import numpy as np
import torch
import gym
import copy
from acktr.model import Policy
import sys
sys.path.append('../')
import config
# from utils import get_mask_from_obs, get_rotmask_from_obs
from acktr.utils import get_mask_from_obs, get_rotmask_from_obs


class nnModel(object):
    """Inference wrapper around a pretrained ``acktr.model.Policy``.

    Loads a checkpoint saved as ``(state_dict, ob_rms)`` and exposes helpers
    that run the policy on one observation and return a value estimate plus
    a masked action distribution.

    Attributes:
        alen: number of discrete actions,
            ``container_size[0] * container_size[1] * (1 + enable_rotation)``.
        olen: flattened observation length, ``channel * container area``.
        height: upper bound used for the observation space,
            ``container_size[2] * 10``.
        device: torch device the model and all inputs live on.
        config: the configuration object passed to ``__init__``.
        box: ``(1, _BOX_HISTORY, 4)`` tensor passed to ``base`` by
            ``evaluate``; its last row is overwritten after every
            ``evaluate`` call (presumably a history of recent boxes — TODO
            confirm against the training code).
    """

    # Length of the box buffer fed to the model in ``evaluate``.
    _BOX_HISTORY = 120

    def __init__(self, url, config, sub=False):
        """Build the wrapper and load the checkpoint at ``url``.

        Args:
            url: path to a checkpoint saved as ``(state_dict, ob_rms)``.
            config: object providing ``container_size``, ``enable_rotation``,
                ``channel``, ``cuda`` and ``device``.
            sub: if True, load via ``_load_model_sub`` (``obs_shape=None``).
        """
        area = config.container_size[0] * config.container_size[1]
        self.alen = area * (1 + config.enable_rotation)
        self.olen = config.channel * area
        self.height = config.container_size[2] * 10
        if config.cuda:
            self.device = torch.device("cuda:" + str(config.device))
            # Only touch the CUDA runtime when CUDA is actually requested;
            # calling set_device unconditionally crashes on CPU-only hosts.
            torch.cuda.set_device(config.device)
        else:
            self.device = torch.device("cpu")
        self.config = config
        if sub:
            self._model = self._load_model_sub(url)
        else:
            self._model = self._load_model(url)
        self.box = torch.zeros(1, self._BOX_HISTORY, 4, device=self.device)

    @staticmethod
    def _remap_state_dict(state_dict):
        """Rename checkpoint keys to match the current ``Policy`` layout.

        Strips a ``module.`` prefix (DataParallel), drops ``add_bias.`` and
        renames ``_bias`` to ``bias``, in that order.
        """
        return {
            k.replace('module.', '').replace('add_bias.', '').replace('_bias', 'bias'): v
            for k, v in state_dict.items()
        }

    def _load_model(self, url):
        """Load the main policy checkpoint and move it to ``self.device``."""
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        # map_location lets CUDA-saved checkpoints load on CPU-only machines.
        model_pretrained, _ob_rms = torch.load(url, map_location=self.device)
        observation_space = gym.spaces.Box(low=0.0, high=self.height, shape=(self.olen, ))
        action_space = gym.spaces.Discrete(self.alen)
        actor_critic = Policy(obs_shape=observation_space.shape, action_space=action_space)
        print(actor_critic)

        load_dict = self._remap_state_dict(model_pretrained)
        for k in load_dict:
            print(k)

        actor_critic.load_state_dict(load_dict)
        return actor_critic.to(self.device)

    def _load_model_sub(self, url_sub):
        """Load a sub-policy checkpoint (built with ``obs_shape=None``)."""
        action_space = gym.spaces.Discrete(self.alen)
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        model_pretrained_sub, _ob_rms = torch.load(url_sub, map_location=self.device)
        actor_critic_sub = Policy(obs_shape=None, action_space=action_space)

        load_dict_sub = self._remap_state_dict(model_pretrained_sub)
        for k in load_dict_sub:
            print(k)

        actor_critic_sub.load_state_dict(load_dict_sub)
        return actor_critic_sub.to(self.device)

    def _action_mask(self, obs):
        """Return the feasibility mask for ``obs`` using this instance's config."""
        if self.config.enable_rotation:
            return get_rotmask_from_obs(1, obs)
        return get_mask_from_obs(1, obs)

    @staticmethod
    def _masked_probs(poss, mask):
        """Softmax ``poss`` over all its elements, multiply by ``mask`` and flatten.

        The result is not renormalised after masking.
        """
        probs = np.exp(poss - np.max(poss))
        probs /= np.sum(probs)
        return np.reshape(probs * mask, (-1,))

    def _record_box(self, hmap):
        """Write ``hmap[1:5, 0, 0]`` into the last row of ``self.box``.

        NOTE(review): the original code also executed
        ``self.box[0, :-1] = self.box[0, :-1]``, which is a no-op and was
        dropped. If a sliding history was intended, it would be
        ``self.box[0, :-1] = self.box[0, 1:].clone()`` — confirm against the
        training code before enabling, as it changes model inputs.
        """
        self.box[0, -1] = hmap[1:5, 0, 0]

    def evaluate(self, obs, hmap, use_mask=True):
        """Run the main policy on one observation.

        Args:
            obs: observation array-like accepted by ``torch.FloatTensor``.
            hmap: array-like indexable as ``hmap[i][0][0]`` for ``i`` in 1..4
                (assumed at least 3-D with >= 5 entries along axis 0 — TODO
                confirm).
            use_mask: NOTE(review): currently ignored; the feasibility mask
                is always applied (kept for caller compatibility).

        Returns:
            ``(value, probs)``: the value estimate as a float, and a 1-D numpy
            array of softmaxed policy outputs multiplied by the action mask.

        Side effects:
            Overwrites the last row of ``self.box``.
        """
        x = copy.deepcopy(obs)
        mask = self._action_mask(obs)
        x = torch.FloatTensor(x).to(self.device)
        hmap = torch.FloatTensor(hmap).to(self.device)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            value, logits, _, _, _ = self._model.base(x, hmap, 0, 0, self.box)
            poss = self._model.dist.get_policy_distribution(logits)
        self._record_box(hmap)

        value = float(value)
        poss = poss.cpu().detach().numpy()
        return value, self._masked_probs(poss, mask)

    def evaluate_sub(self, obs, hmap, use_mask=True):
        """Run the sub-policy on one observation.

        Same inputs and outputs as ``evaluate``, but calls
        ``base(x, hmap, 0, 0)`` (three return values) and does not touch
        ``self.box``. ``use_mask`` is likewise ignored; the mask is always
        applied.
        """
        x = copy.deepcopy(obs)
        mask = self._action_mask(obs)
        x = torch.FloatTensor(x).to(self.device)
        hmap = torch.FloatTensor(hmap).to(self.device)
        with torch.no_grad():
            value, logits, _ = self._model.base(x, hmap, 0, 0)
            poss = self._model.dist.get_policy_distribution(logits)

        value = float(value)
        poss = poss.cpu().detach().numpy()
        return value, self._masked_probs(poss, mask)

    def sample_action(self, obs):
        """Sample a discrete action from the policy.

        Returns:
            ``(value, action)`` as ``(float, int)``.

        NOTE(review): this calls ``self._model.base(x, 0, 0)`` and unpacks
        four values, whereas ``evaluate`` calls ``base(x, hmap, 0, 0, box)``
        and unpacks five; it also relies on ``self._model.binary``. One of
        the two call conventions is presumably stale — confirm against
        ``acktr.model.Policy`` before relying on this method.
        """
        x = torch.FloatTensor(copy.deepcopy(obs)).to(self.device)
        with torch.no_grad():
            value, logits, _, pred = self._model.base(x, 0, 0)
            poss = self._model.dist.get_policy_distribution(logits)
            pred = self._model.binary(pred)

        value = float(value)
        # Bias logits by the auxiliary prediction; the weight 7 is an
        # unexplained constant carried over from the original — TODO confirm.
        cat = torch.distributions.Categorical(logits=poss + pred * 7)
        action = int(cat.sample())
        return value, action


