import sys
import os
import warnings
warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./model'))

import torch
import torch.nn as nn
from BC.bc import BCAgent
from utils.dataloader import SADataLoader

"""
         |  agent  |   adv   | action
tag      |    14   |   16    |   5
push     |    19   |    8    |   5
adv      |    10   |    8    |   5
connect4 |  2*6*7  |  2*6*7  |   7
boxing   | 84*84*6 | 84*84*6 |   18
tennis   | 84*84*6 | 84*84*6 |   18

"""

class BCTrainer:
    """Train one ``BCAgent`` per entry of ``args.player_list``.

    The trainer builds the agents, loads the training split through
    ``SADataLoader`` and, for every batch, hands each agent its slice of the
    state/action tensors. The value returned by calling an agent is treated
    as that agent's scalar loss and is only used for logging here; no
    optimizer is created or stepped in this class.
    """

    # Environments where every agent receives the whole, unsplit state batch.
    _SHARED_STATE_ENVS = ('badminton', 'connect4', 'holdem')
    # Environments whose state batch is unbound along dim 1 (one slice per
    # player) instead of being split by ``state_dim_list``.
    _STACKED_STATE_ENVS = ('tennis', 'box')

    def __init__(self, args, datalen):
        """Read the configuration, build the agents and the train loader.

        Args:
            args: Configuration object. Attributes read here: ``device``,
                ``env_name``, ``player_type``, ``agent_num``, ``adv_num``,
                ``action_dim``, ``player_list``, ``train.total_epochs``,
                ``train.data_path``, ``train.weight_save_path``,
                ``train.step_save_weight``, ``train.batch_size`` and,
                depending on the environment, ``model.latent_dim``,
                ``state_dim_list``, ``action_dim_list``, ``player_name``.
            datalen: Dataset size tag. It selects ``<datalen>.pkl`` for
                non-badminton environments and is forwarded to
                ``save_weights``.
        """
        # --------------- params ---------------
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.player_type = args.player_type
        self.agent_num = args.agent_num
        self.adv_num = args.adv_num
        self.action_dim = args.action_dim
        self.player_list = list(args.player_list)

        # Per-player widths used to split the state batch along dim 1.
        if self.env in self._STACKED_STATE_ENVS:
            self.state_dim_list = [args.model.latent_dim] * len(self.player_list)
        elif self.env in ['connect4']:
            self.state_dim_list = [args.model.latent_dim]
        else:
            self.state_dim_list = list(args.state_dim_list)

        # Per-player widths used to split the action batch along dim 1.
        if self.env in ['badminton']:
            self.action_dim_list = list(args.action_dim_list)
        else:
            # One action column per adversary followed by one per agent.
            # NOTE(review): assumes the dataset stores actions in that order.
            self.action_dim_list = [1] * self.adv_num + [1] * self.agent_num

        self.epochs = args.train.total_epochs
        self.data_path = args.train.data_path
        self.save_weight_path = args.train.weight_save_path
        self.step_save_weight = args.train.step_save_weight
        self.datalen = datalen

        # --------------- model ---------------
        self.model = {
            name: BCAgent(args, name, idx).to(self.device)
            for idx, name in enumerate(self.player_list)
        }

        # --------------- dataloader ---------------
        if self.env == 'badminton':
            self.player_name = args.player_name
            path = f'{self.data_path}/badminton/{self.player_name}_dataset.csv'
            player = self.player_name
        else:
            path = f'{self.data_path}/{self.env}/{self.datalen}.pkl'
            player = None
        dl = SADataLoader(env_name=args.env_name, pkl_path=path,
                          player_name=player,
                          batch_size=args.train.batch_size)
        loader = dl.get_dataloader()
        self.train_loader = loader["train"]
        print(f'= Dataset len: {self.datalen} =\n')

    # ========================== Train ===============================
    def train(self):
        """Run ``self.epochs`` training epochs.

        The loss is printed every 10 epochs. Every ``step_save_weight``
        epochs each agent's ``save_weights`` is called; a falsy interval
        (``0`` or ``None``) disables checkpointing instead of raising
        ``ZeroDivisionError``.

        Returns:
            list[float]: The ``"epoch_loss"`` of every epoch, in order.
        """
        train_losses = []

        for epoch in range(1, self.epochs + 1):
            epoch_loss = self.train_epoch()["epoch_loss"]
            train_losses.append(epoch_loss)

            if epoch % 10 == 0:
                print(f'---------- Epoch {epoch}/{self.epochs} ----------')
                print(f'Train Loss: {epoch_loss:.4f}')

            if self.step_save_weight and epoch % self.step_save_weight == 0:
                for agent in self.model.values():
                    agent.save_weights(self.save_weight_path, self.env, self.datalen)

        return train_losses

    def train_epoch(self):
        """Feed every batch of the train loader to every agent once.

        Returns:
            dict: ``{"epoch_loss": float}``, the sum over batches of the
            player-averaged loss divided by ``len(train_loader.dataset)``.
        """
        for agent in self.model.values():
            agent.train()

        # The routing mode depends only on the environment, so decide it once
        # instead of rebuilding a split function for every batch.
        shared_state = self.env in self._SHARED_STATE_ENVS
        stacked_state = self.env in self._STACKED_STATE_ENVS
        num_players = len(self.player_list)
        running_loss = 0.0

        for state, action in self.train_loader:
            state = state.to(self.device)
            action = action.to(self.device)

            if shared_state:
                states_list = None
            elif stacked_state:
                states_list = torch.unbind(state, dim=1)
            else:
                states_list = torch.split(state, self.state_dim_list, dim=1)
            actions_list = torch.split(action, self.action_dim_list, dim=1)

            batch_loss = 0.0
            for idx, name in enumerate(self.player_list):
                agent_state = state if shared_state else states_list[idx]
                # float() drops any autograd history attached to the returned
                # loss so the running total does not keep graphs alive.
                batch_loss += float(self.model[name](agent_state, actions_list[idx]))

            running_loss += batch_loss / num_players

        # NOTE(review): normalised by sample count, not batch count; this is
        # a per-sample mean only if BCAgent returns a batch-summed loss.
        num_samples = len(self.train_loader.dataset)
        return {"epoch_loss": running_loss / num_samples}