import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Adam
import torch.optim as optim

from utils.distributions import Categorical, DiagGaussian, Heatmap, MultiHeatmap
from utils.model import get_grid, Flatten, NNBase
from utils.gnn import GNN

import copy
import numpy as np
import cv2
from PIL import Image
import random

class ANS_Policy(nn.Module):
    """CNN actor-critic network used by ``RL_Policy`` when ``model_type == 'ans'``.

    ``actor`` is a conv encoder with 9 input channels and three 2x2 max-pools,
    ending in a flatten. ``critic`` holds the two hidden Linear layers, the
    scalar value head and a 72-entry orientation embedding; all four are
    applied in ``forward``.
    """

    def __init__(self, input_shape, **kwargs):
        """Build the encoder and heads.

        Args:
            input_shape: indexed as ``(C, H, W)``; only ``H`` and ``W`` are read
                (the first conv is hard-wired to 9 input channels).
            **kwargs: accepted for interface compatibility; unused.
        """
        super(ANS_Policy, self).__init__()

        # Each MaxPool2d(2) floors the spatial size, so after three of them the
        # feature map is (H // 8) x (W // 8). Computing H / 8 * W / 8 in floating
        # point overestimates this whenever H or W is not a multiple of 8, which
        # would give the first critic Linear the wrong input width.
        out_size = (int(input_shape[1]) // 8) * (int(input_shape[2]) // 8)

        self.is_recurrent = False
        self.rec_state_size = 1
        self.output_size = 256

        hidden_size = 512

        self.actor = nn.Sequential(
            nn.Conv2d(9, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            Flatten()
        )

        # Index layout is relied on by forward() and by checkpoint keys:
        # 0/1 hidden layers, 2 value head, 3 orientation embedding.
        self.critic = nn.ModuleList(
            [
                nn.Linear(out_size * 32 + 8, hidden_size),
                nn.Linear(hidden_size, self.output_size),
                nn.Linear(self.output_size, 1),
                nn.Embedding(72, 8)
            ]
        )

        self.downscaling = 2
        self.train()

    def forward(self, inputs, rnn_hxs, masks, extras):
        """Return ``(value, features, rnn_hxs)``; ``masks`` is unused.

        The last column of ``extras`` is used as an integer index into the
        72-entry orientation embedding, so it must lie in ``[0, 72)``.
        """
        x = self.actor(inputs)
        orientation_emb = self.critic[3](extras[:, -1]).squeeze(1)
        x = torch.cat((x, orientation_emb), 1)

        x = F.relu(self.critic[0](x))
        x = F.relu(self.critic[1](x))

        return self.critic[2](x).squeeze(-1), x, rnn_hxs

# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L15
class RL_Policy(nn.Module):
    """Actor-critic policy: a network (``ANS_Policy`` or ``GNN``) plus an action distribution.

    The actor (network actor + distribution) and the critic are trained by two
    separate Adam optimizers. ``act`` can optionally re-sample frontier goals
    that fail a distance / stuck mask (``action_masked == 3``).
    """

    def __init__(self, obs_shape, action_space, model_type='gconv', base_kwargs=None, lr=None, eps=None,
                 num_robots=1, device=None, args=None):
        """Build the network, the action distribution and both optimizers.

        Args:
            obs_shape: forwarded to the network constructor.
            action_space: only spaces whose class is named ``Box`` are supported.
            model_type: ``'ans'`` or ``'gnn'``; anything else (including the
                default ``'gconv'``) raises ``NotImplementedError``.
            base_kwargs: extra network arguments (``num_gnn_layer``,
                ``use_history``, ``ablation`` for ``'gnn'``).
            lr: ``(actor_lr, critic_scale)``; the critic learning rate is
                ``lr[0] * lr[1]``. Required.
            eps: Adam epsilon; ``None`` uses Adam's default.
            num_robots: number of robots per environment.
            device: device for tensors created in ``act``.
            args: run configuration; the masked sampler reads
                ``threshold_noise`` and ``action_mask_type`` from it.
        """
        # nn.Module must be initialised before any attribute is assigned.
        super(RL_Policy, self).__init__()
        self.num_robots = num_robots
        self.device = device
        self.args = args

        if base_kwargs is None:
            base_kwargs = {}

        if model_type == 'ans':
            self.network = ANS_Policy(obs_shape, **base_kwargs)
        elif model_type == 'gnn':
            self.network = GNN(obs_shape,
                               base_kwargs.get('num_gnn_layer') * ['self', 'cross'],
                               base_kwargs.get('use_history'),
                               base_kwargs.get('ablation'),
                               num_robots)
        else:
            raise NotImplementedError

        if model_type == 'gnn':
            assert action_space.__class__.__name__ == "Box"
            self.dist = MultiHeatmap()
        elif action_space.__class__.__name__ == "Box":
            num_outputs = action_space.shape[0]
            self.dist = DiagGaussian(self.network.output_size, num_outputs)
        else:
            raise NotImplementedError

        if lr is None:
            raise TypeError("lr must be an (actor_lr, critic_scale) pair, got None")
        adam_kwargs = {} if eps is None else {'eps': eps}

        # Parameters are passed as ordered lists: optimizer state_dicts map
        # state by parameter position, so the order must be reproducible for
        # save()/load() to pair each Adam moment with the right tensor.
        self.actor_optimizer = optim.Adam(
            self._trainable_parameters(self.network.actor, self.dist), lr=lr[0], **adam_kwargs)
        self.critic_optimizer = optim.Adam(
            self._trainable_parameters(self.network.critic), lr=lr[0] * lr[1], **adam_kwargs)

        self.model_type = model_type

    @staticmethod
    def _trainable_parameters(*modules):
        """Return the ``requires_grad`` parameters of ``modules`` in a stable order, without duplicates."""
        seen = set()
        params = []
        for module in modules:
            for p in module.parameters():
                if p.requires_grad and id(p) not in seen:
                    seen.add(id(p))
                    params.append(p)
        return params

    @staticmethod
    def _adaptive_range_thresholds(dist_map, near, far):
        """Adapt the valid goal-distance band ``[near, far]`` to ``dist_map``.

        - If any distance lies in ``[near, far]``, the band is returned unchanged.
        - Otherwise, if any distance is ``<= far``, ``near`` drops to the minimum distance.
        - Otherwise (every distance exceeds ``far``) the band becomes ``[min, max]``.

        An empty ``dist_map`` keeps the band unchanged. Works for torch tensors
        and numpy arrays.
        """
        if dist_map[(dist_map >= near) * (dist_map <= far)].shape[0] > 0:
            return near, far
        if dist_map.reshape(-1).shape[0] == 0:
            return near, far
        if dist_map[(dist_map <= far)].shape[0] > 0:
            return dist_map.min().item(), far
        return dist_map.min().item(), dist_map.max().item()

    @property
    def is_recurrent(self):
        return self.network.is_recurrent

    @property
    def rec_state_size(self):
        """Size of rnn_hx."""
        return self.network.rec_state_size

    @property
    def downscaling(self):
        return self.network.downscaling

    def forward(self, inputs, rnn_hxs, masks, extras):
        """Run the network, passing ``extras`` only when it is not ``None``."""
        if extras is None:
            return self.network(inputs, rnn_hxs, masks)
        else:
            return self.network(inputs, rnn_hxs, masks, extras)

    def act(self, inputs, rnn_hxs, masks, extras=None, deterministic=False,
            past_global_infos=None, dist_map=None, global_position=None, action_masked=0, noise_size=0):
        """Choose an action and return its log-probability under the unmodified distribution.

        ``deterministic`` takes the distribution mode. Otherwise, with
        ``action_masked == 3`` the masked re-sampling of ``_masked_sample`` is
        used; any other value samples directly.

        Returns:
            ``(value, action, action_log_probs, rnn_hxs, actor_features)``
        """
        value, actor_features, rnn_hxs = self(inputs, rnn_hxs, masks, extras)
        dist = self.dist(actor_features)

        if deterministic:
            action = dist.mode()
        elif action_masked == 3:
            action = self._masked_sample(dist, inputs, actor_features, past_global_infos,
                                         dist_map, global_position, noise_size)
        else:
            action = dist.sample()

        action_log_probs = dist.log_probs(action)

        return value, action, action_log_probs, rnn_hxs, actor_features

    def _masked_sample(self, dist, inputs, actor_features, past_global_infos, dist_map, global_position,
                       noise_size):
        """Sample one frontier goal per (env, robot), re-sampling goals rejected by the masks.

        A goal is rejected (and a different frontier drawn) when
        (a) its distance, read from channels ``8:8+num_robots`` of ``inputs``,
            lies outside the adaptive band ``[threshold_near, threshold_far]``, or
        (b) the recent history in ``past_global_infos`` indicates the robot is
            stuck AND the goal is close to a past goal.
        Each criterion allows a bounded number of re-samples
        (``threshold_resample``), and the last remaining candidate is always
        accepted. ``args.action_mask_type`` 1/2 disables (b)/(a) respectively.

        ``past_global_infos`` is indexed as ``[step?, env, robot, field]`` —
        presumably a history of per-step records; TODO confirm with the caller.
        """
        # Work on a copy so zeroing probabilities of rejected goals leaves ``dist`` intact.
        dist_temp = copy.deepcopy(dist)

        num_envs = len(actor_features)
        is_sample = torch.ones([num_envs, self.num_robots]).to(self.device)
        # -1 marks "no accepted goal yet".
        action_valid = torch.zeros([num_envs, self.num_robots]).long().to(self.device) - 1
        past_global_infos = torch.from_numpy(np.array(past_global_infos)).to(self.device)

        threshold_area = 1          # alpha2,m^2
        threshold_goal = 10         # alpha3,pixel
        threshold_pos = 1
        threshold_l_move_dist = 1   # alpha1

        threshold_near = 0.2    # beta_near
        threshold_far = 2       # beta_far

        threshold_resample = 20

        if self.args.threshold_noise:
            print("threshold noise size = ", noise_size)

            # NOTE: the ``noise_size`` argument is always replaced by a fresh random draw.
            noise_size = random.uniform(-self.args.threshold_noise, self.args.threshold_noise)
            print("noise_size = ", noise_size)

            threshold_area *= (1 + noise_size)
            threshold_goal *= (1 + noise_size)
            threshold_l_move_dist *= (1 + noise_size)

            threshold_near *= (1 + noise_size)
            threshold_far *= (1 + noise_size)

        # Adapt the distance band once, from the (possibly noise-scaled) base
        # band. Recomputing it per environment from its own previous output made
        # it alternate between envs and discarded the noise scaling.
        threshold_near, threshold_far = self._adaptive_range_thresholds(dist_map, threshold_near, threshold_far)

        num_probs = torch.zeros([num_envs, self.num_robots]).to(self.device)
        is_last_choice = torch.zeros([num_envs, self.num_robots]).int().to(self.device)
        resample_counters = torch.zeros([num_envs, self.num_robots, 2]).long().to(self.device)

        ds = 4 * self.downscaling
        # Only reported by the diagnostic print below.
        is_in_range = is_stacked = is_repeat_goal = None

        while torch.sum(is_sample) > 0:
            action = dist_temp.sample()
            is_sample *= 0.

            # Remove each freshly drawn goal from dist_temp so a re-sample picks a
            # different frontier; the last remaining candidate is never removed.
            for e in range(action.shape[0]):
                for i in range(action.shape[1]):
                    idx_frontier = action[e, i]
                    num_probs[e, i] = torch.where(dist_temp.dist[e].probs[i] > 0)[0].shape[0]
                    if num_probs[e, i] > 1:
                        dist_temp.dist[e].probs[i, idx_frontier] = 0
                        is_last_choice[e, i] = False
                    else:
                        is_last_choice[e, i] = True

            # Fields 2 and 4 of each robot's global_position offset goal coordinates.
            global_position_npy = global_position.view(
                action.shape[0], self.num_robots, -1)[:, :, [2, 4]].cpu().numpy()

            for e in range(action.shape[0]):
                dist_map_maxpool = inputs[e, 8:8 + self.num_robots, :, :]
                # (row, col) of every nonzero cell in input channel 1 (the frontier map).
                frontier_idx = torch.nonzero(inputs[e, 1, :, :]).cpu().numpy()

                for i in range(self.num_robots):
                    if action_valid[e, i] != -1:
                        continue  # this robot already has an accepted goal

                    next_goal_maxpool = frontier_idx[action[e, i]]
                    action_dist = dist_map_maxpool[i, next_goal_maxpool[0], next_goal_maxpool[1]]

                    # Centre of the ds x ds block of the chosen frontier cell, shifted by the position offset.
                    global_goal = [*(frontier_idx[action[e, i]] * ds + ds // 2 - global_position_npy[e, i])]
                    next_goal = torch.Tensor(global_goal).unsqueeze(0).double().to(self.device)

                    action_valid[e, i] = action[e, i]
                    is_in_range = bool(threshold_near <= action_dist <= threshold_far)

                    if len(past_global_infos) > 2:
                        past_goals = past_global_infos[:, e, i, 0:2]
                        past_start_pos = past_global_infos[:, e, i, 3:5]
                        past_end_pos = past_global_infos[:, e, i, 10:12]
                        past_areas = past_global_infos[:, e, i, 17]
                        past_l_move_dists = past_global_infos[:, e, i, 18]

                        # Distances between every pair of distinct past goals.
                        goals_dist = []
                        for m, g1 in enumerate(past_goals):
                            for n, g2 in enumerate(past_goals):
                                if m != n:
                                    goals_dist.append(torch.sqrt(torch.sum(pow((g1 - g2), 2))).unsqueeze(0))
                        goals_dist = torch.cat(goals_dist, dim=0).to(self.device)

                        # Distance from the candidate goal to each past goal.
                        next_goal_dists = []
                        for g in past_goals:
                            g = g.unsqueeze(0)
                            next_goal_dists.append(torch.sqrt(torch.sum(pow((g - next_goal), 2), dim=1)))
                        next_goal_dists = torch.cat(next_goal_dists, dim=0).to(self.device)

                        pos_dist = torch.sum(torch.sqrt(torch.sum(pow((past_start_pos - past_end_pos), 2), dim=1)))
                        area_sum = torch.sum(past_areas, dim=0)
                        l_move_dists_sum = torch.sum(past_l_move_dists, dim=0)

                        is_stacked = bool(area_sum < threshold_area and
                                          pos_dist < threshold_pos and
                                          (goals_dist < threshold_goal).any() and
                                          l_move_dists_sum < threshold_l_move_dist)
                        is_repeat_goal = bool((next_goal_dists < threshold_goal).any())
                    else:
                        is_stacked = False
                        is_repeat_goal = False

                    # 0: both masks, 1: only valid-distance mask, 2: only stuck mask.
                    if self.args.action_mask_type == 1:
                        print("Only valid distance mask, no stuck mask")
                        is_stacked = False
                        is_repeat_goal = False
                    elif self.args.action_mask_type == 2:
                        print("Only stuck mask, no valid distance mask")
                        is_in_range = True

                    if not is_in_range:
                        if resample_counters[e, i, 0] <= threshold_resample and not is_last_choice[e, i]:
                            is_sample[e, i] = True
                            action_valid[e, i] = -1
                            resample_counters[e, i, 0] += 1
                        else:
                            action_valid[e, i] = action[e, i]
                            is_sample[e, i] = False
                            resample_counters[e, i, 0] *= 0

                    if is_stacked and is_repeat_goal:
                        if resample_counters[e, i, 1] <= threshold_resample and not is_last_choice[e, i]:
                            is_sample[e, i] = 1
                            action_valid[e, i] = -1
                            resample_counters[e, i, 1] += 1
                        else:
                            action_valid[e, i] = action[e, i]
                            is_sample[e, i] = 0
                            resample_counters[e, i, 1] *= 0

        if (action_valid == -1).any():
            print("action_valid = ", action_valid)
            print("is_in_range = {}".format(is_in_range))
            print("is_stacked = {}".format(is_stacked))
            print("is_repeat_goal = {}".format(is_repeat_goal))
            print("is_sample = {}".format(is_sample))
            print("resample_counters=", resample_counters)

        return copy.deepcopy(action_valid)

    def get_value(self, inputs, rnn_hxs, masks, extras=None):
        """Return ``(value, actor_features)`` for the given inputs."""
        value, actor_features, _ = self(inputs, rnn_hxs, masks, extras)
        return value, actor_features

    def evaluate_actions(self, inputs, rnn_hxs, masks, action, extras=None):
        """Return value, log-probs of ``action``, mean entropy, rnn_hxs and actor features."""
        value, actor_features, rnn_hxs = self(inputs, rnn_hxs, masks, extras)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        return value, action_log_probs, dist_entropy, rnn_hxs, actor_features

    def load(self, path, device):
        """Restore network weights and both optimizer states from a checkpoint written by ``save``.

        The optimizers are rebuilt first. ``load_state_dict`` then restores the
        saved hyper-parameters, so ``lr=1e-3`` here is only a placeholder.
        """
        self.actor_optimizer = optim.Adam(
            self._trainable_parameters(self.network.actor, self.dist), lr=1e-3)
        self.critic_optimizer = optim.Adam(
            self._trainable_parameters(self.network.critic), lr=1e-3)
        state_dict = torch.load(path, map_location=device)
        self.network.load_state_dict(state_dict['network'])
        self.actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
        self.critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
        del state_dict

    def load_critic(self, path, device):
        """Load only the ``critic.*`` weights of a checkpoint's network into ``self.network.critic``."""
        prefix = 'critic.'
        state_dict = torch.load(path, map_location=device)['network']
        # Strip only the leading prefix; a substring replace would also mangle
        # nested keys that contain "critic." again.
        self.network.critic.load_state_dict(
            {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)})
        del state_dict

    def save(self, path):
        """Write network weights and both optimizer states to ``path``.

        Uses the legacy (non-zip) serialization whenever the installed
        ``torch.save`` lets us choose it, so older torch versions can read the file.
        """
        import inspect

        state = {
            'network': self.network.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
        }

        # Feature-test the keyword instead of comparing version strings.
        if '_use_new_zipfile_serialization' in inspect.signature(torch.save).parameters:
            torch.save(state, path, _use_new_zipfile_serialization=False)
        else:
            torch.save(state, path)
