import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from demo.unet import UNetMLP
from demo.helper import at_least_ndim, linear_beta_schedule, cosine_beta_schedule


class DDPM(nn.Module):
    """Multi-agent DDPM with one epsilon-prediction network per player.

    Every name in ``args.player_list`` gets its own ``UNetMLP`` denoiser plus a
    frozen EMA copy. During training, all networks regress the same injected
    noise. The other agents' predictions are also pulled towards the focal
    agent's (``player_type``). During sampling, the focal agent's conditional and
    unconditional predictions are mixed with a weight ``lam``. With ``ddgi=True``,
    ``lam`` is derived from the distance between the other agents' mean
    prediction and the focal agent's; otherwise ``lam`` is fixed at 0.
    """

    def __init__(self, args, x_max=None, x_min=None):
        """Build the per-player denoisers, EMA copies, optimizer and noise schedule.

        Args:
            args: config object. Reads ``device``, ``env_name``, ``player_type``,
                ``player_type_idx``, ``player_list``, ``state_dim_list`` (except
                for the tennis/box/connect4 envs), ``model.{sample_steps,
                beta_schedule, latent_dim, cond_dim}`` and
                ``train.{ema_rate, lr, weight_decay}``.
            x_max, x_min: optional bounds on clean samples. When either is set,
                predicted noise is clipped during sampling (see ``predict_function``).

        Raises:
            ValueError: if ``args.model.beta_schedule`` is not "linear" or "cosine".
        """
        super().__init__()
        # ---- params ---------------------------------------------------------
        # Use args.device only when CUDA is available; otherwise run on CPU.
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.diffusion_steps = args.model.sample_steps
        self.beta_schedule = args.model.beta_schedule

        self.player_type = args.player_type
        self.player_idx = args.player_type_idx
        self.player_list = list(args.player_list)
        # Every player except the focal one.
        self.other_agent_list = [name for name in self.player_list if name != self.player_type]

        if self.env in ['tennis', 'box', 'connect4']:
            # These environments use one shared latent size for every player.
            self.state_dim_list = [args.model.latent_dim] * len(self.player_list)
        else:
            self.state_dim_list = list(args.state_dim_list)

        self.x_max, self.x_min = x_max, x_min
        self.ema_rate = args.train.ema_rate
        self.cond_dim = args.model.cond_dim
        self.state_dim = self.state_dim_list[self.player_idx]

        # ---- model ----------------------------------------------------------
        # One denoiser per player. All are sized to the focal player's state_dim,
        # so they can all be evaluated on the same noisy sample x_t.
        self.model = nn.ModuleDict({
            name: UNetMLP(args, state_dim=self.state_dim).to(self.device)
            for name in self.player_list
        })
        # Frozen EMA copy, used by default at sampling time; see `ema_update`.
        self.model_ema = deepcopy(self.model).requires_grad_(False)
        self.model.train()
        self.model_ema.eval()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay
        )

        # ---- noise schedule -------------------------------------------------
        if self.beta_schedule == "linear":
            beta = linear_beta_schedule(T=self.diffusion_steps)
        elif self.beta_schedule == "cosine":
            beta = cosine_beta_schedule(T=self.diffusion_steps)
        else:
            raise ValueError(f"Unknown beta schedule: {self.beta_schedule}")

        # as_tensor accepts lists, numpy arrays or tensors. Unlike torch.tensor()
        # it does not warn about copy-constructing from an existing tensor.
        self.beta = torch.as_tensor(beta, dtype=torch.float32, device=self.device)
        self.alpha = 1 - self.beta
        # bar_alpha[t] = prod_{s <= t} alpha[s]
        self.bar_alpha = torch.cumprod(self.alpha, 0)

    # ---------------------------------------------------------------------------
    # Checkpointing
    def save_weights(self, path):
        """Save the live and EMA per-player networks to ``path`` with ``torch.save``."""
        torch.save({
            "model": self.model.state_dict(),
            "model_ema": self.model_ema.state_dict(),
        }, path)

    def load_weights(self, path=None, strict=True):
        """Load a checkpoint written by ``save_weights`` and put ``self.model`` in eval mode.

        Raises:
            ValueError: if the file cannot be read or the state dicts do not match.
                The original exception is chained as ``__cause__``.
        """
        try:
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(checkpoint["model"], strict=strict)
            self.model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
        except Exception as e:
            raise ValueError(f"Failed to load diffusion model: {e}") from e
        print(f"Loaded Diffusion model weights from {path}")
        self.model.eval()

    def ema_update(self):
        """Blend live parameters into the EMA copy: p_ema <- rate * p_ema + (1 - rate) * p.

        NOTE(review): only parameters are averaged. Buffers (if UNetMLP has any)
        are never copied into ``model_ema``; confirm this is intended.
        """
        with torch.no_grad():
            for p, p_ema in zip(self.model.parameters(), self.model_ema.parameters()):
                p_ema.data.mul_(self.ema_rate).add_(p.data, alpha=1. - self.ema_rate)

    # ---------------------------------------------------------------------------
    # Training
    def add_noise(self, x0, t=None, eps=None):
        """Forward-diffuse every agent's clean sample with ONE shared ``t`` and ``eps``.

        Args:
            x0: dict name -> clean sample tensor. 1-D tensors are treated as a
                batch of one.
            t: optional timestep tensor. Drawn uniformly from
                [0, diffusion_steps), one per batch row, when omitted.
            eps: optional noise. Drawn as ``randn_like`` of the first sample when
                omitted.

        Returns:
            (xts, t, eps), where xts[name] = sqrt(bar_alpha_t) * x0 + sqrt(1 - bar_alpha_t) * eps.
        """
        # Promote unbatched samples first, so the batch size used for t and the
        # shape of eps agree with the tensors they are combined with.
        x0 = {name: (x.unsqueeze(0) if x.dim() == 1 else x) for name, x in x0.items()}
        first_x = next(iter(x0.values()))
        if t is None:
            t = torch.randint(self.diffusion_steps, (first_x.shape[0],), device=self.device)
        if eps is None:
            eps = torch.randn_like(first_x)
        xts = {}
        for name, x in x0.items():
            bar_alpha = at_least_ndim(self.bar_alpha[t], x.dim())
            xts[name] = x * bar_alpha.sqrt() + eps * (1 - bar_alpha).sqrt()
        return xts, t, eps

    def forward(self, x0, condition=None):
        """Run one training step on clean samples ``x0`` and return the scalar loss."""
        xts, t, eps = self.add_noise(x0)
        return self.update(xts, t, eps, condition)

    def update(self, xts, t, eps, condition, update_ema=True):
        """Take one optimizer step on the multi-agent denoising loss.

        loss = MSE(focal, eps) + sum_other MSE(other, eps)
               + 0.1 * sum_other MSE(other, focal)

        ``condition['state']`` is used as the conditioning input for every
        network. Returns the loss as a Python float.
        """
        pred1 = self.model[self.player_type](xts[self.player_type], t, cond=condition['state'])
        pred2 = [
            self.model[name](xts[name], t, cond=condition['state'])
            for name in self.other_agent_list
        ]

        loss1 = ((pred1 - eps) ** 2).mean()
        loss2 = sum([((pred - eps) ** 2).mean() for pred in pred2])
        # Consistency term pulling the other agents towards the focal prediction.
        loss3 = sum([((pred - pred1) ** 2).mean() for pred in pred2])

        loss = loss1 + loss2 + 0.1 * loss3
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if update_ema:
            self.ema_update()

        return loss.item()

    # ---------------------------------------------------------------------------
    # Inference
    @property
    def clip_pred(self):
        """True when at least one of ``x_max`` / ``x_min`` was supplied."""
        return (self.x_max is not None) or (self.x_min is not None)

    @torch.no_grad()
    def predict_function(self, x, t, bar_alpha, condition=None, use_ema=True, d_weight=1.0, alpha=0.5, ddgi=True):
        """Predict the noise in ``x`` at step ``t`` as a guided mix of the focal agent's outputs.

        pred = lam * eps_cond + (1 - lam) * eps_uncond. Both terms come from the
        same network set (EMA when ``use_ema``).

        Args:
            condition: dict whose ``'state'`` entry is the conditioning input.
                When ``None``, every network is called with ``cond=None``.
            d_weight: scale applied to the distance before the sigmoid (ddgi only).
            alpha: unused; kept for interface compatibility.
            ddgi: when True, lam = sigmoid(d_weight * ||mean_other_eps - eps_cond||_2),
                using one norm over the whole batch. This requires at least one
                other agent, because ``torch.stack`` of an empty list raises.
                When False, lam = 0.

        Returns:
            (pred, lam) where ``lam`` is a Python float.
        """
        model = self.model_ema if use_ema else self.model
        cond = None if condition is None else condition['state']

        eps_agent = model[self.player_type](x, t, cond=cond)
        # Unconditional prediction from the SAME network set as the conditional one.
        eps_agent_none = model[self.player_type](x, t, cond=None)

        if ddgi:
            eps_adv = torch.stack(
                [model[name](x, t, cond=cond) for name in self.other_agent_list], dim=0
            ).mean(dim=0)
            d = (eps_adv - eps_agent).norm(p=2)
            lam = torch.sigmoid(d_weight * d)
        else:
            # NOTE: lam is fixed here. Earlier notes listed 0.0 / 0.25 / 0.5 / 0.75
            # as values tried for push, tag, spread, connect4 and holdem.
            lam = torch.tensor(0.0, device=self.device)
        pred = lam * eps_agent + (1 - lam) * eps_agent_none

        if self.clip_pred:
            # Bound eps so the implied x0 = (x - sqrt(1-ba) * eps) / sqrt(ba) stays in [x_min, x_max].
            upper_bound = (x - bar_alpha.sqrt() * self.x_min) / (1 - bar_alpha).sqrt() \
                if self.x_min is not None else None
            lower_bound = (x - bar_alpha.sqrt() * self.x_max) / (1 - bar_alpha).sqrt() \
                if self.x_max is not None else None
            pred = pred.clip(lower_bound, upper_bound)

        return pred, lam.item()

    @torch.no_grad()
    def sample(self, prior=None, n_samples=1, sample_steps=None, extra_sample_steps=0,
               temperature=1.0, d_weight=5.0, sample_alpha=0.5, condition=None, ddgi=True):
        """Run ancestral DDPM sampling from Gaussian noise shaped like ``prior``.

        Args:
            prior: tensor whose shape (and dtype) defines the samples. Only its
                shape is used, not its values. Required.
            n_samples: number of timesteps placed in each step's ``t`` batch.
            sample_steps: ignored. The chain always runs ``self.diffusion_steps`` steps.
            extra_sample_steps: extra noiseless denoising steps applied at t = 0.
            temperature: scale of the initial noise.
            d_weight, sample_alpha, ddgi, condition: forwarded to ``predict_function``.

        Returns:
            (x_0, lam_list), where lam_list holds one guidance weight per main step.

        Raises:
            ValueError: if ``prior`` is None.
        """
        if prior is None:
            raise ValueError("sample() requires `prior` to determine the sample shape")

        xt = torch.randn_like(prior, device=self.device) * temperature
        lam_list = []
        for t in range(self.diffusion_steps - 1, -1, -1):
            t_batch = torch.full((n_samples,), t, dtype=torch.long, device=self.device)
            bar_alpha = self.bar_alpha[t]
            bar_alpha_prev = self.bar_alpha[t - 1] if t > 0 else torch.tensor(1.0, device=self.device)
            alpha = self.alpha[t]
            beta = self.beta[t]

            pred_theta, lam = self.predict_function(
                xt, t_batch, bar_alpha, condition=condition,
                d_weight=d_weight, alpha=sample_alpha, ddgi=ddgi,
            )
            # Posterior mean: (x_t - beta_t / sqrt(1 - bar_alpha_t) * eps) / sqrt(alpha_t).
            xt = 1 / alpha.sqrt() * (xt - beta / (1 - bar_alpha).sqrt() * pred_theta)
            lam_list.append(lam)

            if t != 0:
                # Posterior variance beta_t * (1 - bar_alpha_{t-1}) / (1 - bar_alpha_t).
                xt = xt + (beta * (1 - bar_alpha_prev) / (1 - bar_alpha)).sqrt() * torch.randn_like(xt)

        if extra_sample_steps > 0:
            t_batch = torch.full((n_samples,), 0, dtype=torch.long, device=self.device)
            bar_alpha = self.bar_alpha[0]
            alpha = self.alpha[0]
            beta = self.beta[0]
            for _ in range(extra_sample_steps):
                # predict_function returns (pred, lam); only pred is needed here.
                pred_theta, _extra_lam = self.predict_function(
                    xt, t_batch, bar_alpha, condition=condition,
                    d_weight=d_weight, alpha=sample_alpha, ddgi=ddgi,
                )
                xt = 1 / alpha.sqrt() * (xt - beta / (1 - bar_alpha).sqrt() * pred_theta)
        return xt, lam_list