import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from demo.ddpm import DDPM
from emb.emb import StateEmb
from utils.plot import Plot_losses
from utils.dataloader import SADataLoader
from utils.logger import Logger
from demo.helper import set_seed
import time


class PolicyTrainer:
    """Train a DDPM policy model on (state, action) pairs for one player.

    For the environments in ``_EMB_ENVS`` the raw states are first passed
    through a pre-trained ``StateEmb`` model (loaded in ``__init__`` and kept
    in eval mode); for other environments the states are fed to the DDPM
    as-is. Actions are converted to Gaussian-smoothed one-hot vectors (except
    for badminton) before being handed to the DDPM.
    """

    # Environments whose states are embedded by a pre-trained StateEmb model.
    _EMB_ENVS = ('tennis', 'box', 'connect4')
    # Environments whose state batch is passed to the model without being
    # split/indexed per player.
    _UNSPLIT_STATE_ENVS = ('badminton', 'connect4', 'holdem')

    def __init__(self, args, datalen):
        """Build the embedding/DDPM models, LR scheduler, data loader and logger.

        Args:
            args: Config object. Reads ``seed``, ``device``, ``env_name``,
                ``agent_num``, ``adv_num``, ``action_dim``, ``player_type``,
                ``player_type_idx``, ``player_list``, ``state_dim_list`` /
                ``model.latent_dim``, ``action_dim_list`` (badminton),
                ``player_name`` (badminton), ``emb_path`` (embedding envs) and
                several ``args.train.*`` fields.
            datalen: Dataset length tag; used in the dataset/weight file
                names and in log output.
        """
        # --------------- params ---------------
        # Seed first so model initialisation below is reproducible.
        set_seed(args.seed)
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.agent_num = args.agent_num
        self.adv_num = args.adv_num
        self.action_dim = args.action_dim
        self.player_type = args.player_type
        self.player_idx = args.player_type_idx
        self.player_list = list(args.player_list)

        # Per-player state widths: embedded envs all share the latent size.
        if self.env in self._EMB_ENVS:
            self.state_dim_list = [args.model.latent_dim] * len(self.player_list)
        else:
            self.state_dim_list = list(args.state_dim_list)

        # Per-player action widths: one column per player unless configured.
        if self.env == 'badminton':
            self.action_dim_list = list(args.action_dim_list)
        else:
            self.action_dim_list = [1] * (self.adv_num + self.agent_num)

        # train params
        self.epochs = args.train.policy_total_epochs
        self.datalen = datalen
        self.data_path = args.train.data_path
        self.weight_save_path = args.train.weight_save_path
        self.plot_save_path = args.train.plot_save_path
        self.step_epoch_log = args.train.step_epoch_log
        self.step_save_weight = args.train.step_save_weight

        # --------------- model ---------------
        if self.env in self._EMB_ENVS:
            self.emb_model = StateEmb(args).to(self.device)
            self.load_weights(emb_path=args.emb_path)
            self.emb_model.eval()  # embedding model is frozen; only the DDPM trains
        self.model = DDPM(args).to(self.device)
        self.lr_scheduler = CosineAnnealingLR(self.model.optimizer, self.epochs)

        # --------------- dataloader ---------------
        if self.env == 'badminton':
            self.player_name = args.player_name
            path = f'{self.data_path}/badminton/{self.player_name}_dataset.csv'
            player = self.player_name
        else:
            path = f'{self.data_path}/{self.env}/{self.datalen}.pkl'
            player = None
        dl = SADataLoader(env_name=args.env_name, pkl_path=path,
                          player_name=player,
                          batch_size=args.train.batch_size)
        self.train_loader = dl.get_dataloader()['train']
        print(f'= Dataset len: {self.datalen} =\n')

        # --------------- logger ---------------
        # NOTE(review): log file name "tran" may be a typo for "train"; kept
        # unchanged because other tooling may expect the existing name.
        self.logger = Logger(log_file="tran",
                             log_dir=f"{args.train.logs_path}/{args.env_name}",
                             initial_message=(f"Epoch: {self.epochs}, batch size: {args.train.batch_size}, "
                                              f"lr: {args.train.lr}, train data len: {self.datalen}"))

    # ========================== Train ===============================
    def train(self):
        """Run the full training loop, then log the wall time and plot losses.

        Each epoch trains once over ``self.train_loader`` and steps the cosine
        LR schedule. The loss is printed every 5 epochs, logged every
        ``step_epoch_log`` epochs, and weights are saved every
        ``step_save_weight`` epochs.
        """
        train_losses, val_losses = [], []  # no validation pass: val_losses stays empty
        start_time = time.perf_counter()  # monotonic clock for durations

        for epoch in range(1, self.epochs + 1):
            epoch_loss = self.train_epoch()["epoch_loss"]
            self.lr_scheduler.step()

            train_losses.append(epoch_loss)
            if epoch % 5 == 0:
                print(f'---------- Epoch {epoch}/{self.epochs} ----------')
                print(f'Train Loss: {epoch_loss:.5f}')

            if epoch % self.step_epoch_log == 0:
                self.logger.log_ddpm_loss(epoch, epoch_loss)

            # NOTE(review): if step_save_weight > epochs, no weights are ever saved.
            if epoch % self.step_save_weight == 0:
                self.save_weights(self.weight_save_path)

        self.logger.log_train_time(time.perf_counter() - start_time)

        path = f'{self.plot_save_path}/{self.env}'
        os.makedirs(path, exist_ok=True)
        Plot_losses(train_losses, val_losses, title="policy", path=f'{path}/policy.png')

    @torch.no_grad()
    def state_process(self, state, name):
        """Return the conditioning state for the DDPM.

        For ``_EMB_ENVS`` the state is embedded via
        ``self.emb_model.state_embed(state, name)``; otherwise it is returned
        unchanged.
        """
        if self.env in self._EMB_ENVS:
            return self.emb_model.state_embed(state, name)
        return state

    @torch.no_grad()
    def action_process(self, x, std=1.0):
        """Encode actions for the DDPM.

        Args:
            x: Action tensor of shape ``[batch]`` or ``[batch, 1]``; for
                non-badminton envs it holds discrete class indices.
            std: Width of the Gaussian kernel placed on the true class.

        Returns:
            Non-badminton: a ``[batch, action_dim]`` tensor whose rows are
            ``softmax(-(k - x)**2 / (2 * std**2))`` over classes
            ``k = 0 .. action_dim - 1`` (a smoothed one-hot).
            Badminton: ``x`` itself (a 1-D input is first unsqueezed to 2-D).
        """
        if x.dim() == 1:
            x = x.unsqueeze(1)  # [batch] -> [batch, 1]

        if self.env == 'badminton':
            return x

        classes = torch.arange(self.action_dim, device=self.device).float()  # [0, ..., action_dim - 1]
        # [action_dim] - [batch, 1] broadcasts to [batch, action_dim].
        logits = -(classes - x) ** 2 / (2 * std ** 2)
        return F.softmax(logits, dim=-1)

    def train_epoch(self):
        """Run one pass over ``self.train_loader``.

        Returns:
            ``{"epoch_loss": float}``: the sum of the values returned by
            ``self.model(x, cond)`` over all batches, divided by the number of
            samples in ``self.train_loader.dataset``.
        """
        self.model.train()
        epoch_loss = 0.0

        for states, actions in self.train_loader:
            states, actions = states.to(self.device), actions.to(self.device)

            if self.env in ('tennis', 'box'):
                states = torch.unbind(states, dim=1)  # indexed by player below
            elif self.env not in self._UNSPLIT_STATE_ENVS:
                states = torch.split(states, self.state_dim_list, dim=1)
            actions = torch.split(actions, self.action_dim_list, dim=1)

            if self.env in self._UNSPLIT_STATE_ENVS:
                player_state = states
            else:
                player_state = states[self.player_idx]
            cond = {'state': self.state_process(player_state, name=self.player_type)}
            x = {name: self.action_process(actions[i]) for i, name in enumerate(self.player_list)}

            # float() turns a (possibly graph-attached, possibly GPU) scalar
            # loss into a Python number so the running sum does not keep every
            # batch's autograd graph alive.
            epoch_loss += float(self.model(x, cond))

        # NOTE(review): if self.model returns a per-batch *mean* loss, dividing
        # by the sample count (rather than the batch count) scales the reported
        # loss by roughly 1/batch_size — confirm against DDPM.forward.
        return {"epoch_loss": epoch_loss / len(self.train_loader.dataset)}

    # =========================== Load / Save weights ===========================
    def save_weights(self, save_weight_path):
        """Save DDPM weights under ``<save_weight_path>/<env>/``.

        The file is ``<player_name>_policy_<datalen>.pth`` for badminton and
        ``<player_type>_policy_<datalen>.pth`` otherwise; the directory is
        created if missing.
        """
        path = f'{save_weight_path}/{self.env}'
        os.makedirs(path, exist_ok=True)
        prefix = self.player_name if self.env == 'badminton' else self.player_type
        self.model.save_weights(f'{path}/{prefix}_policy_{self.datalen}.pth')

    def load_weights(self, emb_path=None):
        """Load pre-trained StateEmb weights from ``<emb_path>/<env>/emb.pth``.

        Raises:
            ValueError: if ``emb_path`` is falsy.
        """
        if not emb_path:
            raise ValueError("It must be emb_path")
        path = f'{emb_path}/{self.env}/emb.pth'
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)
        self.emb_model.load_state_dict(checkpoint)