import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
import copy
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_g3nn_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log

import scipy.sparse as sp
from datetime import datetime

class Trainer(object):
    """Train a discriminative model jointly with a generative model on one graph.

    The constructor loads the seed, device and dataset once; every call to
    :meth:`train` builds a fresh model/generator/optimizer and runs the
    full-batch training loop, logging validation/test accuracy per epoch.
    """

    def __init__(self, config):
        """Set up logging, seeding, device selection and load the dataset.

        Args:
            config: Experiment configuration. This class reads ``config.seed``,
                ``config.train`` (including ``num_epochs``) and
                ``config.generative.lamda``; the loader helpers may read more.
        """
        super(Trainer, self).__init__()

        self.config = config
        self.log_folder_name, self.log_dir = set_log(self.config)
        self.seed = load_seed(self.config.seed)
        self.device = load_device()[0]
        self.x, self.y, self.adj, self.train_mask, self.valid_mask, self.test_mask = load_data(self.config)
        # Collapse per-node label rows into class indices (presumably one-hot
        # labels from load_data -- TODO confirm); F.nll_loss needs indices.
        self.y = torch.argmax(self.y, dim=1)

    def _train_step(self):
        """Run one full-batch optimisation step on the joint objective.

        The loss is ``nll_generative + lamda * nll_discriminative`` where the
        discriminative term is the NLL on the training nodes only.
        """
        self.model.train()
        self.generator.train()
        self.optimizer.zero_grad()
        # The model output is passed to F.nll_loss, so it is expected to be
        # per-node log-probabilities.
        post_y_log_prob = self.model(self.x, self.adj)
        nll_generative = self.generator.nll_generative(
            self.x, self.y, self.adj, self.train_mask, post_y_log_prob)
        nll_discriminative = F.nll_loss(post_y_log_prob[self.train_mask],
                                        self.y[self.train_mask])
        loss = nll_generative + self.config.generative.lamda * nll_discriminative
        loss.backward()
        self.optimizer.step()

    def _evaluate(self):
        """Return ``[train_acc, valid_acc, test_acc]`` for the current model."""
        self.model.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time on every epoch.
        with torch.no_grad():
            logits = self.model(self.x, self.adj)
            accs = []
            for mask in (self.train_mask, self.valid_mask, self.test_mask):
                pred = logits[mask].max(1)[1]
                accs.append(pred.eq(self.y[mask]).sum().item() / mask.sum().item())
        return accs

    def train(self, ts):
        """Train for ``config.train.num_epochs`` epochs and log accuracies.

        Args:
            ts: Experiment identifier. It is stored as ``config.exp_name``,
                used as the checkpoint name and as the log file stem.

        Model selection keeps the accuracies of the epoch with the highest
        validation accuracy (the earliest epoch wins ties).
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        self.params = load_model_params(self.config)
        self.model, self.generator, self.optimizer, self.scheduler = load_g3nn_model_optimizer(
            self.params, self.config.train, self.device)

        self.model = self.model.to(self.device)
        self.generator = self.generator.to(self.device)

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}.log')), mode='a')
        logger.log(f'{self.ckpt}', verbose=False)
        start_log(logger, self.config)
        train_log(logger, self.config)

        # NOTE(review): self.scheduler is created but never stepped in this
        # loop -- confirm whether that is intentional.
        best_val_accs = 0
        selected_accs = None
        # range(num_epochs) runs exactly num_epochs steps and makes the
        # 1-based `epoch+1` label start at 1 (the previous range(1, n) trained
        # one epoch fewer and labelled the first epoch as 2).
        for epoch in range(self.config.train.num_epochs):
            self._train_step()
            accs = self._evaluate()
            # Always accept the first epoch so selected_accs is never None when
            # formatted below, even if validation accuracy is exactly 0.
            if selected_accs is None or accs[1] > best_val_accs:
                best_val_accs = accs[1]
                selected_accs = accs

            logger.log(f'{epoch+1:03d} | val: {accs[1]:.3e} | test: {accs[2]:.3e}  | best val: {selected_accs[1]:.3e} | best test: {selected_accs[2]:.3e}', verbose=False)
            print(f'[Epoch {epoch+1:04d}] | val: {accs[1]:.3e} | test: {accs[2]:.3e}  | best val: {selected_accs[1]:.3e} | best test: {selected_accs[2]:.3e}', end = '\r')
        # Terminate the carriage-return progress line.
        print()