import torch
import torch.nn.functional as F
import argparse
import random
import glob
import os
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from sklearn.metrics import f1_score, accuracy_score
from model.ugcformer import UGCFormer
from utils import DataLoader, dataset_split, adj_norm, set_seed, edge_index_to_coo, consis_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, args, data, optimizer, split):
    """Run one optimisation step and report the training and validation loss.

    Args:
        model: Module called as ``model(feat=..., topo=..., adj_normal=...)``
            and returning three values; the last one is indexed by node and
            used as classification logits.
        args: Namespace providing ``temp_con``, ``weight_con`` and ``fastmode``.
        data: Graph object exposing ``x``, ``adj``, ``adj_normal`` and ``y``.
        optimizer: Optimizer bound to ``model``'s parameters.
        split: Mapping with ``'train'`` and ``'valid'`` node indices/masks.

    Returns:
        Tuple ``(train_loss, val_loss)`` of Python floats.
    """
    model.train()
    optimizer.zero_grad()
    outputs_z, outputs_y, output = model(feat=data.x, topo=data.adj, adj_normal=data.adj_normal)

    # Auxiliary term computed by ``consis_loss`` from the first two model
    # outputs (presumably a consistency regulariser between them -- see utils).
    loss_consis = consis_loss([outputs_z, outputs_y], args.temp_con)

    train_idx = split['train']
    train_loss = F.cross_entropy(output[train_idx], data.y[train_idx]) \
                 + args.weight_con * loss_consis

    train_loss.backward()
    optimizer.step()

    valid_idx = split['valid']
    # Validation never needs gradients, so skip building an autograd graph.
    with torch.no_grad():
        if not args.fastmode:
            # Re-run the forward pass in eval mode with the updated weights.
            model.eval()
            _, _, output = model(feat=data.x, topo=data.adj, adj_normal=data.adj_normal)
        # In fast mode the training-mode logits computed above are reused, so
        # ``val_loss`` is always defined (previously this path raised
        # UnboundLocalError at the return statement).
        val_loss = F.cross_entropy(output[valid_idx], data.y[valid_idx])
    return train_loss.item(), val_loss.item()

def test(model, data, split):
    """Evaluate ``model`` on the test nodes of ``split``.

    Args:
        model: Module called as ``model(feat=..., topo=..., adj_normal=...)``;
            its third return value is used as per-node class scores.
        data: Graph object exposing ``x``, ``adj``, ``adj_normal`` and ``y``.
        split: Mapping whose ``'test'`` entry selects the test nodes.

    Returns:
        Dict with keys ``'micro_f1'``, ``'macro_f1'`` and ``'acc'``.
    """
    model.eval()
    # Inference only: disable autograd so no graph is recorded.
    with torch.no_grad():
        _, _, output = model(feat=data.x, topo=data.adj, adj_normal=data.adj_normal)

    test_idx = split['test']
    y_test = data.y[test_idx].detach().cpu().numpy()
    # Predicted class = index of the largest score along the last dimension.
    y_pred = output[test_idx].argmax(-1).detach().cpu().numpy()

    return {
        'micro_f1': f1_score(y_test, y_pred, average='micro'),
        'macro_f1': f1_score(y_test, y_pred, average='macro'),
        'acc': accuracy_score(y_test, y_pred),
    }

def _checkpoint_epoch(path):
    """Return the epoch encoded in a ``<epoch>.pkl`` file name, or None."""
    stem, ext = os.path.splitext(os.path.basename(path))
    if ext != '.pkl' or not stem.isdigit():
        return None
    try:
        return int(stem)
    except ValueError:
        return None


def _prune_checkpoints(best_epoch, remove_newer=False):
    """Delete numbered ``*.pkl`` checkpoints older (or newer) than ``best_epoch``."""
    for path in glob.glob('*.pkl'):
        epoch_nb = _checkpoint_epoch(path)
        if epoch_nb is None:
            # Not one of our ``<epoch>.pkl`` files: leave it alone instead of
            # crashing on ``int()`` or deleting an unrelated pickle.
            continue
        stale = epoch_nb > best_epoch if remove_newer else epoch_nb < best_epoch
        if stale:
            os.remove(path)


def main(args):
    """Train and evaluate UGCFormer for ``args.runs`` runs and log the scores.

    For every run a fresh model is trained with early stopping on the
    validation loss, the best checkpoint is reloaded and scored on the test
    split. Mean and standard deviation of micro/macro F1 over all runs are
    printed and appended, together with the hyper-parameters, to
    ``results/<dataset>.csv``.

    Checkpoints are written to the current working directory as
    ``<epoch>.pkl``; the best checkpoint of the last run is left on disk.

    Args:
        args: Parsed command-line namespace (see the argument parser of this
            script for the available options).
    """
    dataset = DataLoader(args)
    data = dataset[0].to(device)
    data.num_classes = dataset.num_classes
    data.name = args.dataset

    # topology representations
    data.adj = edge_index_to_coo(data.edge_index, data.num_nodes)

    # topology for updating the topology representations
    data.adj_normal = adj_norm(data.adj)
    if args.norm:
        data.adj = data.adj_normal

    activation_cls = {'relu': nn.ReLU, 'prelu': nn.PReLU,
                      'lrelu': nn.LeakyReLU, 'elu': nn.ELU,
                      'gelu': nn.GELU}[args.activation]

    mi_list, ma_list = [], []

    for run_id in range(args.runs):
        # cora/citeseer/pubmed draw a random seed per run; every other
        # dataset reuses the fixed ``args.seed``.
        if args.dataset not in ['cora', 'citeseer', 'pubmed']:
            seed = args.seed
        else:
            seed = random.randint(1, 100)
        set_seed(seed)
        model = UGCFormer(n=data.num_nodes,
                          nclass=data.num_classes,
                          nfeat=data.num_features,
                          nhidden=args.hidden,
                          nlayer=args.num_layers,
                          dropout=args.dropout,
                          # New instance per run so that learnable activations
                          # (PReLU) do not carry weights over between runs.
                          activation=activation_cls(),
                          weight_attr=args.weight_attr,
                          num_heads=args.num_heads,
                          use_weight=True
                          ).to(device)

        optimizer = torch.optim.Adam(
            model.parameters(), weight_decay=args.weight_decay, lr=args.lr)

        min_val_loss = 1e9
        best_epoch = 0
        bad_counter = 0
        split = dataset_split(data, run_id)
        with tqdm(total=args.epochs, desc='(T)') as pbar:
            for epoch in range(0, args.epochs):
                train_loss, val_loss = train(model, args, data, optimizer, split)
                pbar.set_postfix({'Train Loss': train_loss, 'Val Loss': val_loss})
                pbar.update()
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    torch.save(model.state_dict(), '{}.pkl'.format(epoch))
                    best_epoch = epoch
                    bad_counter = 0
                    # Older checkpoints only become stale when a new best is
                    # saved, so pruning here is sufficient.
                    _prune_checkpoints(best_epoch)
                else:
                    bad_counter += 1
                # Early stopping: ``patience`` consecutive non-improving epochs.
                if bad_counter == args.patience:
                    break

            # Drop leftovers numbered after the best epoch (e.g. from earlier runs).
            _prune_checkpoints(best_epoch, remove_newer=True)

        model.load_state_dict(
            torch.load('{}.pkl'.format(best_epoch), map_location=device))
        result = test(model, data, split)
        mic, mac = result["micro_f1"], result["macro_f1"]
        print('{}-th run,'.format(run_id+1), 'micro-f1:{:.2f}%,'.format(mic*100), 'macro-f1:{:.2f}%;'.format(mac*100))
        mi_list.append(mic)
        ma_list.append(mac)

    mi_mean, mi_std = np.mean(mi_list), np.std(mi_list)
    ma_mean, ma_std = np.mean(ma_list), np.std(ma_list)

    print("Micro-F1, mean ± std: {:.2f}%±{:.2f}".format(mi_mean*100,mi_std*100))
    print("Macro-F1, mean ± std: {:.2f}%±{:.2f}".format(ma_mean*100,ma_std*100))

    filename = f'results/{args.dataset}.csv'
    print(f"Saving results to the'{filename}'")

    # The results directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'a+') as write_obj:
        write_obj.write(f"mi:{mi_mean*100:.2f} ± {mi_std*100:.2f},"
                        + f"ma:{ma_mean*100:.2f} ± {ma_std*100:.2f},"
                        + f"lr:{args.lr},"
                        + f"wd:{args.weight_decay},"
                        + f"num_layers:{args.num_layers},"
                        + f"weight_attr:{args.weight_attr},"
                        + f"weight_con:{args.weight_con},"
                        + f"num_heads:{args.num_heads},"
                        + f"hidden:{args.hidden},"
                        + f"dropout:{args.dropout},"
                        + f"activation:{args.activation},"
                        + f"runs:{args.runs},"
                        + f"epochs:{args.epochs},"
                        + f"patience:{args.patience},"
                        + f"norm:{args.norm},\n"
                        )

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` would turn every non-empty string -- including ``"False"``
    -- into ``True``; this converter parses the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options of the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--fastmode', type=_str2bool, default=False)
parser.add_argument('--data_dir', type=str, default="datasets/")
parser.add_argument('--seed', type=int, default=2024)
parser.add_argument('--runs', type=int, default=10)
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--patience', type=int, default=50)
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--activation', type=str, default='relu')
parser.add_argument('--norm', type=_str2bool, default=True)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--num_heads', type=int, default=2)
parser.add_argument('--weight_con', type=float, default=1)
parser.add_argument('--temp_con', type=float, default=0.4)
parser.add_argument('--weight_attr', type=float, default=0.5)

# Only parse arguments and start training when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    args = parser.parse_args()
    print("data:", args.dataset)
    main(args)
