import time
import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam
import os
from task.base_task import BaseTask
from task.utils import unsupervised_node_cls_evaluate
from task.utils import unsupervised_node_cls_train_v2, unsupervised_node_cls_evaluate_v2
from task.utils import unsupervised_node_cls_mini_batch_train_v2
from model.lrw import LearnableRandomWalk
import datetime

class UnsupervisedNodeClassification(BaseTask):
    """Unsupervised training task for a random-walk model on node data.

    Constructing the task runs the whole experiment: the model is trained
    ``normalize_times`` times through :meth:`execute`, and when more than one
    run is requested a timing/accuracy summary is logged afterwards.

    Attributes set for callers:
        acc: mean training loss returned by the last :meth:`execute` call
            (the name is kept for backward compatibility).
        epoch_times: wall-clock seconds of every epoch of the last run.
    """

    def __init__(self, logger, args, dataset, model, normalize_times, lr, weight_decay, epochs, logepochs, early_stop, device, walk_time, train_loader=None,
                 loss_fn=nn.CrossEntropyLoss()):
        """Configure the task and immediately run all training repetitions.

        Args:
            logger: logger used for progress and summary messages.
            args: run configuration; this class reads ``data_name``,
                ``model_name``, ``train_batch_size`` and the model
                hyper-parameters used to rebuild the model in :meth:`execute`.
            dataset: dataset object providing ``x``, ``y``, ``adj``, ``name``,
                the index splits and the feature/class counts.
            model: model factory exposing ``model_init()``; for the
                ``"arxivdir"`` dataset it is used directly as the model.
            normalize_times: number of independent training repetitions.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            epochs: maximum number of epochs per repetition.
            logepochs: stored on the instance; not used by this class.
            early_stop: stored on the instance; not used by this class (the
                training loop uses a fixed patience).
            device: device the model and labels are moved to.
            walk_time: forwarded to the training helpers.
            train_loader: loader used for mini-batch training on "arxivdir".
            loss_fn: loss passed to the evaluation helper in
                :meth:`postprocess`.
        """
        super(UnsupervisedNodeClassification, self).__init__()
        self.logger = logger
        self.normalize_times = normalize_times
        # NOTE(review): nothing in this class appends to this record, so the
        # val/test part of the final summary is NaN unless other code fills it.
        self.normalize_record = {"val_acc": [], "test_acc": []}
        self.task_level = "node"
        self.dataset = dataset
        self.labels = self.dataset.y
        self.args = args
        self.model_zoo = model
        if args.data_name == "arxivdir":
            self.model = model
        else:
            self.model = model.model_init()
        self.optimizer_1 = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.optimizer_2 = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.epochs = epochs
        self.logepochs = logepochs
        self.loss_fn = loss_fn
        self.early_stop = early_stop
        self.device = device
        self.walk_time = walk_time
        self.batch_size = None
        self.train_loader = None
        self.acc_train = None
        self.epoch_times = []
        # Exact comparison: `x in ("arxivdir")` is a substring test on a plain
        # string, not tuple membership.
        if self._uses_mini_batches():
            self.batch_size = self.args.train_batch_size
            logger.info(f"Mini-batch training size: {self.batch_size}")
            self.train_loader = train_loader

        total_epochs_time = []
        total_time = []
        normalize_times_st = time.time()
        for i in range(self.normalize_times):
            begin_t = time.time()
            if i > 0:
                # Every repetition after the first starts from a fresh model.
                # NOTE(review): for "arxivdir" the first model is `model`
                # itself, yet this still requires `model.model_init()`.
                self.model = self.model_zoo.model_init()
                self.optimizer_1 = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
                self.optimizer_2 = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
            # execute() returns (mean loss, loss variance); the per-epoch
            # timings needed for the summary are published on the instance.
            self.acc, _ = self.execute()
            total_epochs_time.extend(self.epoch_times)
            total_time.append(time.time() - begin_t)

        if self.normalize_times > 1:
            logger.info("Optimization Finished!")
            logger.info("Total training time is: {:.4f}s".format(time.time() - normalize_times_st))
            logger.info("Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(
                round(np.mean(self.normalize_record["val_acc"]), 4),
                round(np.std(self.normalize_record["val_acc"], ddof=1), 4),
                round(np.mean(self.normalize_record["test_acc"]), 4),
                round(np.std(self.normalize_record["test_acc"], ddof=1), 4)))
            logger.info("Mean Epoch ± Std Epoch: {:.4f}s±{:.4f}s, Mean Total ± Std Total: {:.4f}s±{:.4f}s".format(
                np.mean(total_epochs_time), np.std(total_epochs_time),
                np.mean(total_time), np.std(total_time)))

    def _uses_mini_batches(self):
        """Return True when the dataset is trained with the mini-batch helper."""
        return self.args.data_name == "arxivdir"

    def _save_checkpoint(self):
        """Save the current model's state dict and return the file path."""
        model_save_dir = "./model_files/lrw/"+self.dataset.name
        # makedirs also creates missing parent directories; exist_ok avoids a
        # check-then-create race.
        os.makedirs(model_save_dir, exist_ok=True)
        model_save_path = model_save_dir+"/"+self.args.model_name+".pth"
        torch.save(self.model.state_dict(), model_save_path)
        return model_save_path

    def execute(self):
        """Run one training repetition and evaluate the saved checkpoint.

        Preprocesses the model on the dataset, trains for at most
        ``self.epochs`` epochs (saving a checkpoint every fifth epoch and at
        the end), then rebuilds a ``LearnableRandomWalk`` from the final
        checkpoint and prints the validation-loss statistics.

        Side effects: replaces ``self.model``, moves ``self.labels`` to
        ``self.device`` and stores the per-epoch durations in
        ``self.epoch_times``.

        Returns:
            tuple: ``(mean, variance)`` of the per-epoch mean training losses.
        """
        pre_time_st = time.time()
        self.model.preprocess(self.dataset.adj, self.dataset.x)
        pre_time_ed = time.time()
        if self.normalize_times == 1:
            self.logger.info(f"Preprocessing done in {(pre_time_ed - pre_time_st):.4f}s")
        self.model = self.model.to(self.device)
        self.labels = self.labels.to(self.device)
        time_list = []
        total_loss = []
        stop = 0
        # NOTE(review): a starting loss above ~1 never counts as an
        # improvement until it falls below this initial value.
        best_losses = 1
        mini_batch = self._uses_mini_batches()
        for epoch in range(self.epochs):
            t = time.time()
            if stop > 10:
                self.logger.info("Early stop!")
                break
            if mini_batch:
                losses_train = unsupervised_node_cls_mini_batch_train_v2(self.model , self.device, self.optimizer_1, epoch, self.dataset.train_idx, self.train_loader, self.walk_time)
            else:
                losses_train = unsupervised_node_cls_train_v2(self.model, self.device, self.optimizer_1, epoch, self.dataset.train_idx, self.walk_time)
            losses_train = np.array(losses_train)
            # Reduce to one scalar per epoch; comparing the raw array below
            # would be ambiguous when the helper returns several losses.
            epoch_loss = losses_train.mean()
            total_loss.append(epoch_loss)
            epoch_time = time.time() - t

            # Periodic checkpoint, kept outside the timed region.
            if epoch % 5 == 0:
                self._save_checkpoint()
            if self.normalize_times == 1:
                self.logger.info("Epoch: {:03d}, loss_train: {:.4f}, time: {:.4f}s".format(epoch+1, epoch_loss, epoch_time))
            print("Epoch: {:03d}, loss_train: {:.4f}, time: {:.4f}s".format(epoch+1, epoch_loss, epoch_time))

            time_list.append(epoch_time)
            torch.cuda.empty_cache()
            if epoch_loss < best_losses-0.001:
                best_losses = epoch_loss
                stop = 0
            # The counter is bumped on every epoch, so it reads 1 right after
            # an improvement and training stops once it exceeds 10.
            stop += 1

        self.epoch_times = time_list
        total_loss = np.array(total_loss)
        print(f"The mean of total training loss is {total_loss.mean()}")
        print(f"The variance of totol training loss is {total_loss.var()}")
        model_save_path = self._save_checkpoint()

        # Rebuild the model from the checkpoint that was just written and
        # score it on the validation split.
        self.model  = LearnableRandomWalk(data_name=self.args.data_name, n=self.labels.size(0), hidden_att_dim=self.args.hidden_att_dim,
                                feat_dim=self.dataset.num_features, hidden_dim=self.args.hidden_dim, output_dim=self.dataset.num_node_classes,
                                cof=self.args.walk_cof, delta=self.args.walk_delta, walk_time=self.args.walk_time,
                                walk_len=self.args.walk_len, dropout=self.args.dropout, task_level=self.task_level, y=self.labels)
        self.model.preprocess(self.dataset.adj, self.dataset.x)
        self.model.load_state_dict(torch.load(model_save_path))
        self.model = self.model.to(self.device)
        losses_val = unsupervised_node_cls_evaluate_v2(self.model , self.device, self.optimizer_1, self.dataset.val_idx, self.labels)
        losses_val = np.array(losses_val)
        print(f"The mean of validation loss is {losses_val.mean()}")
        print(f"The variance of validation loss is {losses_val.var()}")

        self.logger.info("=========Finish========")
        return total_loss.mean(), total_loss.var()

    def postprocess(self, x):
        """Evaluate the model on ``x`` and return ``(acc_val, acc_test)``.

        Puts the model in eval mode and delegates to
        ``unsupervised_node_cls_evaluate`` with the dataset splits, the
        configured loss function, the labels and the second optimizer; the
        training loss and training accuracy it returns are discarded.
        """
        self.model.eval()
        loss_train, acc_train, acc_val, acc_test = \
            unsupervised_node_cls_evaluate(self.model, x, self.dataset.train_idx, self.dataset.val_idx,
                                           self.dataset.test_idx, self.loss_fn, self.labels, self.optimizer_2)
        return acc_val, acc_test
    

class UnsupervisedTestNodeClassification(BaseTask):
    """Load a trained random-walk model and dump the labels along its walks.

    Constructing the task loads the checkpoint stored under
    ``./model_files/lrw/<dataset name>/lrw.pth``, runs ``model_forward`` once
    per node on the CPU and writes one CSV row per node
    (``node_id,label_1,label_2,...``) to a timestamped file under
    ``./log/<model name>/<dataset name>/``.
    """

    def __init__(self, logger, dataset, model, normalize_times, lr, weight_decay, epochs, logepochs, early_stop, device, 
                 loss_fn=nn.CrossEntropyLoss()):
        """Load the model and write the per-node walk labels.

        Args:
            logger: logger receiving the walks and labels of every node.
            dataset: dataset object providing ``y``, ``name``, ``adj``, ``x``.
            model: model factory exposing ``model_init()`` and
                ``args.model_name``.
            normalize_times: stored on the instance; not otherwise used here.
            lr, weight_decay, epochs, logepochs, early_stop, loss_fn:
                accepted for signature parity; not used by this class.
            device: ignored; the walks are always generated on the CPU.
        """
        super(UnsupervisedTestNodeClassification, self).__init__()
        self.logger = logger
        self.normalize_times = normalize_times
        self.normalize_record = {"val_acc": [], "test_acc": []}

        self.dataset = dataset
        self.labels = self.dataset.y

        self.model_zoo = model
        self.model = model.model_init()

        # Inference below runs on the CPU and converts tensors to numpy, so
        # the checkpoint is mapped to the CPU regardless of where it was saved.
        device = torch.device('cpu')
        model_load_path = "./model_files/lrw/"+self.dataset.name + "/lrw.pth"
        checkpoint = torch.load(model_load_path, map_location=device)
        if isinstance(checkpoint, nn.Module):
            # A fully pickled module: use it as-is.
            self.model = checkpoint
        else:
            # A state dict, which is what the training task in this file
            # saves. Preprocess first, mirroring how that task rebuilds its
            # model before calling load_state_dict.
            self.model.preprocess(self.dataset.adj, self.dataset.x)
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.logger.info("model load completed!")

        now_time = datetime.datetime.now()
        log_dir = "./log/"+model.args.model_name+"/"+self.dataset.name
        # open(..., "w") does not create missing directories.
        os.makedirs(log_dir, exist_ok=True)
        log_path = log_dir+"/labels_"+str(now_time.strftime('%Y-%m-%d %H-%M-%S'))+".csv"
        with open(log_path, "w") as f:
            for i in range(self.labels.shape[0]):
                walk, rw = self.model.model_forward(device, i)
                # detach()/cpu() keep the numpy conversion valid even if the
                # returned tensor tracks gradients.
                walk = walk.detach().cpu().numpy()
                walk_labels = self.labels[walk].cpu().numpy()
                logger.info(f"The node {i}'s random walks are {walk}")
                logger.info(f"The labels are {walk_labels}")
                sentence = str(i)+","+','.join(map(str, walk_labels))+"\n"
                f.write(sentence)
