Cross Entropy Methodで超簡単に攻略するCartPole(初心者向け)

こんにちは。 強化学習のシュミレータとして最も有名なものの1つがgym環境だと思います。中でもclassical controlのCartPoleは初心者が最初に取り組むものでしょう。 インターネット上にもこれを基本的なQ learningや、中には深層強化学習を用いて学習させているものもあります。

CartPoleはとても簡単な環境であるため、難しい強化学習手法を用いるまでもありません(もちろん実装の練習にはいいと思います)。 今日はCross Entropy Methodと呼ばれる進化計算的な手法を使って攻略してみたいと思います。

Cross Entropy Method

Cross Entropy Methodの概略は以下の通りです。

  1. パラメータの平均  \mu、分散  \sigma ^2を初期化
  2. 平均  \mu、分散  \sigma ^2正規分布に従って重みパラメータをn_population個生成
  3. n_population個のそれぞれのパラメータで独立にゲームプレイする
  4. n_population個のパラメータのうち、成績がよかったn_elite個のパラメータの平均と分散を新たな \mu,  \sigma ^2とする
  5. 2~4を繰り返す

原始的な方法でとてもシンプルだと思います。 では具体的な実験に移ります。

実験環境

必要なライブラリはgymnasium, numpyのみです。適宜pipなどでインストールしてください。 gymnasiumはOpenAI gymの後継プロジェクトです(今回の内容だけならgymでもいいです)

実装

今回は行動を決定するモデルとして、シンプルにobservationを入力とする1層ニューラルネットワークを考えます。 CartPoleのobservationの次元数および可能な行動数は以下で確認できるようにそれぞれ4と2です。

import gymnasium as gym
env = gym.make("CartPole-v1")
n_features = env.observation_space.shape[0]
n_actions = env.action_space.n

print(n_features)
print(n_actions)
# 4
# 2

よって今回学習するパラメータの形状は(4+1 (bias項), 2)となります。つまりたったの10個です。非線形ですらないので、これでできるのかという気もしますができます。

まず平均と分散 (ここでは標準偏差) から人口分のパラメータを生成する関数です。(手順2に必要)

def get_batch_weights(
    mean: np.ndarray,
    stddev: float,
    pop_num: int
) -> np.ndarray:
    
    return [np.random.normal(mean, stddev) for _ in range(pop_num)]

状態から行動を取得する関数は以下のように実装できます。 単に重み行列をかけて大きな値を返す方を選んでいます。

def get_action(state: np.ndarray, weight: np.ndarray) -> int: 
    return np.argmax(state @ weight[:n_features] + weight[n_features])

あるパラメータのもとで1エピソードこなして報酬和を得る関数を実装します。(手順3に必要)

def run(env: gym.Env, weight: np.ndarray) -> int:
    state, _ = env.reset()
    ret = 0
    while True:
        action = get_action(state, weight)
        state, reward, terminated, truncated, _ = env.step(action)
        ret += reward
        if terminated or truncated:
            break
    return ret

成績上位のパラメータをいくつか集めてその平均と標準偏差を求めます。

def get_new_mean_stddev(
    batch_weights: List[np.ndarray],
    rets: List[int],
    elite_num: int
) -> tuple[np.ndarray, np.ndarray]:
    
    idx = np.argsort(rets)[:elite_num]
    elite_weights = np.array(batch_weights)[idx.astype(int)]
    mean = np.mean(elite_weights, axis=0)
    stddev = np.sqrt(np.var(elite_weights, axis=0))

    return (mean, stddev)

これで準備ができました。以下がメインコードです。 今回は100個体を走らせてそのうち10%の成績優秀者を次世代の平均、標準偏差の獲得に用いました。

mean = np.zeros(shape=(n_features+1, n_actions))
stddev = np.ones(shape=(n_features+1, n_actions))


num_update = 10
pop_num = 100
elite_rate = 0.1
elite_num = int(pop_num * elite_rate)
for i in range(num_update):
    batch_weights = get_batch_weights(mean, stddev, pop_num)
    rets = [-1 * run(env, weight) for weight in batch_weights]
    mean, stddev = get_new_mean_stddev(batch_weights, rets, elite_num)
    print(f"generation: {i+1}, average score: {-np.mean(rets)}")

print("final parameters")
print("mean:")
print(mean)
print("stddev:")
print(stddev)

実行すると下のような結果が得られました。10世代分しか回していないですが十分ですね。

generation: 1, average score: 16.99
generation: 2, average score: 37.69
generation: 3, average score: 111.54
generation: 4, average score: 211.78
generation: 5, average score: 300.22
generation: 6, average score: 412.23
generation: 7, average score: 439.32
generation: 8, average score: 466.04
generation: 9, average score: 453.44
generation: 10, average score: 477.09
final parameters
mean:
[[ 0.46792536  0.67930319]
 [-1.03678675  0.83203438]
 [-1.54660413  0.60338747]
 [-1.09715425  1.39296107]
 [ 0.12084092  0.14781027]]
stddev:
[[0.26158623 0.15344902]
 [0.48321259 0.17085536]
 [0.43412831 0.24177063]
 [0.17802083 0.19848629]
 [0.03832805 0.03300419]]

Poetryに関するTips

Poetryの使い方を簡単にまとめました。後半では自分がつまづいた箇所の解決法を紹介しています。

基本コマンド

$ poetry new <project-name>

新たにプロジェクトを作成します。

$ poetry install 

仮想環境を立ち上げます。

$ poetry shell

仮想環境を有効にします。

$ poetry add <package-name>

仮想環境にパッケージを追加します。

以下では知っておくと役に立つかもしれない情報を紹介します。

仮想環境の配置場所について

仮想環境はプロジェクトディレクトリで

$ poetry env info

を打つことで確認でき下のような結果が得られると思います。(公式サイトの出力例)

Virtual environment
Python:         3.7.1
Implementation: CPython
Path:           /path/to/poetry/cache/virtualenvs/test-O3eWbxRl-py3.7
Valid:          True

System
Platform: darwin
OS:       posix
Python:   /path/to/main/python

デフォルトの設定だと仮想環境はプロジェクト配下ではなく、別の場所に配置されます。 仮想環境の配置場所をプロジェクト配下にしたければ

$ poetry config virtualenvs.in-project true

とします。 またvirtualenvs.in-projectがfalseの場合は

$ poetry config virtualenvs.path

によって確認できるパスに収められるようです。

仮想環境のpythonのバージョンについて

仮想環境で使用されるpythonのバージョンは上で説明したpoetry env infoの出力から確認できます。 私の環境ではpyenvでインストールしたpythonを有効にしているにも関わらず、デフォルトではシステムに元々備わっているpythonが使われるようでした。 これは

$ poetry config virtualenvs.prefer-active-python true

によってpoetryがshellで有効になっているpythonを見つけてくれるようです。