import os

import gym
import torch as t

from .base_simulator import BaseSimulator


class Procgen(BaseSimulator):
    """Simulator wrapper around a Procgen game created through ``gym.make``.

    The game is selected by the ``SIMULATOR`` environment variable, which is
    read once when this class body executes (i.e. at import time).
    ``self.state`` is a ``(observation, level_complete)`` tuple.
    """

    # Procgen game name, e.g. "coinrun"; None when SIMULATOR is unset.
    name = os.environ.get("SIMULATOR")

    # Size of the discrete action space; equals len(ACTIONS).
    n_actions = 15
    # NOTE(review): not read in this file; presumably used by BaseSimulator
    # or external callers — confirm before removing.
    failed = 0.0
    # Per-game cap on steps per episode (enforced in step()); games not in
    # the table fall back to 1000.
    max_steps = {
        "bigfish": 6000,
        "bossfight": 4000,
        "caveflyer": 1000,
        "chaser": 1000,
        "climber": 1000,
        "coinrun": 1000,
        "dodgeball": 1000,
        "fruitbot": 1000,
        "heist": 1000,
        "jumper": 1000,
        "leaper": 500,
        "maze": 500,
        "miner": 1000,
        "ninja": 1000,
        "plunder": 2000,
        "starpilot": 1000,
    }.get(name, 1000)

    # Button combination for each discrete action index (index i -> action i).
    # NOTE(review): appears to mirror procgen's own combo ordering — confirm.
    ACTIONS = [
        ("LEFT", "DOWN"),
        ("LEFT",),
        ("LEFT", "UP"),
        ("DOWN",),
        (),
        ("UP",),
        ("RIGHT", "DOWN"),
        ("RIGHT",),
        ("RIGHT", "UP"),
        ("D",),
        ("A",),
        ("W",),
        ("S",),
        ("Q",),
        ("E",),
    ]

    def __init__(self):
        """Create the underlying gym environment for ``self.name``.

        Raises:
            ValueError: if no game name is configured (``SIMULATOR`` was unset
                or empty when this module was imported).
        """
        super().__init__()
        if not self.name:
            # Without this check gym.make would be asked for
            # "procgen:procgen-None-v0" and fail with an opaque registry error.
            raise ValueError(
                "SIMULATOR environment variable must name a Procgen game "
                "(e.g. 'coinrun') before Procgen is imported"
            )
        self.env = gym.make(f"procgen:procgen-{self.name}-v0")

    def reset(self):
        """Reset the environment and the step counter.

        Returns:
            The new state tuple ``(observation, False)``.
        """
        self.state = (self.env.reset(), False)
        self.timestep = 0
        return self.state

    def step(
        self,
        action,
    ):
        """Advance the environment by one action.

        Args:
            action: Action passed straight to ``self.env.step``.

        Returns:
            ``(state, reward, done)``. ``state`` is
            ``(observation, level_complete)``. ``done`` is truthy when the
            environment reports a terminal step or when ``max_steps`` steps
            have elapsed since the last ``reset()``.
        """
        self.timestep += 1
        observation, reward, terminal, info = self.env.step(action)
        # Cast to bool so the flag has the same type as the False set in reset().
        self.state = (observation, bool(info["prev_level_complete"]))
        return self.state, reward, terminal or self.timestep >= self.max_steps

    def render(
        self,
        *args,
        **kwargs,
    ):
        """Render the environment; any arguments are accepted and ignored."""
        self.env.render()

    def is_solved(self):
        """Return the level-complete flag stored in the current state."""
        return self.state[1]

    def state_tensor(self):
        """Return the current observation as a float32 tensor of shape (1, 64, 64, 3).

        Assumes the observation holds exactly 64 * 64 * 3 values (procgen's
        RGB frame); ``view`` raises otherwise.
        """
        return t.tensor(self.state[0], dtype=t.float32).view(1, 64, 64, 3)
