Cross Entropy Methodで超簡単に攻略するCartPole(初心者向け)
こんにちは。 強化学習のシュミレータとして最も有名なものの1つがgym環境だと思います。中でもclassical controlのCartPoleは初心者が最初に取り組むものでしょう。 インターネット上にもこれを基本的なQ learningや、中には深層強化学習を用いて学習させているものもあります。
CartPoleはとても簡単な環境であるため、難しい強化学習手法を用いるまでもありません(もちろん実装の練習にはいいと思います)。 今日はCross Entropy Methodと呼ばれる進化計算的な手法を使って攻略してみたいと思います。
Cross Entropy Method
Cross Entropy Methodの概略は以下の通りです。
- パラメータの平均 、分散 を初期化
- 平均 、分散 の正規分布に従って重みパラメータをn_population個生成
- n_population個のそれぞれのパラメータで独立にゲームプレイする
- n_population個のパラメータのうち、成績がよかったn_elite個のパラメータの平均と分散を新たな, とする
- 2~4を繰り返す
原始的な方法でとてもシンプルだと思います。 では具体的な実験に移ります。
実験環境
必要なライブラリはgymnasium, numpyのみです。適宜pipなどでインストールしてください。 gymnasiumはOpenAI gymの後継プロジェクトです(今回の内容だけならgymでもいいです)
実装
今回は行動を決定するモデルとして、シンプルにobservationを入力とする1層ニューラルネットワークを考えます。 CartPoleのobservationの次元数および可能な行動数は以下で確認できるようにそれぞれ4と2です。
import gymnasium as gym env = gym.make("CartPole-v1") n_features = env.observation_space.shape[0] n_actions = env.action_space.n print(n_features) print(n_actions) # 4 # 2
よって今回学習するパラメータの形状は(4+1 (bias項), 2)となります。つまりたったの10個です。非線形ですらないので、これでできるのかという気もしますができます。
まず平均と分散 (ここでは標準偏差) から人口分のパラメータを生成する関数です。(手順2に必要)
def get_batch_weights( mean: np.ndarray, stddev: float, pop_num: int ) -> np.ndarray: return [np.random.normal(mean, stddev) for _ in range(pop_num)]
状態から行動を取得する関数は以下のように実装できます。 単に重み行列をかけて大きな値を返す方を選んでいます。
def get_action(state: np.ndarray, weight: np.ndarray) -> int: return np.argmax(state @ weight[:n_features] + weight[n_features])
あるパラメータのもとで1エピソードこなして報酬和を得る関数を実装します。(手順3に必要)
def run(env: gym.Env, weight: np.ndarray) -> int: state, _ = env.reset() ret = 0 while True: action = get_action(state, weight) state, reward, terminated, truncated, _ = env.step(action) ret += reward if terminated or truncated: break return ret
成績上位のパラメータをいくつか集めてその平均と標準偏差を求めます。
def get_new_mean_stddev( batch_weights: List[np.ndarray], rets: List[int], elite_num: int ) -> tuple[np.ndarray, np.ndarray]: idx = np.argsort(rets)[:elite_num] elite_weights = np.array(batch_weights)[idx.astype(int)] mean = np.mean(elite_weights, axis=0) stddev = np.sqrt(np.var(elite_weights, axis=0)) return (mean, stddev)
これで準備ができました。以下がメインコードです。 今回は100個体を走らせてそのうち10%の成績優秀者を次世代の平均、標準偏差の獲得に用いました。
mean = np.zeros(shape=(n_features+1, n_actions)) stddev = np.ones(shape=(n_features+1, n_actions)) num_update = 10 pop_num = 100 elite_rate = 0.1 elite_num = int(pop_num * elite_rate) for i in range(num_update): batch_weights = get_batch_weights(mean, stddev, pop_num) rets = [-1 * run(env, weight) for weight in batch_weights] mean, stddev = get_new_mean_stddev(batch_weights, rets, elite_num) print(f"generation: {i+1}, average score: {-np.mean(rets)}") print("final parameters") print("mean:") print(mean) print("stddev:") print(stddev)
実行すると下のような結果が得られました。10世代分しか回していないですが十分ですね。
generation: 1, average score: 16.99 generation: 2, average score: 37.69 generation: 3, average score: 111.54 generation: 4, average score: 211.78 generation: 5, average score: 300.22 generation: 6, average score: 412.23 generation: 7, average score: 439.32 generation: 8, average score: 466.04 generation: 9, average score: 453.44 generation: 10, average score: 477.09 final parameters mean: [[ 0.46792536 0.67930319] [-1.03678675 0.83203438] [-1.54660413 0.60338747] [-1.09715425 1.39296107] [ 0.12084092 0.14781027]] stddev: [[0.26158623 0.15344902] [0.48321259 0.17085536] [0.43412831 0.24177063] [0.17802083 0.19848629] [0.03832805 0.03300419]]
Poetryに関するTips
Poetryの使い方を簡単にまとめました。後半では自分がつまづいた箇所の解決法を紹介しています。
基本コマンド
$ poetry new <project-name>
新たにプロジェクトを作成します。
$ poetry install
仮想環境を立ち上げます。
$ poetry shell
仮想環境を有効にします。
$ poetry add <package-name>
仮想環境にパッケージを追加します。
以下では知っておくと役に立つかもしれない情報を紹介します。
仮想環境の配置場所について
仮想環境はプロジェクトディレクトリで
$ poetry env info
を打つことで確認でき下のような結果が得られると思います。(公式サイトの出力例)
Virtual environment Python: 3.7.1 Implementation: CPython Path: /path/to/poetry/cache/virtualenvs/test-O3eWbxRl-py3.7 Valid: True System Platform: darwin OS: posix Python: /path/to/main/python
デフォルトの設定だと仮想環境はプロジェクト配下ではなく、別の場所に配置されます。 仮想環境の配置場所をプロジェクト配下にしたければ
$ poetry config virtualenvs.in-project true
とします。 またvirtualenvs.in-projectがfalseの場合は
$ poetry config virtualenvs.path
によって確認できるパスに収められるようです。
仮想環境のpythonのバージョンについて
仮想環境で使用されるpythonのバージョンは上で説明したpoetry env infoの出力から確認できます。 私の環境ではpyenvでインストールしたpythonを有効にしているにも関わらず、デフォルトではシステムに元々備わっているpythonが使われるようでした。 これは
$ poetry config virtualenvs.prefer-active-python true
によってpoetryがshellで有効になっているpythonを見つけてくれるようです。