import torch
import numpy as np
import random
from minatar import Environment
from models.utils import qnet


# Torch device shared by this module: CUDA when torch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class minatarenv:
    """Wrapper around a MinAtar ``Environment`` bundled with an expert Q-network.

    The wrapper exposes ``reset``/``step`` methods that return channel-first
    ``float32`` observations, and ``create_dataset`` which picks an action by
    mixing uniformly random actions with the expert's greedy action.
    """

    # Probability of taking a uniformly random action (instead of the expert's
    # greedy one) for each ``data_collecting`` mode.
    # NOTE(review): 'good' has the HIGHEST random-action rate and 'bad' the
    # lowest, which looks inverted relative to the names — confirm this is
    # intended before relying on the labels.
    _RANDOM_ACTION_PROB = {'good': 0.7, 'mid': 0.5, 'bad': 0.3}

    def __init__(self, envname, expert_policy_path):
        """Create the environment and load the expert policy.

        Args:
            envname: Game name passed to ``minatar.Environment``.
            expert_policy_path: Path of a checkpoint readable by ``torch.load``
                that contains a ``'policy_net_state_dict'`` entry.
        """
        self.env = Environment(envname)
        self.name = envname
        self.MAX_STEPS = 1e5  # kept as a float; cast with int() before using as a count

        # Index 2 of the environment's state shape is the axis that
        # ``get_state`` moves to the front, i.e. the channel axis.
        self.in_channels = self.env.state_shape()[2]
        self.num_actions_ = self.env.num_actions()

        self.expert_policy = qnet().to(device).eval()
        # ``set_init`` receives this wrapper after the attributes above are set;
        # presumably it sizes the network from them — confirm in models.utils.
        self.expert_policy.set_init({'env': self})
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(expert_policy_path, map_location=device)
        self.expert_policy.load_state_dict(checkpoint['policy_net_state_dict'])

    def create_dataset(self, data_collecting, s):
        """Choose one action for state ``s`` with an epsilon-greedy expert.

        With the probability configured for ``data_collecting`` a uniformly
        random action is returned; otherwise the action with the largest
        expert output along dimension 1 is returned.

        Args:
            data_collecting: One of ``'good'``, ``'mid'`` or ``'bad'``.
            s: Input forwarded unchanged to ``self.expert_policy``. The
                ``.max(1)`` / ``.view(1, 1)`` calls imply the network output
                has exactly one row — presumably a batch holding a single
                state tensor; confirm against the caller.

        Returns:
            A NumPy integer array of shape ``(1, 1)`` holding the action.

        Raises:
            NotImplementedError: If ``data_collecting`` is not a known mode.
        """
        try:
            epsilon = self._RANDOM_ACTION_PROB[data_collecting]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which the lookup cannot take.
            raise NotImplementedError(f"data_collecting is {data_collecting}, it should be chosen from good, mid, bad") from None

        # NOTE: the coin flip uses numpy's global RNG while the random action
        # uses the stdlib ``random`` module, so both must be seeded for
        # reproducible runs.
        if np.random.binomial(1, epsilon) == 1:
            return np.array([[random.randrange(self.num_actions_)]])

        with torch.no_grad():
            greedy_action = self.expert_policy(s).max(1)[1].view(1, 1)
        return greedy_action.cpu().detach().numpy()

    def get_state(self, s):
        """Convert a state with channels on axis 2 to a channel-first float32 array.

        Args:
            s: Array-like with three axes.

        Returns:
            ``np.ndarray`` of dtype ``float32`` with axes reordered ``(2, 0, 1)``.
        """
        return np.transpose(np.array(s), (2, 0, 1)).astype(np.float32)

    @staticmethod
    def _as_action_index(action):
        """Collapse a single-element action (int, ndarray or tensor) to an ``int``."""
        if isinstance(action, torch.Tensor):
            value = action.item()
        else:
            value = np.asarray(action).item()
        index = int(value)
        if index != value:
            # Refuse to silently truncate something like 2.7 to 2.
            raise ValueError(f"action must be an integer index, got {value!r}")
        return index

    def step(self, action):
        """Apply ``action`` to the environment and return the transition.

        Args:
            action: A single action index given as a Python ``int``, a NumPy
                scalar/array with one element (e.g. the ``(1, 1)`` array
                returned by ``create_dataset``), or a one-element tensor.

        Returns:
            Tuple ``(s_prime, reward, terminated, None)`` where ``s_prime`` is
            the channel-first ``float32`` state after acting and ``reward`` /
            ``terminated`` are the values returned by ``env.act``. The last
            item is always ``None``.

        Raises:
            ValueError: If ``action`` is not integer-valued, or is an array
                with more than one element.
        """
        # Pass a plain int so the environment never holds a view onto the
        # caller's action buffer.
        reward, terminated = self.env.act(self._as_action_index(action))
        s_prime = self.get_state(self.env.state())
        return s_prime, reward, terminated, None

    def reset(self):
        """Reset the environment.

        Returns:
            Tuple ``(s, None)`` with the channel-first ``float32`` state read
            after the reset; the second item is always ``None``.
        """
        self.env.reset()
        s = self.get_state(self.env.state())
        return s, None

    def num_actions(self):
        """Return the action count read from the environment at construction."""
        return self.num_actions_

