import torch
import torch.nn as nn
import numpy as np
from utilities.util import *
from models.model import Model
from collections import namedtuple



class SECA(Model):
    """Multi-agent actor-critic model with per-agent policies and a joint critic.

    Each agent owns (or shares) a small MLP policy over its own observation,
    while a single MLP critic scores the concatenation of all observations
    and all (one-hot) actions.  The critic is additionally evaluated on every
    joint action (see ``cal_values``) to build the advantages in ``get_loss``.

    NOTE(review): ``self.n_``, ``self.obs_dim``, ``self.act_dim``,
    ``self.hid_dim``, ``self.cuda_``, ``self.args``, ``init_weights``,
    ``reload_params_to_target`` and ``unpack_data`` are not defined in this
    class; they are presumably provided by ``Model`` -- confirm there.
    """

    def __init__(self, args, target_net=None):
        """Build the networks, optionally attach a target network.

        Args:
            args: Configuration object forwarded to ``Model.__init__``.  This
                class reads ``args.shared_parameters``, ``args.order``,
                ``args.target``, ``args.gamma`` and
                ``args.normalize_advantages``.
            target_net: Optional network used for bootstrapping in
                ``get_loss`` when ``args.target`` is truthy.
        """
        super(SECA, self).__init__(args)
        self.construct_model()
        self.apply(self.init_weights)
        # Identity check: ``!=`` on arbitrary objects may be overloaded.
        if target_net is not None:
            self.target_net = target_net
            self.reload_params_to_target()
        self.Transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'next_state', 'done', 'last_step'),
        )
        self.build_act_all()
        # Order in which agents are visited when building advantages.
        self.order = self.args.order

    def build_act_all(self):
        """Enumerate every joint one-hot action and store it in ``self.act_all``.

        The result has ``act_dim ** n_`` rows and ``n_ * act_dim`` columns.
        Each pass prepends one more agent's one-hot block, so the first
        ``act_dim`` columns vary slowest and the last ``act_dim`` columns
        vary fastest across rows.
        """
        identity = torch.eye(self.act_dim)
        act = identity
        for _ in range(self.n_ - 1):
            rows = act.size(0)
            # For each action i of the newly prepended agent, pair its one-hot
            # row with every joint action enumerated so far.
            act = torch.cat(
                [
                    torch.cat(
                        (identity[i].view(1, -1).expand(rows, -1), act),
                        dim=-1,
                    )
                    for i in range(self.act_dim)
                ],
                dim=0,
            )
        self.act_all = cuda_wrapper(act, self.cuda_)

    def construct_policy_net(self):
        """Create one policy MLP per agent in ``self.action_dicts``.

        With ``args.shared_parameters`` every agent's ``ModuleDict`` refers to
        the same three ``nn.Linear`` layers; otherwise each agent gets its own.
        """
        action_dicts = []
        if self.args.shared_parameters:
            l1 = nn.Linear(self.obs_dim, self.hid_dim)
            l2 = nn.Linear(self.hid_dim, self.hid_dim)
            a = nn.Linear(self.hid_dim, self.act_dim)
            for _ in range(self.n_):
                action_dicts.append(nn.ModuleDict({
                    'layer_1': l1,
                    'layer_2': l2,
                    'action_head': a,
                }))
        else:
            for _ in range(self.n_):
                action_dicts.append(nn.ModuleDict({
                    'layer_1': nn.Linear(self.obs_dim, self.hid_dim),
                    'layer_2': nn.Linear(self.hid_dim, self.hid_dim),
                    'action_head': nn.Linear(self.hid_dim, self.act_dim),
                }))
        self.action_dicts = nn.ModuleList(action_dicts)

    def construct_value_net(self):
        """Create the joint critic in ``self.value_dicts``.

        The critic input is all observations concatenated with all actions
        (``n_ * obs_dim + n_ * act_dim`` features).  The same three layers are
        registered under every agent index, i.e. the critic is always shared.
        """
        l1 = nn.Linear(self.n_ * self.obs_dim + self.n_ * self.act_dim, self.hid_dim)
        l2 = nn.Linear(self.hid_dim, self.hid_dim)
        v = nn.Linear(self.hid_dim, 1)
        value_dicts = []
        for _ in range(self.n_):
            value_dicts.append(nn.ModuleDict({
                'layer_1': l1,
                'layer_2': l2,
                'value_head': v,
            }))
        self.value_dicts = nn.ModuleList(value_dicts)

    def construct_model(self):
        """Build the critic first, then the policies."""
        self.construct_value_net()
        self.construct_policy_net()

    def policy(self, obs, schedule=None, last_act=None, last_hid=None, info=None, stat=None):
        """Return per-agent action logits.

        Args:
            obs: Tensor indexed as ``obs[:, i, :]`` for agent ``i``.
            schedule, last_act, last_hid, info, stat: Accepted for interface
                compatibility; not used by this implementation.  ``info`` and
                ``stat`` default to ``None`` rather than a shared mutable dict.

        Returns:
            Tensor of raw ``action_head`` outputs stacked along dim 1, one
            slice per agent.
        """
        actions = []
        for i in range(self.n_):
            net = self.action_dicts[i]
            h = torch.relu(net['layer_1'](obs[:, i, :]))
            h = torch.relu(net['layer_2'](h))
            actions.append(net['action_head'](h))
        return torch.stack(actions, dim=1)

    def value(self, obs, act):
        """Evaluate the joint critic for the given observations and actions.

        Args:
            obs: Observations; flattened to ``(batch, -1)`` before use.
            act: Actions; flattened to ``(batch, -1)`` before use.

        Returns:
            Tensor with the ``value_head`` outputs stacked along dim 1 (one
            slice per agent index, each with a trailing dimension of 1).
        """
        batch_size = obs.size(0)
        obs = obs.view(batch_size, -1)
        act = act.view(batch_size, -1)
        # Every agent index sees the same joint (obs, act) input.
        inputs = torch.cat((obs, act), dim=-1).unsqueeze(1).expand(-1, self.n_, -1)
        values = []
        for i in range(self.n_):
            net = self.value_dicts[i]
            h = torch.relu(net['layer_1'](inputs[:, i, :]))
            h = torch.relu(net['layer_2'](h))
            values.append(net['value_head'](h))
        return torch.stack(values, dim=1)

    def cal_values(self, obs):
        """Evaluate the critic on every joint action in ``self.act_all``.

        Args:
            obs: Observations; flattened to ``(batch, n_ * obs_dim)``.

        Returns:
            Tensor of shape ``(batch, act_all.size(0), 1)`` holding the critic
            output for each enumerated joint action.
        """
        batch_size = obs.size(0)
        a_n = self.act_all.size(0)
        obs_ = obs.view(batch_size, -1)
        obs_ = obs_.unsqueeze(1).expand(batch_size, a_n, self.n_ * self.obs_dim)
        act_ = self.act_all.unsqueeze(0).expand(batch_size, a_n, self.n_ * self.act_dim)
        # Only index 0 is used: all entries of value_dicts share the same layers.
        critic = self.value_dicts[0]
        h = torch.relu(critic['layer_1'](torch.cat((obs_, act_), dim=-1)))
        h = torch.relu(critic['layer_2'](h))
        return critic['value_head'](h)

    def cal_prob(self, probs):
        """Return the product of per-agent distributions over joint actions.

        Args:
            probs: Tensor indexed as ``probs[:, i, :]`` with ``act_dim``
                entries for agent ``i``.

        Returns:
            Detached tensor of shape ``(batch, act_dim, ..., act_dim)`` with
            ``n_`` action axes; entry ``[b, a_0, ..., a_{n-1}]`` is
            ``prod_i probs[b, i, a_i]``.
        """
        batch_size = probs.size(0)
        view_list = [1] * self.n_
        act_dim_list = [self.act_dim] * self.n_
        final_prob = cuda_wrapper(
            torch.ones([batch_size] + act_dim_list, dtype=torch.float), self.cuda_
        )
        for i in range(self.n_):
            prob = probs[:, i, :]
            # Put agent i's distribution on its own axis, broadcast over the rest.
            view_list[i] = -1
            new_prob = prob.view([batch_size] + view_list)
            view_list[i] = 1
            new_prob = new_prob.expand([batch_size] + act_dim_list)
            final_prob *= new_prob
        return final_prob.detach()

    def get_loss(self, batch):
        """Compute per-agent policy and value losses for one batch.

        Args:
            batch: Object with a ``state`` field and whatever
                ``self.unpack_data`` requires (presumably a ``Transition``
                batch -- confirm against the caller).

        Returns:
            Tuple ``(action_loss, value_loss, action_out)`` where both losses
            are averaged over the batch dimension (one entry per agent) and
            ``action_out`` is the raw policy output for ``state``.
        """
        batch_size = len(batch.state)
        rewards, last_step, done, actions, state, next_state = self.unpack_data(batch)
        action_out = self.policy(state)
        values = self.value(state, actions)
        action_prob = torch.softmax(action_out, dim=-1)
        values = values.squeeze(-1)

        # Bootstrap from the target network when enabled, otherwise from self.
        bootstrap_net = self.target_net if self.args.target else self
        next_action_out = bootstrap_net.policy(next_state, last_act=actions)
        next_actions = select_action(self.args, next_action_out, status='train', exploration=False)
        next_values = bootstrap_net.value(next_state, next_actions)
        next_values = next_values.squeeze(-1)

        returns = cuda_wrapper(torch.zeros((batch_size, self.n_), dtype=torch.float), self.cuda_)

        assert values.size() == next_values.size()
        assert returns.size() == values.size()
        # One-step TD target; the bootstrap term is dropped only when the
        # transition is both the last step and flagged as done.
        for i in reversed(range(rewards.size(0))):
            if last_step[i]:
                next_return = 0 if done[i] else next_values[i].detach()
            else:
                next_return = next_values[i].detach()
            returns[i] = rewards[i] + self.args.gamma * next_return

        # value loss
        deltas = returns - values
        value_loss = deltas.pow(2).mean(dim=0)

        # action loss: critic values for every joint action, one axis per agent
        value_all = self.cal_values(state)
        value_all = value_all.view([batch_size] + [self.act_dim] * self.n_).detach()

        # Calculate the advantages.  Visiting agents in self.order, replace
        # agent i's distribution with its taken action and weight the change
        # in joint probability by the critic values.
        tmp = action_prob.clone()
        F1_prob = self.cal_prob(action_prob)
        adv = [None] * self.n_
        for i in self.order:
            tmp[:, i, :] = actions[:, i, :]
            F2_prob = self.cal_prob(tmp)
            adv[i] = (F2_prob - F1_prob) * value_all
            tmp = tmp.clone()
            F1_prob = F2_prob.clone()
        # NOTE(review): torch.stack fails if self.order does not visit every
        # agent index, since unvisited entries stay None.
        advantages = torch.stack(adv, dim=1)
        # Sum out every joint-action axis, leaving (batch, n_).
        for _ in range(self.n_):
            advantages = advantages.sum(dim=-1)
        advantages = advantages.detach()
        if self.args.normalize_advantages:
            advantages = batchnorm(advantages)
        log_prob = multinomials_log_density(actions, action_out).contiguous().view(-1, self.n_)
        assert log_prob.size() == advantages.size()
        action_loss = - advantages * log_prob
        action_loss = action_loss.mean(dim=0)
        return action_loss, value_loss, action_out
