# Ref: https://github.com/NTURobotLearningLab/dbc

import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))
sys.path.append(os.path.abspath('./model/baseline/DBC2'))

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os, sys
import argparse
import numpy as np
from model.emb.emb import StateEmb
from model.utils.agent_utils import *
from utils.dataloader import SADataLoader
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")


def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a sigmoid-shaped beta (noise variance) schedule.

    A sigmoid is evaluated on ``timesteps`` evenly spaced points in
    ``[-6, 6]`` and rescaled so the values lie between ``beta_start`` and
    ``beta_end``.

    Args:
        timesteps: Number of diffusion steps (length of the schedule).
        beta_start: Lower bound the schedule approaches at the first step.
        beta_end: Upper bound the schedule approaches at the last step.

    Returns:
        A 1-D tensor of length ``timesteps`` with increasing beta values.
    """
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (beta_end - beta_start) + beta_start


class MLPDiffusion(nn.Module):
    """MLP noise predictor conditioned on the diffusion step.

    The network maps an input of size ``input_dim`` back to ``input_dim``.
    Each of the first ``depth - 1`` linear layers has a learned per-step
    embedding added to its output before the ReLU.

    Args:
        n_steps: Number of diffusion steps (rows of each embedding table).
        input_dim: Size of the input and output vectors.
        num_units: Width of the hidden layers.
        depth: Controls the number of hidden layers / step embeddings.
        device: Device the layers are moved to. ``None`` selects CUDA when
            it is available and falls back to CPU otherwise, so the default
            construction no longer fails on CPU-only hosts.

    NOTE: ``self.linears`` holds ``2 * depth + 1`` modules, but ``forward``
    only applies the first ``2 * (depth - 1)`` of them plus the last one, so
    the final hidden Linear/ReLU pair is never executed. The layout is kept
    unchanged so that ``state_dict`` keys still match saved checkpoints.
    With ``depth == 1`` only the last layer runs, which expects ``num_units``
    input features.
    """

    def __init__(self, n_steps, input_dim=23, num_units=1024, depth=4, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged.
        layers = [nn.Linear(input_dim, num_units), nn.ReLU()]
        for _ in range(depth - 1):
            layers.append(nn.Linear(num_units, num_units))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(num_units, input_dim))
        self.linears = nn.ModuleList(layers).to(device)

        # One embedding table per conditioned hidden layer.
        self.step_embeddings = nn.ModuleList(
            [nn.Embedding(n_steps, num_units) for _ in range(depth - 1)]
        ).to(device)

    def forward(self, x, t):
        """Predict the noise for inputs ``x`` at integer diffusion steps ``t``.

        Args:
            x: Tensor whose last dimension is ``input_dim``.
            t: Long tensor of step indices used to look up the embeddings;
                its embedding must be broadcastable against the hidden
                activations.

        Returns:
            Tensor with the same last dimension as ``x``.
        """
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)       # Linear
            x += t_embedding                   # inject the step information
            x = self.linears[2 * idx + 1](x)   # ReLU
        return self.linears[-1](x)


########### training loss function: sample a random timestep t per example and compute the noise-prediction loss
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """Compute the noise-prediction training loss for one batch.

    A diffusion step is drawn for every example, ``x_0`` is noised to that
    step in closed form, and the model is asked to predict the noise that
    was added.

    Args:
        model: Callable ``model(x, t)`` returning a tensor shaped like ``x``.
        x_0: Clean batch; the first dimension is the batch size.
        alphas_bar_sqrt: 1-D tensor indexed by step (coefficient of ``x_0``).
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step (coefficient
            of the noise).
        n_steps: Number of diffusion steps; sampled steps lie in
            ``[0, n_steps)``.

    Returns:
        Scalar tensor: mean squared error between the true and the
        predicted noise.
    """
    # Work on the device of the schedule tensors instead of relying on a
    # module-level global, so the index tensor always matches what it indexes.
    target_device = alphas_bar_sqrt.device
    batch_size = x_0.shape[0]

    # Draw steps for half of the batch and mirror them (t, n_steps - 1 - t)
    # for the other half.
    half = torch.randint(0, n_steps, size=(batch_size // 2,)).to(target_device)
    t = torch.cat([half, n_steps - 1 - half], dim=0)  # [2 * (batch_size // 2)]
    if batch_size % 2 == 1:
        # Odd batch: one extra independent step for the remaining example.
        extra = torch.randint(0, n_steps, size=(1,), device=target_device)
        t = torch.cat([t, extra], dim=0)

    t_col = t.unsqueeze(-1)  # [batch_size, 1] so coefficients broadcast over features
    a = alphas_bar_sqrt[t_col]
    aml = one_minus_alphas_bar_sqrt[t_col]
    e = torch.randn_like(x_0)

    # Closed-form forward process: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
    x = x_0 * a + e * aml
    if x.dtype == torch.float64:
        x = x.to(torch.float32)
    output = model(x, t)

    return (e - output).square().mean()


########### reverse diffusion sampling functions (inference)
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt):
    """Run the full reverse diffusion chain starting from Gaussian noise.

    Args:
        model: Noise-prediction model passed through to ``p_sample``.
        shape: Shape of the sample tensor to generate.
        n_steps: Number of reverse steps; steps run from ``n_steps - 1``
            down to ``0``.
        betas: 1-D beta schedule indexed by step.
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step.

    Returns:
        List of ``n_steps + 1`` tensors: the initial noise followed by the
        sample produced after each reverse step.
    """
    # Create the initial noise on the schedule's device; a CPU tensor here
    # would not match the step tensor that p_sample moves to the compute device.
    cur_x = torch.randn(shape, device=betas.device)
    x_seq = [cur_x]
    for step in reversed(range(n_steps)):
        cur_x = p_sample(model, cur_x, step, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """Perform one reverse diffusion step from step ``t``.

    Args:
        model: Callable ``model(x, t)`` returning the predicted noise.
        x: Current noisy sample.
        t: Integer step index.
        betas: 1-D beta schedule indexed by step.
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step.

    Returns:
        Tensor shaped like ``x``: the posterior mean plus
        ``sqrt(beta_t)``-scaled Gaussian noise.
    """
    # Place the step index on the schedule's device rather than depending on
    # a module-level global.
    step = torch.tensor([t]).to(betas.device)
    beta_t = betas[step]

    coeff = beta_t / one_minus_alphas_bar_sqrt[step]
    eps_theta = model(x, step)
    mean = (1 / (1 - beta_t).sqrt()) * (x - (coeff * eps_theta))

    # NOTE(review): noise is added at every step, including t == 0. The DDPM
    # paper's sampling algorithm uses z = 0 on the final step -- confirm
    # whether the extra noise is intended before changing it.
    z = torch.randn_like(x)
    sigma_t = beta_t.sqrt()
    return mean + sigma_t * z


def reconstruct(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, betas=None):
    """Noise ``x_0`` to the last diffusion step and denoise it back.

    Args:
        model: Noise-prediction model passed through to ``p_sample``.
        x_0: Clean input batch.
        alphas_bar_sqrt: 1-D tensor indexed by step (coefficient of ``x_0``).
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step (coefficient
            of the noise).
        n_steps: Number of diffusion steps; the schedules are indexed with
            ``0 .. n_steps - 1``.
        betas: Optional 1-D beta schedule. When omitted it is recovered from
            ``alphas_bar_sqrt`` (``alpha_t = alpha_bar_t / alpha_bar_{t-1}``).
            Previously the function read an undefined local ``betas`` and
            could not run.

    Returns:
        Tensor shaped like ``x_0``: the sample after ``n_steps`` reverse steps.
    """
    dev = alphas_bar_sqrt.device

    if betas is None:
        alphas_bar = alphas_bar_sqrt ** 2
        alphas_bar_prev = torch.cat(
            [torch.ones(1, dtype=alphas_bar.dtype, device=dev), alphas_bar[:-1]], 0
        )
        betas = 1 - alphas_bar / alphas_bar_prev
    betas = torch.clip(betas, 0.0001, 0.9999).to(dev)

    x_0 = x_0.to(dev)
    # Every element is noised to the LAST valid step. The index must be
    # n_steps - 1; n_steps itself is out of range for the schedule tensors.
    t = torch.full(x_0.shape, n_steps - 1, dtype=torch.long, device=dev)
    # coefficient of x0
    a = alphas_bar_sqrt[t]
    # coefficient of eps
    aml = one_minus_alphas_bar_sqrt[t]
    # random noise eps
    e = torch.randn_like(x_0)

    # forward process in closed form
    x_t = x_0 * a + e * aml
    if x_t.dtype == torch.float64:
        x_t = x_t.to(torch.float32)

    # reverse process: steps n_steps - 1, ..., 0
    for step in reversed(range(n_steps)):
        x_t = p_sample(model, x_t, step, betas, one_minus_alphas_bar_sqrt)
    return x_t



def DDPMTrainer(args, datalen):
    """Train one MLPDiffusion noise predictor per player on (state, action) data.

    For every name in ``args.player_list`` a separate ``MLPDiffusion`` is
    trained with ``diffusion_loss_fn`` on the concatenation of that player's
    state (optionally embedded by a pretrained ``StateEmb``) and action.
    Weights are written under ``./model/baseline/DBC2/_weight/<env_name>/``
    every 100 epochs.

    Args:
        args: Configuration object. Fields read here: ``env_name``,
            ``emb_path``, ``device``, ``data_path``, ``player_name``,
            ``player_list``, ``state_dim_list``, ``action_dim_list``,
            ``action_dim``, ``adv_num``, ``agent_num``,
            ``model.latent_dim`` and ``ddpm.{batch_size, num_epoch,
            num_steps, hidden_dim, depth, lr}``.
        datalen: Used to build the dataset file name ``<datalen>.pkl`` for
            every environment except ``'badminton'``.

    The visible part of this function does not return a value.
    """
    # hyper parameters
    batch_size = args.ddpm.batch_size
    num_epoch = args.ddpm.num_epoch
    num_steps = args.ddpm.num_steps
    betas = sigmoid_beta_schedule(num_steps)

    # --------------- Emb Model -----------------
    # These environments feed states through a pretrained embedding model.
    if args.env_name in ['tennis', 'box', 'connect4']:
        emb_path = f'{args.emb_path}/{args.env_name}/emb.pth'
        emb_model = StateEmb(args).to(args.device)
        # map_location lets a checkpoint saved on another device load here.
        checkpoint = torch.load(emb_path, weights_only=True, map_location=args.device)
        emb_model.load_state_dict(checkpoint)
        emb_model.eval()

    model_save_path = f'./model/baseline/DBC2/_weight/{args.env_name}/'
    os.makedirs(model_save_path, exist_ok=True)

    betas = torch.clip(betas, 0.0001, 0.9999).to(device)

    # calculate alpha, alpha_prod, alpha_prod_previous, alpha_bar_sqrt
    alphas = 1 - betas
    alphas_prod = torch.cumprod(alphas, 0)
    alphas_prod_p = torch.cat(
        [torch.tensor([1]).float().to(device), alphas_prod[:-1]], 0
    )
    alphas_bar_sqrt = torch.sqrt(alphas_prod)
    one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
    one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)

    assert (
        alphas.shape
        == alphas_prod.shape
        == alphas_prod_p.shape
        == alphas_bar_sqrt.shape
        == one_minus_alphas_bar_log.shape
        == one_minus_alphas_bar_sqrt.shape
    )
    print("all the same shape", betas.shape)

    # --------------- Data -----------------
    if args.env_name == 'badminton':
        path = f'{args.data_path}/badminton/{args.player_name}_dataset.csv'
        player = args.player_name
    else:
        path = f'{args.data_path}/{args.env_name}/{datalen}.pkl'
        player = None
    dl = SADataLoader(env_name=args.env_name,
                      pkl_path=path,
                      player_name=player,
                      batch_size=batch_size)
    loader = dl.get_dataloader()
    train_loader = loader['train']
    print("Training model...")

    # state dim list
    if args.env_name in ['tennis', 'box', 'connect4']:
        state_dim_list = [args.model.latent_dim] * len(args.player_list)
    else:
        state_dim_list = list(args.state_dim_list)

    # action dim list
    if args.env_name in ['badminton']:
        action_dim_list = list(args.action_dim_list)
    else:
        action_dim_list = [1] * args.adv_num + [1] * args.agent_num

    input_dim_list = [x + args.action_dim for x in state_dim_list]
    print('input_dim_list:', input_dim_list)

    # output dimension is state_dim + action_dim; inputs are x and step.
    # The device is passed explicitly: MLPDiffusion's own default is "cuda",
    # which fails on CPU-only hosts even though `device` falls back to CPU.
    model = nn.ModuleDict({
        name: MLPDiffusion(
            num_steps,
            input_dim=input_dim_list[player_idx],
            num_units=args.ddpm.hidden_dim,
            depth=args.ddpm.depth,
            device=device,
        ).to(device)
        for player_idx, name in enumerate(args.player_list)
    })

    optimizer = torch.optim.Adam(model.parameters(), lr=args.ddpm.lr)
    train_loss_list = []

    for epoch in tqdm(range(0, num_epoch + 1)):
        total_loss = 0
        for batch_x in train_loader:
            states, actions = batch_x

            if args.env_name in ['tennis', 'box']:
                states = torch.unbind(states, dim=1)
            elif args.env_name in ['badminton', 'connect4', 'holdem']:
                pass  # the whole state tensor is shared by all players
            else:
                states = torch.split(states, state_dim_list, dim=1)
            actions = torch.split(actions, action_dim_list, dim=1)

            losses = 0.0
            for player_idx, name in enumerate(args.player_list):
                if args.env_name == 'badminton':
                    action = actions[player_idx]
                else:
                    action = action_one_hot(actions[player_idx], num_classes = args.action_dim).to(device)
                action = action.squeeze(1) if action.dim() > 2 else action
                action = action.to(device)

                if args.env_name == 'connect4':
                    # The embedding model is frozen (not in the optimizer),
                    # so no autograd graph is needed for it.
                    with torch.no_grad():
                        state = emb_model.state_embed(states, name).to(device)

                elif args.env_name in ['badminton', 'holdem']:
                    state = states.to(device)

                elif args.env_name in ['tennis', 'box']:
                    with torch.no_grad():
                        state = emb_model.state_embed(states[player_idx], name).to(device)
                else:
                    state = states[player_idx].to(device)

                batch = torch.cat([state, action], dim=-1).to(device)
                loss = diffusion_loss_fn(
                    model[name], batch, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps
                )

                # One optimizer step per player per batch.
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model[name].parameters(), 0.5)
                optimizer.step()
                loss = loss.cpu().detach()
                losses += loss

            total_loss += losses

        # NOTE(review): each loss is already a per-batch mean, yet the sum is
        # divided by the number of samples rather than the number of batches.
        # Kept as-is so recorded values stay comparable -- confirm the intent.
        ave_loss = total_loss / len(train_loader.dataset)
        train_loss_list.append(ave_loss)

        if epoch % 100 == 0:
            for name, m in model.items():
                if args.env_name == 'badminton':
                    torch.save(m.state_dict(), f'{model_save_path}/{name}_{args.player_name}_ddpm.pth')
                else:
                    torch.save(m.state_dict(), f'{model_save_path}/{name}_ddpm.pth')