import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))
sys.path.append(os.path.abspath('./model/baseline/DBC'))

import numpy as np
import torch
import torch.nn as nn
from cleandiffuser.nn_condition import MLPCondition
from cleandiffuser.nn_diffusion import DiT1d, PearceMlp
from cleandiffuser.diffusion.ddpm import DDPM
from cleandiffuser.utils import report_parameters, general_one_hot
from model.utils.dataloader import SASADataLoader
from model.emb.emb import StateEmb
from model.utils.agent_utils import *


def DPTrainer(args, datalen):
    """Train one diffusion behaviour-cloning (DDPM) policy per player.

    For every name in ``args.player_list`` a ``DDPM`` agent is built from a
    ``DiT1d`` denoiser and a per-player ``MLPCondition`` encoder. Each agent
    is updated with the player's action as the diffusion target and the
    flattened (previous state, current state) pair as the condition.
    Checkpoints are written under
    ``./model/baseline/DBC/_weight/dp/<env_name>/``.

    Args:
        args: Experiment configuration. Attributes read here: ``env_name``,
            ``data_path``, ``player_name`` (badminton only), ``batch_size``,
            ``emb_path``, ``device``, ``action_dim``, ``player_list``,
            ``action_dim_list`` (badminton only), ``adv_num``, ``agent_num``,
            ``model.latent_dim``, ``state_dim_list``, ``obs_steps``,
            ``sample_steps``, ``lr``, ``save_freq`` and ``gradient_steps``.
        datalen: Dataset identifier. It selects ``<datalen>.pkl`` for
            non-badminton environments, is printed, and is embedded in the
            checkpoint file names.

    Returns:
        None. The trained agents are persisted through ``DDPM.save``.
    """
    print(f'Env name: {args.env_name}')
    print(f"==============================================================================")

    # Every network below is placed on ``args.device``; batches must go to the
    # same device, so do not auto-detect CUDA independently here.
    device = args.device
    save_path = f'./model/baseline/DBC/_weight/dp/{args.env_name}/'
    os.makedirs(save_path, exist_ok=True)

    # ---------------------- Create Dataset ----------------------
    if args.env_name == 'badminton':
        path = f'{args.data_path}/badminton/{args.player_name}_dataset.csv'
        player = args.player_name
    else:
        path = f'{args.data_path}/{args.env_name}/{datalen}.pkl'
        player = None
    dl = SASADataLoader(env_name=args.env_name, pkl_path=path,
                        player_name=player,
                        batch_size=args.batch_size)
    loader = dl.get_dataloader()
    train_seq = loader["train"]
    print(f'= Dataset len: {datalen} =\n')

    # --------------- Emb Model -----------------
    # These environments feed states through a pre-trained embedding model;
    # the others use the raw state vectors.
    uses_emb = args.env_name in ['tennis', 'box', 'connect4']
    emb_model = None
    if uses_emb:
        emb_path = f'{args.emb_path}/{args.env_name}/emb.pth'
        emb_model = StateEmb(args).to(device)
        # map_location lets a checkpoint saved on another device load here.
        checkpoint = torch.load(emb_path, map_location=device, weights_only=True)
        emb_model.load_state_dict(checkpoint)
        emb_model.eval()

    # --------------- Create Diffusion Model -----------------
    act_dim = args.action_dim
    player_list = list(args.player_list)
    if args.env_name in ['badminton']:
        action_dim_list = list(args.action_dim_list)
    else:
        # One action column per adversary and per agent.
        action_dim_list = [1] * args.adv_num + [1] * args.agent_num

    if uses_emb:
        obs_dim_list = [args.model.latent_dim] * len(player_list)
    else:
        obs_dim_list = list(args.state_dim_list)

    nn_diffusion = DiT1d(
        act_dim, emb_dim=256, d_model=384, n_heads=6, depth=4,
        timestep_emb_type="fourier").to(device)

    nn_condition = nn.ModuleDict({
        name: MLPCondition(in_dim=args.obs_steps * obs_dim_list[idx],
                           out_dim=256,
                           hidden_dims=[256, ],
                           act=nn.ReLU(),
                           dropout=0.25).to(device)
        for idx, name in enumerate(player_list)
    })

    print(f"======================= Parameter Report of Diffusion Model =======================")
    report_parameters(nn_diffusion, topk=3)
    print(f"==============================================================================")

    # NOTE(review): the same ``nn_diffusion`` instance is handed to every
    # player's DDPM, so the denoiser is shared while the condition encoders
    # are per-player -- confirm this sharing is intended.
    agent = {
        name: DDPM(
            nn_diffusion=nn_diffusion,
            nn_condition=nn_condition[name],
            device=device,
            diffusion_steps=args.sample_steps,
            ema_rate=0.9999,
            optim_params={"lr": args.lr})
        for name in player_list
    }

    # ----------------- Training ----------------------
    # ``n_gradient_step`` is incremented once per full pass over ``train_seq``,
    # so ``args.gradient_steps`` and ``args.save_freq`` are counted in passes.
    n_gradient_step = 0
    n_players = len(player_list)
    per_player_unbind = args.env_name in ['tennis', 'box']
    shared_state = args.env_name in ['badminton', 'connect4', 'holdem']

    while True:
        # Accumulate over the whole pass; resetting per batch would make the
        # reported value reflect only the final batch.
        epoch_loss = 0.0
        n_batches = 0

        # ``prev_actions`` is part of each batch but is not used for training.
        for prev_states, _prev_actions, states, actions in train_seq:
            prev_states = prev_states.to(device)
            states = states.to(device)
            actions = actions.to(device)

            # diffusionBC
            # |o|o|
            # | |a|
            if per_player_unbind:
                # One slice per player along dim 1.
                states = torch.unbind(states, dim=1)
                prev_states = torch.unbind(prev_states, dim=1)
            elif not shared_state:
                # Concatenated per-player observations: split by width.
                states = torch.split(states, obs_dim_list, dim=1)
                prev_states = torch.split(prev_states, obs_dim_list, dim=1)
            actions = torch.split(actions, action_dim_list, dim=1)

            for idx, name in enumerate(player_list):
                if args.env_name == 'badminton':
                    naction = actions[idx]
                else:
                    naction = general_one_hot(actions[idx], args.action_dim).to(device)
                naction = naction.squeeze(1) if naction.dim() > 2 else naction
                if args.agent_num > 1:
                    naction = naction.view(naction.shape[0], -1).unsqueeze(1)
                else:
                    naction = naction.unsqueeze(1)

                if uses_emb:
                    # connect4 embeds the full state; tennis/box embed the
                    # player's own slice.
                    prev_in = prev_states if args.env_name == 'connect4' else prev_states[idx]
                    cur_in = states if args.env_name == 'connect4' else states[idx]
                    # The embedding model is not given to any DDPM optimizer
                    # in this function, so no autograd graph is needed.
                    with torch.no_grad():
                        prev_state = emb_model.state_embed(prev_in, name)
                        state = emb_model.state_embed(cur_in, name)
                elif shared_state:
                    prev_state, state = prev_states, states
                else:
                    prev_state, state = prev_states[idx], states[idx]

                condition = torch.stack([prev_state, state], dim=1)
                condition = condition.flatten(start_dim=1)
                epoch_loss += agent[name].update(naction, condition)['loss']

            n_batches += 1

        # Mean loss per (batch, player) update; the guard keeps an empty
        # loader or player list from raising.
        losses = epoch_loss / max(n_batches * n_players, 1)

        if n_gradient_step % args.save_freq == 0:
            print(f"Loss: {losses}")
            for name, model in agent.items():
                if player is not None:
                    p = f"{player}_{name}_{datalen}.pt"
                else:
                    p = f"{name}_{datalen}.pt"
                model.save(save_path + p)

        n_gradient_step += 1
        if n_gradient_step > args.gradient_steps:
            # finish
            break



class DPAgent:
    """Evaluation-time wrapper around a trained diffusion BC (DDPM) policy.

    The constructor rebuilds the ``DiT1d`` / ``MLPCondition`` / ``DDPM`` stack,
    loads the checkpoint for the selected player from
    ``./model/baseline/DBC/_weight/dp/<env>/`` and switches it to eval mode.
    ``get_action`` conditions the sampler on the previous and current
    observation and converts the sampled vector into an environment action.
    """

    def __init__(self, args, p_type=None, p_idx=None, player_name=None):
        """Build the networks and load the trained weights.

        Args:
            args: Experiment configuration. Attributes read here: ``device``,
                ``env_name``, ``player_type``, ``player_type_idx``,
                ``player_name``, ``traj_length``, ``strength``,
                ``action_dim``, ``model.x_max``, ``model.x_min``,
                ``model.latent_dim``, ``sample_steps``,
                ``extra_sample_steps``, ``emb_path``, ``state_dim_list`` and
                ``obs_steps``.
            p_type: Player type; falls back to ``args.player_type`` when None.
            p_idx: Index of the player type; falls back to
                ``args.player_type_idx`` when None.
            player_name: Badminton player name; falls back to
                ``args.player_name`` when None. Ignored for other envs.
        """
        # Use the configured device everywhere so the networks, the action
        # bounds and the sampling prior cannot end up on different devices.
        self.device = torch.device(args.device)
        self.env = args.env_name
        self.player_type = p_type if p_type is not None else args.player_type
        self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx

        if self.env == 'badminton':
            self.player_name = player_name if player_name is not None else args.player_name
            print(f"=== Player name: {self.player_name} === ")
        else:
            self.player_name = None

        self.traj_length = args.traj_length
        # ``strength`` is the suffix of the checkpoint file name.
        self.strength = args.strength

        self.action_dim = args.action_dim
        self.x_max = args.model.x_max
        self.x_min = args.model.x_min
        self.sampling_steps = args.sample_steps
        self.extra_sample_steps = args.extra_sample_steps
        self.solver = None

        # ----- model -----
        # emb
        uses_emb = self.env in ['tennis', 'box', 'connect4']
        if uses_emb:
            self.emb_model = StateEmb(args).to(self.device)
            # map_location lets a checkpoint saved on another device load here.
            checkpoint = torch.load(f'{args.emb_path}/{self.env}/emb.pth',
                                    map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)
            self.emb_model.eval()
            self.obs_dim = args.model.latent_dim * args.obs_steps
        else:
            self.emb_model = None
            self.obs_dim = list(args.state_dim_list)[self.player_type_idx] * args.obs_steps

        self.act_dim = args.action_dim
        # Rolling one-step history used to build the condition.
        self.prev_state = None
        self.round = 1

        # model
        nn_diffusion = DiT1d(
            self.act_dim, emb_dim=256, d_model=384, n_heads=6, depth=4,
            timestep_emb_type="fourier").to(self.device)
        nn_condition = MLPCondition(
            in_dim=self.obs_dim, out_dim=256, hidden_dims=[256, ],
            act=nn.ReLU(), dropout=0.25).to(self.device)

        # Upper / lower bounds passed to DDPM as x_max / x_min.
        x_max = torch.ones((1, 1, self.action_dim), device=self.device) * self.x_max
        x_min = torch.ones((1, 1, self.action_dim), device=self.device) * self.x_min
        self.agent = DDPM(
            nn_diffusion=nn_diffusion, nn_condition=nn_condition, device=self.device,
            diffusion_steps=args.sample_steps, ema_rate=0.99, x_max=x_max, x_min=x_min)

        # load model
        print("## DBC agent eval ##")
        print(f"Use Player: {self.player_type}")
        if self.player_name is None:
            save_path = f'./model/baseline/DBC/_weight/dp/{self.env}/{self.player_type}_{self.strength}.pt'
        else:
            save_path = f'./model/baseline/DBC/_weight/dp/{self.env}/{self.player_name}_{self.player_type}_{self.strength}.pt'
        self.agent.load(save_path)
        self.agent.eval()

    @staticmethod
    def _to_tensor(data, dtype):
        """Return ``data`` unchanged if it is a tensor, else wrap it in one."""
        return data if isinstance(data, torch.Tensor) else torch.tensor(data, dtype=dtype)

    def reset(self):
        """Clear the per-trajectory history (round counter and previous state)."""
        self.round = 1
        self.prev_state = None

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert a raw environment state into a tensor on ``self.device``.

        Args:
            state: Raw state. For badminton it is passed to
                ``badminton_process_state``; otherwise it is a tensor or
                something ``torch.tensor`` accepts.
            name: Player identifier forwarded to the embedding model.

        Returns:
            The processed observation, moved to ``self.device``.
        """
        if self.env == 'badminton':
            state = badminton_process_state(state)
        elif self.env in ['tennis', 'box', 'connect4']:
            state = self._to_tensor(state, torch.float32)
            state = self.emb_model.state_embed(state, name)
        else:
            state = self._to_tensor(state, torch.float32)
        return state.to(self.device)

    @torch.no_grad()
    def get_action(self, state, info=None):
        """Sample an action for ``state`` from the diffusion policy.

        Args:
            state: Current raw environment state (tensor, tuple for
                badminton, or array-like).
            info: Environment info mapping. Despite the ``None`` default it
                is required: ``info['round']`` (badminton) or
                ``info['done']`` (other envs) is read unconditionally.

        Returns:
            The argmax action tensor for tennis/box, the result of
            ``select_action_with_mask`` for connect4/holdem, the result of
            ``badminton_action_process`` for badminton, and a Python scalar
            (``action.item()``) otherwise.
        """
        if self.env == 'badminton':
            if self.round > self.traj_length or info['round'][-1] == 1:
                self.reset()
        else:
            # NOTE(review): this resets whenever ``info['done']`` is falsy,
            # i.e. on every step while that flag is unset, which keeps the
            # previous-state slot at zeros. Looks like it may be inverted --
            # confirm the meaning of ``info['done']`` with the caller.
            if not info['done'] or self.round > self.traj_length:
                self.reset()

        # state: on the first round there is no history, so use a zero
        # placeholder of the matching container type.
        first_round = self.round == 1
        if isinstance(state, torch.Tensor):
            prev_state = torch.zeros_like(state).to(self.device) if first_round else self.prev_state
        elif isinstance(state, tuple):  # badminton
            prev_state = (0, (0, 0), (0, 0), (0, 0)) if first_round else self.prev_state
        else:
            prev_state = np.zeros_like(state) if first_round else self.prev_state
        self.prev_state = state

        prev_obs = self.state_process(prev_state, self.player_type)
        obs = self.state_process(state, self.player_type)
        prior = torch.zeros((1, 1, self.act_dim), device=self.device)

        # ========== Policy ========== #
        condition = torch.cat([prev_obs, obs], dim=-1).to(self.device)

        naction_prob, _ = self.agent.sample_x(
            prior=prior, n_samples=1, sample_steps=self.sampling_steps,
            solver=self.solver, condition_cfg=condition, w_cfg=1.0,
            use_ema=True, extra_sample_steps=self.extra_sample_steps)
        # Drop the leading sample dimension, then pick the highest-scoring entry.
        probs = naction_prob.squeeze(0)
        action = torch.argmax(probs, dim=1)
        self.round += 1

        # output action
        if self.env in ['tennis', 'box']:
            return action
        elif self.env in ['connect4', 'holdem']:
            return select_action_with_mask(action, probs.squeeze(0), info, use_random=False)
        elif self.env in ['badminton']:
            return badminton_action_process(probs.squeeze(0), info, state)
        else:
            return action.item()