import os
import time
from tqdm import tqdm
import numpy as np
import torch

from utils.loader import load_seed, load_device, load_sde, load_prop_data
from utils.logger import Logger, set_log
from utils.classifier_utils import load_classifier_params, load_classifier_batch, load_classifier_optimizer, \
                                   load_classifier_loss_fn, start_log, train_log
from utils.graph_utils import gen_noise, mask_x, mask_adjs


class Trainer_classifier(object):
    """Train a graph property classifier and periodically checkpoint it.

    Data loading, model/optimizer construction, loss computation and log
    formatting are delegated to the ``utils`` helpers imported at the top of
    the file; this class wires them together into an epoch loop.
    """

    def __init__(self, config):
        """Prepare log paths, device, data loaders and classifier params.

        Args:
            config: experiment configuration object. ``gpu`` and ``seed`` are
                read directly; the object is also passed to ``set_log``,
                ``load_prop_data`` and ``load_classifier_params``.
        """
        super(Trainer_classifier, self).__init__()
        self.config = config
        # NOTE(review): ``module`` is blanked before ``set_log``; presumably so
        # the log folder name carries no module suffix — confirm in set_log.
        self.config.module = ''
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(self.config)

        self.device = load_device(self.config.gpu)

        self.train_loader, self.test_loader = load_prop_data(self.config)
        self.params = load_classifier_params(self.config)

        load_seed(self.config.seed)

    def train(self, ts):
        """Train the classifier for ``config.train.num_epochs`` epochs.

        Each epoch runs one optimization pass over the training loader,
        optionally steps the LR scheduler, evaluates on the test loader, logs
        the mean loss and S/P correlation metrics for both splits, and every
        ``config.train.save_interval`` epochs writes a checkpoint to
        ``./checkpoints/<config.data.data>/C-<ts>_epoch<k>.pth``.

        Args:
            ts: run identifier (e.g. a timestamp string); stored as
                ``config.exp_name`` and embedded in log/checkpoint names.
        """
        self.config.exp_name = ts
        self.ckpt = f'C-{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        self.model, self.optimizer, self.scheduler = load_classifier_optimizer(self.params, self.config.train, self.device)
        # NOTE(review): eps and the SDEs are not used by this loop; they are
        # kept as attributes, presumably for external consumers — confirm.
        self.eps = self.config.train.eps
        self.sde_x = load_sde(self.config.sde.x)
        self.sde_adj = load_sde(self.config.sde.adj)

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}_{self.config.train.prop}.log')), mode='a')
        start_log(logger, self.config)
        train_log(logger, self.config)

        logger.log(str(self.model))
        logger.log('-'*100)

        self.loss_fn = load_classifier_loss_fn(self.config)

        save_interval = self.config.train.save_interval
        for epoch in range(self.config.train.num_epochs):
            t_start = time.time()

            self._train_one_epoch(epoch)
            if self.config.train.lr_schedule:
                self.scheduler.step()
            self._evaluate()

            logger.log(f'Epoch: {epoch+1:03d} | {time.time()-t_start:.2f}s | '
                       f'TRAIN loss: {np.mean(self.trainloss):.4e} | '
                       f'S corr: {np.mean(self.train_s_corr):.4f} | '
                       f'P corr: {np.mean(self.train_p_corr):.4f} | '
                       f'TEST loss: {np.mean(self.testloss):.4e} | '
                       f'S corr: {np.mean(self.test_s_corr):.4f} | '
                       f'P corr: {np.mean(self.test_p_corr):.4f}', verbose=False)

            # Save after every ``save_interval``-th epoch (1-based count).
            if epoch % save_interval == save_interval - 1:
                self._save_checkpoint(epoch)

    def _train_one_epoch(self, epoch):
        """Run one optimization pass over the training loader, recording per-batch metrics."""
        self.trainloss = []
        self.train_s_corr = []
        self.train_p_corr = []

        for train_b in tqdm(self.train_loader, desc=f'[Epoch {epoch+1}]'):
            x, adj, labels = load_classifier_batch(train_b, self.device)

            # Set train mode every step; _evaluate switches the model to eval.
            self.model.train()
            self.optimizer.zero_grad()
            loss, s_corr, p_corr = self.loss_fn(self.model, x, adj, labels)
            loss.backward()

            self.trainloss.append(loss.item())
            self.train_s_corr.append(s_corr)
            self.train_p_corr.append(p_corr)

            self.optimizer.step()

    def _evaluate(self):
        """Evaluate the model on the test loader without gradients, recording per-batch metrics."""
        self.testloss = []
        self.test_s_corr = []
        self.test_p_corr = []

        self.model.eval()
        with torch.no_grad():
            for test_b in self.test_loader:
                x, adj, labels = load_classifier_batch(test_b, self.device)
                loss, s_corr, p_corr = self.loss_fn(self.model, x, adj, labels)
                self.testloss.append(loss.item())
                self.test_s_corr.append(s_corr)
                self.test_p_corr.append(p_corr)

    def _save_checkpoint(self, epoch):
        """Save config, params and model weights for 0-based ``epoch``; return the file path."""
        ckpt_path = f'./checkpoints/{self.config.data.data}/{self.ckpt}_epoch{epoch+1}.pth'
        # torch.save does not create missing parent directories; without this
        # the first periodic save fails with FileNotFoundError.
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save({
            'model_config': self.config,
            'params' : self.params,
            'state_dict': self.model.state_dict()},
            ckpt_path)
        return ckpt_path
