import gc
import gym
import gzip
import gym.spaces
import json
import numpy as np
import os
import retro
import retro.data
from gym.utils import seeding
# Installed gym release as a tuple of ints, so it can be compared against
# version thresholds such as ``gym_version >= (0, 9, 6)``.
gym_version = tuple(map(int, gym.__version__.split('.')))

# Names exported by ``from <module> import *``.
__all__ = ['RetroEnv']
class RetroEnv(gym.Env):
    """
    Gym Retro environment class

    Provides a Gym interface to classic video games
    """

    metadata = {'render.modes': ['human', 'rgb_array'],
                'video.frames_per_second': 60.0}

    def __init__(self, game, state=retro.State.DEFAULT, scenario=None, info=None, use_restricted_actions=retro.Actions.FILTERED,
                 record=False, players=1, inttype=retro.data.Integrations.STABLE, obs_type=retro.Observations.IMAGE):
        """
        Create the emulator for ``game`` and build the action/observation spaces.

        Args:
            game: Name of the game; used to look up the ROM and data files
                through ``retro.data``.
            state: ``retro.State.NONE`` for no initial save state,
                ``retro.State.DEFAULT`` to pick the state listed in the game's
                ``metadata.json`` (``default_player_state`` indexed by player
                count, else ``default_state``), or the name of a state to
                load via :meth:`load_state`.
            scenario: Scenario name (``.json`` is appended and the file is
                looked up through ``retro.data``) or a path ending in
                ``.json``. Defaults to ``'scenario'``.
            info: Info name or ``.json`` path, resolved like ``scenario``.
                Defaults to ``'data'``.
            use_restricted_actions: ``retro.Actions`` value selecting the
                action space: ``DISCRETE`` -> ``Discrete``,
                ``MULTI_DISCRETE`` -> ``MultiDiscrete``, anything else ->
                ``MultiBinary`` (``FILTERED`` additionally passes actions
                through ``data.filter_action``).
            record: ``False`` to disable recording, ``True`` to record movies
                into the current working directory, or a directory path.
            players: Number of players.
            inttype: Integration type forwarded to the ``retro.data`` lookups.
            obs_type: ``retro.Observations.IMAGE`` or
                ``retro.Observations.RAM``.

        Raises:
            AssertionError: If the info or scenario file fails to load.
        """
        if not hasattr(self, 'spec'):
            self.spec = None
        self._obs_type = obs_type
        self.img = None
        self.ram = None
        self.viewer = None
        self.gamename = game
        self.statename = state
        self.initial_state = None
        self.players = players

        rom_path = retro.data.get_romfile_path(game, inttype)
        metadata_path = retro.data.get_file_path(game, 'metadata.json', inttype)

        if state == retro.State.NONE:
            self.statename = None
        elif state == retro.State.DEFAULT:
            self.statename = None
            try:
                with open(metadata_path) as f:
                    game_metadata = json.load(f)
                if 'default_player_state' in game_metadata and self.players <= len(game_metadata['default_player_state']):
                    self.statename = game_metadata['default_player_state'][self.players - 1]
                elif 'default_state' in game_metadata:
                    self.statename = game_metadata['default_state']
                else:
                    self.statename = None
            except (IOError, json.JSONDecodeError):
                # Best effort: a missing or malformed metadata file simply
                # means there is no default state to load.
                pass

        if self.statename:
            self.load_state(self.statename, inttype)

        self.data = retro.data.GameData()

        if info is None:
            info = 'data'

        if info.endswith('.json'):
            # assume it's a path
            info_path = info
        else:
            info_path = retro.data.get_file_path(game, info + '.json', inttype)

        if scenario is None:
            scenario = 'scenario'

        if scenario.endswith('.json'):
            # assume it's a path
            scenario_path = scenario
        else:
            scenario_path = retro.data.get_file_path(game, scenario + '.json', inttype)

        self.system = retro.get_romfile_system(rom_path)

        # We can't have more than one emulator per process. Before creating an
        # emulator, ensure that unused ones are garbage-collected
        gc.collect()
        self.em = retro.RetroEmulator(rom_path)
        self.em.configure_data(self.data)
        self.em.step()

        core = retro.get_system_info(self.system)
        self.buttons = core['buttons']
        self.num_buttons = len(self.buttons)

        try:
            # The load must not live inside an ``assert``: under ``python -O``
            # asserts are stripped and the data would never be loaded.
            loaded = self.data.load(info_path, scenario_path)
            if not loaded:
                # AssertionError is kept so existing callers that catch it
                # continue to work.
                raise AssertionError('Failed to load info (%s) or scenario (%s)' % (info_path, scenario_path))
        except Exception:
            # Release the emulator so another one can be created in this process.
            del self.em
            raise

        self.button_combos = self.data.valid_actions()
        if use_restricted_actions == retro.Actions.DISCRETE:
            # One discrete action per combination of choices across all
            # button groups, raised to the number of players.
            combos = 1
            for combo in self.button_combos:
                combos *= len(combo)
            self.action_space = gym.spaces.Discrete(combos ** players)
        elif use_restricted_actions == retro.Actions.MULTI_DISCRETE:
            # One entry per button group per player; older gym releases take
            # (low, high) pairs instead of a size.
            self.action_space = gym.spaces.MultiDiscrete([len(combos) if gym_version >= (0, 9, 6) else (0, len(combos) - 1) for combos in self.button_combos] * players)
        else:
            self.action_space = gym.spaces.MultiBinary(self.num_buttons * players)

        kwargs = {}
        if gym_version >= (0, 9, 6):
            kwargs['dtype'] = np.uint8

        if self._obs_type == retro.Observations.RAM:
            shape = self.get_ram().shape
        else:
            # Every player's screen is fetched, but only the first one
            # determines the observation shape.
            img = [self.get_screen(p) for p in range(players)]
            shape = img[0].shape
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=shape, **kwargs)

        self.use_restricted_actions = use_restricted_actions
        self.movie = None
        self.movie_id = 0
        self.movie_path = None
        if record is True:
            self.auto_record()
        elif record is not False:
            self.auto_record(record)
        self.seed()
        if gym_version < (0, 9, 6):
            # Older gym releases dispatch to underscore-prefixed methods.
            self._seed = self.seed
            self._step = self.step
            self._reset = self.reset
            self._render = self.render
            self._close = self.close

    def _update_obs(self):
        """
        Refresh and return the cached observation for the configured type.

        Raises:
            ValueError: If the observation type is neither RAM nor IMAGE.
        """
        if self._obs_type == retro.Observations.RAM:
            self.ram = self.get_ram()
            return self.ram
        elif self._obs_type == retro.Observations.IMAGE:
            self.img = self.get_screen()
            return self.img
        else:
            raise ValueError('Unrecognized observation type: {}'.format(self._obs_type))

    def action_to_array(self, a):
        """
        Convert an action from the action space into per-player button arrays.

        Args:
            a: An integer for ``DISCRETE`` actions, otherwise an indexable
                sequence holding every player's entries back to back.

        Returns:
            A list with one ``np.uint8`` array of length ``num_buttons`` per
            player; element ``i`` is 1 when button ``i`` is pressed.
        """
        actions = []
        num_combos = len(self.button_combos)
        for p in range(self.players):
            action = 0
            if self.use_restricted_actions == retro.Actions.DISCRETE:
                # Mixed-radix decode: each button group consumes one "digit".
                # ``a`` is rebound (not modified in place) so a caller-owned
                # array is never mutated; the remainder carries over to the
                # next player.
                for combo in self.button_combos:
                    a, current = divmod(a, len(combo))
                    action |= combo[current]
            elif self.use_restricted_actions == retro.Actions.MULTI_DISCRETE:
                # The MultiDiscrete space has one entry per button group (not
                # per button) for each player, so slice by the group count.
                ap = a[num_combos * p:num_combos * (p + 1)]
                for buttons, choice in zip(self.button_combos, ap):
                    action |= buttons[choice]
            else:
                ap = a[self.num_buttons * p:self.num_buttons * (p + 1)]
                for i, pressed in enumerate(ap):
                    action |= int(pressed) << i
                if self.use_restricted_actions == retro.Actions.FILTERED:
                    action = self.data.filter_action(action)
            # Expand the bitmask into one flag per button.
            ap = np.zeros([self.num_buttons], np.uint8)
            for i in range(self.num_buttons):
                ap[i] = (action >> i) & 1
            actions.append(ap)
        return actions

    def step(self, a):
        """
        Apply action ``a``, advance the emulator one step and return the result.

        Returns:
            Tuple ``(observation, reward, done, info)``; ``reward`` is a list
            with one entry per player when there is more than one player.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.img is None and self.ram is None:
            raise RuntimeError('Please call env.reset() before env.step()')

        for p, ap in enumerate(self.action_to_array(a)):
            if self.movie:
                for i in range(self.num_buttons):
                    self.movie.set_key(i, ap[i], p)
            self.em.set_button_mask(ap, p)

        if self.movie:
            self.movie.step()
        self.em.step()
        self.data.update_ram()
        ob = self._update_obs()
        rew, done, info = self.compute_step()
        return ob, rew, bool(done), dict(info)

    def reset(self):
        """
        Restore the initial state (if any), release all buttons, step the
        emulator once and return the first observation.

        When automatic recording is enabled a new movie file is started.
        """
        if self.initial_state:
            self.em.set_state(self.initial_state)
        for p in range(self.players):
            self.em.set_button_mask(np.zeros([self.num_buttons], np.uint8), p)
        self.em.step()
        if self.movie_path is not None:
            # ``statename`` is None when the environment has no save state;
            # fall back to a fixed label instead of crashing in basename().
            state_label = self.statename if self.statename else 'none'
            rel_statename = os.path.splitext(os.path.basename(state_label))[0]
            self.record_movie(os.path.join(self.movie_path, '%s-%s-%06d.bk2' % (self.gamename, rel_statename, self.movie_id)))
            self.movie_id += 1
        if self.movie:
            self.movie.step()
        self.data.reset()
        self.data.update_ram()
        return self._update_obs()

    def seed(self, seed=None):
        """
        Seed the environment's random number generator.

        Returns:
            ``[seed1, seed2]`` where ``seed1`` comes from gym's seeding helper
            and ``seed2`` is derived from it.
        """
        self.np_random, seed1 = seeding.np_random(seed)
        # Derive a random seed. This gets passed as a uint, but gets
        # checked as an int elsewhere, so we need to keep it below
        # 2**31.
        seed2 = seeding.hash_seed(seed1 + 1) % 2**31
        return [seed1, seed2]

    def render(self, mode='human', close=False):
        """
        Render the current frame.

        Args:
            mode: ``'rgb_array'`` returns the image; ``'human'`` shows it in a
                viewer window and returns whether the window is still open.
                Any other mode returns ``None``.
            close: When true, close the viewer (if any) and return ``None``.
        """
        if close:
            if self.viewer:
                self.viewer.close()
            return

        img = self.get_screen() if self.img is None else self.img
        if mode == "rgb_array":
            return img
        elif mode == "human":
            if self.viewer is None:
                # Imported lazily so the rendering module is only needed when
                # a window is actually requested.
                from gym.envs.classic_control.rendering import SimpleImageViewer
                self.viewer = SimpleImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen

    def close(self):
        """
        Release the resources held by the environment.

        Finalizes any movie being recorded, closes the viewer window and
        drops the emulator. Safe to call on a partially constructed instance
        and more than once.
        """
        # getattr() guards against __init__ having failed before these
        # attributes were assigned.
        movie = getattr(self, 'movie', None)
        if movie:
            movie.close()
            self.movie = None
        viewer = getattr(self, 'viewer', None)
        if viewer:
            viewer.close()
            self.viewer = None
        if hasattr(self, 'em'):
            del self.em

    def get_action_meaning(self, act):
        """
        Return the names of the buttons pressed by action ``act``.

        Returns:
            A list of button names for a single player, or a list of such
            lists (one per player) otherwise.
        """
        actions = [
            [self.buttons[i] for i in np.extract(action, np.arange(len(action)))]
            for action in self.action_to_array(act)
        ]
        if self.players == 1:
            return actions[0]
        return actions

    def get_ram(self):
        """Return all memory blocks, ordered by offset, as one ``uint8`` array."""
        memory_blocks = self.data.memory.blocks
        blocks = [np.frombuffer(memory_blocks[offset], dtype=np.uint8)
                  for offset in sorted(memory_blocks)]
        return np.concatenate(blocks)

    def get_screen(self, player=0):
        """
        Return the emulator screen, cropped to ``player``'s crop rectangle.

        A zero width/height, or a rectangle extending past the image, is
        clamped to the image edge. The uncropped image is returned as-is when
        the rectangle covers it entirely.
        """
        img = self.em.get_screen()
        x, y, w, h = self.data.crop_info(player)
        full_h, full_w = img.shape[0], img.shape[1]
        # Convert (width, height) into exclusive right/bottom edges.
        right = full_w if (not w or x + w > full_w) else x + w
        bottom = full_h if (not h or y + h > full_h) else y + h
        if x == 0 and y == 0 and right == full_w and bottom == full_h:
            return img
        return img[y:bottom, x:right]

    def load_state(self, statename, inttype=retro.data.Integrations.DEFAULT):
        """
        Load a gzip-compressed save state to be restored on :meth:`reset`.

        Args:
            statename: State name; ``.state`` is appended when missing.
            inttype: Integration type forwarded to the ``retro.data`` lookup.
        """
        if not statename.endswith('.state'):
            statename += '.state'

        with gzip.open(retro.data.get_file_path(self.gamename, statename, inttype), 'rb') as fh:
            self.initial_state = fh.read()

        self.statename = statename

    def compute_step(self):
        """
        Collect the outcome of the most recent emulator step.

        Returns:
            Tuple ``(reward, done, info)``; ``reward`` is a list with one
            entry per player when there is more than one player.
        """
        if self.players > 1:
            reward = [self.data.current_reward(p) for p in range(self.players)]
        else:
            reward = self.data.current_reward()
        done = self.data.is_done()
        return reward, done, self.data.lookup_all()

    def record_movie(self, path):
        """Start recording a movie to ``path``, seeded with the initial state if any."""
        self.movie = retro.Movie(path, True, self.players)
        self.movie.configure(self.gamename, self.em)
        if self.initial_state:
            self.movie.set_state(self.initial_state)

    def stop_record(self):
        """Disable automatic recording and close the movie in progress, if any."""
        self.movie_path = None
        self.movie_id = 0
        if self.movie:
            self.movie.close()
            self.movie = None

    def auto_record(self, path=None):
        """
        Enable automatic recording: each :meth:`reset` starts a new movie.

        Args:
            path: Directory for the movie files; the current working
                directory is used when empty or ``None``.
        """
        if not path:
            path = os.getcwd()
        self.movie_path = path