"""
Our code is based on GraphENS:
https://github.com/JoonHyung-Park/GraphENS
"""

import os.path as osp
import random
import torch
import torch.nn.functional as F
from nets import *
from data_utils import *
from args import parse_args
from models import *
from losses import *
from sklearn.metrics import balanced_accuracy_score, f1_score
import statistics
import numpy as np
import optuna
from functools import partial
import torch.nn.functional as F
from scipy.stats import wilcoxon
# Module-level state shared by main()/detailed_main() and their nested
# train()/test() closures. Those functions declare these names `global` and
# assign them: class_num_list, idx_info, prev_out, aggregator,
# data_train_mask, data_val_mask, data_test_mask.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Arg Parser ##
args = parse_args()

## Handling exception from arguments ##
# TAM must be combined with warmup >= 1. Raise explicitly rather than using
# `assert`, which is stripped when Python runs with -O.
if args.warmup < 1 and args.tam:
    raise ValueError("args.warmup must be >= 1 when args.tam is enabled")
# args.imb_ratio <= 1 is accepted: main() only imbalances the training split
# when args.imb_ratio > 1.
def regularizer(pred, gt, trial, args, imb_map):
    """Return the prediction-prior regularization term selected by ``args.regularizer``.

    Both hyper-parameters are requested from ``trial`` on every call, even when
    no regularizer is active. This ensures every Optuna trial records them,
    because ``detailed_main`` reads both back from the trial it is given.

    Supported values of ``args.regularizer`` (``p`` is the softmax of ``pred``
    over its last dimension, averaged over its first dimension, and ``C`` is
    its length):

    * ``'kl'``: ``sum_c coef * p_c * log((p_c + 1e-7) * C)``. This is a KL
      term that pulls the mean prediction towards the uniform distribution.
    * ``'kl_reverse'``: ``sum_c coef * p_c * log((p_c + 1e-7) / imb_map_c ** kl_power)``.
      The powered target is not renormalized.
    * ``None``: no regularization.

    Args:
        pred: logits with classes on the last dimension (assumed 2-D
            ``(nodes, classes)`` -- TODO confirm for every caller).
        gt: labels for the rows of ``pred``. The current regularizers do not
            use it; it is kept for interface compatibility.
        trial: Optuna trial (or frozen trial) providing ``suggest_float``.
        args: parsed arguments. Reads ``regularizer``, ``reg_coef`` and
            ``kl_power`` (the latter two are ``(low, high)`` search bounds).
        imb_map: one weight per class; only used by ``'kl_reverse'``.

    Returns:
        A scalar tensor, or the int ``0`` when ``args.regularizer`` is None.

    Raises:
        ValueError: if ``args.regularizer`` is not one of the values above.
    """
    regularizer_coef = trial.suggest_float(
        "regularizer_coef", args.reg_coef[0], args.reg_coef[1], log=True
    )
    kl_power = trial.suggest_float(
        "kl_power", args.kl_power[0], args.kl_power[1], log=True
    )

    kind = args.regularizer
    if kind is None:
        return 0
    if kind not in ('kl', 'kl_reverse'):
        # Fail fast: an unknown name would otherwise return None and only
        # surface later as a TypeError at the caller's `loss += ...`.
        raise ValueError(f"Unknown regularizer: {kind!r}")

    # Mean predicted class distribution over the rows of `pred`.
    pred_mean = F.softmax(pred, dim=-1).mean(dim=0)

    if kind == 'kl':
        temp = regularizer_coef * (
            pred_mean * torch.log((pred_mean + 1e-7) * len(pred_mean))
        )
    else:
        # Follow `pred` to its device instead of relying on a module global.
        target = imb_map.to(pred.device) ** kl_power
        temp = regularizer_coef * (pred_mean * torch.log((pred_mean + 1e-7) / target))
    return temp.sum()

    
def main(trial):
    """Optuna objective: score one hyper-parameter configuration on validation.

    Loads ``args.dataset`` and runs 10 seeded repetitions. Each repetition:

    * selects the data split;
    * optionally imbalances the training split (when ``args.imb_ratio > 1``);
    * builds a fresh model, criterion and optimizer;
    * trains for up to 699 epochs, stopping once more than 100 consecutive
      epochs fail to improve the validation score
      ``(accuracy + macro-F1) / 2``.

    Only the train and validation masks are evaluated here.

    The regularizer hyper-parameters are drawn from ``trial`` inside
    ``regularizer``, which ``train`` calls on every step.

    Args:
        trial: Optuna trial for this configuration.

    Returns:
        Mean over repetitions of the best validation score reached.

    Raises:
        NotImplementedError: for an unsupported ``args.net`` or ``args.loss_type``.
    """
    global class_num_list, idx_info, prev_out, aggregator
    global data_train_mask, data_val_mask, data_test_mask

    ## Load Dataset ##
    dataset = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
    dataset = get_dataset(dataset, path, split_type='public')

    data = dataset[0]
    n_cls = data.y.max().item() + 1
    data = data.to(device)

    def train():
        """Run one optimisation step on the training nodes, then step the LR scheduler."""
        global class_num_list, aggregator
        global data_train_mask, data_val_mask, data_test_mask

        model.train()
        optimizer.zero_grad()

        output = model(data.x, data.edge_index)

        ## Apply TAM ##
        output = adjust_output(args, output, data.edge_index, data.y,
                               data_train_mask, aggregator, class_num_list, epoch)

        loss = criterion(output, data.y[data_train_mask])
        loss += regularizer(output, data.y[data_train_mask], trial, args, imb_map)
        loss.backward()

        # Validation loss (cross-entropy + the same regularizer) drives
        # ReduceLROnPlateau. It is computed in eval mode, before the update.
        with torch.no_grad():
            model.eval()
            output = model(data.x, data.edge_index)
            val_loss = F.cross_entropy(output[data_val_mask], data.y[data_val_mask])
            val_loss += regularizer(output[data_val_mask], data.y[data_val_mask],
                                    trial, args, imb_map)

        optimizer.step()
        scheduler.step(val_loss)

    @torch.no_grad()
    def test():
        """Return (accs, baccs, f1s), each ordered ``[train, val]``, in eval mode."""
        model.eval()
        logits = model(data.x, data.edge_index)
        accs, baccs, f1s = [], [], []

        for mask in (data_train_mask, data_val_mask):
            pred = logits[mask].max(1)[1]
            y_pred = pred.cpu().numpy()
            y_true = data.y[mask].cpu().numpy()
            accs.append(pred.eq(data.y[mask]).sum().item() / mask.sum().item())
            baccs.append(balanced_accuracy_score(y_true, y_pred))
            f1s.append(f1_score(y_true, y_pred, average='macro'))

        return accs, baccs, f1s

    n_repeats = 10
    seed = 100
    best_scores = []
    for r in range(n_repeats):

        ## Fix seed ##
        torch.cuda.empty_cache()
        seed += 1
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

        if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']:
            # Masks are 2-D here; column r % 10 selects this repetition's split.
            split = r % 10
            data_train_mask = data.train_mask[:, split].clone()
            data_val_mask = data.val_mask[:, split].clone()
            data_test_mask = data.test_mask[:, split].clone()
        else:
            data_train_mask = data.train_mask.clone()
            data_val_mask = data.val_mask.clone()
            data_test_mask = data.test_mask.clone()

        ## Data statistic ##
        stats = data.y[data_train_mask]
        n_data = []
        for i in range(n_cls):
            n_data.append(int((stats == i).sum().item()))
        idx_info = get_idx_info(data.y, n_cls, data_train_mask)
        class_num_list = n_data

        # Artificial imbalance: only the last n_cls // 2 classes are capped,
        # at int(max_num * (1 / imb_ratio)) nodes. max_num is the largest
        # count among the remaining (head) classes.
        imb_class_num = n_cls // 2
        max_num = np.max(class_num_list[:n_cls - imb_class_num])
        new_class_num_list = []
        for i in range(n_cls):
            if args.imb_ratio > 1 and i > n_cls - 1 - imb_class_num:
                new_class_num_list.append(min(int(max_num * (1. / args.imb_ratio)), class_num_list[i]))
            else:
                new_class_num_list.append(class_num_list[i])
        class_num_list = new_class_num_list

        if args.imb_ratio > 1:
            data_train_mask, idx_info = split_semi_dataset(
                len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device)

        # Inverse training-node count per class, normalised to sum to 1.
        # It depends only on this repetition's training mask, so it is built
        # once here instead of on every train() call.
        imb_map = torch.tensor([1 / (data.y[data_train_mask] == cl_id).sum().item()
                                for cl_id in data.y.unique()])
        imb_map = imb_map / imb_map.sum()

        ## Model Selection ##
        if args.net == 'GCN':
            model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim,
                               nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        elif args.net == 'GAT':
            model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim,
                               nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        elif args.net == "SAGE":
            model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim,
                                nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        else:
            raise NotImplementedError("Not Implemented Architecture!")

        ## Criterion Selection ##
        if args.loss_type == 'ce':  # CE
            criterion = CrossEntropy()
        elif args.loss_type == 'bs':
            criterion = BalancedSoftmax(class_num_list)
        else:
            raise NotImplementedError("Not Implemented Loss!")

        model = model.to(device)
        criterion = criterion.to(device)

        # Set optimizer
        optimizer = torch.optim.Adam([
            dict(params=model.reg_params, weight_decay=5e-4),
            dict(params=model.non_reg_params, weight_decay=0),], lr=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                               factor=0.5,
                                                               patience=100,
                                                               verbose=False)

        # Train models
        best_val_acc_f1 = 0
        aggregator = MeanAggregation()
        patient = 0
        for epoch in range(1, 700):
            train()
            accs, _, f1s = test()
            val_acc = accs[1]
            val_f1 = f1s[1]
            val_acc_f1 = (val_acc + val_f1) / 2.
            if val_acc_f1 > best_val_acc_f1:
                best_val_acc_f1 = val_acc_f1
                patient = 0
            else:
                patient += 1
            if patient > 100:
                break

        best_scores.append(best_val_acc_f1)

    return statistics.mean(best_scores)

def detailed_main(trial):
    """Re-evaluate a configuration and append its test metrics to a results file.

    Mirrors ``main`` but runs 20 seeded repetitions of up to 1999 epochs. In
    each repetition, the checkpoint with the best validation score
    ``(accuracy + macro-F1) / 2`` is selected. The test accuracy, balanced
    accuracy and macro-F1 at that checkpoint are recorded.

    The per-repetition lists are appended to ``output_<loss_type>.md``,
    together with the ``regularizer_coef`` / ``kl_power`` read from ``trial``.

    Args:
        trial: trial whose parameters include ``regularizer_coef`` and
            ``kl_power`` (the script entry point passes ``study.best_trial``).

    Raises:
        NotImplementedError: for an unsupported ``args.net`` or ``args.loss_type``.
    """
    global class_num_list, idx_info, prev_out, aggregator
    global data_train_mask, data_val_mask, data_test_mask

    ## Load Dataset ##
    dataset = args.dataset
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset)
    dataset = get_dataset(dataset, path, split_type='public')

    data = dataset[0]
    n_cls = data.y.max().item() + 1
    data = data.to(device)

    def train():
        """Run one optimisation step on the training nodes, then step the LR scheduler."""
        global class_num_list, aggregator
        global data_train_mask, data_val_mask, data_test_mask

        model.train()
        optimizer.zero_grad()

        output = model(data.x, data.edge_index)

        ## Apply TAM ##
        output = adjust_output(args, output, data.edge_index, data.y,
                               data_train_mask, aggregator, class_num_list, epoch)

        loss = criterion(output, data.y[data_train_mask])
        loss += regularizer(output, data.y[data_train_mask], trial, args, imb_map)
        loss.backward()

        # Validation loss (cross-entropy + the same regularizer) drives
        # ReduceLROnPlateau. It is computed in eval mode, before the update.
        with torch.no_grad():
            model.eval()
            output = model(data.x, data.edge_index)
            val_loss = F.cross_entropy(output[data_val_mask], data.y[data_val_mask])
            val_loss += regularizer(output[data_val_mask], data.y[data_val_mask],
                                    trial, args, imb_map)

        optimizer.step()
        scheduler.step(val_loss)

    @torch.no_grad()
    def test():
        """Return (accs, baccs, f1s), each ordered ``[val, test]``, in eval mode."""
        model.eval()
        logits = model(data.x, data.edge_index)
        accs, baccs, f1s = [], [], []

        # Order matters: the training loop selects checkpoints on index 0
        # (validation) and reports index 1 (test).
        for mask in (data_val_mask, data_test_mask):
            pred = logits[mask].max(1)[1]
            y_pred = pred.cpu().numpy()
            y_true = data.y[mask].cpu().numpy()
            accs.append(pred.eq(data.y[mask]).sum().item() / mask.sum().item())
            baccs.append(balanced_accuracy_score(y_true, y_pred))
            f1s.append(f1_score(y_true, y_pred, average='macro'))

        return accs, baccs, f1s

    n_repeats = 20
    seed = 100
    avg_test_acc, avg_test_bacc, avg_test_f1 = [], [], []
    for r in range(n_repeats):

        ## Fix seed ##
        torch.cuda.empty_cache()
        seed += 1
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

        if args.dataset in ['squirrel', 'chameleon', 'Wisconsin']:
            # Masks are 2-D here; column r % 10 selects this repetition's split.
            split = r % 10
            data_train_mask = data.train_mask[:, split].clone()
            data_val_mask = data.val_mask[:, split].clone()
            data_test_mask = data.test_mask[:, split].clone()
        else:
            data_train_mask = data.train_mask.clone()
            data_val_mask = data.val_mask.clone()
            data_test_mask = data.test_mask.clone()

        ## Data statistic ##
        stats = data.y[data_train_mask]
        n_data = []
        for i in range(n_cls):
            n_data.append(int((stats == i).sum().item()))
        idx_info = get_idx_info(data.y, n_cls, data_train_mask)
        class_num_list = n_data

        # Artificial imbalance: only the last n_cls // 2 classes are capped,
        # at int(max_num * (1 / imb_ratio)) nodes. max_num is the largest
        # count among the remaining (head) classes.
        imb_class_num = n_cls // 2
        max_num = np.max(class_num_list[:n_cls - imb_class_num])
        new_class_num_list = []
        for i in range(n_cls):
            if args.imb_ratio > 1 and i > n_cls - 1 - imb_class_num:
                new_class_num_list.append(min(int(max_num * (1. / args.imb_ratio)), class_num_list[i]))
            else:
                new_class_num_list.append(class_num_list[i])
        class_num_list = new_class_num_list

        if args.imb_ratio > 1:
            data_train_mask, idx_info = split_semi_dataset(
                len(data.x), n_data, n_cls, class_num_list, idx_info, data.x.device)

        # Inverse training-node count per class, normalised to sum to 1.
        # It depends only on this repetition's training mask, so it is built
        # once here instead of on every train() call.
        imb_map = torch.tensor([1 / (data.y[data_train_mask] == cl_id).sum().item()
                                for cl_id in data.y.unique()])
        imb_map = imb_map / imb_map.sum()

        ## Model Selection ##
        if args.net == 'GCN':
            model = create_gcn(nfeat=dataset.num_features, nhid=args.feat_dim,
                               nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        elif args.net == 'GAT':
            model = create_gat(nfeat=dataset.num_features, nhid=args.feat_dim,
                               nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        elif args.net == "SAGE":
            model = create_sage(nfeat=dataset.num_features, nhid=args.feat_dim,
                                nclass=n_cls, dropout=0.5, nlayer=args.n_layer)
        else:
            raise NotImplementedError("Not Implemented Architecture!")

        ## Criterion Selection ##
        if args.loss_type == 'ce':  # CE
            criterion = CrossEntropy()
        elif args.loss_type == 'bs':
            criterion = BalancedSoftmax(class_num_list)
        else:
            raise NotImplementedError("Not Implemented Loss!")

        model = model.to(device)
        criterion = criterion.to(device)

        # Set optimizer
        optimizer = torch.optim.Adam([
            dict(params=model.reg_params, weight_decay=5e-4),
            dict(params=model.non_reg_params, weight_decay=0),], lr=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                               factor=0.5,
                                                               patience=100,
                                                               verbose=False)

        # Train models. Starting from -inf guarantees the first epoch records
        # test metrics, so they are never undefined or left over from the
        # previous repetition.
        best_val_acc_f1 = float('-inf')
        prev_out = None
        aggregator = MeanAggregation()
        patient = 0
        for epoch in range(1, 2000):
            train()
            accs, baccs, f1s = test()
            val_acc_f1 = (accs[0] + f1s[0]) / 2.
            if val_acc_f1 > best_val_acc_f1:
                best_val_acc_f1 = val_acc_f1
                test_acc, test_bacc, test_f1 = accs[1], baccs[1], f1s[1]
                patient = 0
            else:
                patient += 1
            if patient > 100:
                break

        avg_test_acc.append(test_acc)
        avg_test_bacc.append(test_bacc)
        avg_test_f1.append(test_f1)

    ## Log results ##
    # Read the tuned values back from the trial so they can be logged.
    regularizer_coef = trial.suggest_float(
        "regularizer_coef", args.reg_coef[0], args.reg_coef[1], log=True
    )
    kl_power = trial.suggest_float(
        "kl_power", args.kl_power[0], args.kl_power[1], log=True
    )
    if args.loss_type not in ('ce', 'bs'):
        raise NotImplementedError("Not Implemented Loss!")

    def _rounded(values):
        # Round to 4 decimals for the log while keeping float values.
        return [float(f"{num:.4f}") for num in values]

    # Appends to output_ce.md or output_bs.md depending on the loss.
    with open(f'output_{args.loss_type}.md', 'a', encoding='utf-8') as file:
        file.write(f'dataset:{args.dataset},net:{args.net},tam:{args.tam},regularizer:{args.regularizer},imb_ratio:{args.imb_ratio}\n,regularizer_coef:{regularizer_coef:.4f},kl_power:{kl_power:.4f}\n')
        file.write(f'avg_test_acc: {_rounded(avg_test_acc)}\n')
        file.write(f'avg_test_bacc: {_rounded(avg_test_bacc)}\n')
        file.write(f'avg_test_f1: {_rounded(avg_test_f1)}\n\n\n')
if __name__ == '__main__':
    # Stage 1: search the hyper-parameters, maximising the mean validation
    # score that main() returns for each trial.
    study = optuna.create_study(direction="maximize")
    study.optimize(main, n_trials=args.num_trial)

    # Stage 2: re-run the best configuration and log its test metrics.
    best = study.best_trial
    detailed_main(best)

    
    
   
        
        
