import argparse
from dataset_loader import DataLoader
from utils import random_planetoid_splits
from models import *
import torch
import torch.nn  as nn
import torch.nn.functional as F
from tqdm import tqdm
import random
import seaborn as sns
import numpy as np
import time
from relabel import relabel, relabel_with_all_data, relabel_with_pseudo_label, label_smoothing, monophily_uniform_smoothing, monophily_uniform_smoothing_with_pseudo_label
from relabel import monophily_uniform_smoothing_with_soft_pseudo_label, relabel_with_marginal_distribution, relabel_with_soft_pseudo_label, monophily_smoothing_with_nodewise_prior
from relabel import monophily_smoothing_with_nodewise_prior_with_pseudo_label, monophily_uniform_smoothing_threshold_alpha, monophily_uniform_smoothing_threshold_alpha_with_pseudo_label
from relabel import monophily_nodewise, monophily_nodewise_with_pseudo_label, monophily_uniform_smoothing_with_likelihood_alpha, monophily_uniform_smoothing_with_likelihood_alpha_with_pseudo_label
from torch_geometric.utils import one_hot
import wandb
from copy import deepcopy
from utils import plot_tsne
import os
import shutil

def get_optimizer(args, model):
    """Build the Adam optimizer used to train ``model``.

    For ``args.net`` equal to ``'GPRGNN'`` or ``'BernNet'`` three parameter
    groups are created: ``model.lin1`` and ``model.lin2`` use ``args.lr`` and
    ``args.weight_decay``, while ``model.prop1`` is never weight-decayed and
    uses ``args.Bern_lr`` for BernNet (``args.lr`` for GPRGNN). Every other
    network gets a single group over ``model.parameters()``.

    Args:
        args: namespace providing ``net``, ``lr``, ``weight_decay`` and, for
            BernNet only, ``Bern_lr``.
        model: the network whose parameters are optimized.

    Returns:
        A ``torch.optim.Adam`` instance.
    """
    uses_propagation_groups = args.net == 'GPRGNN' or args.net == 'BernNet'
    if not uses_propagation_groups:
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # The two linear layers share the regular learning rate and weight decay.
    param_groups = [
        {'params': layer.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr}
        for layer in (model.lin1, model.lin2)
    ]

    # The propagation layer is excluded from weight decay; only BernNet gives
    # it a dedicated learning rate.
    propagation_lr = args.Bern_lr if args.net == 'BernNet' else args.lr
    param_groups.append({'params': model.prop1.parameters(), 'weight_decay': 0.0, 'lr': propagation_lr})

    return torch.optim.Adam(param_groups)

def get_soft_target(models, data, args):
    """Return the temperature-softened class distribution of an ensemble.

    Each model in ``models`` is switched to eval mode and applied to ``data``;
    the resulting logits are averaged, divided by ``args.temperature`` and
    passed through a softmax over dimension 1. Everything runs under
    ``torch.no_grad()``, so the result carries no autograd graph.

    Note: the models are left in eval mode.
    """
    with torch.no_grad():
        summed_logits = 0
        for member in models:
            member.eval()
            summed_logits = summed_logits + member(data)
        mean_logits = summed_logits / len(models)
        tempered = mean_logits / args.temperature
        probabilities = F.softmax(tempered, dim=1)

    return probabilities

def RunExp(args, dataset, data, Net, percls_trn, val_lb, num_run):
    """Train teacher model(s), then distill them into one student model.

    Steps performed:
      1. Build ``args.model_num`` teacher networks and create a random split
         via ``random_planetoid_splits`` (seeded with ``args.seed``).
      2. Train every teacher on the hard labels with validation-loss based
         early stopping.
      3. Average the teachers' logits into soft targets (``get_soft_target``).
      4. Train a fresh student on a mix of the hard-label loss and the
         temperature-scaled KL divergence to the soft targets, weighted by
         ``args.soft_label_ratio``.
      5. Optionally (``args.tsne``) plot t-SNE figures of the best student's
         logits for up to 1000 nodes of each split.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).
        dataset: dataset object; ``dataset.num_classes`` is read here and the
            object is forwarded to ``Net``.
        data: graph data object with ``y``, ``train_mask``, ``val_mask`` and
            ``test_mask``; calling it with attribute names is expected to
            yield ``(name, value)`` pairs, as used in the evaluation helper.
        Net: model class, instantiated as ``Net(dataset, args)``.
        percls_trn: forwarded to ``random_planetoid_splits``.
        val_lb: forwarded to ``random_planetoid_splits``.
        num_run: run index, only used in the t-SNE output file names.

    Returns:
        Tuple ``(test_acc, best_val_acc, theta, time_run, teacher_test_acc,
        teacher_best_val_acc, best_model_state_dict)`` where the student
        values are taken at the epoch with the lowest validation loss and
        ``time_run`` lists the wall-clock time of every training step
        (teachers and student).
    """

    def extract_theta(net):
        # BernNet reports its (rectified) propagation coefficients; every
        # other network simply reports the configured alpha.
        if args.net == 'BernNet':
            return torch.relu(net.prop1.temp.clone().detach().cpu()).numpy()
        return args.alpha

    def train(model, optimizer, data, dprate, soft_targets=None):
        # ``dprate`` is accepted for call-site compatibility but is not used.
        model.train()
        optimizer.zero_grad()
        out = model(data)[data.train_mask]

        loss = F.nll_loss(F.log_softmax(out, dim=1), data.y[data.train_mask])
        if soft_targets is not None:
            soft_log_prob = F.log_softmax(out / args.temperature, dim=-1)
            # F.kl_div treats target entries equal to 0 as contributing 0,
            # whereas an explicit ``target * target.log()`` yields NaN there.
            # 'batchmean' divides the summed divergence by the number of rows,
            # matching the previous hand-written normalisation.
            distill_loss = F.kl_div(
                soft_log_prob, soft_targets[data.train_mask], reduction='batchmean'
            ) * (args.temperature ** 2)
            loss = (1 - args.soft_label_ratio) * loss + args.soft_label_ratio * distill_loss

        loss.backward()
        optimizer.step()

    def evaluate_splits(logits, data):
        # Accuracy, predictions and NLL loss for the train/val/test splits,
        # all derived from one set of logits.
        accs, preds, losses = [], [], []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            split_logits = logits[mask]
            pred = split_logits.max(1)[1]
            acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
            loss = F.nll_loss(F.log_softmax(split_logits, dim=1), data.y[mask])

            preds.append(pred.detach().cpu())
            accs.append(acc)
            losses.append(loss.detach().cpu())
        return accs, preds, losses

    def test(model, data):
        model.eval()
        # A single forward pass is reused for all splits; no graph is needed.
        with torch.no_grad():
            return evaluate_splits(model(data), data)

    def ensemble_test(models, data):
        with torch.no_grad():
            logits = 0
            for model in models:
                model.eval()
                logits += model(data)
            logits /= len(models)
            return evaluate_splits(logits, data)

    def should_stop(epoch, val_loss, history):
        # Stop once the current validation loss exceeds the mean of the
        # ``early_stopping`` losses that precede it (``history`` already
        # contains the current loss as its last element).
        if args.early_stopping > 0 and epoch > args.early_stopping:
            window = torch.tensor(history[-(args.early_stopping + 1):-1])
            return bool(val_loss > window.mean().item())
        return False

    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')
    # Teachers are instantiated before the split is drawn, as before, so the
    # order of random-number consumption is unchanged.
    tmp_nets = nn.ModuleList([Net(dataset, args) for _ in range(args.model_num)])

    #randomly split dataset
    permute_masks = random_planetoid_splits
    data = permute_masks(data, dataset.num_classes, percls_trn, val_lb,args.seed)
    models, data = nn.ModuleList([tmp_net.to(device) for tmp_net in tmp_nets]), data.to(device)

    backup_y = data.y
    optimizers = [get_optimizer(args, model) for model in models]

    # Teacher statistics are those of the single best (lowest validation
    # loss) epoch across all teachers.
    teacher_best_val_acc = teacher_test_acc = 0
    teacher_best_val_loss = float('inf')

    time_run=[]
    for model_idx in range(args.model_num):
        teacher_val_loss_history = []
        for epoch in range(args.epochs):
            t_st=time.time()
            train(models[model_idx], optimizers[model_idx], data, args.dprate)
            time_run.append(time.time()-t_st)  # each epoch train times

            [train_acc, val_acc, tmp_test_acc], preds, [
                train_loss, val_loss, tmp_test_loss] = test(models[model_idx], data)

            if val_loss < teacher_best_val_loss:
                teacher_best_val_acc = val_acc
                teacher_best_val_loss = val_loss
                teacher_test_acc = tmp_test_acc

            teacher_val_loss_history.append(val_loss)
            if should_stop(epoch, val_loss, teacher_val_loss_history):
                break

    # NOTE(review): the ensemble metrics are computed but neither logged nor
    # returned; the teacher numbers reported below come from single models.
    ensemble_test(models, data)
    if args.wandb:
        wandb.log({'teacher_val_acc': teacher_best_val_acc, 'teacher_test_acc': teacher_test_acc})
    print(f"teacher test: {teacher_test_acc:.4f}, teacher val: {teacher_best_val_acc:.4f}, teacher_val_loss: {teacher_best_val_loss:.4f}")

    soft_target = get_soft_target(models, data, args)
    del models

    tmp_net = Net(dataset, args)
    model = tmp_net.to(device)
    optimizer = get_optimizer(args, model)

    best_val_acc = test_acc = 0
    best_val_loss = float('inf')
    # Initialised up front so the return statement (and the t-SNE branch)
    # cannot hit an unbound name when the validation loss never improves.
    best_model_state_dict = deepcopy(model.state_dict())
    theta = extract_theta(model)
    val_loss_history = []
    for epoch in range(args.epochs):
        t_st=time.time()
        train(model, optimizer, data, args.dprate, soft_target)
        time_run.append(time.time()-t_st)  # each epoch train times

        [train_acc, val_acc, tmp_test_acc], preds, [
            train_loss, val_loss, tmp_test_loss] = test(model, data)

        if args.wandb:
            wandb.log({'train_acc': train_acc, 'val_acc': val_acc, 'test_acc': tmp_test_acc, 'train_loss': train_loss, 'val_loss':val_loss,
                'test_loss': tmp_test_loss, 'epoch':epoch})

        if val_loss < best_val_loss:
            best_val_acc = val_acc
            best_val_loss = val_loss
            test_acc = tmp_test_acc
            best_model_state_dict = deepcopy(model.state_dict())
            theta = extract_theta(model)

        val_loss_history.append(val_loss)
        if should_stop(epoch, val_loss, val_loss_history):
            break

    data.y = backup_y
    print(f"test: {test_acc:.4f}, val: {best_val_acc:.4f}, val_loss: {best_val_loss:.4f}")
    if args.wandb:
        wandb.log({'stud_test': test_acc, 'stud_val': best_val_acc, 'stud_val_loss':best_val_loss})

    if args.tsne:
        print("TSNE plot")
        model.load_state_dict(best_model_state_dict)
        model.eval()
        with torch.no_grad():
            logit = model(data)

        split_masks = (('train', data.train_mask), ('val', data.val_mask), ('test', data.test_mask))
        for split_name, mask in split_masks:
            # reshape(-1) keeps a 1-D index even when the split holds a single
            # node (squeeze() would collapse it to a 0-dim tensor); at most
            # the first 1000 nodes of each split are plotted.
            sampled_idxs = mask.nonzero().reshape(-1)[:1000]
            plot_tsne(logit[sampled_idxs].cpu().detach(), backup_y[sampled_idxs].cpu().detach(), f'fig/{args.tsne_path}_{num_run}_{split_name}.png')

    return test_acc, best_val_acc, theta, time_run, teacher_test_acc, teacher_best_val_acc, best_model_state_dict


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=2108550661, help='seeds for random splits.')
    parser.add_argument('--epochs', type=int, default=1000, help='max epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate.')       
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay.')  
    parser.add_argument('--early_stopping', type=int, default=200, help='early stopping.')
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')
    parser.add_argument('--K', type=int, default=10, help='propagation steps for APPNP/ChebNet/GPRGNN.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPNP/GPRGNN.')
    parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
    parser.add_argument('--Init', type=str,choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'], default='PPR', help='initialization for GPRGNN.')
    parser.add_argument('--heads', default=8, type=int, help='attention heads for GAT.')
    parser.add_argument('--output_heads', default=1, type=int, help='output_heads for GAT.')

    parser.add_argument('--dataset', type=str, choices=['Cora','Citeseer','Pubmed','Computers','Photo','Chameleon','Squirrel','Actor','Texas','Cornell'],
                        default='Cornell')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, choices=['GCN', 'GAT', 'APPNP', 'ChebNet', 'GPRGNN','BernNet','MLP'], default='GCN')
    parser.add_argument('--Bern_lr', type=float, default=0.01, help='learning rate for BernNet propagation layer.')

    # Arguments for relabeling
    parser.add_argument('--labeling_method', type=str, default='monophily_uniform', choices=['monophily_uniform', 'monophily_marginal', 
                            'monophily_nodewise_prior', 'monophily_uniform_threshold', 'monophily_nodewise', 'monophily_likelihood_alpha'])
    parser.add_argument('--soft_label_ratio', type=float, default=1, help='interpolation ratio for soft label')
    parser.add_argument('--smoothing_ratio', type=float, default=0.1, help='interpolation ratio for uniform soft label')
    parser.add_argument('--degree_cutoff', type=int, default=1, help='nodes with a degree lower than the cutoff will be disregarded')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature for probability calculation')
    parser.add_argument('--pseudo_label_type', type=str, default='hard', choices=['soft', 'hard'])
    parser.add_argument('--pseudo_temperature', type=float, default=1.0, help='temperature for probability calculation')
    parser.add_argument('--pseudo_smoothing_ratio', type=float, default=0.1, help='interpolation ratio for uniform soft label')
    parser.add_argument('--prior_noise', type=float, default=0.1)
    parser.add_argument('--tsne', type=int, default=0)
    parser.add_argument('--tsne_path', type=str, default='tsne')
    parser.add_argument('--threshold_gap', type=float, default=1e-6)

    parser.add_argument('--wandb', type=int, default=1)
    parser.add_argument('--model_num', type=int, default=1)

    # Will be removed
    parser.add_argument('--num_hop', type=int, default=1)
    parser.add_argument('--confident_labeling', type=int, default=0)
    parser.add_argument('--adaptive_alpha', type=int, default=0)
    parser.add_argument('--adaptive_alpha_a', type=float, default=0)
    parser.add_argument('--confident_pseudo_labeling', type=int, default=0)
    parser.add_argument('--confident_quantile', type=float, default=0)

    args = parser.parse_args()
    args.confident_labeling = bool(args.confident_labeling)
    args.adaptive_alpha = bool(args.adaptive_alpha)
    args.confident_pseudo_labeling = bool(args.confident_pseudo_labeling)
    args.tsne = bool(args.tsne)
    args.wandb = bool(args.wandb)

    if args.adaptive_alpha:
        assert args.degree_cutoff==1
    else:
        assert args.adaptive_alpha_a==0

    if args.wandb:
        wandb.init(entity="jsheo12304",project='relabel',config=args)

    #10 fixed seeds for splits
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]

    print(args)
    print("---------------------------------------------")

    gnn_name = args.net
    if gnn_name == 'GCN':
        Net = GCN_Net
    elif gnn_name == 'GAT':
        Net = GAT_Net
    elif gnn_name == 'APPNP':
        Net = APPNP_Net
    elif gnn_name == 'ChebNet':
        Net = ChebNet
    elif gnn_name == 'GPRGNN':
        Net = GPRGNN
    elif gnn_name == 'BernNet':
        Net = BernNet
    elif gnn_name =='MLP':
        Net = MLP

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    #print(data)

    percls_trn = int(round(args.train_rate*len(data.y)/dataset.num_classes))
    val_lb = int(round(args.val_rate*len(data.y)))

    results = []
    teacher_results = []
    time_results=[]
    best_model_state_dicts = []
    for RP in tqdm(range(args.runs)):
        args.seed=SEEDS[RP]
        test_acc, best_val_acc, theta_0,time_run, teacher_test_acc, teacher_best_val_acc, best_model_state_dict = RunExp(args, dataset, data, Net, percls_trn, val_lb, RP)
        time_results.append(time_run)
        results.append([test_acc, best_val_acc, 0])
        teacher_results.append([teacher_test_acc, teacher_best_val_acc, 0])
        best_model_state_dicts.append(best_model_state_dict)
        print(f'run_{str(RP+1)} \t test_acc: {test_acc:.4f}')
        if args.net == 'BernNet':
            print('Theta:', [float('{:.4f}'.format(i)) for i in theta_0])

    run_sum=0
    epochsss=0
    for i in time_results:
        run_sum+=sum(i)
        epochsss+=len(i)

    print("each run avg_time:",run_sum/(args.runs),"s")
    print("each epoch avg_time:",1000*run_sum/epochsss,"ms")

    test_acc_mean, val_acc_mean, _ = np.mean(results, axis=0) * 100
    test_acc_std = np.sqrt(np.var(results, axis=0)[0]) * 100
    teacher_test_acc_mean, teacher_val_acc_mean, _ = np.mean(teacher_results, axis=0) * 100
    teacher_test_acc_std = np.sqrt(np.var(teacher_results, axis=0)[0]) * 100

    values=np.asarray(results)[:,0]
    uncertainty=np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values,func=np.mean,n_boot=1000),95)-values.mean()))
    teacher_values=np.asarray(teacher_results)[:,0]
    teacher_uncertainty=np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(teacher_values,func=np.mean,n_boot=1000),95)-teacher_values.mean()))

    #print(uncertainty*100)
    if args.wandb:
        wandb.log({'test': test_acc_mean, 'val': val_acc_mean, 'test_uncertainty':uncertainty, 'teacher_test': teacher_test_acc_mean, 'teacher_val': teacher_val_acc_mean, 'teacher_test_uncertainty':teacher_uncertainty})
    print(f'{gnn_name} on dataset {args.dataset}, in {args.runs} repeated experiment:')
    print(f'test acc mean = {test_acc_mean:.2f} ± {uncertainty*100:.2f}  \t val acc mean = {val_acc_mean:.2f}')
    if args.wandb:
        wandb.finish()

    save_dir = f'checkpoints/KD/{args.dataset}/{args.net}'
    os.makedirs(save_dir, exist_ok=True)
    cur_checkpoints = os.listdir(save_dir)
    if len(cur_checkpoints) == 0:
        os.makedirs(osp.join(save_dir, f'val_{val_acc_mean:.2f}_test_{test_acc_mean:.2f}'), exist_ok=True)
        for i in range(args.runs):
            torch.save(best_model_state_dicts[i], osp.join(save_dir, f'val_{val_acc_mean:.2f}_test_{test_acc_mean:.2f}', f'{i}_val_{results[i][1]:.4f}_test_{results[i][0]:.4f}.pth'))
    else:
        best_val = 0
        for cur_checkpoint in cur_checkpoints:
            strs = cur_checkpoint.split('_')
            val_acc = float(strs[1])
            test_acc = float(strs[3])
            if val_acc > best_val:
                best_val = val_acc
                best_checkpoint = cur_checkpoint
        if best_val < val_acc_mean:
            os.makedirs(osp.join(save_dir, f'val_{val_acc_mean:.2f}_test_{test_acc_mean:.2f}'), exist_ok=True)
            for i in range(args.runs):
                torch.save(best_model_state_dicts[i], osp.join(save_dir, f'val_{val_acc_mean:.2f}_test_{test_acc_mean:.2f}', f'{i}_val_{results[i][1]:.4f}_test_{results[i][0]:.4f}.pth'))
            for cur_checkpoint in cur_checkpoints:
                shutil.rmtree(osp.join(save_dir, cur_checkpoint), ignore_errors=False, onerror=None)
    
