查看原文
其他

【他山之石】技术总结《OpenAI Gym》

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:Criss

转自:https://www.meltycriss.com/2018/03/26/tech-gym/

本文首先介绍Gym的核心函数调用链,然后介绍如何创建自定义的Gym环境,最后给出一些使用Gym过程中碰到的问题及其解决方案


01

Gym核心函数调用链
一般来说,使用Gym的代码如下:
# main.pyimport gymdef choose_action(o): ...env = gym.make('CartPole-v0')o = env.reset()while True: a = choose_action(o) o_, r, done, info = env.step(a) o = o_ if done: break
可见,关键的函数有:
  • env = gym.make('CartPole-v0')

  • env.reset()

  • env.step(a)

我们先关注env.reset()和env.step(a)。这两个函数是超类Env的成员函数,Env的相关代码如下:
# gym/core.pyclass Env(object): ... # Override in ALL subclasses def _step(self, action): raise NotImplementedError def _reset(self): raise NotImplementedError ... def step(self, action): return self._step(action) def reset(self): return self._reset() ...
可以看到这两个函数依赖于子类的_reset(self)和_step(self, action)实现,子类CartPoleEnv的相关代码如下:
# gym/envs/classic_control/CartPole.pyclass CartPoleEnv(gym.Env): ... def _step(self, action): ... def _reset(self): ... ...
综上,env.reset()和env.step(a)实际上是调用子类的_reset(self)和_step(self, action)。
下面我们关注gym.make('CartPole-v0'),它的实现如下:
# gym/envs/registration.py# Have a global registryregistry = EnvRegistry()...def make(id): return registry.make(id)
可以看到gym.make依赖于类EnvRegistry的成员函数make,EnvRegistry的相关代码如下:
# gym/envs/registration.pyclass EnvRegistry(object): def __init__(self): # 注册表 # key: 环境名称(e.g., 'CartPole-v0') # value:类型为EnvSpec,可以暂时理解为环境 self.env_specs = {} def make(self, id): ... # 根据环境名称,通过成员函数找到对应的环境 spec = self.spec(id) # 实例化环境 env = spec.make() ... return env ... def spec(self, id): ... ...
可见类EnvRegistry的成员函数make依赖于类EnvSpec的成员函数make,EnvSpec的相关代码如下:
# gym/envs/registration.pydef load(name): ...# EnvSpec与Env之间的关系类似于说明商品规格的订单与商品之间的关系,# 下面用一个例子来说明:# 假设你网购看中了一款衣服,那么你会挑选该款衣服的颜色、码数,然后再下单。# 在这个例子里面,那款衣服就是Env,而说明该款衣服颜色、码数的订单就是EnvSpec。# 这就是为什么EnvRegistry.make(self, id)中,在得到spec之后还要再spec.make(),# 因为EnvSpec并不是Env,正如订单不是衣服。class EnvSpec(object): def __init__(self, id, entry_point=None, ...): self.id = id ... self._entry_point = entry_point ... def make(self): ... # 动态加载环境类 # 相当于以下代码 # from self._entry_point import classA # cls = classA cls = load(self._entry_point) # 实例化环境 env = cls(**self._kwargs) ... return env ...
至此,我们对Gym的核心函数调用链有了一个基本的了解:
  • gym.make(id):通过EnvRegistry中的注册表找到对应的EnvSpec,EnvSpec根据entry_point动态import对应的Env,并将其实例化;
  • env.reset()和env.step(a):子类的_reset(self)和_step(self, action)。


02

创建自定义环境
对Gym的核心函数调用链有了基本了解后,我们知道创建自定义环境的关键有两个:
  • 第一个是搭建自己的Env子类FooEnv;
  • 第二个是注册FooEnv(i.e., 将FooEnv添加到registry.env_specs中),使得gym.make(id)可以找到FooEnv。
官方文档推荐的自定义环境目录结构如下:
gym-foo/ README.md setup.py #将gym_foo这个package加到系统环境变量中 gym_foo/ #核心部分 __init__.py #注册FooEnv envs/ __init__.py foo_env.py #实现FooEnv
实现FooEnv没什么特别的,就是根据自己的需求,实现_step(self, action)、_reset(self)等函数。
值得一提的是注册FooEnv,我们无需自己实现注册环境的代码,因为Gym已经有现成的注册环境API,我们只需要调用该API即可。在我们的自定义环境中,负责注册FooEnv的文件为gym-foo/gym_foo/__init__.py,它的内容如下:
# gym-foo/gym_foo/__init__.pyfrom gym.envs.registration import registerregister( id='foo-v0', # 环境名 entry_point='gym_foo.envs:FooEnv', # 环境类,之后就根据这个路径动态import环境)
可见,注册的关键是register函数,而register函数的实现如下:
# gym/envs/registration.py# Have a global registryregistry = EnvRegistry()# Gym的注册环境APIdef register(id, **kwargs): return registry.register(id, **kwargs)def make(id): return registry.make(id)
可以看到register的实现依赖于类EnvRegistry的成员函数register,其相关代码如下:
# gym/envs/registration.pyclass EnvRegistry(object): ... def register(self, id, **kwargs): ... # 将FooEnv对应的“订单”写到“注册表”上 self.env_specs[id] = EnvSpec(id, **kwargs)
综上,我们可以通过API函数register注册自定义的环境FooEnv。

03

注意事项

3.1 server render

假如你通过ssh连接server,在server上运行(i.e., python main.py)以下代码(关键点在使用env.render()保存录像):
# main.pyimport gymfrom gym import wrappersenv = gym.make('CartPole-v0')env = wrappers.Monitor(env, 'video')for i_episode in range(20): observation = env.reset() for t in range(100): env.render() action = env.action_space.sample() observation, reward, done, info = env.step(action) if done: break
那么你会得到一个报错,报错的信息大概是pyglet.canvas.xlib.NoSuchDisplayException: Cannot connect to "None"。
原因大概是env.render()需要图形界面(就是弹出来的那个框框),而当你使用ssh连接server时是没有图形界面的。因此我们需要一个虚拟的图形界面,而xvfb-run就是一个提供虚拟图形界面的工具。
所以我们需要使用xvfb-run -a -s "-screen 0 1400x900x24 +extension RANDR" -- python main.py来运行我们的代码。
一般来说,运行上述指令是会报错的,报错的信息大概是pyglet requires an X server with GLX,主要原因在于显卡驱动以及cuda的安装有问题,没有加--no-opengl的flag。解决方案可以参考这里这里

3.2 保存每一段episode的录像

wrappers.Monitor默认不会保存所有episode的录像,但我们可以通过以下代码来设置保存所有episode的录像:
env = wrappers.Monitor(env, 'video', video_callable=lambda episode_id: True)

3.3 动态修改episode的最大step

env._max_episode_steps = xxx。注意,这仅当env的类型为TimeLimit时可用。

3.4 关于wrapper

  • 相同的两个wrapper不能叠加(e.g., Monitor不可以和Monitor叠加,但是Monitor可以和TimeLimit叠加),否则会报double wrapper的错。
  • 在注册FooEnv时,加不加max_episode_steps=xxx会影响返回的Env的类型。假如加了,返回的是TimeLimit类型的wrapper;假如不加,返回的就是裸的FooEnv。
  • Monitor里面有两个recorder,一个是stat_recorder,用于保存数据(reward之类的);另一个是video_recorder,用于录像。Monitor会在每一次调用env.reset和env.step之后调用render

3.5 屏蔽log信息

# main.pyimport logging# suppress INFO level logging 'Making new env: ...'logging.getLogger('gym.envs.registration').setLevel(logging.WARNING)# suppress INFO level logging 'Starting new video recorder writing to ...'logging.getLogger('gym.monitoring.video_recorder').setLevel(logging.WARNING)# suppress INFO level logging 'Creating monitor directory ...'logging.getLogger('gym.wrappers.monitoring').setLevel(logging.WARNING)


本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“他山之石”历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存