import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))
sys.path.append(os.path.abspath('./model/baseline/DBC'))

import torch
import torch.nn as nn
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from cleandiffuser.diffusion import ContinuousDiffusionSDE
from cleandiffuser.invdynamic import MlpInvDynamic
from cleandiffuser.nn_condition import MLPCondition
from cleandiffuser.nn_diffusion import DiT1d
from cleandiffuser.utils import general_one_hot

from model.utils.dataloader import TrajsDataLoader
from model.emb.emb import StateEmb
from model.utils.agent_utils import *


def DDTrainer(args, datalen):
    """Train per-player diffusion planners and inverse-dynamics models.

    For every name in ``args.player_list`` this builds a
    ``ContinuousDiffusionSDE`` over observation trajectories (``DiT1d``
    backbone + ``MLPCondition``) and an ``MlpInvDynamic`` model that predicts
    the joint action of all agents from ``(o_t, o_{t+1})``. Both are trained
    on batches from ``TrajsDataLoader``. One gradient step means one batch:
    every player's models are updated once on it.

    Args:
        args: experiment config (attribute access). Fields read here include
            ``env_name``, ``data_path``, ``player_name``, ``batch_size``,
            ``traj_length``, ``emb_path``, ``player_list``, ``action_dim``,
            ``action_dim_list`` / ``state_dim_list``, ``adv_num``,
            ``agent_num``, ``model.{latent_dim,emb_dim,horizon}``,
            ``ema_rate``, ``predict_noise``, ``next_obs_loss_weight``,
            ``diffusion_gradient_steps``, ``invdyn_gradient_steps``,
            ``log_interval`` and ``save_interval``.
        datalen: dataset identifier. It is used in the ``.pkl`` file name for
            non-badminton envs and in every checkpoint file name.

    Checkpoints go to ``./model/baseline/DBC/_weight/dd/<env_name>/`` as
    ``d_<tag>_<datalen>.pt`` (diffusion) and ``i_<tag>_<datalen>.pt``
    (inverse dynamics). ``<tag>`` is ``args.player_name`` for badminton and
    the player name otherwise.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    print(f'Env name: {args.env_name}')
    print("==============================================================================")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    save_path = f'./model/baseline/DBC/_weight/dd/{args.env_name}/'
    os.makedirs(save_path, exist_ok=True)

    # ---------------------- Create Dataset ----------------------
    if args.env_name == 'badminton':
        path = f'{args.data_path}/badminton/{args.player_name}_dataset.csv'
        player = args.player_name
    else:
        path = f'{args.data_path}/{args.env_name}/{datalen}.pkl'
        player = None

    dl = TrajsDataLoader(env_name=args.env_name,
                         pkl_path=path,
                         player_name=player,
                         batch_size=args.batch_size,
                         split_ratio=1.0,
                         pad_length=args.traj_length)
    loader = dl.get_dataloader()
    train_seq = loader["train"]
    print(f'= Dataset len: {datalen} =\n')

    # --------------- Emb Model (frozen, image/board envs only) -----------------
    emb_model = None
    if args.env_name in ['tennis', 'box', 'connect4']:
        emb_path = f'{args.emb_path}/{args.env_name}/emb.pth'
        emb_model = StateEmb(args).to(device)
        checkpoint = torch.load(emb_path, weights_only=True)
        emb_model.load_state_dict(checkpoint)
        emb_model.eval()

    # --------------- Dimensions -----------------
    player_list = list(args.player_list)
    if args.env_name in ['badminton']:
        action_dim_list = list(args.action_dim_list)
    else:
        action_dim_list = [1] * args.adv_num + [1] * args.agent_num

    if args.env_name in ['tennis', 'box', 'connect4']:
        obs_dim_list = [args.model.latent_dim] * len(player_list)
    else:
        obs_dim_list = list(args.state_dim_list)

    # --------------- Network Architecture -----------------
    nn_diffusion = nn.ModuleDict({
        name: DiT1d(obs_dim_list[idx],
                    emb_dim=args.model.emb_dim,
                    d_model=320,
                    n_heads=10, depth=2,
                    timestep_emb_type="fourier").to(device)
        for idx, name in enumerate(player_list)
    })

    # NOTE: DDAgent must build this with the same activation (ReLU) at eval time.
    nn_condition = nn.ModuleDict({
        name: MLPCondition(in_dim=1, out_dim=args.model.emb_dim,
                           hidden_dims=[args.model.emb_dim, ],
                           act=nn.ReLU(),
                           dropout=0.1).to(device)  # same device as every other module
        for name in player_list
    })

    # ----------------- Masking -------------------
    # Row 0 (the current observation) is the fixed prior; row 1 (the next
    # observation) gets a custom loss weight.
    fix_mask, loss_weight = {}, {}
    for idx, name in enumerate(player_list):
        fix_mask[name] = torch.zeros((args.model.horizon, obs_dim_list[idx]))
        fix_mask[name][0] = 1.
        loss_weight[name] = torch.ones((args.model.horizon, obs_dim_list[idx]))
        loss_weight[name][1] = args.next_obs_loss_weight

    # --------------- Diffusion Model with Classifier-Free Guidance --------------------
    agent = {
        name: ContinuousDiffusionSDE(
            nn_diffusion=nn_diffusion[name],
            nn_condition=nn_condition[name],
            fix_mask=fix_mask[name],
            loss_weight=loss_weight[name],
            ema_rate=args.ema_rate,
            device=device,
            predict_noise=args.predict_noise,
            noise_schedule="linear")
        for name in player_list
    }

    # --------------- Inverse Dynamic -------------------
    # Each player's model predicts the concatenated actions of *all* agents.
    agent_nums = args.agent_num + args.adv_num
    invdyn = {
        name: MlpInvDynamic(
            obs_dim_list[idx],
            args.action_dim * agent_nums, 512,
            nn.Tanh(), {"lr": 2e-4},
            device=device)
        for idx, name in enumerate(player_list)
    }

    # ---------------------- Schedulers ----------------------
    # T_max is measured in gradient steps, so each scheduler must be stepped
    # exactly once per batch (per player).
    diffusion_lr_scheduler, invdyn_lr_scheduler = {}, {}
    for name in player_list:
        diffusion_lr_scheduler[name] = CosineAnnealingLR(agent[name].optimizer, args.diffusion_gradient_steps)
        invdyn_lr_scheduler[name] = CosineAnnealingLR(invdyn[name].optim, args.invdyn_gradient_steps)

    # ---------------------- Training ----------------------
    n_gradient_step = 0
    log = {"avg_loss_diffusion": 0., "avg_loss_invdyn": 0.}
    finished = False

    # Cycle over the loader until the gradient-step budget is spent. All
    # per-step bookkeeping lives inside the batch loop; previously it sat at
    # the outer-loop level, which counted loader passes instead of steps.
    while not finished:
        n_batches = 0
        for batch in train_seq:
            n_batches += 1
            obss = batch['states'].to(device)
            acts = batch['actions'].to(device)

            if args.env_name in ['tennis', 'box']:
                obs = torch.unbind(obss, dim=2)  # one tensor per player
            elif args.env_name in ['badminton', 'connect4', 'holdem']:
                obs = obss
            else:
                obs = torch.split(obss, obs_dim_list, dim=2)

            train_invdyn = n_gradient_step < args.invdyn_gradient_steps
            if train_invdyn:
                # Joint action target; identical for every player, so build it once.
                act = torch.split(acts, action_dim_list, dim=-1)
                a = torch.cat([
                    act[i] if args.env_name == 'badminton'
                    else general_one_hot(act[i], args.action_dim)
                    for i in range(agent_nums)
                ], dim=-1)

            # ----------- Gradient Step (one per player) ------------
            for idx, name in enumerate(player_list):
                with torch.no_grad():
                    if args.env_name in ['tennis', 'box']:
                        o = emb_model.state_embed(obs[idx], name)
                    elif args.env_name in ['connect4']:
                        o = emb_model.state_embed(obs, name)
                    elif args.env_name in ['push', 'tag', 'adv', 'spread']:
                        o = obs[idx]
                    else:
                        o = obs

                log["avg_loss_diffusion"] += agent[name].update(o)['loss']
                diffusion_lr_scheduler[name].step()

                if train_invdyn:
                    log["avg_loss_invdyn"] += invdyn[name].update(o[:, :-1], a[:, :-1], o[:, 1:])['loss']
                    invdyn_lr_scheduler[name].step()

            # ----------- Logging (losses are summed over players) ------------
            if (n_gradient_step + 1) % args.log_interval == 0:
                log["gradient_steps"] = n_gradient_step + 1
                log["avg_loss_diffusion"] /= args.log_interval
                log["avg_loss_invdyn"] /= args.log_interval
                print(log)
                log = {"avg_loss_diffusion": 0., "avg_loss_invdyn": 0.}

            # ----------- Saving ------------
            if (n_gradient_step + 1) % args.save_interval == 0:
                for name in player_list:
                    # NOTE(review): for badminton every entry of player_list writes
                    # the same file names, so only the last one survives if the
                    # list has more than one player.
                    tag = player if player is not None else name
                    agent[name].save(save_path + f"d_{tag}_{datalen}.pt")
                    invdyn[name].save(save_path + f"i_{tag}_{datalen}.pt")

            n_gradient_step += 1
            if n_gradient_step >= args.diffusion_gradient_steps:
                finished = True
                break

        if n_batches == 0:
            # Guard against spinning forever on an empty dataset.
            raise ValueError(f"Training loader for '{args.env_name}' yielded no batches.")


class DDAgent:
    """Inference wrapper around a trained diffusion planner and inverse-dynamics model.

    Loads the checkpoints written by ``DDTrainer`` for one player. At each
    ``get_action`` call it samples an observation trajectory seeded with the
    current observation, then decodes this player's action from the first
    planned transition ``(o_0, o_1)`` with the inverse-dynamics model.
    """

    def __init__(self, args, p_type=None, p_idx=None, player_name=None):
        """Build the networks and load the saved weights.

        Args:
            args: experiment config (attribute access).
            p_type: player name. It is used in non-badminton checkpoint file
                names and as ``name`` for ``StateEmb.state_embed``. Defaults
                to ``args.player_type``.
            p_idx: index of this player in ``args.state_dim_list`` and in the
                chunked inverse-dynamics output. Defaults to
                ``args.player_type_idx``.
            player_name: badminton only; selects the checkpoint. Defaults to
                ``args.player_name``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.player_type = p_type if p_type is not None else args.player_type
        self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx
        self.agent_nums = args.agent_num + args.adv_num

        if self.env == 'badminton':
            self.player_name = player_name if player_name is not None else args.player_name
            print(f"=== Player name: {self.player_name} === ")
        else:
            self.player_name = None

        self.traj_length = args.traj_length
        # Checkpoint-name suffix; presumably the ``datalen`` used by DDTrainer.
        self.strength = args.strength

        # ----- model -----
        # frozen state embedding (image/board envs only)
        if self.env in ['tennis', 'box', 'connect4']:
            self.emb_model = StateEmb(args).to(self.device)
            checkpoint = torch.load(f'{args.emb_path}/{self.env}/emb.pth', weights_only=True)
            self.emb_model.load_state_dict(checkpoint)
            self.emb_model.eval()
        else:
            self.emb_model = None

        if args.env_name in ['tennis', 'box', 'connect4']:
            self.obs_dim = args.model.latent_dim
        else:
            self.obs_dim = list(args.state_dim_list)[self.player_type_idx]
        self.act_dim = args.action_dim
        self.temperature = args.temperature
        self.sampling_steps = args.sample_steps
        self.solver = args.model.solver

        nn_diffusion = DiT1d(
            self.obs_dim,
            emb_dim=args.model.emb_dim,
            d_model=320,
            n_heads=10,
            depth=2,
            timestep_emb_type="fourier")

        # Must match DDTrainer, which trains this MLP with ReLU. Activations have
        # no parameters, so a mismatch would load silently but evaluate a
        # different network than the one trained.
        nn_condition = MLPCondition(
            in_dim=1,
            out_dim=args.model.emb_dim,
            hidden_dims=[args.model.emb_dim, ],
            act=nn.ReLU(),
            dropout=0.1)

        fix_mask = torch.zeros((args.model.horizon, self.obs_dim))
        fix_mask[0] = 1.
        loss_weight = torch.ones((args.model.horizon, self.obs_dim))
        loss_weight[1] = args.next_obs_loss_weight

        self.agent = ContinuousDiffusionSDE(
            nn_diffusion, nn_condition,
            fix_mask=fix_mask, loss_weight=loss_weight, ema_rate=args.ema_rate,
            device=self.device, predict_noise=args.predict_noise, noise_schedule="linear")

        self.invdyn = MlpInvDynamic(self.obs_dim, self.act_dim * self.agent_nums, 512, nn.Tanh(), {"lr": 2e-4}, device=self.device)
        save_path = f'./model/baseline/DBC/_weight/dd/{args.env_name}/'
        if self.env == 'badminton':
            self.agent.load(save_path + f"d_{self.player_name}_{self.strength}.pt")
            self.invdyn.load(save_path + f"i_{self.player_name}_{self.strength}.pt")
        else:
            self.agent.load(save_path + f"d_{self.player_type}_{self.strength}.pt")
            self.invdyn.load(save_path + f"i_{self.player_type}_{self.strength}.pt")
        self.agent.eval()
        self.invdyn.eval()

        print("## DD agent eval ##")
        print(f"Use Player: {self.player_type}")

        # record
        self.round = 1
        self.prev_state = None
        # Constant CFG condition used for every sample.
        self.condition = torch.ones((1, 1), device=self.device)

    @staticmethod
    def _to_tensor(data, dtype):
        """Return *data* unchanged if it is a tensor, else ``torch.tensor(data, dtype=dtype)``."""
        return data if isinstance(data, torch.Tensor) else torch.tensor(data, dtype=dtype)

    def reset(self):
        """Restart the per-episode step counter."""
        self.round = 1
        self.prev_state = None

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert a raw env state into the observation vector the planner expects.

        Badminton states go through ``badminton_process_state``. All other
        states become float32 tensors on ``self.device``. For tennis, box and
        connect4 they are additionally embedded by the frozen ``emb_model``.
        """
        if self.env == 'badminton':
            state = badminton_process_state(state)
        else:
            # Move to the device *before* embedding: emb_model lives on self.device.
            state = self._to_tensor(state, torch.float32).to(self.device)
            if self.env in ['tennis', 'box', 'connect4']:
                state = self.emb_model.state_embed(state, name)
        return state.to(self.device)

    @torch.no_grad()
    def get_action(self, state, info=None):
        """Plan from *state* and return this player's action.

        Args:
            state: raw environment state for this player.
            info: env info. Badminton reads ``info['round']``; other envs read
                ``info['done']``. connect4/holdem and badminton also pass it to
                their action post-processing helpers. ``None`` is accepted
                for the reset check.

        Returns:
            For tennis/box, an action tensor. For connect4/holdem and
            badminton, the output of the respective post-processing helper.
            Otherwise, a Python scalar.
        """
        if self.env == 'badminton':
            if self.round > self.traj_length or (info is not None and info['round'][-1] == 1):
                self.reset()
        else:
            # NOTE(review): this resets whenever the episode is *not* done, which
            # looks inverted. Within this class ``self.round`` only drives these
            # resets, so it does not affect the returned action; confirm intent.
            if info is None or not info['done'] or self.round > self.traj_length:
                self.reset()

        obs = self.state_process(state, self.player_type)

        # ========== Policy ========== #
        # Only step 0 is filled; fix_mask[0] = 1 presumably keeps it fixed while sampling.
        prior = torch.zeros((1, self.traj_length, self.obs_dim), device=self.device)
        prior[:, 0] = obs
        traj, _ = self.agent.sample(
            prior, solver=self.solver,
            n_samples=1, sample_steps=self.sampling_steps, use_ema=True,
            condition_cfg=self.condition, w_cfg=2.5, temperature=self.temperature)

        # Inverse dynamics on the first planned transition (o_0 -> o_1). The
        # output concatenates every agent's action scores; keep only ours.
        states = obs.repeat(traj.size(0), 1)
        probs = self.invdyn.predict(states, traj[:, 1, :])
        chunks_probs = torch.chunk(probs, self.agent_nums, dim=-1)
        my_probs = chunks_probs[self.player_type_idx]

        action = torch.argmax(my_probs, dim=1)
        self.round += 1

        # output action
        if self.env in ['tennis', 'box']:
            return action
        elif self.env in ['connect4', 'holdem']:
            return select_action_with_mask(action, my_probs.squeeze(0), info, use_random=False)
        elif self.env in ['badminton']:
            return badminton_action_process(my_probs.squeeze(0), info, state)
        else:
            return action.item()
