import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from acktr.distributions import Bernoulli, Categorical, DiagGaussian
from acktr.utils import init
import config
import sys
from acktr.vae import VanillaVAE, Decoder
from torch.distributions import Normal
sys.path.append('../')

class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis.

    Uses ``Tensor.view``, so the input must be viewable without a copy.
    """

    def forward(self, x):
        batch_dim = x.shape[0]
        return x.view(batch_dim, -1)


class Policy(nn.Module):
    """Actor-critic policy: a feature ``base`` plus an action head ``dist``.

    Only ``base=None`` is supported; the base is picked from ``obs_shape``:

    * ``None``      -> ``MLPBase(100)``; for a ``Discrete`` action space the
      head is ``Categorical(base.output_size, 100)``.
    * length 3      -> ``CNNBase(obs_shape[0])``; no action head is built.
    * length 1      -> ``CNNPro(obs_shape[0])``; for a ``Discrete`` action
      space the head is ``Categorical(base.output_size, action_space.n)``.

    Methods with a ``_sub`` suffix call the base as
    ``base(inputs, hmap, rnn_hxs, masks)`` and unpack three outputs; the
    other methods add a ``box`` argument and unpack five outputs.

    Raises:
        NotImplementedError: if ``base`` is given, or ``obs_shape`` is not
            ``None`` and has a length other than 1 or 3.
    """

    def __init__(self, obs_shape, action_space, base=None, base_kwargs=None):
        super(Policy, self).__init__()
        if base_kwargs is None:
            base_kwargs = {}
        if base is not None:
            raise NotImplementedError(
                "Policy only supports base=None; the base class is selected "
                "from obs_shape")

        is_discrete = action_space.__class__.__name__ == "Discrete"
        if obs_shape is None:
            self.base = MLPBase(100, **base_kwargs)
            if is_discrete:
                self.dist = Categorical(self.base.output_size, 100)
        elif len(obs_shape) == 3:
            # NOTE(review): no ``self.dist`` is created for this base, so
            # act()/evaluate_actions() will fail with it — confirm intended.
            self.base = CNNBase(obs_shape[0], **base_kwargs)
        elif len(obs_shape) == 1:
            print("Base = CNNPro")
            self.base = CNNPro(obs_shape[0], **base_kwargs)
            if is_discrete:
                self.dist = Categorical(self.base.output_size, action_space.n)
        else:
            raise NotImplementedError(
                "unsupported obs_shape {!r}; expected None or a shape of "
                "length 1 or 3".format(obs_shape))

    @property
    def is_recurrent(self):
        """Delegates to ``self.base.is_recurrent``."""
        return self.base.is_recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Size of rnn_hx (delegates to the base)."""
        return self.base.recurrent_hidden_state_size

    def forward(self, inputs, rnn_hxs, masks):
        raise NotImplementedError

    def binary(self, input):
        """Return 1 where ``input >= 0.5`` and 0 elsewhere (same shape/dtype)."""
        ones = torch.ones_like(input)
        zeros = torch.zeros_like(input)
        return torch.where(input >= 0.5, ones, zeros)

    def _masked_dist(self, actor_features, location_masks):
        """Run the action head on ``actor_features`` with masks divided by 100.

        Returns whatever ``self.dist`` returns; callers unpack it as
        ``(dist, bad_prob, mask_dist)``.
        """
        # NOTE(review): the /100 presumably rescales the environment's mask
        # values into [0, 1] — confirm against the code that builds them.
        return self.dist(actor_features, location_masks / 100)

    def act(self, inputs, hmap, rnn_hxs, masks, location_masks, box, deterministic=False):
        """Select an action (mode if ``deterministic`` else a sample).

        Returns:
            ``(value, action, action_log_probs, rnn_hxs, env_loss)`` where
            ``value``, ``rnn_hxs`` and ``env_loss`` are the base's 1st, 3rd
            and 5th outputs.
        """
        value, actor_features, rnn_hxs, _, env_loss = self.base(inputs, hmap, rnn_hxs, masks, box)
        dist, _, _ = self._masked_dist(actor_features, location_masks)
        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)
        return value, action, action_log_probs, rnn_hxs, env_loss

    def act_sub(self, inputs, hmap, rnn_hxs, masks, location_masks, deterministic=False):
        """Like ``act`` for a three-output base (no ``box``, no ``env_loss``).

        Returns:
            ``(value, action, action_log_probs, rnn_hxs)``.
        """
        value, actor_features, rnn_hxs = self.base(inputs, hmap, rnn_hxs, masks)
        dist, _, _ = self._masked_dist(actor_features, location_masks)
        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)
        return value, action, action_log_probs, rnn_hxs

    def get_value(self, inputs, hmap, rnn_hxs, masks, box):
        """Return the first of the five base outputs (the value estimate)."""
        value, _, _, _, _ = self.base(inputs, hmap, rnn_hxs, masks, box)
        return value

    def get_value_sub(self, inputs, hmap, rnn_hxs, masks):
        """Return the first of the three base outputs (the value estimate)."""
        value, _, _ = self.base(inputs, hmap, rnn_hxs, masks)
        return value

    def evaluate_actions(self, inputs, hmap, rnn_hxs, masks, action, location_masks, box):
        """Score given ``action``s under the current policy.

        Returns:
            ``(value, action_log_probs, dist_entropy, dist, bad_prob, env_loss)``
            where ``dist_entropy`` is the mean entropy and ``value`` is the
            base's 4th output.
        """
        # NOTE(review): this returns the base's 4th output (``v_pre`` for
        # CNNPro) while act()/get_value() return the 1st — confirm intended.
        _, actor_features, _, v2, env_loss = self.base(inputs, hmap, rnn_hxs, masks, box)
        dist, bad_prob, _ = self._masked_dist(actor_features, location_masks)
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()
        return v2, action_log_probs, dist_entropy, dist, bad_prob, env_loss

    def evaluate_actions_sub(self, inputs, hmap, rnn_hxs, masks, action, location_masks):
        """Like ``evaluate_actions`` for a three-output base.

        Returns:
            ``(value, action_log_probs, dist_entropy, rnn_hxs, bad_prob)``.
        """
        value, actor_features, rnn_hxs = self.base(inputs, hmap, rnn_hxs, masks)
        dist, bad_prob, _ = self._masked_dist(actor_features, location_masks)
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()
        return value, action_log_probs, dist_entropy, rnn_hxs, bad_prob


class NNBase(nn.Module):
    """Shared base for actor-critic feature extractors with an optional GRU.

    Args:
        recurrent: when true, a single-layer ``nn.GRU`` is built with zero
            biases and orthogonally initialised weights.
        recurrent_input_size: feature size fed into the GRU.
        hidden_size: GRU hidden size; also reported as ``output_size``.
    """

    def __init__(self, recurrent, recurrent_input_size, hidden_size):
        super(NNBase, self).__init__()

        self._hidden_size = hidden_size
        self._recurrent = recurrent

        if not recurrent:
            return

        self.gru = nn.GRU(recurrent_input_size, hidden_size)
        for pname, tensor in self.gru.named_parameters():
            if 'bias' in pname:
                nn.init.constant_(tensor, 0)
            elif 'weight' in pname:
                nn.init.orthogonal_(tensor)

    @property
    def is_recurrent(self):
        """Whether the GRU was built."""
        return self._recurrent

    @property
    def recurrent_hidden_state_size(self):
        """GRU hidden size, or 1 as a placeholder when not recurrent."""
        return self._hidden_size if self._recurrent else 1

    @property
    def output_size(self):
        """Width of the features this base produces."""
        return self._hidden_size

    def _forward_gru(self, x, hxs, masks):
        """Run the GRU over one step or a flattened ``(T * N, -1)`` rollout.

        If ``x`` has as many rows as ``hxs`` it is treated as one step for N
        environments. Otherwise ``x`` is unflattened to ``(T, N, -1)`` and
        processed in chunks split at every time step t >= 1 where any
        environment's mask is 0; the hidden state is multiplied by the mask
        at the start of each chunk.

        Returns:
            ``(outputs, hxs)`` with outputs re-flattened to ``(T * N, -1)``
            in the multi-step case.
        """
        n_envs = hxs.size(0)
        if x.size(0) == n_envs:
            out, h = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0))
            return out.squeeze(0), h.squeeze(0)

        steps = int(x.size(0) / n_envs)
        seq = x.view(steps, n_envs, x.size(1))
        step_masks = masks.view(steps, n_envs)

        # Time indices (shifted by 1 because we scanned masks[1:]) at which
        # some environment was reset; t=0 and t=T always bound a chunk.
        resets = ((step_masks[1:] == 0.0).any(dim=-1).nonzero().view(-1) + 1).tolist()
        bounds = [0] + resets + [steps]

        h = hxs.unsqueeze(0)
        chunks = []
        for start, end in zip(bounds[:-1], bounds[1:]):
            chunk_out, h = self.gru(seq[start:end], h * step_masks[start].view(1, -1, 1))
            chunks.append(chunk_out)

        out = torch.cat(chunks, dim=0).view(steps * n_envs, -1)
        return out, h.squeeze(0)

class CNNBase(NNBase):
    """Three-conv + one-linear image feature extractor with a scalar value head.

    ``Linear(32 * 7 * 7, hidden_size)`` requires the conv stack to emit 7x7
    maps, which is what an 84x84 input yields with the kernel/stride pairs
    below (84 -> 20 -> 9 -> 7).

    Args:
        num_inputs: number of input channels.
        recurrent: if true, features are passed through ``NNBase``'s GRU.
        hidden_size: feature width (also the GRU input/hidden size).
    """

    def __init__(self, num_inputs, recurrent=False, hidden_size=512):
        super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size)

        def zero_bias(tensor):
            return nn.init.constant_(tensor, 0)

        def relu_init(module):
            # Orthogonal weights with the ReLU gain, zero bias.
            return init(module, nn.init.orthogonal_, zero_bias,
                        nn.init.calculate_gain('relu'))

        def plain_init(module):
            # Orthogonal weights, zero bias, no explicit gain argument.
            return init(module, nn.init.orthogonal_, zero_bias)

        self.main = nn.Sequential(
            relu_init(nn.Conv2d(num_inputs, 32, 8, stride=4)),
            nn.ReLU(),
            relu_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            relu_init(nn.Conv2d(64, 32, 3, stride=1)),
            nn.ReLU(),
            Flatten(),
            relu_init(nn.Linear(32 * 7 * 7, hidden_size)),
            nn.ReLU(),
        )
        self.critic_linear = plain_init(nn.Linear(hidden_size, 1))

        self.train()

    def forward(self, inputs, rnn_hxs, masks):
        """Return ``(value, features, rnn_hxs)``; inputs are divided by 255.0 first."""
        features = self.main(inputs / 255.0)
        if self.is_recurrent:
            features, rnn_hxs = self._forward_gru(features, rnn_hxs, masks)
        return self.critic_linear(features), features, rnn_hxs

def conv3x3(in_planes, out_planes, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
    )

class ChannelAttention(nn.Module):
    """Channel attention gate (CBAM-style).

    Global average- and max-pooled descriptors are passed through a shared
    bottleneck of 1x1 convolutions, summed, and squashed with a sigmoid,
    giving per-channel weights of shape ``(N, in_planes, 1, 1)``.

    Args:
        in_planes: number of input channels.
        ratio: bottleneck reduction factor; the hidden width is
            ``in_planes // ratio``, clamped to at least 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour ``ratio`` (it used to be ignored in favour of a literal 16);
        # clamp so small ``in_planes`` never produce a zero-width layer.
        hidden = max(1, in_planes // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc(avgpool(x)) + fc(maxpool(x)))``."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """Spatial attention gate producing per-pixel weights in (0, 1).

    The channel-wise mean and max maps are stacked into a 2-channel image,
    convolved by one bias-free ``kernel_size`` x ``kernel_size`` filter with
    padding ``kernel_size // 2``, and passed through a sigmoid. For odd
    ``kernel_size`` the output shape is ``(N, 1, H, W)``.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((mean_map, peak_map), dim=1)
        return self.sigmoid(self.conv1(stacked))

class MLPBase(NNBase):
    """Convolutional actor-critic trunk over a fixed 10x10 grid.

    Despite the name this is not an MLP. ``forward`` does the following:

    * reshapes ``hmap`` to ``(-1, config.channel, 10, 10)`` and ``inputs``
      to ``(-1, 1, 10, 10)``;
    * scales both by 0.1 and stacks them on the channel axis;
    * feeds the result through five padded 3x3 convolutions (spatial size
      is kept), then through separate actor and critic heads.

    Args:
        num_inputs: forwarded to ``NNBase`` as the GRU input size.
        recurrent: forwarded to ``NNBase``. NOTE(review): ``forward`` never
            calls the GRU, so ``recurrent=True`` only allocates unused
            weights — confirm intended.
        hidden_size: width of the actor/critic hidden layers.
    """

    # Side length of the square grid that ``forward`` reshapes inputs to.
    _GRID = 10

    def __init__(self, num_inputs, recurrent=False, hidden_size=256):
        super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)

        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), nn.init.calculate_gain('relu'))
        cells = self._GRID * self._GRID

        # config.channel hmap planes + 1 input plane.
        self.share = nn.Sequential(
            init_(nn.Conv2d(config.channel + 1, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
        )

        self.actor = nn.Sequential(
            init_(nn.Conv2d(64, 8, 1, stride=1)),
            nn.ReLU(),
            Flatten(),
            init_(nn.Linear(8 * cells, hidden_size)),
            nn.ReLU(),
        )

        self.critic = nn.Sequential(
            init_(nn.Conv2d(64, 4, 1, stride=1)),
            nn.ReLU(),
            Flatten(),
            init_(nn.Linear(4 * cells, hidden_size)),
            nn.ReLU(),
        )
        self.critic_linear = init_(nn.Linear(hidden_size, 1))
        self.train()

    def forward(self, inputs, hmap, rnn_hxs, masks):
        """Return ``(value, hidden_actor, rnn_hxs)``; ``rnn_hxs`` is passed through."""
        g = self._GRID
        hmap = hmap.reshape((-1, config.channel, g, g)) * 0.1
        x = inputs.reshape((-1, 1, g, g)) * 0.1

        x = self.share(torch.cat((hmap, x), 1))
        hidden_critic = self.critic(x)
        hidden_actor = self.actor(x)
        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
        
class CNNPro(NNBase):
    """Attention-CNN actor trunk with a VAE-based critic.

    Actor path:

    * ``hmap`` (``config.channel`` planes at 10x the container resolution)
      is downsampled by ``l1`` (strides 2 and 5).
    * It is refined by two rounds of the shared channel/spatial attention
      around ``l2``, then reduced to one plane by ``l3``.
    * That plane is concatenated with ``inputs`` and passed through
      ``share`` and ``actor``.

    Critic path: ``state_enc1`` encodes the first five ``hmap`` planes and
    ``state_enc2`` encodes the first ``hmap`` plane stacked with an LSTM
    encoding of ``box``. ``critic_linear`` maps the first element of each
    encoder's output to a scalar.

    Args:
        num_inputs: forwarded to ``NNBase`` as the GRU input size.
        recurrent: must be False (``forward`` asserts this).
        hidden_size: width of the actor hidden layer and input width of
            ``critic_linear``. NOTE(review): ``critic_linear`` is applied to
            the VAE outputs built with 256, so this presumably must stay 256.
    """

    def __init__(self, num_inputs, recurrent=False, hidden_size=256):

        super(CNNPro, self).__init__(recurrent, num_inputs, hidden_size)
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), nn.init.calculate_gain('relu'))
        cells = config.container_size[0] * config.container_size[1]

        # Submodules are registered in the original order so state_dict keys
        # and the RNG-driven initialisation sequence are unchanged.
        self.l1 = nn.Sequential(
            init_(nn.Conv2d(config.channel, 64, 5, stride=2, padding=2)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 5, stride=5, padding=2)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
        )
        self.l2 = nn.Sequential(
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
        )
        self.l3 = nn.Sequential(
            nn.ReLU(),
            init_(nn.Conv2d(64, 1, 3, stride=1, padding=1)),
            nn.ReLU(),
        )
        # Shared by both attention rounds in forward().
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()

        # Input: 1 plane from l3 + (1 + enable_rotation) planes from inputs.
        self.share = nn.Sequential(
            init_(nn.Conv2d(2 + config.enable_rotation, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
            init_(nn.Conv2d(64, 64, 3, stride=1, padding=1)),
            nn.ReLU(),
        )

        self.actor = nn.Sequential(
            init_(nn.Conv2d(64, 8, 1, stride=1)),
            nn.ReLU(),
            Flatten(),
            init_(nn.Linear(8 * cells, hidden_size)),
            nn.ReLU(),
        )

        # Not used by forward(); kept so the module's state_dict keys stay
        # unchanged for existing checkpoints.
        self.critic = nn.Sequential(
            init_(nn.Conv2d(64, 4, 1, stride=1)),
            nn.ReLU(),
            Flatten(),
            init_(nn.Linear(4 * cells, hidden_size)),
            nn.ReLU(),
        )
        self.critic_linear = init_(nn.Linear(hidden_size, 1))
        self.state_enc1 = VanillaVAE(5, 256)
        self.state_enc2 = VanillaVAE(5, 256)

        # NOTE(review): hidden size 100*4 must match the fold in forward()
        # (4 * container_size[1] * 10), i.e. presumably container_size[1] == 10.
        self.rnn = nn.LSTM(input_size=4, hidden_size=100*4, num_layers=1, batch_first=True)

        self.train()

    def loss_function(self, v1, v2, mu1, mu2, log_var1, log_var2, kld_weight=0.00025):
        r"""Value-consistency loss plus weighted latent regularisers.

        Computes::

            MSE(v1, v2) + kld_weight * (KLD + MSE(mu1, mu2) + MSE(log_var1, log_var2))

        where KLD is the mean, over the concatenated samples of both
        encodings, of the per-sample
        KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
        taken over dim 1.

        Args:
            v1, v2: the two value predictions to pull together.
            mu1, mu2: latent means of the two encodings.
            log_var1, log_var2: latent log-variances of the two encodings.
            kld_weight: weight on the regularisation terms.

        Returns:
            Scalar loss tensor.
        """
        recons_loss = F.mse_loss(v1, v2)
        mu_dis_loss = F.mse_loss(mu1, mu2)
        var_dis_loss = F.mse_loss(log_var1, log_var2)
        kld_loss1 = -0.5 * torch.sum(1 + log_var1 - mu1 ** 2 - log_var1.exp(), dim=1)
        kld_loss2 = -0.5 * torch.sum(1 + log_var2 - mu2 ** 2 - log_var2.exp(), dim=1)
        kld_loss = torch.mean(torch.cat((kld_loss1, kld_loss2), 0), dim=0)
        return recons_loss + kld_weight * (kld_loss + mu_dis_loss + var_dis_loss)

    def _attend(self, feat):
        """Apply the shared channel-then-spatial attention with a residual add."""
        x_att = self.ca(feat) * feat
        return self.sa(x_att) * x_att + feat

    def forward(self, inputs, hmap, rnn_hxs, masks, box):
        """Compute actor features and the two critic values.

        Returns:
            ``(v_obs, hidden_actor, rnn_hxs, v_pre, env_loss)``. ``v_obs``
            comes from the encoding of ``hmap[:, 0:5]``. ``v_pre`` comes from
            the encoding of ``hmap[:, 0:1]`` plus the ``box`` LSTM map.
            ``env_loss`` is ``loss_function`` applied to both. ``rnn_hxs`` is
            passed through.
        """
        c0 = config.container_size[0]
        c1 = config.container_size[1]

        # Encode the box sequence and fold the LSTM outputs into 4 planes of
        # size (c0*10, c1*10). NOTE(review): presumably box has c0*10 steps.
        out_box, _ = self.rnn(box)
        out_box = out_box.reshape((-1, c0 * 10, 4, c1 * 10)).transpose(1, 2)

        hmap = hmap.reshape((-1, config.channel, c0 * 10, c1 * 10)) * 0.1
        s1 = hmap[:, 0:5, :, :]
        s2 = torch.cat((hmap[:, 0:1, :, :], out_box), 1)
        x = inputs.reshape((-1, 1 + config.enable_rotation, c0, c1)) * 0.1

        # Critic encoders; per the original inline note each returns
        # (z, mu, log_var) — confirm against VanillaVAE.
        z1 = self.state_enc1(s1)
        z2 = self.state_enc2(s2)

        # Actor attention CNN.
        hmap = self._attend(self.l1(hmap))
        hmap = self._attend(self.l2(hmap))
        hmap = self.l3(hmap)

        assert not self.is_recurrent
        share = self.share(torch.cat((hmap, x), 1))
        hidden_actor = self.actor(share)

        v_pre = self.critic_linear(z2[0])
        v_obs = self.critic_linear(z1[0])
        env_loss = self.loss_function(v_pre, v_obs, z1[1], z2[1], z1[2], z2[2])

        return v_obs, hidden_actor, rnn_hxs, v_pre, env_loss
