import os
import time
from tqdm import tqdm
import numpy as np
import torch

from utils.loader import load_model_optimizer, load_data, load_low_data, load_loss_fn, load_batch, load_seed, \
                         load_device, load_model_params
from utils.logger import Logger, set_log, start_log, train_log


class Trainer(object):
    """Jointly train two models, ``model_x`` and ``model_adj``, epoch by epoch.

    A single loss function (``load_loss_fn``) returns a pair ``(loss_x, loss_adj)``.
    Each loss is back-propagated into its own model/optimizer pair.

    Per-epoch mean losses are written to a log file. Checkpoints are saved every
    ``config.train.save_interval`` epochs.

    When ``config.train.low_protein`` is truthy, only a training loader is built
    and no test-set evaluation is performed.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.module = self.config.module
        # NOTE(review): `self.ckpt_dir` is not used by this class; checkpoints are
        # written to the path built in `_checkpoint_path`. Confirm the two agree.
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(self.config)

        self.seed = load_seed(self.config.seed)
        self.device = load_device(self.config.gpu)
        if self.config.train.low_protein:
            # Low-data mode: only a training loader; `self.test_loader` is not set.
            self.train_loader = load_low_data(self.config)
        else:
            self.train_loader, self.test_loader = load_data(self.config)

        self.params_x, self.params_adj = load_model_params(self.config)

    def train(self, ts):
        """Build both models and run the full training loop.

        Args:
            ts: Run identifier. It is stored in ``config.exp_name`` and used to
                name the log file and checkpoints (``f'{module}-{ts}'``).
        """
        self.config.exp_name = ts
        self.ckpt = f'{self.module}-{ts}'
        print(f'\033[91m{self.ckpt}\033[0m')

        ### Load Model & Optimizer
        self.model_x, self.optimizer_x, self.scheduler_x = load_model_optimizer(self.params_x, self.config.train, self.device)
        self.model_adj, self.optimizer_adj, self.scheduler_adj = load_model_optimizer(self.params_adj, self.config.train, self.device)

        logger = self._start_logger()
        self.loss_fn = load_loss_fn(self.config)

        for epoch in range(self.config.train.num_epochs):
            # Per-epoch loss histories; kept as attributes so they remain inspectable.
            self.train_x = []
            self.train_adj = []
            self.test_x = []
            self.test_adj = []
            t_start = time.time()

            mean_train_x, mean_train_adj = self._train_epoch(epoch)

            self.model_x.eval()
            self.model_adj.eval()

            if self.config.train.low_protein:
                test_part = ''
            else:
                mean_test_x, mean_test_adj = self._evaluate()
                test_part = f'test x: {mean_test_x:.3e} | test adj: {mean_test_adj:.3e} | '

            # Elapsed time is measured after evaluation, so it covers train + test.
            logger.log(f'Epoch: {epoch+1:03d} | {time.time()-t_start:.2f}s | '
                       f'{test_part}'
                       f'train x: {mean_train_x:.3e} | train adj: {mean_train_adj:.3e}', verbose=False)

            if self._is_save_epoch(epoch):
                self._save_checkpoint(epoch)

    def _log_name(self):
        """Return the log-file name for the current run (``self.ckpt`` must be set)."""
        if self.config.train.low_protein:
            return f'{self.ckpt}_{self.config.train.low_protein}_qed_sa_low.log'
        return f'{self.ckpt}.log'

    def _start_logger(self):
        """Open the run's logger and write config, model summaries and training setup."""
        logger = Logger(str(os.path.join(self.log_dir, self._log_name())), mode='a')
        start_log(logger, self.config)
        logger.log(str(vars(self.config)))
        logger.log('-'*100)

        logger.log(str(self.model_x))
        logger.log(str(self.model_adj))
        logger.log('-'*100)

        train_log(logger, self.config)
        return logger

    def _train_epoch(self, epoch):
        """Run one optimisation pass over the training loader.

        Appends per-batch losses to ``self.train_x`` / ``self.train_adj``. Steps
        the LR schedulers once at the end if ``config.train.lr_schedule`` is set.

        Returns:
            Tuple ``(mean_train_x, mean_train_adj)`` of the batch losses.
        """
        self.model_x.train()
        self.model_adj.train()
        for train_b in tqdm(self.train_loader, desc=f'[Epoch {epoch+1}]'):
            self.optimizer_x.zero_grad()
            self.optimizer_adj.zero_grad()
            x, adj = load_batch(train_b, self.device)

            loss_x, loss_adj = self.loss_fn(self.model_x, self.model_adj, x, adj)

            # Each loss is back-propagated separately into its own model's gradients.
            loss_x.backward()
            loss_adj.backward()

            torch.nn.utils.clip_grad_norm_(self.model_x.parameters(), self.config.train.grad_norm)
            torch.nn.utils.clip_grad_norm_(self.model_adj.parameters(), self.config.train.grad_norm)

            self.optimizer_x.step()
            self.optimizer_adj.step()

            self.train_x.append(loss_x.item())
            self.train_adj.append(loss_adj.item())

        if self.config.train.lr_schedule:
            self.scheduler_x.step()
            self.scheduler_adj.step()

        return np.mean(self.train_x), np.mean(self.train_adj)

    def _evaluate(self):
        """Compute mean losses on ``self.test_loader`` without gradient tracking.

        Appends per-batch losses to ``self.test_x`` / ``self.test_adj``.

        Returns:
            Tuple ``(mean_test_x, mean_test_adj)``.
        """
        with torch.no_grad():
            for test_b in self.test_loader:
                x, adj = load_batch(test_b, self.device)
                loss_x, loss_adj = self.loss_fn(self.model_x, self.model_adj, x, adj)
                self.test_x.append(loss_x.item())
                self.test_adj.append(loss_adj.item())

        return np.mean(self.test_x), np.mean(self.test_adj)

    def _is_save_epoch(self, epoch):
        """True on the last epoch of every ``save_interval``-sized window (0-based ``epoch``)."""
        interval = self.config.train.save_interval
        return epoch % interval == interval - 1

    def _checkpoint_path(self, epoch):
        """Return the checkpoint file path for 0-based ``epoch`` (file name uses ``epoch+1``)."""
        return f'./checkpoints/{self.config.data.data}/{self.ckpt}_epoch{epoch+1}.pth'

    def _save_checkpoint(self, epoch):
        """Save config, model params and both state dicts; return the written path."""
        path = self._checkpoint_path(epoch)
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            'model_config': self.config,
            'params_x': self.params_x,
            'params_adj': self.params_adj,
            'x_state_dict': self.model_x.state_dict(),
            'adj_state_dict': self.model_adj.state_dict(),
        }, path)
        return path
