import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .helpers import (cosine_beta_schedule,
                            linear_beta_schedule,
                            vp_beta_schedule,
                            extract,
                            Losses)
from .helpers import SinusoidalPosEmb, init_weights
from utils.transformer import Transformer
from utils.embed import polynomial_embed, binary_embed

# from agent.model import Model
class Model(nn.Module):
    """Entity-attention network used as the denoiser (or plain MLP policy head).

    Each observation is split by a per-task decomposer into own / enemy / ally
    features, every entity is embedded as one token, the tokens are mixed by a
    ``Transformer`` and the pooled token outputs are fed to an MLP head that
    emits a ``latent_dim``-sized vector.

    Layer sizes are derived from the decomposer of the *first* task in
    ``task2decomposer``; presumably all tasks share these feature widths —
    TODO confirm against the decomposer implementation.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, args, is_vqvae=False, num_embeddings=5):
        """Build all embedding layers, the transformer and the output head.

        Args:
            task2input_shape_info: mapping task -> dict; only the
                ``"last_action_shape"`` entry is read here.
            task2decomposer: mapping task -> observation/action decomposer.
            task2n_agents: mapping task -> number of agents.
            args: config namespace (reads ``latent_dim``,
                ``policy_entity_embed_dim``, ``attn_embed_dim``,
                ``obs_agent_id``, ``obs_last_action``, ``id_length``,
                ``encoding_dim``, ``max_ally_num``, ``max_enemy_num``,
                ``policy_head``, ``policy_depth``, ``use_encoding``,
                ``simple_mlp_agent`` and optionally ``use_role_encoder`` /
                ``only_role_encoding``).
            is_vqvae: when true the output width is ``num_embeddings``
                instead of ``args.latent_dim``.
            num_embeddings: output width used when ``is_vqvae`` is true.
        """
        super().__init__()

        self.args = args
        self.task2last_action_shape = {
            task: task2input_shape_info[task]["last_action_shape"]
            for task in task2input_shape_info
        }
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents

        self.latent_dim = num_embeddings if is_vqvae else args.latent_dim
        self.hidden_size = 256
        time_dim = 32

        # Diffusion-step embedding: sinusoidal features followed by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, self.hidden_size),
            nn.Mish(),
            nn.Linear(self.hidden_size, time_dim),
        )

        self.entity_embed_dim = args.policy_entity_embed_dim
        self.attn_embed_dim = args.attn_embed_dim

        # Feature widths are taken from the first registered task.
        first_task = list(task2decomposer.keys())[0]
        reference = task2decomposer[first_task]
        obs_own_dim = reference.own_obs_dim
        obs_en_dim, obs_al_dim = reference.obs_nf_en, reference.obs_nf_al
        n_actions_no_attack = reference.n_actions_no_attack
        has_attack_action = n_actions_no_attack != reference.n_actions

        # When both flags are on, forward() appends the agent id and the
        # compact last-action state to the own observation; with attack
        # actions one extra slot goes to the own features and one to each
        # enemy's features.
        wrapped_obs_own_dim = obs_own_dim
        if args.obs_agent_id and args.obs_last_action:
            wrapped_obs_own_dim += args.id_length + n_actions_no_attack
            if has_attack_action:
                wrapped_obs_own_dim += 1
                obs_en_dim += 1

        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim)

        # The encoding input is twice as wide when a role encoder is used in
        # addition to (not instead of) the base encoding.
        doubled_encoding = (getattr(args, "use_role_encoder", True)
                            and not getattr(args, "only_role_encoding", False))
        encoding_in_dim = 2 * args.encoding_dim if doubled_encoding else args.encoding_dim
        self.encoding_value = nn.Linear(encoding_in_dim, self.entity_embed_dim)

        # Learned per-slot index embeddings for ally and enemy tokens.
        max_ally_num = args.max_ally_num
        self.ally_time_embed = nn.Embedding(max_ally_num, self.entity_embed_dim)
        max_enemy_num = args.max_enemy_num
        self.enemy_time_embed = nn.Embedding(max_enemy_num, self.entity_embed_dim)

        self.transformer = Transformer(self.entity_embed_dim, args.policy_head, args.policy_depth, self.entity_embed_dim)

        # Head input: own + pooled enemy + pooled ally + raw encoding embedding
        # (4 blocks), plus the encoding token's transformer output when
        # use_encoding is on (5 blocks). The diffusion path additionally
        # concatenates the noisy sample and the time embedding.
        n_feature_blocks = 5 if args.use_encoding else 4
        input_dim = n_feature_blocks * self.entity_embed_dim
        if not args.simple_mlp_agent:
            input_dim += self.latent_dim + time_dim

        self.layer = nn.Sequential(
            nn.Linear(input_dim, self.hidden_size),
            nn.Mish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Mish(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Mish(),
            nn.Linear(self.hidden_size, self.latent_dim),
        )

        # Not used in forward(); kept as a parameter holder.
        self.fc1 = nn.Linear(input_dim, self.hidden_size)

        self.apply(init_weights)

    @staticmethod
    def _add_index_embedding(tokens, table):
        """Add ``table``'s embedding of each token's slot index to ``tokens``."""
        n_rows, seq_len, _ = tokens.shape
        slots = torch.arange(seq_len, device=tokens.device).long()  # (seq_len,)
        slots = slots.view(1, seq_len).expand(n_rows, -1)
        return tokens + table(slots)  # (n_rows, seq_len, entity_embed_dim)

    def forward(self, x, time, inputs, hidden_state, task, task_encoding):
        """Run the network for one task.

        Args:
            x: noisy sample concatenated into the head input; ignored (may be
                ``None``) when ``args.simple_mlp_agent`` is set.
            time: diffusion step indices fed to ``time_mlp``; ignored when
                ``args.simple_mlp_agent`` is set.
            inputs: flat per-agent input laid out as
                ``[obs | last_action | agent_id]`` along the last dimension.
            hidden_state: history token, viewed as ``(-1, 1, entity_embed_dim)``.
            task: key into the per-task dictionaries.
            task_encoding: encoding vector fed to ``encoding_value``.

        Returns:
            ``(out, h)`` where ``out`` is the head output (last dim
            ``latent_dim``) and ``h`` is the transformer output of the last
            token — the history token when ``args.use_hidden`` is set,
            otherwise the last ally token.
        """
        hidden_state = hidden_state.view(-1, 1, self.entity_embed_dim)
        decomposer = self.task2decomposer[task]
        n_agents = self.task2n_agents[task]
        last_action_shape = self.task2last_action_shape[task]

        # Split the flat input into observation and last-action segments.
        # The trailing agent-id segment is not read: the ids are re-embedded
        # from the agent index below.
        obs_dim = decomposer.obs_dim
        action_end = obs_dim + last_action_shape
        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:action_end]

        # own_obs presumably has one row per (batch element, agent) — TODO confirm.
        own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs_inputs)
        bs = own_obs.shape[0] // n_agents

        # One binary id code per agent, tiled over the batch.
        id_codes = []
        for agent_idx in range(n_agents):
            code = binary_embed(agent_idx + 1, self.args.id_length, self.args.max_agent)
            id_codes.append(torch.as_tensor(code, dtype=own_obs.dtype))
        agent_id_inputs = torch.stack(id_codes, dim=0).repeat(bs, 1).to(own_obs.device)
        _, attack_action_info, compact_action_states = decomposer.decompose_action_info(last_action_inputs)

        # Wrap the own observation only when both flags are enabled, matching
        # the width chosen in __init__.
        if self.args.obs_last_action and self.args.obs_agent_id:
            own_obs = torch.cat([own_obs, agent_id_inputs, compact_action_states], dim=-1)

        # Stack per-entity features; append per-enemy attack info when present.
        enemy_feats = torch.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = torch.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = torch.stack(ally_feats, dim=0)

        # Embed every entity as a token; entity axis is moved to dim 1.
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)
        encoding_hidden = self.encoding_value(task_encoding).unsqueeze(1)
        history_hidden = hidden_state

        ally_hidden = self._add_index_embedding(ally_hidden, self.ally_time_embed)
        enemy_hidden = self._add_index_embedding(enemy_hidden, self.enemy_time_embed)

        # Token order: [encoding?] own, enemies..., allies..., [history?]
        with_encoding = bool(self.args.use_encoding)
        tokens = [encoding_hidden] if with_encoding else []
        tokens += [own_hidden, enemy_hidden, ally_hidden]
        if self.args.use_hidden:
            tokens.append(history_hidden)
        total_hidden = torch.cat(tokens, dim=1)

        outputs = self.transformer(total_hidden, None)

        enemy_length = enemy_hidden.shape[1]
        ally_length = ally_hidden.shape[1]
        own_pos = 1 if with_encoding else 0
        enemy_start = own_pos + 1
        ally_start = enemy_start + enemy_length

        h = outputs[:, -1:, :]
        base_action_inputs = outputs[:, own_pos, :]
        # Max-pool over the enemy tokens and over the ally tokens.
        obs_enemy = torch.max(outputs[:, enemy_start:ally_start, :], dim=1)[0]
        obs_ally = torch.max(outputs[:, ally_start:ally_start + ally_length, :], dim=1)[0]

        pooled = [outputs[:, 0, :]] if with_encoding else []
        pooled += [base_action_inputs, obs_enemy, obs_ally]
        obs_out = torch.cat(pooled, dim=-1)

        # The raw encoding embedding is always appended to the head input.
        encoding_flat = encoding_hidden.squeeze(1)
        if self.args.simple_mlp_agent:
            head_inputs = [obs_out, encoding_flat]
        else:
            head_inputs = [x, self.time_mlp(time), obs_out, encoding_flat]
        out = self.layer(torch.cat(head_inputs, dim=-1))

        return out, h


class MtDiffusion(nn.Module):
    """Gaussian diffusion sampler/trainer over a latent vector, conditioned via ``Model``.

    The forward process ``q`` noises a clean sample according to a beta
    schedule; ``Model`` is trained either to predict the added noise
    (``predict_epsilon=True``) or the clean sample directly. At inference the
    reverse chain is run for ``n_timesteps`` steps and, unless ``normal`` is
    requested, several candidates are drawn per state and the one with the
    highest ``q_func`` value is returned.
    """

    def __init__(self, noise_ratio, task2input_shape_info, task2decomposer, task2n_agents, args,
                 beta_schedule='vp', n_timesteps=1000,
                 loss_type='l2', clip_denoised=True, predict_epsilon=True,
                 behavior_sample=16, eval_sample=512, deterministic=False,
                 is_vqvae=False, num_embeddings=5):
        """Build the denoiser and precompute all schedule-dependent buffers.

        Args:
            noise_ratio: scale applied to the noise injected at every reverse
                step (also stored as ``max_noise_ratio``).
            task2input_shape_info, task2decomposer, task2n_agents: per-task
                dictionaries forwarded to ``Model``.
            args: config namespace; reads ``latent_dim`` and
                ``simple_mlp_agent`` here, the rest is consumed by ``Model``.
            beta_schedule: one of ``'linear'``, ``'cosine'`` or ``'vp'``.
            n_timesteps: length of the diffusion chain.
            loss_type: key into ``Losses``.
            clip_denoised: clamp the reconstructed sample to ``[-1, 1]``.
            predict_epsilon: whether the model output is noise or ``x_0``.
            behavior_sample: candidates per state when ``eval=False``.
            eval_sample: candidates per state when ``eval=True``.
            deterministic: if true, no noise is injected during evaluation.
            is_vqvae: use ``num_embeddings`` as the sample width instead of
                ``args.latent_dim``.
            num_embeddings: sample width when ``is_vqvae`` is true.

        Raises:
            ValueError: if ``beta_schedule`` is not a supported name.
        """
        super(MtDiffusion, self).__init__()

        # Fail fast: an unknown schedule used to leave ``betas`` unbound and
        # only surfaced later as a NameError.
        if beta_schedule not in ('linear', 'cosine', 'vp'):
            raise ValueError(
                "Unknown beta_schedule {!r}; expected 'linear', 'cosine' or 'vp'".format(beta_schedule))

        self.task2input_shape_info = task2input_shape_info
        self.task2decomposer = task2decomposer
        self.task2n_agents = task2n_agents
        self.args = args

        if not is_vqvae:
            self.action_dim = args.latent_dim
        else:
            self.action_dim = num_embeddings
        self.model = Model(task2input_shape_info, task2decomposer, task2n_agents, args, is_vqvae=is_vqvae, num_embeddings=num_embeddings)
        self.hidden_size = self.model.hidden_size

        self.max_noise_ratio = noise_ratio
        self.noise_ratio = noise_ratio

        self.behavior_sample = behavior_sample
        self.eval_sample = eval_sample
        self.deterministic = deterministic

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(n_timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(n_timesteps)
        else:  # 'vp' (validated above)
            betas = vp_beta_schedule(n_timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with alpha_bar_{-1} := 1.
        alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])

        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # The log is clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt (not np.sqrt) keeps these as tensors without relying on
        # NumPy/torch interop.
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.loss_fn = Losses[loss_type]()

    def init_hidden(self):
        """Return a zero ``(1, entity_embed_dim)`` hidden state on the model's device/dtype."""
        # make hidden states on same device as model
        return self.model.fc1.weight.new(1, self.model.entity_embed_dim).zero_()

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the clean-sample estimate from the model output.

        If ``self.predict_epsilon`` the model output ``noise`` is the (scaled)
        noise and is inverted through the schedule coefficients; otherwise the
        model predicts ``x_0`` directly and ``noise`` is returned unchanged.
        """
        if self.predict_epsilon:
            return (
                    extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                    extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        posterior_mean = (
                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, s, hidden_state, task, task_encoding):
        """Run the denoiser once and return the reverse-step Gaussian parameters.

        Returns:
            ``(model_mean, posterior_variance, posterior_log_variance,
            new_hidden_state)``.
        """
        noise, new_hidden_state = self.model(x, t, s, hidden_state, task, task_encoding)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise)

        if self.clip_denoised:
            x_recon.clamp_(-1., 1.)
        # NOTE: the previous ``assert RuntimeError()`` in the unclipped branch
        # never fired (an exception instance is truthy), so clip_denoised=False
        # has always simply skipped the clamp; that behaviour is kept.

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, new_hidden_state

    @torch.no_grad()
    def p_sample(self, x, t, s, hidden_state, task, task_encoding):
        """Draw ``x_{t-1}`` from ``x_t`` for one reverse step.

        Noise is scaled by ``self.noise_ratio`` and suppressed entirely for
        rows whose timestep is 0.
        """
        b = x.shape[0]

        model_mean, _, model_log_variance, new_hidden_state = self.p_mean_variance(
            x=x, t=t, s=s, hidden_state=hidden_state, task=task, task_encoding=task_encoding)

        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise * self.noise_ratio, new_hidden_state

    @torch.no_grad()
    def p_sample_loop(self, state, shape, hidden_state, task, task_encoding):
        """Run the full reverse chain starting from standard Gaussian noise.

        Args:
            shape: ``(batch_size, sample_dim)`` of the tensor to generate.

        Returns:
            ``(x, new_hidden_state)`` after the final (t == 0) step.
        """
        device = self.betas.device

        batch_size = shape[0]
        x = torch.randn(shape, device=device)

        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            # The same input hidden_state is fed at every step: new_hidden_state
            # is expected to be identical regardless of the step, since it
            # should depend only on state.
            x, new_hidden_state = self.p_sample(x, timesteps, state, hidden_state, task, task_encoding)

        return x, new_hidden_state

    @torch.no_grad()
    def _sample_best_of_n(self, n_samples, state, hidden_state, task, task_encoding, q_func):
        """Draw ``n_samples`` candidates per state and keep the highest-valued one.

        All conditioning tensors are tiled by the same ``n_samples`` so their
        leading dimensions stay aligned with the tiled ``state``.
        """
        raw_batch_size = state.shape[0]
        state = state.repeat(n_samples, 1)
        shape = (state.shape[0], self.action_dim)

        if hidden_state is not None:
            if len(hidden_state.shape) == 3:
                # Flatten a 3-D hidden state to 2-D rows before tiling.
                hidden_state = hidden_state.reshape(-1, hidden_state.shape[-1])
            tiled_hidden_state = hidden_state.repeat(n_samples, 1)
        else:
            tiled_hidden_state = None
        task_encoding = task_encoding.repeat(n_samples, 1)

        action, new_hidden_state = self.p_sample_loop(state, shape, tiled_hidden_state, task, task_encoding)

        q = q_func(state, action, task, task_encoding)
        # (n_samples * batch, d) -> (batch, n_samples, d)
        action = action.view(n_samples, raw_batch_size, -1).transpose(0, 1)
        if new_hidden_state is not None:
            # Keep the hidden state of the first candidate group only.
            new_hidden_state = new_hidden_state.reshape(n_samples, raw_batch_size, -1)[0]
        q = q.view(n_samples, raw_batch_size, -1).transpose(0, 1)
        action_idx = torch.argmax(q, dim=1, keepdim=True).repeat(1, 1, self.action_dim)

        return action.gather(dim=1, index=action_idx).view(raw_batch_size, -1), new_hidden_state

    @torch.no_grad()
    def sample(self, state, hidden_state, task, task_encoding, eval=False, q_func=None, normal=False):
        """Produce one sample per state row.

        Args:
            state: flat model input, one row per sample.
            hidden_state: history token(s) forwarded to the model.
            task: task key.
            task_encoding: encoding vector(s) forwarded to the model.
            eval: use ``eval_sample`` candidates (and no noise if
                ``deterministic``) instead of ``behavior_sample``.
            q_func: callable ``(state, action, task, task_encoding) -> q``
                used to rank candidates; required unless ``normal`` is true
                or ``args.simple_mlp_agent`` is set.
            normal: draw a single sample without candidate ranking.

        Returns:
            ``(action, new_hidden_state)``.
        """
        if self.args.simple_mlp_agent:
            # No diffusion: the model maps the state directly to the output.
            action, new_hidden_state = self.model(None, None, state, hidden_state, task, task_encoding)
            return action, new_hidden_state

        if self.deterministic and eval:
            self.noise_ratio = 0
        else:
            self.noise_ratio = self.max_noise_ratio

        if normal:
            shape = (state.shape[0], self.action_dim)
            return self.p_sample_loop(state, shape, hidden_state, task, task_encoding)

        # Both modes share the best-of-N routine; the behaviour branch used to
        # tile the hidden state and encoding by eval_sample by mistake.
        n_samples = self.eval_sample if eval else self.behavior_sample
        return self._sample_best_of_n(n_samples, state, hidden_state, task, task_encoding, q_func)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t`` (forward process ``q(x_t | x_0)``)."""
        if noise is None:
            noise = torch.randn_like(x_start)

        sample = (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        return sample

    def p_losses(self, x_start, state, t, hidden_state, task, task_encoding, weights=1.0):
        """Compute the training loss for given timesteps.

        With ``args.simple_mlp_agent`` the model output is compared directly
        to ``x_start`` (as a weighted negative log-likelihood when
        ``args.logprob_loss`` is set, otherwise through ``loss_fn``). In the
        diffusion case ``x_start`` is noised to ``t`` and the model output is
        matched to the noise or to ``x_start`` depending on
        ``predict_epsilon``.

        Returns:
            ``(loss, new_hidden_state)``.
        """
        if self.args.simple_mlp_agent:
            action, new_hidden_state = self.model(None, None, state, hidden_state, task, task_encoding)
            if getattr(self.args, "logprob_loss", False):
                log_probs = F.log_softmax(action, dim=-1)
                nll_per_sample = - (x_start * log_probs).sum(dim=-1, keepdim=True)
                loss = nll_per_sample * weights
            else:
                loss = self.loss_fn(action, x_start, weights)
            return loss, new_hidden_state

        noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        x_recon, new_hidden_state = self.model(x_noisy, t, state, hidden_state, task, task_encoding)

        assert noise.shape == x_recon.shape

        if self.predict_epsilon:
            loss = self.loss_fn(x_recon, noise, weights)
        else:
            loss = self.loss_fn(x_recon, x_start, weights)

        return loss, new_hidden_state

    def loss(self, x, state, hidden_state, task, task_encoding, weights=1.0):
        """Sample one uniform random timestep per row of ``x`` and return ``p_losses``."""
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        return self.p_losses(x, state, t, hidden_state, task, task_encoding, weights)

    def forward(self, state, hidden_state, task, task_encoding, eval=False, q_func=None, normal=False):
        """Alias for :meth:`sample`."""
        return self.sample(state, hidden_state, task, task_encoding, eval, q_func, normal)