import torch
from torch.optim import Adam
from gym.spaces import Box

from utils import onehot_from_logits, gumbel_softmax
from utils import soft_update, hard_update
from utils import MLPNetworks

## Only discrete action spaces are considered in this file.
# Shared mean-squared-error criterion; used below for the critic (value-function) loss.
MSELoss = torch.nn.MSELoss()

# From https://github.com/shariqiqbal2810/maddpg-pytorch/

class Agent(object):
    """One learner: policy and critic networks plus their optimizers.

    Every network exists twice, as ``s_*`` and ``t_*``.  Right after
    construction each pair is passed to ``hard_update`` (presumably so that
    both copies start with identical weights -- confirm against ``utils``).
    Only the ``s_*`` networks are attached to an optimizer.
    Only discrete actions are supported (enforced by an assertion).
    """

    # Attributes whose ``state_dict`` is exported / restored, in that order.
    _STATE_KEYS = ('s_pol', 's_cri', 't_pol', 't_cri', 'pol_opt', 'cri_opt')

    def __init__(self, n_p_i, n_p_o, n_c_i, hidden_dim=64,
                 lr=0.01, discrete_action=True):
        """Build the four networks and the two Adam optimizers.

        Args:
            n_p_i: input size of the policy networks.
            n_p_o: output size of the policy networks.
            n_c_i: input size of the critic networks (their output size is 1).
            hidden_dim: hidden size forwarded to ``MLPNetworks``.
            lr: learning rate shared by both optimizers.
            discrete_action: must compare equal to ``True``.

        Raises:
            AssertionError: if ``discrete_action`` is not ``True``.
        """
        assert discrete_action == True

        def build_policy():
            return MLPNetworks(n_p_i, n_p_o, hidden_dim=hidden_dim,
                               constrain_out=True, discrete_action=discrete_action)

        def build_critic():
            return MLPNetworks(n_c_i, 1, hidden_dim=hidden_dim, constrain_out=False)

        # Construction order is kept (s_pol, t_pol, s_cri, t_cri) so that any
        # seeded random initialisation is consumed in the same sequence.
        self.s_pol = build_policy()
        self.t_pol = build_policy()
        hard_update(self.t_pol, self.s_pol)

        self.s_cri = build_critic()
        self.t_cri = build_critic()
        hard_update(self.t_cri, self.s_cri)

        self.pol_opt = Adam(self.s_pol.parameters(), lr=lr)
        self.cri_opt = Adam(self.s_cri.parameters(), lr=lr)

        self.discrete_action = discrete_action
        # Stored noise scale; nothing in this class reads it back.
        self.exploration = 0.3

    def scale_noise(self, scale):
        """Store ``scale`` as the current exploration value."""
        self.exploration = scale

    def step(self, obs, explore=False):
        """Run the ``s_pol`` network on ``obs`` and turn its output into an action.

        With ``explore`` truthy the output goes through ``gumbel_softmax``
        (device ``'cpu'``, ``hard=True``); otherwise through
        ``onehot_from_logits``.
        """
        logits = self.s_pol(obs)
        if not explore:
            return onehot_from_logits(logits)
        # Exploration is only meant for data collection.
        return gumbel_softmax(logits, 'cpu', hard=True)

    def get_params(self):
        """Return a dict mapping each network/optimizer name to its ``state_dict``."""
        return {key: getattr(self, key).state_dict() for key in self._STATE_KEYS}

    def load_params(self, params):
        """Restore every network/optimizer from a dict produced by ``get_params``."""
        for key in self._STATE_KEYS:
            getattr(self, key).load_state_dict(params[key])


class MARL(object):
    """Trainer for a group of ``Agent`` objects.

    Each agent is updated either with a centralised critic (MADDPG,
    ``cen_update``) or a purely local one (DDPG, ``dec_update``), selected
    per agent through ``algo_type``.
    """

    def __init__(self, agent_init_params, algo_type, n_agent, gamma=0.95,
                 tau=0.01, lr=0.01, hidden_dim=64, discrete_action=False):
        """Create the agents and put their networks into the rollout state.

        Args:
            agent_init_params: one dict per agent with the keyword arguments
                ``n_p_i``, ``n_p_o`` and ``n_c_i`` for ``Agent``.
            algo_type: per-agent algorithm tag; ``'dec'`` selects
                ``dec_update``, anything else selects ``cen_update``.
            n_agent: number of agents (stored as-is).
            gamma: discount factor used in the bootstrapped targets.
            tau: coefficient handed to ``soft_update``.
            lr: learning rate forwarded to every agent.
            hidden_dim: hidden size forwarded to every agent.
            discrete_action: forwarded to every agent.  NOTE: ``Agent``
                asserts that this is ``True``, so the default of ``False``
                fails as soon as there is at least one agent.
        """
        self.n_agent = n_agent
        self.agents = [Agent(lr=lr, discrete_action=discrete_action, hidden_dim=hidden_dim, **params)
                       for params in agent_init_params]
        self.algo_type = algo_type
        self.agent_init_params = agent_init_params
        self.gamma, self.n_iter = gamma, 0
        self.tau, self.lr, self.discrete_action = tau, lr, discrete_action
        # Keep the constructor arguments so that save() also works on an
        # instance that was not created through one of the classmethods.
        self.init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr, 'algo_type': algo_type,
                          'hidden_dim': hidden_dim, 'agent_init_params': agent_init_params,
                          'n_agent': n_agent, 'discrete_action': discrete_action}
        self.prep_training()
        self.prep_rollout()

    @property
    def all_s_pol(self):
        """List of every agent's ``s_pol`` network, in agent order."""
        return [a_j.s_pol for a_j in self.agents]

    @property
    def all_t_pol(self):
        """List of every agent's ``t_pol`` network, in agent order."""
        return [a_j.t_pol for a_j in self.agents]

    def scale_noise(self, scale):
        """Forward ``scale`` to ``Agent.scale_noise`` of every agent."""
        for a_j in self.agents:
            a_j.scale_noise(scale)

    def step(self, obs, explore=False):
        """Return one action per agent; ``obs`` is paired with the agents by position."""
        return [a_j.step(obs_j, explore=explore) for a_j, obs_j in zip(self.agents, obs)]

    def update(self, sample, idx_a, logger=None):
        """Update agent ``idx_a`` with the algorithm named in ``algo_type[idx_a]``."""
        if self.algo_type[idx_a] == 'dec':
            self.dec_update(sample, idx_a, logger=logger)
        else:
            self.cen_update(sample, idx_a, logger=logger)

    # MADDPG implementation
    def cen_update(self, sample, idx_a, logger=None):
        """Update critic and policy of agent ``idx_a`` with a centralised critic.

        Args:
            sample: tuple ``(obs, act, rew, n_obs, don)``, each indexable per
                agent.  Entries are concatenated along ``dim=1``, so they are
                presumably ``(batch, features)`` tensors -- TODO confirm
                against the replay buffer.
            idx_a: index of the agent to update.
            logger: optional object with ``add_scalars``; receives both losses.
        """
        obs, act, rew, n_obs, don = sample
        curr_a = self.agents[idx_a]

        ## Value function updates
        curr_a.cri_opt.zero_grad()
        all_t_act = [onehot_from_logits(pi_j(n_obs_j)) for pi_j, n_obs_j in
                        zip(self.all_t_pol, n_obs)]
        # Input for the target value estimation: all next observations and target actions
        t_vf_in = torch.cat((*n_obs, *all_t_act), dim=1)
        # Target value; the (1 - done) factor removes the bootstrap term on terminal steps
        t_value = (rew[idx_a].view(-1, 1) +
                   self.gamma * curr_a.t_cri(t_vf_in) * (1 - don[idx_a].view(-1, 1)))
        # Input for the value estimation: all observations and the stored actions
        s_vf_in = torch.cat((*obs, *act), dim=1)
        s_value = curr_a.s_cri(s_vf_in)
        # MSE loss; the target is detached so no gradient reaches the t_* networks
        vf_loss = MSELoss(s_value, t_value.detach())
        vf_loss.backward()
        torch.nn.utils.clip_grad_norm_(curr_a.s_cri.parameters(), 0.5)
        curr_a.cri_opt.step()

        ## Policy updates
        curr_a.pol_opt.zero_grad()
        curr_pol_out = curr_a.s_pol(obs[idx_a])
        # Use the device the logits actually live on instead of a hard-coded
        # 'cuda', so that training after prep_training(device='cpu') works too.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, curr_pol_out.device.type, hard=True)
        all_s_act = [onehot_from_logits(pi_j(obs_j)) for pi_j, obs_j in
                        zip(self.all_s_pol, obs)]
        # Only the updated agent's action is the gumbel_softmax sample
        all_s_act[idx_a] = curr_pol_vf_in

        s_vf_in = torch.cat((*obs, *all_s_act), dim=1)
        # Policy loss: maximise the critic value, plus a small penalty on the raw outputs
        pol_loss = -curr_a.s_cri(s_vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        torch.nn.utils.clip_grad_norm_(curr_a.s_pol.parameters(), 0.5)
        curr_a.pol_opt.step()

        if logger is not None:
            logger.add_scalars('a_%i/L'%idx_a, {'vf_loss':vf_loss, 'pol_loss':pol_loss}, self.n_iter)

    # DDPG implementation
    def dec_update(self, sample, idx_a, logger=None):
        """Update critic and policy of agent ``idx_a`` from its own data only.

        Same contract as ``cen_update``, but the critic only sees the
        agent's own observation and action.
        """
        obs, act, rew, n_obs, don = sample
        obs_i, act_i, rew_i, n_obs_i, don_i = (obs[idx_a], act[idx_a], rew[idx_a],
                                               n_obs[idx_a], don[idx_a])
        curr_a = self.agents[idx_a]
        s_pol_i, t_pol_i = curr_a.s_pol, curr_a.t_pol
        s_cri_i, t_cri_i = curr_a.s_cri, curr_a.t_cri
        pol_opt_i, cri_opt_i = curr_a.pol_opt, curr_a.cri_opt

        ## Value function updates
        cri_opt_i.zero_grad()
        # Target value estimation
        t_vf_in = torch.cat((n_obs_i, onehot_from_logits(t_pol_i(n_obs_i))), dim=1)
        t_value = (rew_i.view(-1, 1) +
                   self.gamma * t_cri_i(t_vf_in) * (1 - don_i.view(-1, 1)))
        # Value estimation
        s_vf_in = torch.cat((obs_i, act_i), dim=1)
        s_value = s_cri_i(s_vf_in)
        vf_loss = MSELoss(s_value, t_value.detach())
        vf_loss.backward()
        torch.nn.utils.clip_grad_norm_(s_cri_i.parameters(), 0.5)
        cri_opt_i.step()

        ## Policy updates
        pol_opt_i.zero_grad()
        # One forward pass feeds both the sampled action and the
        # regulariser, exactly as cen_update does.
        pol_out_i = s_pol_i(obs_i)
        pol_vf_in_i = gumbel_softmax(pol_out_i, pol_out_i.device.type, hard=True)
        s_vf_in = torch.cat((obs_i, pol_vf_in_i), dim=1)
        # Policy loss
        pol_loss = -s_cri_i(s_vf_in).mean()
        pol_loss += (pol_out_i**2).mean() * 1e-3
        pol_loss.backward()
        torch.nn.utils.clip_grad_norm_(s_pol_i.parameters(), 0.5)
        pol_opt_i.step()

        if logger is not None:
            logger.add_scalars('a_%i/L'%idx_a, {'vf_loss':vf_loss, 'pol_loss':pol_loss}, self.n_iter)

    def update_all_targets(self):
        """Apply ``soft_update`` (with ``tau``) to every ``t_*`` network and bump ``n_iter``."""
        for a_j in self.agents:
            soft_update(a_j.t_pol, a_j.s_pol, self.tau)
            soft_update(a_j.t_cri, a_j.s_cri, self.tau)
        self.n_iter += 1

    def prep_training(self, device='cuda'):
        """Put all four networks of every agent in train mode and move them to ``device``."""
        for a_j in self.agents:
            for name in ('s_pol', 't_pol', 's_cri', 't_cri'):
                net = getattr(a_j, name)
                net.train()
                setattr(a_j, name, net.to(device))

    def prep_rollout(self, device='cpu'):
        """Put every agent's ``s_pol`` in eval mode and move it to ``device``."""
        for a_j in self.agents:
            a_j.s_pol.eval()
            a_j.s_pol = a_j.s_pol.to(device)

    def post_rollout(self, device='cuda'):
        """Put every agent's ``s_pol`` back in train mode and move it to ``device``."""
        for a_j in self.agents:
            a_j.s_pol.train()
            a_j.s_pol = a_j.s_pol.to(device)

    def save(self, filename):
        """Write the constructor arguments and all agent states to ``filename``.

        The networks are moved to the CPU for serialisation and afterwards
        always moved to ``'cuda'`` in train mode, whatever state they were in
        before the call.
        """
        self.prep_training(device='cpu')
        try:
            save_dict = {'init_dict': self.init_dict,
                         'agent_params': [a_j.get_params() for a_j in self.agents]}
            torch.save(save_dict, filename)
        finally:
            # Restore the training placement even if serialisation failed.
            self.prep_training(device='cuda')

    @staticmethod
    def _space_dim(space):
        """Return the size of one action space: ``shape[0]`` for ``Box``, ``n`` otherwise."""
        if isinstance(space, Box):
            return space.shape[0]
        return space.n

    @classmethod
    def init_from_env(cls, env, agent_algo='cen', adv_algo='cen', gamma=0.95, tau=0.01, lr=0.01, hidden_dim=64):
        """Build an instance whose network sizes are derived from ``env``.

        Args:
            env: object exposing ``agents``, ``action_space`` and
                ``observation_space`` (the latter two iterable per agent).
            agent_algo: algorithm tag for agents that are not adversaries.
            adv_algo: algorithm tag for agents whose ``adversary`` attribute
                is truthy.
            gamma, tau, lr, hidden_dim: forwarded to the constructor.

        Returns:
            The new instance, with ``init_dict`` holding the constructor
            arguments that were used.
        """
        algo_type = [adv_algo if getattr(agent, 'adversary', False) else agent_algo
                     for agent in env.agents]

        # Every space is sized according to its own type, and the
        # centralised critic input size is computed only once.
        obs_dims = [obsp.shape[0] for obsp in env.observation_space]
        act_dims = [cls._space_dim(acsp) for acsp in env.action_space]
        n_c_i_c = sum(obs_dims) + sum(act_dims)

        agent_init_params = []
        for k, (n_obs_k, n_act_k) in enumerate(zip(obs_dims, act_dims)):
            # 'cen' critics see everything, others only the agent's own obs/action
            n_c_i = n_c_i_c if algo_type[k] == 'cen' else n_obs_k + n_act_k
            agent_init_params.append({'n_p_i': n_obs_k,
                                      'n_p_o': n_act_k,
                                      'n_c_i': n_c_i})
        n_agent = len(agent_init_params)
        # Discrete only if no action space is a Box (well-defined even with zero agents).
        discrete_action = all(not isinstance(acsp, Box) for acsp in env.action_space)

        init_dict = {'gamma': gamma, 'tau': tau, 'lr': lr, 'algo_type': algo_type,
                     'hidden_dim': hidden_dim, 'agent_init_params': agent_init_params,
                     'n_agent': n_agent, 'discrete_action': discrete_action}
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename):
        """Recreate an instance from a file written by ``save``.

        NOTE: ``torch.load`` unpickles the file; only load files from a
        trusted source.

        Returns:
            The restored instance.
        """
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a_j, params in zip(instance.agents, save_dict['agent_params']):
            a_j.load_params(params)
        # Previously missing: without this the caller received None.
        return instance



