import os
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import random
import torch.nn.functional as F
import copy
from utils.loader import load_seed, load_device, load_data, load_model_params, load_model_optimizer, load_loss_fn, \
                         load_clgnn_model_optimizer, load_simple_loss_fn
from utils.logger import Logger, set_log, start_log, train_log

import scipy.sparse as sp
from datetime import datetime

class Trainer(object):
    def __init__(self, config):
        """Store the config and load logging paths, seed, device and data.

        Args:
            config: Experiment configuration object. It is passed to
                ``set_log`` and ``load_data``, and its ``seed`` attribute is
                passed to ``load_seed``.
        """
        super().__init__()
        self.config = config

        # Logging locations are derived from the configuration.
        log_folder_name, log_dir = set_log(config)
        self.log_folder_name = log_folder_name
        self.log_dir = log_dir

        # Seed before loading data, matching the original call order.
        self.seed = load_seed(config.seed)

        # load_device() returns an indexable result; only its first entry is used.
        self.device = load_device()[0]

        # load_data yields six items: features, labels, adjacency and three masks.
        (self.x,
         self.y,
         self.adj,
         self.train_mask,
         self.valid_mask,
         self.test_mask) = load_data(config)



    def train_batch(self, epoch, idx_opt_new, idx_true_new, pre_train=True, train_prediction=None, test_prediction=None):
        self.model.train()
        self.optimizer.zero_grad()
        idx_opt_new = idx_opt_new.to(self.device)
        idx_true_new = idx_true_new.to(self.device)

        output = self.model.forward(self.features, self.labels, train_prediction, self.adj, idx_true_new, self.config.collective.rand_number, pre_train)
        loss_train = F.nll_loss(output[idx_opt_new], self.labels[idx_opt_new])

        loss_train.backward()
        self.optimizer.step()

        self.model.eval()

        output_infer = self.model.forward(self.features, self.labels, test_prediction, self.adj, idx_true_new, self.config.collective.rand_number, pre_train)
        max_output_infer = torch.argmax(output_infer, dim = 1)
        acc_train = torch.mean((max_output_infer[idx_opt_new]==self.labels[idx_opt_new]).float()).item()
        acc_val = torch.mean((max_output_infer[self.idx_val]==self.labels[self.idx_val]).float()).item()
        acc_test = torch.mean((max_output_infer[self.idx_test]==self.labels[self.idx_test]).float()).item()

        return acc_train, acc_val, acc_test, (output/self.config.collective.temp).data.exp(), (output_infer/self.config.collective.temp).data.exp()


    def train_k_rounds(self, idx_obs_all, idx_train_all):
        best_test, best_val = 0, 0
        train_prediction = test_prediction = None

        for k in range(0,self.config.collective.iteration):
            t_total = time.time()
            if k > 0:
                train_prediction = best_train_output
                test_prediction = best_test_output
            not_use_pred = (k == 0) or (False)
            if k > 0:
                rr = np.random.permutation(len(idx_obs_all))
                idx_obs_all = idx_obs_all[rr]
                idx_train_all = idx_train_all[rr]

            for e, (idx_obs, idx_train) in enumerate(zip(idx_obs_all, idx_train_all)):
                for epoch_num in range(self.config.train.num_epochs):
                    acc_train, acc_val, acc_test, train_output_prob, test_output_prob = self.train_batch(epoch_num, idx_train, idx_obs,
                                                                                pre_train=(k==0),
                                                                                train_prediction=train_prediction,
                                                                                test_prediction=test_prediction)

                    print(f'[Iteration {k+1:02d}] | [Epoch {self.config.train.num_epochs*e+epoch_num+1:05d}] | val: {acc_val:.3e} | test: {acc_test:.3e} | best val: {best_val:.3e} | best test: {best_test:.3e}', end = '\r')
                self.logger.log(f'Iteration {k+1:02d} | {self.config.train.num_epochs*e+epoch_num+1:05d} | val: {acc_val:.3e} | test: {acc_test:.3e} | best val: {best_val:.3e} | best test: {best_test:.3e}', verbose=False)

                if acc_val > best_val:
                    best_val, best_test = acc_val, acc_test
                    best_train_output, best_test_output = train_output_prob, test_output_prob

        del self.model, optimizer
        torch.cuda.empty_cache()
        return best_test.item(), best_val.item()



    def train(self, ts):
        """Build the model, logger and index splits for run ``ts``, then train.

        Args:
            ts: Run identifier. It is stored as ``config.exp_name`` and its
                string form becomes the checkpoint name and log-file stem.
        """
        self.config.exp_name = ts
        self.ckpt = f'{ts}'
        print('\033[91m' + f'{self.ckpt}' + '\033[0m')

        # ---- model / optimizer / scheduler ----
        self.params = load_model_params(self.config)
        model, self.optimizer, self.scheduler = load_clgnn_model_optimizer(
            self.params, self.config.train, self.device)
        self.model = model.to(self.device)

        # ---- logger (appends to <log_dir>/<ckpt>.log) ----
        log_path = os.path.join(self.log_dir, f'{self.ckpt}.log')
        self.logger = Logger(str(log_path), mode='a')
        self.logger.log(f'{self.ckpt}', verbose=False)
        start_log(self.logger, self.config)
        train_log(self.logger, self.config)

        # ---- data ----
        self.features = self.x
        # Labels are reduced with argmax over dim 1 to one class index per row.
        self.labels = torch.argmax(self.y, dim=1).to(self.device)

        # Convert the masks into index tensors on the training device. The
        # explicit ``== True`` comparison is kept exactly as in the original.
        train_nodes = torch.where(self.train_mask == True)[0].to(self.device)
        self.idx_val = torch.where(self.valid_mask == True)[0].to(self.device)
        self.idx_test = torch.where(self.test_mask == True)[0].to(self.device)
        self.idx_obs_infer = train_nodes

        # Each split reshuffles the training nodes: the first 70% form the
        # optimisation set and the remainder the observed set.
        opt_ratio = 0.7
        n_opt = int(train_nodes.shape[0] * opt_ratio)
        shuffled = train_nodes
        obs_splits, opt_splits = [], []
        for _ in range(self.config.collective.batch):
            shuffled = shuffled[torch.randperm(shuffled.shape[0])]
            obs_splits.append(shuffled[n_opt:])
            opt_splits.append(shuffled[:n_opt])
        idx_obs_all = torch.stack(obs_splits, dim=0)
        idx_train_all = torch.stack(opt_splits, dim=0)

        acc, valid = self.train_k_rounds(idx_obs_all=idx_obs_all, idx_train_all=idx_train_all)
        # Finish the carriage-return progress line written during training.
        print()