import os
import torch
import torch.optim as optim
from emb.emb import StateEmb
from utils.plot import Plot_losses
from utils.dataloader import SADataLoader

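"""
Per-environment dimensions (columns presumably: agent input size, adversary
input size, and number of discrete actions — confirm against the env wrappers):

         |  agent  |   adv   | action
tag      |    14   |   16    |   5
push     |    19   |    8    |   5
adv      |    10   |    8    |   5
connect4 |  2*6*7  |  2*6*7  |   7
boxing   | 84*84*6 | 84*84*6 |   18
tennis   | 84*84*6 | 84*84*6 |   18
pong     | 84*84*6 | 84*84*6 |   18

"""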

class EmbTrainer:
    """Train a ``StateEmb`` model on state samples and checkpoint the best epoch.

    The model is called as ``self.model(state)`` and must return a
    single-element loss tensor (it is back-propagated and read with
    ``.item()``). Actions yielded by the data loaders are ignored.
    """

    def __init__(self, args, datalen):
        """Build the model, the train/val data loaders and the optimizer.

        Args:
            args: configuration object; reads ``device``, ``env_name`` and
                ``train.data_path``, ``train.batch_size``,
                ``train.emb_total_epochs``, ``train.lr``,
                ``train.weight_decay``.
            datalen: identifier selecting the pickle file
                ``{args.train.data_path}/{args.env_name}/{datalen}.pkl``.
        """
        # Falls back to CPU whenever CUDA is unavailable, regardless of args.device.
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

        # model
        self.model = StateEmb(args).to(self.device)

        # dataloader: get_dataloader() returns a dict with "train" and "val" loaders.
        dl = SADataLoader(env_name=args.env_name,
                          pkl_path=f"{args.train.data_path}/{args.env_name}/{datalen}.pkl",
                          batch_size=args.train.batch_size)
        loader = dl.get_dataloader()
        self.train_loader = loader["train"]
        self.val_loader = loader["val"]

        # params
        self.env = args.env_name
        self.epochs = args.train.emb_total_epochs
        self.optimizer = optim.AdamW(self.model.parameters(), lr=args.train.lr,
                                     weight_decay=args.train.weight_decay)

    def train(self, args):
        """Run the full training loop, checkpoint on improvement and plot losses.

        Weights are saved via ``save_weights`` whenever the validation loss
        strictly improves. The file on disk therefore holds the best epoch,
        while ``self.model`` ends with the last epoch's weights. A NaN
        validation loss never compares as an improvement, so it never
        triggers a save.

        Args:
            args: configuration object; reads ``env_name``,
                ``train.weight_save_path`` and ``train.plot_save_path``.
        """
        best_val_loss = float('inf')
        train_losses, val_losses = [], []

        for epoch in range(1, self.epochs + 1):
            train_loss = self.train_epoch()
            val_loss = self.val_epoch()

            # logger
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            # Report every 5 epochs, and always on the final epoch so that
            # short runs (fewer than 5 epochs) still print their result.
            if epoch % 5 == 0 or epoch == self.epochs:
                print(f'Epoch {epoch}/{self.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

            # save weight
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_weights(args.train.weight_save_path, args.env_name)

        # plot
        title = f'emb_{args.env_name}'
        path = f'{args.train.plot_save_path}/{args.env_name}'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        Plot_losses(train_losses, val_losses, title=title, path=f'{path}/emb.png')

    def train_epoch(self):
        """Run one optimisation pass over ``self.train_loader``.

        Returns:
            float: sum of per-batch losses divided by the number of training
            samples.
        """
        self.model.train()
        total_loss = 0.0
        # Actions are not used by the embedding loss, so they are left on the
        # host (no needless per-batch device copy, and non-tensor actions work).
        for state, _action in self.train_loader:
            state = state.to(self.device)
            loss = self.model(state)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        # NOTE(review): this is a true per-sample average only if the model
        # returns a summed (not batch-mean) loss — confirm StateEmb's reduction.
        return total_loss / len(self.train_loader.dataset)

    @torch.no_grad()
    def val_epoch(self):
        """Evaluate the model on ``self.val_loader`` without gradient tracking.

        Returns:
            float: sum of per-batch losses divided by the number of
            validation samples (same normalisation as ``train_epoch``).
        """
        self.model.eval()
        total_loss = 0.0
        # Actions are unused here as well; only states are moved to the device.
        for state, _action in self.val_loader:
            state = state.to(self.device)
            loss = self.model(state)
            total_loss += loss.item()

        return total_loss / len(self.val_loader.dataset)

    def save_weights(self, save_weight_path, env_name):
        """Save the model ``state_dict`` to ``{save_weight_path}/{env_name}/emb.pth``.

        The target directory is created if missing; an existing checkpoint at
        that path is overwritten.
        """
        path = f'{save_weight_path}/{env_name}'
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), f'{path}/emb.pth')