import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))

import torch
import torch.nn as nn
import torch.nn.functional as F
from baseline.DBC2.ddpm import MLPDiffusion, diffusion_loss_fn, sigmoid_beta_schedule
from utils.dataloader import SADataLoader
from utils.agent_utils import *
from emb.emb import StateEmb


class DBCAgent(nn.Module):
    """Behavioural-cloning agent wrapping a ``BCModule`` policy and its optimizer.

    For the ``tennis``, ``box`` and ``connect4`` environments the agent also
    owns a ``StateEmb`` model whose ``state_embed`` output is what the policy
    consumes at evaluation time.

    Args:
        args: Configuration object. The constructor reads ``device``,
            ``player_type``, ``player_type_idx``, ``strength``, ``env_name``,
            ``player_name`` (badminton only), ``model.latent_dim`` or
            ``state_dim_list``, ``action_dim``, ``bc.hidden_dim``, ``bc.lr``,
            ``agent_num``, ``emb_path`` and, when ``eval`` is true,
            ``policy_path``.
        p_type: Player type; falls back to ``args.player_type`` when ``None``.
        p_idx: Index into ``args.state_dim_list``; falls back to
            ``args.player_type_idx`` when ``None``.
        eval: When true, policy weights are loaded from ``args.policy_path``
            and the policy is switched to eval mode. (The name shadows the
            builtin but is kept for caller compatibility.)
    """

    # Environments whose observations go through the StateEmb model.
    _EMBEDDED_ENVS = ('tennis', 'box', 'connect4')

    def __init__(self, args, p_type=None, p_idx=None, eval=False):
        super(DBCAgent, self).__init__()
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.player_type = p_type if p_type is not None else args.player_type
        self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx
        self.strength = args.strength

        # environment-dependent dimensions
        self.env = args.env_name
        self.player_name = args.player_name if self.env == 'badminton' else None
        if self.env in self._EMBEDDED_ENVS:
            self.state_dim = args.model.latent_dim
        else:
            self.state_dim = list(args.state_dim_list)[self.player_type_idx]
        self.action_dim = args.action_dim

        # policy network
        self.hidden_dim = args.bc.hidden_dim
        self.agent_num = args.agent_num
        self.policy = BCModule(state_dim=self.state_dim,
                               action_dim=self.action_dim,
                               hidden_dim=self.hidden_dim,
                               device=self.device,
                               continuous=(self.env == 'badminton'),
                               ).to(self.device)

        # NOTE(review): the configured learning rate is scaled by a hard-coded
        # factor of 30; the reason is not visible in this file.
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=args.bc.lr * 30)

        if self.env in self._EMBEDDED_ENVS:
            self.emb_model = StateEmb(args).to(self.device)  # emb model
            self.load_weights(emb_path=args.emb_path)
            self.emb_model.eval()

        if eval:
            print("## DBC Agent eval ##")
            self.load_weights(policy_path=args.policy_path)
            self.policy.eval()
        else:
            print("## DBC Agent train ##")

    @torch.no_grad()
    def eval_state_process(self, state, name):
        """Convert a raw observation into the tensor fed to the policy.

        Embedded environments pass the state through
        ``emb_model.state_embed(state, name)``; badminton uses
        ``badminton_process_state``; every other environment only converts
        non-tensor input to a float32 tensor. The result is moved to
        ``self.device``.
        """
        if self.env == 'badminton':
            state = badminton_process_state(state)
        else:
            if not isinstance(state, torch.Tensor):
                state = torch.tensor(state, dtype=torch.float32)
            if self.env in self._EMBEDDED_ENVS:
                state = self.emb_model.state_embed(state, name)
        return state.to(self.device)

    def forward(self, state, action):
        """Run one behavioural-cloning optimisation step on a batch.

        Note that this is a training step, not a pure forward pass: it
        computes the loss, back-propagates and steps ``self.optimizer``.

        Args:
            state: Policy input, already processed/embedded.
            action: Target action. Badminton uses an MSE loss against the raw
                policy output; every other environment uses ``F.nll_loss``
                (presumably integer class indices -- confirm against the
                data loader).

        Returns:
            ``(prediction, loss_value)`` where ``prediction`` is the raw
            policy output for badminton and the arg-max action otherwise, and
            ``loss_value`` is a Python float.
        """
        recon_probs, pred_action = self.policy(state)
        is_badminton = self.env == 'badminton'

        if is_badminton:
            loss = F.mse_loss(recon_probs, action)
        else:
            loss = F.nll_loss(recon_probs.squeeze(-1), action.squeeze(-1))

        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()

        prediction = recon_probs if is_badminton else pred_action
        return prediction, loss.item()

    @torch.no_grad()
    def get_action(self, state, info=None):
        """Select an action for ``state`` in the format each environment expects.

        Args:
            state: Raw observation (processed via ``eval_state_process``).
            info: Extra data forwarded to ``select_action_with_mask``
                (connect4/holdem) or ``badminton_action_process``.

        Returns:
            connect4/holdem: result of ``select_action_with_mask``;
            tennis/pong/box: the action tensor with a leading dimension added;
            badminton: result of ``badminton_action_process``;
            otherwise: a Python scalar via ``Tensor.item()``.
        """
        s = self.eval_state_process(state, self.player_type)
        probs, action = self.policy(s)

        if self.agent_num == 1:
            probs = probs.squeeze(-1)
            action = action.squeeze(-1)

        if self.env in ('connect4', 'holdem'):
            return select_action_with_mask(action, probs, info, use_random=False)
        if self.env in ('tennis', 'pong', 'box'):
            return action.unsqueeze(0)  # tensor
        if self.env == 'badminton':
            return badminton_action_process(probs.squeeze(0), info, state)
        return action.item()  # Python scalar

    # Load / Save weights
    def save_weights(self, save_weight_path, env_name, datalen):
        """Save the policy ``state_dict`` under ``save_weight_path``.

        The file is ``bc_{player_name}_{player_type}_{datalen}.pth`` for
        badminton and ``bc_{player_type}_{datalen}.pth`` otherwise.
        ``env_name`` is accepted for interface compatibility but unused.
        """
        path = f'{save_weight_path}'
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(path, exist_ok=True)
        if self.env == 'badminton':
            save_path = f'{path}/bc_{self.player_name}_{self.player_type}_{datalen}.pth'
        else:
            save_path = f'{path}/bc_{self.player_type}_{datalen}.pth'

        torch.save(self.policy.state_dict(), save_path)

    def load_weights(self, emb_path=None, policy_path=None):
        """Load embedding and/or policy weights from disk.

        Args:
            emb_path: Root directory containing ``{env}/emb.pth``; skipped
                when falsy.
            policy_path: Root directory containing
                ``{env}/bc_[{player_name}_]{player_type}_{strength}.pth``;
                skipped when falsy.

        Checkpoints are mapped onto ``self.device`` so that weights saved on
        a GPU can still be loaded after the CPU fallback.
        """
        if emb_path:
            e_path = f'{emb_path}/{self.env}/emb.pth'
            checkpoint = torch.load(e_path, map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)

        if policy_path:
            if self.env == 'badminton':
                path = f'{policy_path}/{self.env}/bc_{self.player_name}_{self.player_type}_{self.strength}.pth'
            else:
                path = f'{policy_path}/{self.env}/bc_{self.player_type}_{self.strength}.pth'
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            self.policy.load_state_dict(checkpoint)
        

# module
class BCModule(nn.Module):
    """Two-stage MLP policy head used for behavioural cloning.

    ``model_fc`` (Linear -> GELU -> Linear) encodes the state and
    ``output_fc`` (Linear -> ReLU -> Linear) maps it to ``action_dim`` outputs.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of outputs.
        hidden_dim: Width of the hidden layers.
        device: Requested device; replaced by CPU when CUDA is unavailable.
        temperature: Divisor applied to the logits before ``log_softmax``.
        continuous: When true, ``forward`` returns the raw outputs instead of
            log-probabilities.
    """

    # Number of leading outputs the continuous branch takes its arg-max over.
    # NOTE(review): presumably these are the discrete part of the badminton
    # action vector -- confirm against the action layout.
    _ARGMAX_WIDTH = 12

    def __init__(self, state_dim, action_dim, hidden_dim, device, temperature=1.0, continuous=False):
        super(BCModule, self).__init__()
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.continuous = continuous
        self.temperature = temperature
        self.model_fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.output_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        # Move BOTH sub-networks; previously only model_fc was moved, which
        # left output_fc on the CPU unless the caller moved the module itself.
        self.to(self.device)

    def forward(self, state):
        """Compute the policy output for ``state``.

        Returns:
            Discrete mode: ``(log_probs, act)`` where ``log_probs`` is the
            temperature-scaled ``log_softmax`` over the last dimension and
            ``act`` its arg-max.
            Continuous mode: ``(outputs, act)`` where ``outputs`` are the raw
            network outputs and ``act`` is the arg-max over the first
            ``_ARGMAX_WIDTH`` outputs of each sample.
        """
        # Follow the parameters rather than self.device so the module keeps
        # working if it is moved with .to() after construction.
        state = state.to(self.model_fc[0].weight.device)
        logits = self.output_fc(self.model_fc(state))

        if self.continuous:
            # Slice the feature (last) dimension. ``logits[:12]`` sliced the
            # first 12 batch rows whenever the input was batched.
            head = logits[..., :self._ARGMAX_WIDTH]
            log_probs = F.log_softmax(head / self.temperature, dim=-1)
            return logits, torch.argmax(log_probs, dim=-1)

        log_probs = F.log_softmax(logits / self.temperature, dim=-1)
        return log_probs, torch.argmax(log_probs, dim=-1)
        

class BCTrainer:
    """Trains one ``DBCAgent`` per player and reports a BC + diffusion loss.

    A pre-trained ``MLPDiffusion`` model per player is loaded from
    ``{save_weight_path}/{name}_ddpm.pth`` and used to score the expert
    versus predicted state-action pairs.

    Args:
        args: Configuration object. The constructor reads ``device``,
            ``env_name``, ``player_type``, ``agent_num``, ``adv_num``,
            ``action_dim``, ``player_list``, ``bc.num_epoch``,
            ``bc.batch_size``, ``data_path``, ``ddpm.num_steps``,
            ``ddpm.hidden_dim``, ``ddpm.depth``, plus ``model.latent_dim`` /
            ``emb_path`` or ``state_dim_list`` and, for badminton,
            ``action_dim_list`` and ``player_name``.
        datalen: Dataset identifier; used in the data file name and passed to
            ``DBCAgent.save_weights``.
    """

    # Environments whose observations go through the StateEmb model.
    _EMBEDDED_ENVS = ('tennis', 'box', 'connect4')
    # Environments where every player receives the whole (unsplit) state batch.
    _JOINT_STATE_ENVS = ('badminton', 'connect4', 'holdem')

    def __init__(self, args, datalen):
        # --------------- params ---------------
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.player_type = args.player_type
        self.agent_num = args.agent_num
        self.adv_num = args.adv_num
        self.action_dim = args.action_dim
        self.player_list = list(args.player_list)
        self.epochs = args.bc.num_epoch
        self.data_path = args.data_path
        self.save_weight_path = f'./model/baseline/DBC2/_weight/{args.env_name}'
        self.datalen = datalen

        # diffusion schedule constants
        self.num_steps = args.ddpm.num_steps
        betas = sigmoid_beta_schedule(self.num_steps)
        alphas = 1 - betas
        alphas_prod = torch.cumprod(alphas, 0).to(self.device)
        self.alphas_bar_sqrt = torch.sqrt(alphas_prod).to(self.device)
        self.one_minus_alphas_bar_log = torch.log(1 - alphas_prod).to(self.device)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod).to(self.device)

        # state_dim_list
        if self.env in self._EMBEDDED_ENVS:
            self.state_dim_list = [args.model.latent_dim] * len(self.player_list)
        else:
            self.state_dim_list = list(args.state_dim_list)

        # action_dim_list
        if self.env == 'badminton':
            self.action_dim_list = list(args.action_dim_list)
        else:
            self.action_dim_list = [1] * self.adv_num + [1] * self.agent_num

        # input_dim_list: the diffusion model sees state and action concatenated
        self.input_dim_list = [x + self.action_dim for x in self.state_dim_list]

        # --------------- model ---------------
        self.model = {
            name: DBCAgent(args, name, idx).to(self.device)
            for idx, name in enumerate(self.player_list)
        }

        self.ddpm = {
            name: MLPDiffusion(
                self.num_steps,
                input_dim=self.input_dim_list[idx],
                num_units=args.ddpm.hidden_dim,
                depth=args.ddpm.depth,
            ).to(self.device)
            for idx, name in enumerate(self.player_list)
        }

        for name in self.player_list:
            # map_location keeps GPU-saved checkpoints loadable after the CPU fallback.
            checkpoint = torch.load(f'{self.save_weight_path}/{name}_ddpm.pth',
                                    map_location=self.device, weights_only=True)
            self.ddpm[name].load_state_dict(checkpoint)
            self.ddpm[name].eval()

        if self.env in self._EMBEDDED_ENVS:
            self.emb_model = StateEmb(args).to(self.device)  # emb model
            checkpoint = torch.load(f'{args.emb_path}/{self.env}/emb.pth',
                                    map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)
            self.emb_model.eval()

        # --------------- dataloader ---------------
        if self.env == 'badminton':
            self.player_name = args.player_name
            path = f'{self.data_path}/badminton/{self.player_name}_dataset.csv'
            player = self.player_name
        else:
            path = f'{self.data_path}/{self.env}/{self.datalen}.pkl'
            player = None
        dl = SADataLoader(env_name=args.env_name, pkl_path=path,
                          player_name=player,
                          batch_size=args.bc.batch_size)
        loader = dl.get_dataloader()
        self.train_loader = loader["train"]
        print(f'= Dataset len: {self.datalen} =\n')

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert ``state`` to a float tensor on ``self.device``.

        For embedded environments the tensor is additionally passed through
        ``emb_model.state_embed(state, name)``.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.tensor(state, dtype=torch.float32)
        if self.env in self._EMBEDDED_ENVS:
            state = self.emb_model.state_embed(state, name)
        return state.to(self.device)

    # ========================== Train ===============================
    def train(self):
        """Run ``self.epochs`` training epochs.

        Progress is printed and every agent's weights are saved on each 10th
        epoch and also after the final epoch, so a run whose length is not a
        multiple of 10 still persists its last state.

        Returns:
            List with the mean loss of every epoch.
        """
        train_losses = []

        # ----- train begin -----
        for epoch in range(1, self.epochs + 1):
            train_loss = self.train_epoch()
            # record + print loss
            train_losses.append(train_loss["epoch_loss"])
            if epoch % 10 == 0 or epoch == self.epochs:
                print(f'---------- Epoch {epoch}/{self.epochs} ----------')
                print(f'Train Loss: {train_loss["epoch_loss"]:.4f}')
                for agent in self.model.values():
                    agent.save_weights(self.save_weight_path, self.env, self.datalen)

        return train_losses

    def ddpm_loss(self, x_true, x_pred, name):
        """Return ``max(L(x_pred) - L(x_true), 0)`` under player ``name``'s diffusion model.

        ``L`` is ``diffusion_loss_fn``. The result is always a tensor
        (``torch.clamp``) instead of the tensor-or-int mix produced by the
        builtin ``max``.
        """
        loss_true = diffusion_loss_fn(self.ddpm[name], x_true, self.alphas_bar_sqrt,
                                      self.one_minus_alphas_bar_sqrt, self.num_steps)
        loss_pred = diffusion_loss_fn(self.ddpm[name], x_pred, self.alphas_bar_sqrt,
                                      self.one_minus_alphas_bar_sqrt, self.num_steps)
        return torch.clamp(torch.as_tensor(loss_pred - loss_true), min=0)

    def train_epoch(self):
        """Train every agent for one pass over ``self.train_loader``.

        Returns:
            ``{"epoch_loss": value}`` -- the per-batch loss (BC loss plus 0.5
            times the diffusion term, averaged over players) averaged over
            the number of batches.
        """
        for agent in self.model.values():
            agent.train()
        total_loss = {"epoch_loss": 0.0}

        joint_state = self.env in self._JOINT_STATE_ENVS
        unbind_states = self.env in ('tennis', 'box')
        num_batches = 0

        for state, action in self.train_loader:
            state = state.to(self.device)
            action = action.to(self.device)

            if not joint_state:
                if unbind_states:
                    states_list = torch.unbind(state, dim=1)
                else:
                    states_list = torch.split(state, self.state_dim_list, dim=1)
            actions_list = torch.split(action, self.action_dim_list, dim=1)

            batch_loss = 0.0
            for idx, name in enumerate(self.player_list):
                raw_obs = state if joint_state else states_list[idx]
                obs = self.state_process(raw_obs, name)
                target = actions_list[idx]

                # The agent's forward performs the optimiser step itself and
                # returns the BC loss as a float.
                pred_action, loss_bc = self.model[name](obs, target)

                # NOTE(review): the diffusion term below only enters the
                # reported loss; the optimiser step inside DBCAgent.forward
                # uses the BC loss alone. Confirm whether it was meant to be
                # part of the training objective.
                # no_grad + float() keeps autograd graphs from piling up in
                # the running total over the epoch.
                with torch.no_grad():
                    if self.env == 'badminton':
                        x_true = torch.cat([obs, target], dim=-1).to(self.device)
                        x_pred = torch.cat([obs, pred_action], dim=-1).to(self.device)
                    else:
                        true_one_hot = action_one_hot(target.squeeze(-1), num_classes=self.action_dim)
                        pred_one_hot = action_one_hot(pred_action.squeeze(-1), num_classes=self.action_dim)
                        x_true = torch.cat([obs, true_one_hot], dim=-1).to(self.device)
                        x_pred = torch.cat([obs, pred_one_hot], dim=-1).to(self.device)
                    loss_dm = float(self.ddpm_loss(x_true, x_pred, name))

                batch_loss += loss_bc + 0.5 * loss_dm

            total_loss["epoch_loss"] += batch_loss / len(self.player_list)
            num_batches += 1

        # The BC loss is a per-batch mean, so average over batches (not
        # samples); guard against an empty loader.
        denom = max(num_batches, 1)
        return {key: value / denom for key, value in total_loss.items()}