import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))
sys.path.append(os.path.abspath('./model/baseline/DBC'))

import hydra
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from cleandiffuser.diffusion import DiscreteDiffusionSDE
from cleandiffuser.nn_condition import IdentityCondition
from cleandiffuser.nn_diffusion import IDQLMlp
from cleandiffuser.utils import report_parameters, IDQLQNet, IDQLVNet, general_one_hot
from model.utils.dataloader import SASARDataLoader
from model.emb.emb import StateEmb
from model.utils.agent_utils import *
from copy import deepcopy


def _cycle(iterable):
    """Yield items from ``iterable`` forever, starting a fresh pass each time it runs out.

    Each pass calls ``iter(iterable)`` again, so a ``DataLoader`` starts a new
    epoch. Returns (instead of looping forever) if a pass yields nothing.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return


def IDQLTrainer(args, datalen):
    """Train an IDQL policy (diffusion actor + IQL critics) for one player and checkpoint it.

    Runs exactly ``args.gradient_steps`` mini-batch updates, re-iterating the
    training split as many times as needed. The V/Q critics are updated on every
    other step and the diffusion actor on every step. All three learning rates
    follow a cosine schedule over ``args.gradient_steps`` steps. Mean losses over
    the last ``args.log_interval`` steps are printed. The latest actor and critic
    checkpoints are written to ``./model/baseline/DBC/_weight/idql/<env_name>/``
    every ``args.save_interval`` steps.

    Args:
        args: Experiment config. Reads ``env_name``, ``data_path``, ``player_name``,
            ``batch_size``, ``emb_path``, ``model.latent_dim``, ``action_dim``,
            ``action_dim_list`` (badminton), ``state_dim_list``, ``player_type_idx``,
            ``player_list``, ``adv_num``/``agent_num``, the ``actor_*``/``critic_*``
            network and learning-rate settings, ``predict_noise``,
            ``diffusion_steps``, ``ema_rate``, ``iql_tau``, ``discount``,
            ``gradient_steps``, ``log_interval``, ``save_interval`` and ``device``.
        datalen: Dataset-size tag; non-badminton envs load ``r{datalen}.pkl``.
    """
    print(f'Env name: {args.env_name}')
    print(f"==============================================================================")

    # Every network below is created on args.device, so batches must go there too.
    device = torch.device(args.device)
    save_path = f'./model/baseline/DBC/_weight/idql/{args.env_name}/'
    os.makedirs(save_path, exist_ok=True)

    embed_envs = ['tennis', 'box', 'connect4']        # observations pass through StateEmb
    unsplit_envs = ['badminton', 'connect4', 'holdem']  # observations are used as-is, not split per player
    idx = args.player_type_idx

    # ---------------------- Create Dataset ----------------------
    if args.env_name == 'badminton':
        path = f'{args.data_path}/badminton/{args.player_name}_dataset.csv'
        player = args.player_name
    else:
        path = f'{args.data_path}/{args.env_name}/r{datalen}.pkl'
        player = None
    dl = SASARDataLoader(env_name=args.env_name, pkl_path=path,
                         player_name=player,
                         batch_size=args.batch_size)
    train_seq = dl.get_dataloader()["train"]
    print(f'= Dataset len: {datalen} =\n')

    # --------------- Emb Model -----------------
    emb_model = None
    embed_name = None
    if args.env_name in embed_envs:
        emb_path = f'{args.emb_path}/{args.env_name}/emb.pth'
        emb_model = StateEmb(args).to(args.device)
        # map_location lets a GPU-saved embedding load on a CPU-only machine.
        checkpoint = torch.load(emb_path, map_location=device, weights_only=True)
        emb_model.load_state_dict(checkpoint)
        emb_model.eval()
        embed_name = args.player_list[idx]

    # --------------- Network Architecture -----------------
    act_dim = args.action_dim
    if args.env_name in embed_envs:
        obs_dim = args.model.latent_dim
        obs_dim_list = [args.model.latent_dim] * len(args.player_list)
    else:
        obs_dim = args.state_dim_list[idx]
        obs_dim_list = list(args.state_dim_list)

    if args.env_name == 'badminton':
        action_dim_list = list(args.action_dim_list)
    else:
        action_dim_list = [1] * args.adv_num + [1] * args.agent_num

    nn_diffusion = IDQLMlp(
        obs_dim, act_dim, emb_dim=64,
        hidden_dim=args.actor_hidden_dim, n_blocks=args.actor_n_blocks, dropout=args.actor_dropout,
        timestep_emb_type="positional").to(args.device)
    nn_condition = IdentityCondition(dropout=0.0).to(args.device)

    print(f"======================= Parameter Report of Diffusion Model =======================")
    report_parameters(nn_diffusion)
    print(f"==============================================================================")

    # --------------- Diffusion Model Actor --------------------
    actor = DiscreteDiffusionSDE(
        nn_diffusion, nn_condition, predict_noise=args.predict_noise, optim_params={"lr": args.actor_learning_rate},
        x_max=+1. * torch.ones((1, act_dim), device=args.device),
        x_min=-1. * torch.ones((1, act_dim), device=args.device),
        diffusion_steps=args.diffusion_steps, ema_rate=args.ema_rate)

    # ------------------ Critic ---------------------
    iql_q = IDQLQNet(obs_dim, act_dim, hidden_dim=args.critic_hidden_dim).to(args.device)
    iql_q_target = deepcopy(iql_q).requires_grad_(False).eval()
    iql_v = IDQLVNet(obs_dim, hidden_dim=args.critic_hidden_dim).to(args.device)

    q_optim = torch.optim.Adam(iql_q.parameters(), lr=args.critic_learning_rate)
    v_optim = torch.optim.Adam(iql_v.parameters(), lr=args.critic_learning_rate)

    # ---------------------- Training ----------------------
    # Schedulers are stepped once per update, so they are built once for the whole run.
    actor_lr_scheduler = CosineAnnealingLR(actor.optimizer, T_max=args.gradient_steps)
    q_lr_scheduler = CosineAnnealingLR(q_optim, T_max=args.gradient_steps)
    v_lr_scheduler = CosineAnnealingLR(v_optim, T_max=args.gradient_steps)

    actor.train()
    iql_q.train()
    iql_v.train()

    # Polyak rate for the target Q network: target <- (1 - tau) * target + tau * online.
    target_tau = 0.005

    log = {"bc_loss": 0., "q_loss": 0., "v_loss": 0.}
    # zip stops on the range first, so no extra batch is drawn after the last step.
    steps_and_batches = zip(range(args.gradient_steps), _cycle(train_seq))
    for n_gradient_step, (prev_states, _prev_actions, states, actions, rewards) in steps_and_batches:
        obs = prev_states.to(device)
        next_obs = states.to(device)
        act = actions.to(device)
        rew = rewards.to(device)
        # The dataset provides no terminal flags: every transition is treated as non-terminal.
        tml = torch.zeros((act.size(0), 1), device=device)

        # ---- keep only this player's slice of the batch
        if args.env_name in ['tennis', 'box']:
            obs = torch.unbind(obs, dim=1)[idx]
            next_obs = torch.unbind(next_obs, dim=1)[idx]
        elif args.env_name not in unsplit_envs:
            obs = torch.split(obs, obs_dim_list, dim=1)[idx]
            next_obs = torch.split(next_obs, obs_dim_list, dim=1)[idx]

        act = torch.split(act, action_dim_list, dim=1)[idx]

        if args.env_name not in unsplit_envs:
            rew = torch.split(rew, action_dim_list, dim=1)
        # NOTE(review): for badminton/connect4/holdem ``rew`` is still the whole
        # reward tensor here, so ``rew[idx]`` selects batch row ``idx`` rather than
        # this player's column -- confirm against the dataloader's reward layout.
        rew = rew[idx]

        if emb_model is not None:
            with torch.no_grad():
                obs = emb_model.state_embed(obs, embed_name)
                next_obs = emb_model.state_embed(next_obs, embed_name)

        if args.env_name != 'badminton':
            act = general_one_hot(act, args.action_dim).to(device)
        if act.dim() > 2:
            act = act.squeeze(1)

        # -- IQL Training (every other step; step 0 always runs, so the losses exist)
        if n_gradient_step % 2 == 0:
            q = iql_q_target(obs, act)
            v = iql_v(obs)
            # Expectile regression of V towards the target Q.
            v_loss = (torch.abs(args.iql_tau - ((q - v) < 0).float()) * (q - v) ** 2).mean()

            v_optim.zero_grad()
            v_loss.backward()
            v_optim.step()

            with torch.no_grad():
                td_target = rew + args.discount * (1 - tml) * iql_v(next_obs)

            q1, q2 = iql_q.both(obs, act)
            q_loss = ((q1 - td_target) ** 2 + (q2 - td_target) ** 2).mean()
            q_optim.zero_grad()
            q_loss.backward()
            q_optim.step()

            q_lr_scheduler.step()
            v_lr_scheduler.step()

            with torch.no_grad():
                for param, target_param in zip(iql_q.parameters(), iql_q_target.parameters()):
                    target_param.lerp_(param, target_tau)

        # -- Policy Training
        bc_loss = actor.update(act, obs)["loss"]
        actor_lr_scheduler.step()

        # ----------- Logging ------------
        # On odd steps the critic losses are the ones from the previous (even) step.
        log["bc_loss"] += bc_loss
        log["q_loss"] += q_loss.item()
        log["v_loss"] += v_loss.item()

        if (n_gradient_step + 1) % args.log_interval == 0:
            log["gradient_steps"] = n_gradient_step + 1
            log["bc_loss"] /= args.log_interval
            log["q_loss"] /= args.log_interval
            log["v_loss"] /= args.log_interval
            print(log)
            log = {"bc_loss": 0., "q_loss": 0., "v_loss": 0.}

        # ----------- Saving ------------
        if (n_gradient_step + 1) % args.save_interval == 0:
            actor.save(save_path + "diffusion_ckpt_latest.pt")
            torch.save({
                "iql_q": iql_q.state_dict(),
                "iql_q_target": iql_q_target.state_dict(),
                "iql_v": iql_v.state_dict(),
            }, save_path + "iql_ckpt_latest.pt")



class IDQLAgent:
    """Inference-time IDQL agent for one player, loaded from ``IDQLTrainer`` checkpoints.

    Builds the same diffusion actor and IQL critics as the trainer. It loads
    ``diffusion_ckpt_latest.pt`` and ``iql_ckpt_latest.pt`` from
    ``./model/baseline/DBC/_weight/idql/<env_name>/``.

    ``get_action`` works in three stages. It draws ``num_candidates`` actions
    from the diffusion policy. It then picks one with probability
    ``softmax(weight_temperature * (Q_target - V))``. Finally, it converts the
    pick into an env-specific action.
    """

    def __init__(self, args, p_type=None, p_idx=None, player_name=None):
        """Build the networks and load the latest checkpoints.

        Args:
            args: Experiment config. Uses the trainer's network fields plus
                ``player_type``, ``traj_length``, ``strength``, ``solver``,
                ``sampling_steps``, ``weight_temperature``, ``temperature`` and
                the optional ``num_candidates`` (default 1).
            p_type: Player type; defaults to ``args.player_type``.
            p_idx: Player index into ``args.state_dim_list``; defaults to
                ``args.player_type_idx``.
            player_name: Badminton player name; defaults to ``args.player_name``.
        """
        # All networks live on args.device; inputs are moved to the same device.
        self.device = torch.device(args.device)
        self.env = args.env_name
        self.player_type = p_type if p_type is not None else args.player_type
        self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx

        if self.env == 'badminton':
            self.player_name = player_name if player_name is not None else args.player_name
            print(f"=== Player name: {self.player_name} === ")
        else:
            self.player_name = None

        self.traj_length = args.traj_length
        self.strength = args.strength

        self.action_dim = args.action_dim
        self.solver = args.solver
        self.sampling_steps = args.sampling_steps
        self.weight_temperature = args.weight_temperature
        self.temperature = args.temperature
        # Number of candidate actions resampled per decision (IDQL). 1 reduces to
        # plain diffusion-policy sampling.
        self.num_candidates = max(1, int(getattr(args, 'num_candidates', None) or 1))

        # ----- model -----
        if self.env in ['tennis', 'box', 'connect4']:
            self.emb_model = StateEmb(args).to(self.device)
            checkpoint = torch.load(f'{args.emb_path}/{self.env}/emb.pth',
                                    map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)
            self.emb_model.eval()
            self.obs_dim = args.model.latent_dim
        else:
            self.emb_model = None
            # Use this agent's own player index (p_idx may override the config).
            self.obs_dim = args.state_dim_list[self.player_type_idx]
        self.act_dim = args.action_dim

        self.prev_state = None
        self.round = 1

        nn_diffusion = IDQLMlp(
            self.obs_dim, self.act_dim, emb_dim=64,
            hidden_dim=args.actor_hidden_dim, n_blocks=args.actor_n_blocks, dropout=args.actor_dropout,
            timestep_emb_type="positional").to(self.device)
        nn_condition = IdentityCondition(dropout=0.0).to(self.device)

        self.actor = DiscreteDiffusionSDE(
            nn_diffusion, nn_condition, predict_noise=args.predict_noise, optim_params={"lr": args.actor_learning_rate},
            x_max=+1. * torch.ones((1, self.act_dim), device=self.device),
            x_min=-1. * torch.ones((1, self.act_dim), device=self.device),
            diffusion_steps=args.diffusion_steps, ema_rate=args.ema_rate)

        self.iql_q = IDQLQNet(self.obs_dim, self.act_dim, hidden_dim=args.critic_hidden_dim).to(self.device)
        self.iql_q_target = deepcopy(self.iql_q).requires_grad_(False).eval()
        self.iql_v = IDQLVNet(self.obs_dim, hidden_dim=args.critic_hidden_dim).to(self.device)

        # ----- load model -----
        print("## IDQL agent eval ##")
        print(f"Use Player: {self.player_type}")

        # The trainer writes one checkpoint set per env; the player name is not part of the path.
        save_path = f'./model/baseline/DBC/_weight/idql/{self.env}/'
        self.actor.load(save_path + "diffusion_ckpt_latest.pt")

        # map_location lets GPU-saved critics load on a CPU-only machine.
        critic_ckpt = torch.load(save_path + "iql_ckpt_latest.pt", map_location=self.device)
        self.iql_q.load_state_dict(critic_ckpt["iql_q"])
        self.iql_q_target.load_state_dict(critic_ckpt["iql_q_target"])
        self.iql_v.load_state_dict(critic_ckpt["iql_v"])

        self.actor.eval()
        self.iql_q.eval()
        self.iql_v.eval()

    @staticmethod
    def _to_tensor(data, dtype):
        """Return ``data`` unchanged if it is already a tensor, else ``torch.tensor(data, dtype=dtype)``."""
        return data if isinstance(data, torch.Tensor) else torch.tensor(data, dtype=dtype)

    def reset(self):
        """Restart the per-trajectory round counter and forget the previous state."""
        self.round = 1
        self.prev_state = None

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert a raw observation into the actor/critic input, on ``self.device``.

        tennis/box/connect4 states are embedded with ``StateEmb.state_embed(state, name)``.
        Badminton states go through ``badminton_process_state``. Other envs are
        only converted to a float32 tensor.
        """
        if self.env == 'badminton':
            state = badminton_process_state(state)
        else:
            # Move to the model device *before* embedding so it matches emb_model.
            state = self._to_tensor(state, torch.float32).to(self.device)
            if self.env in ['tennis', 'box', 'connect4']:
                state = self.emb_model.state_embed(state, name)
        return state.to(self.device)

    @torch.no_grad()
    def get_action(self, state, info=None):
        """Choose an action for ``state`` with IDQL advantage-weighted resampling.

        Args:
            state: Raw observation for this agent's player.
            info: Env info dict. Required in practice: badminton reads
                ``info['round']``, other envs read ``info['done']``. connect4,
                holdem and badminton also pass it to their action post-processors.

        Returns:
            - tennis/box: argmax action index as a tensor of shape ``(1,)``.
            - connect4/holdem: the result of ``select_action_with_mask``.
            - badminton: the result of ``badminton_action_process``.
            - Other envs: the argmax action index as a Python int.
        """
        if self.env == 'badminton':
            if self.round > self.traj_length or info['round'][-1] == 1:
                self.reset()
        else:
            # NOTE(review): this resets while ``done`` is False, which looks inverted.
            # reset() only touches round/prev_state, neither of which feeds the
            # sampled action below.
            if not info['done'] or self.round > self.traj_length:
                self.reset()

        n = self.num_candidates
        # One conditioning row per candidate action.
        obs = self.state_process(state, self.player_type).unsqueeze(0).repeat_interleave(n, dim=0)
        prior = torch.zeros((n, self.act_dim), device=self.device)

        # ========== Policy ========== #
        act, _ = self.actor.sample(
            prior,
            solver=self.solver,
            n_samples=n,
            sample_steps=self.sampling_steps,
            condition_cfg=obs, w_cfg=1.0,
            use_ema=True, temperature=self.temperature)
        act = act.view(n, self.act_dim)

        # ========== Advantage-weighted resampling ========== #
        adv = (self.iql_q_target(obs, act) - self.iql_v(obs)).view(1, n)
        weights = torch.softmax(adv * self.weight_temperature, dim=1)
        chosen = torch.multinomial(weights, 1).squeeze(-1)  # shape (1,)
        probs = act[chosen]                                  # shape (1, act_dim)

        action = torch.argmax(probs, dim=1)
        self.round += 1

        # output action
        if self.env in ['tennis', 'box']:
            return action
        elif self.env in ['connect4', 'holdem']:
            return select_action_with_mask(action, probs.squeeze(0), info, use_random=False)
        elif self.env in ['badminton']:
            return badminton_action_process(probs.squeeze(0), info, state)
        else:
            return action.item()