import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
import copy
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_lpa_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log

import scipy.sparse as sp
from datetime import datetime

class Trainer(object):
    """Train a label-propagation-augmented node classifier and track best-val accuracies.

    The dataset, logging directory, seed and device are set up once in
    ``__init__``; every call to :meth:`train` builds a fresh model, optimizer
    and scheduler from the config and runs a complete training loop.
    """

    def __init__(self, config):
        """Load logging paths, seed, device and data described by ``config``.

        Args:
            config: experiment configuration; ``config.seed`` is read here and
                the whole object is handed to ``set_log`` and ``load_data``.
        """
        super(Trainer, self).__init__()

        self.config = config
        self.log_folder_name, self.log_dir = set_log(self.config)
        self.seed = load_seed(self.config.seed)
        self.device = load_device()[0]
        # NOTE(review): the data tensors are never moved to self.device in this
        # class, while the model is — presumably load_data already places them
        # on the right device; confirm.
        self.x, self.y, self.adj, self.train_mask, self.valid_mask, self.test_mask = load_data(self.config)
        # Collapse the label matrix to integer class indices along dim 1
        # (presumably one-hot labels — TODO confirm against load_data).
        self.y = torch.argmax(self.y, dim=1)

    def train(self, ts):
        """Run one full training experiment tagged ``ts``.

        Trains for exactly ``config.train.num_epochs`` epochs. After each epoch
        it evaluates train/val/test accuracy and keeps the accuracies of the
        epoch with the highest validation accuracy (earliest epoch wins ties).

        Args:
            ts: run identifier; stored as ``config.exp_name`` and used as the
                checkpoint tag and log-file stem.

        Returns:
            ``[train_acc, val_acc, test_acc]`` from the best-validation epoch,
            or ``None`` if no epoch was run.
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        self.params = load_model_params(self.config)
        # presumably adj is an edge index of shape (2, E), so shape[1] is the
        # edge count — TODO confirm against load_data.
        self.params['num_edges'] = self.adj.shape[1]
        # NOTE(review): self.scheduler is created but never stepped in this
        # loop — confirm whether that is intentional.
        self.model, self.optimizer, self.scheduler = load_lpa_model_optimizer(self.params, self.config.train, self.device)
        self.model = self.model.to(self.device)

        logger = Logger(str(os.path.join(self.log_dir, f'{self.ckpt}.log')), mode='a')
        logger.log(f'{self.ckpt}', verbose=False)
        start_log(logger, self.config)
        train_log(logger, self.config)

        def train_step():
            """One optimizer step on the joint classifier + label-propagation loss."""
            self.model.train()
            self.optimizer.zero_grad()
            pred_y, pred_y_lp = self.model(self.x, self.y, self.adj, self.train_mask)

            # Both outputs go through F.nll_loss, so the model is presumably
            # returning log-probabilities; lamda weights the propagation term.
            loss = F.nll_loss(pred_y[self.train_mask], self.y[self.train_mask]) \
                + self.config.lpa.lamda * F.nll_loss(pred_y_lp[self.train_mask], self.y[self.train_mask])
            loss.backward()
            self.optimizer.step()

        @torch.no_grad()  # evaluation only: skip building an autograd graph
        def test():
            """Return [train, val, test] accuracy of the classifier head."""
            self.model.eval()
            # The model receives the full label vector plus train_mask;
            # presumably it only propagates training-node labels — TODO confirm.
            logits, _ = self.model(self.x, self.y, self.adj, self.train_mask)
            accs = []
            for mask in (self.train_mask, self.valid_mask, self.test_mask):
                pred = logits[mask].argmax(dim=1)
                acc = pred.eq(self.y[mask]).sum().item() / mask.sum().item()
                accs.append(acc)
            return accs

        num_epochs = self.config.train.num_epochs
        # -inf guarantees the first epoch is selected, so selected_accs is
        # never None when it is formatted below (even if val acc is 0).
        best_val_acc = float('-inf')
        selected_accs = None
        # Epochs are numbered 1..num_epochs so exactly num_epochs epochs run.
        for epoch in range(1, num_epochs + 1):
            train_step()
            accs = test()
            # Strict '>' keeps the earliest epoch on validation ties.
            if accs[1] > best_val_acc:
                best_val_acc = accs[1]
                selected_accs = accs
            logger.log(f'{epoch:03d} | val: {accs[1]:.3e} | test: {accs[2]:.3e}  | best val: {selected_accs[1]:.3e} | best test: {selected_accs[2]:.3e}', verbose=False)
            print(f'[Epoch {epoch:04d}] | val: {accs[1]:.3e} | test: {accs[2]:.3e}  | best val: {selected_accs[1]:.3e} | best test: {selected_accs[2]:.3e}', end = '\r')
        print()
        return selected_accs