#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import datetime
import math
import os
import random
import sys
import time

import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl import function as fn
from dgl.data import (
    AmazonCoBuyComputerDataset,
    AmazonCoBuyPhotoDataset,
    CiteseerGraphDataset,
    CoauthorCSDataset,
    CoraFullDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    RedditDataset,
)
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm, trange

from models_label_trick import GAT, GCN, MLP, LabelPropagation
import utils
import embedder

class Label_Trick:
    """Full-batch node-classification trainer repeated over several seeds.

    For each seed the data bundle is built with ``embedder``, a ``modeler``
    (MLP classifier) is trained with Adam, and the epoch with the best
    validation accuracy is selected, with patience-based early stopping.
    The test accuracy / macro-F1 of that epoch is recorded. Mean and std
    over seeds are printed and logged.

    ``args`` attributes read during training: n_runs, lr, epochs, patience,
    print_result, dataset, missing_rate, embedder, filling_method, device
    (plus whatever ``modeler`` reads). ``args.seed`` is written per run.
    """

    def __init__(self, args):
        self.args = args

    def training(self):
        """Train/evaluate over ``args.n_runs`` seeds and log the averaged result."""
        file = utils.set_filename(self.args)
        logger = utils.setup_logger('./', '-', file)

        seed_result = {'acc': [], 'macro_F': []}

        # After each seed, the args object carried by the freshly built data
        # bundle is the one used from then on (including for the final log).
        args = self.args

        for seed in trange(self.args.n_runs):
            print(f'============== seed:{seed} ==============')
            utils.seed_everything(seed)
            print('seed:', seed, file)
            args.seed = seed
            # NOTE(review): `embedder` is imported as a module (`import embedder`);
            # calling a module raises TypeError. Presumably a class/function
            # inside that module is meant here — confirm.
            data = embedder(args, seed)
            data.x = torch.nan_to_num(data.x, 0)
            args = data.args

            acc, macro_f = self._train_seed(data, args, seed)
            seed_result['acc'].append(float(acc))
            seed_result['macro_F'].append(float(macro_f))

        self._log_summary(logger, file, args, seed_result)

    @staticmethod
    def _train_seed(data, args, seed):
        """Train one model and return test (acc, macro-F1) at the best-validation epoch."""
        if args.epochs <= 0:
            raise ValueError(f'epochs must be positive, got {args.epochs}')

        model = modeler(args).to(args.device)
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

        best_metric = None
        best_epoch = 0
        best_output = None
        best_test_result = None

        # TODO: mini-batch (neighbor-sampled) training/inference via
        # train_loader / inference_loader is not implemented; full-batch only.
        for epoch in range(args.epochs):
            model.train()
            optimizer.zero_grad()
            loss = model(data.x, data.labels, data.train_mask)
            loss.backward()
            optimizer.step()

            # Evaluate without recording an autograd graph: otherwise every
            # epoch's output (and the retained best_output) pins a full graph.
            model.eval()
            with torch.no_grad():
                output = model.classifier(data.x)

            acc_val, macro_F_val = utils.performance(output[data.val_mask], data.labels[data.val_mask], pre='valid', evaluator=data.evaluator)
            # Strict '>' keeps the earliest epoch among ties, so the epoch used
            # for early stopping and the stored best_output always agree.
            if best_output is None or acc_val > best_metric:
                best_metric = acc_val
                best_epoch = epoch
                best_output = output

            acc_test, macro_F_test = utils.performance(output[data.test_mask], data.labels[data.test_mask], pre='test', evaluator=data.evaluator)
            if epoch == best_epoch:
                best_test_result = [acc_test, macro_F_test]

            if epoch % args.print_result == 0:
                st = "[seed {}][{}-{}][{}-{}][Epoch {}]".format(seed, args.dataset, args.missing_rate, args.embedder, args.filling_method, epoch)
                st += "[Val] ACC: {:.2f}, Macro-F1: {:.2f}|| ".format(acc_val, macro_F_val)
                st += "[Test] ACC: {:.2f}, Macro-F1: {:.2f}\n".format(acc_test, macro_F_test)
                st += "  [*Best Test Result*][Epoch {}] ACC: {:.2f}, Macro-F1: {:.2f}".format(best_epoch, best_test_result[0], best_test_result[1])
                print(st)

            early_stop = epoch - best_epoch > args.patience
            if early_stop or epoch + 1 == args.epochs:
                if early_stop:
                    print("Early stop")
                best_test_result = list(utils.performance(best_output[data.test_mask], data.labels[data.test_mask], pre='test', evaluator=data.evaluator))
                print("[Best Test Result] ACC: {:.2f}, Macro-F1: {:.2f}".format(best_test_result[0], best_test_result[1]))
                torch.cuda.empty_cache()
                break

        return best_test_result[0], best_test_result[1]

    @staticmethod
    def _log_summary(logger, file, args, seed_result):
        """Print and log mean/std of accuracy and macro-F1 across seeds."""
        acc = seed_result['acc']
        f1 = seed_result['macro_F']
        acc_mean, acc_std = np.mean(acc), np.std(acc)
        f1_mean, f1_std = np.mean(f1), np.std(f1)
        summary = '{:.2f}+{:.2f} {:.2f}+{:.2f}'.format(acc_mean, acc_std, f1_mean, f1_std)

        print('[Averaged result] ACC: {:.2f}+{:.2f}, Macro-F: {:.2f}+{:.2f}'.format(acc_mean, acc_std, f1_mean, f1_std))
        print(summary)

        logger.info('')
        logger.info(datetime.datetime.now())
        logger.info(file)
        logger.info(f'----------- missing rate: {args.missing_rate} -----------')
        logger.info(summary)
        logger.info('{:.2f}+{:.2f}'.format(acc_mean, acc_std))
        logger.info('{:.2f}+{:.2f}'.format(f1_mean, f1_std))
        logger.info(args)
        logger.info('=================================')


class modeler(nn.Module):
    """MLP node classifier whose ``forward`` returns the training loss.

    ``forward`` yields the cross-entropy on the selected training nodes only;
    for predictions, call ``self.classifier(x)`` directly (as the training loop does).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Layer sizes and regularisation all come from the experiment namespace.
        self.classifier = MLP(
            num_features=args.n_feat,
            hidden_dim=args.n_hid,
            num_classes=args.n_class,
            num_layers=args.n_layer,
            dropout=args.dropout,
            batch_norm=args.batch_norm,
        )

    def forward(self, x, labels, idx_train):
        """Return the cross-entropy loss of the classifier on the training nodes.

        Args:
            x: node features passed to ``self.classifier``.
            labels: integer class targets, either shape (N,) or an (N, 1)
                column (presumably the OGB node-property layout — confirm).
            idx_train: mask or index used to select training rows (the caller
                passes ``train_mask``).
        """
        output = self.classifier(x)
        # F.cross_entropy needs class-index targets of shape (N,). Flatten an
        # (N, 1) column based on its shape rather than the dataset name, so
        # already-flat labels and differently-named datasets both work.
        if labels.dim() == 2 and labels.size(1) == 1:
            labels = labels.squeeze(1)
        return F.cross_entropy(output[idx_train], labels[idx_train])