import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_simple_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log

import scipy.sparse as sp
from datetime import datetime

class Trainer(object):
    """Train a simple (mean-field) GNN on one graph and report accuracy.

    Data, device, seed and the loss/estimator bundle are loaded once in
    ``__init__``. Each call to :meth:`train` builds a fresh model,
    optimizer and scheduler. It then takes one optimizer step per epoch
    on the whole graph for ``config.train.num_epochs`` epochs.
    """

    def __init__(self, config):
        """Load logging paths, seed, device, graph data and loss functions.

        Args:
            config: experiment configuration. This class reads ``config.seed``
                and ``config.train`` (``num_epochs``, ``lr_schedule``). It passes
                the whole object to the ``utils`` loaders.
        """
        super(Trainer, self).__init__()

        self.config = config
        self.log_folder_name, self.log_dir = set_log(self.config)
        self.seed = load_seed(self.config.seed)
        self.device = load_device()
        self.x, self.y, self.adj, self.train_mask, self.valid_mask, self.test_mask = load_data(self.config)
        self.simple_losses = load_simple_loss_fn(self.config, self.device)

    def train(self, ts):
        """Train a freshly built model and log validation/test accuracy.

        Each epoch takes one optimizer step on the training loss. It then
        evaluates accuracy on the validation and test masks. The reported
        "best test" is the test accuracy at the epoch with the highest
        validation accuracy, i.e. model selection on validation.

        Args:
            ts: run identifier. It is stored as ``config.exp_name`` and used
                as the log-file name ``<log_dir>/<ts>.log``.
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        # Prepare model, optimizer, and logger.
        self.params = load_model_params(self.config)
        self.simple_model, self.simple_optimizer, self.simple_scheduler = load_simple_model_optimizer(
            self.params, self.config.train, self.device)
        self.simple_loss_fn = self.simple_losses.loss_fn
        self.simple_estimator = self.simple_losses.estimate

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}.log')), mode='a')
        logger.log(f'{self.ckpt}', verbose=False)
        start_log(logger, self.config)
        train_log(logger, self.config)

        # Ground-truth class indices never change, so reduce self.y once.
        # self.y is argmax-ed over its last axis; presumably these are one-hot
        # labels. TODO confirm against load_data.
        label = torch.argmax(self.y, dim=-1)

        # Pre-train mean-field GNN.
        best_valid, best_test = 0, 0
        for epoch in range(self.config.train.num_epochs):
            # One optimization step on the training loss over the full graph.
            self.simple_model.train()
            self.simple_optimizer.zero_grad()
            loss = self.simple_loss_fn(self.simple_model, self.x, self.adj, self.y, self.train_mask)
            loss.backward()
            self.simple_optimizer.step()
            if self.config.train.lr_schedule:
                self.simple_scheduler.step()

            # Evaluate GNN. The estimator also receives the labels and the
            # train mask.
            # NOTE(review): evaluation is not wrapped in torch.no_grad();
            # confirm the estimator does not need autograd before adding it.
            self.simple_model.eval()
            y_est = self.simple_estimator(self.simple_model, self.x, self.adj, self.y, self.train_mask)
            pred = torch.argmax(y_est, dim=-1)
            correct = pred == label
            valid_acc = torch.mean(correct[self.valid_mask].float()).item()
            test_acc = torch.mean(correct[self.test_mask].float()).item()

            # Model selection: keep the TEST accuracy observed at the epoch
            # with the best validation accuracy.
            if valid_acc > best_valid:
                best_valid = valid_acc
                best_test = test_acc

            # Log intermediate performance.
            logger.log(f'{epoch+1:03d} | val: {valid_acc:.3e} | test: {test_acc:.3e}  | best val: {best_valid:.3e} | best test: {best_test:.3e}', verbose=False)
            print(f'[Epoch {epoch+1:05d}] | val: {valid_acc:.3e} | test: {test_acc:.3e}  | best val: {best_valid:.3e} | best test: {best_test:.3e}', end = '\r')

        # Move past the carriage-return progress line.
        print(' ')