import torch
import torch.nn.functional as F
import argparse
import glob
import os
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import random
from sklearn.metrics import f1_score, accuracy_score
from model.brformer3 import BRFormer
from util import set_seed, DataLoader, dataset_split

def train(model, args, data, optimizer, split, epoch):
    """Perform one optimisation step and then score the updated model on the validation split.

    Args:
        model: Module called as ``model(x, edge_index, epoch)`` while training
            and ``model(x, edge_index)`` while evaluating; both calls must
            return an ``(out, spec_loss)`` pair.
        args: Namespace providing ``beta_spec``, the weight applied to
            ``spec_loss`` in the training objective.
        data: Object exposing ``x``, ``edge_index`` and ``y``.
        optimizer: Optimizer that owns the model parameters.
        split: Mapping with ``'train'`` and ``'valid'`` index entries used to
            select rows of ``out`` and ``data.y``.
        epoch: Current epoch number, forwarded to the model's training call.

    Returns:
        Tuple ``(loss, val_loss, spec_loss)`` of Python floats: the combined
        training objective, the validation NLL measured after the parameter
        update, and the raw (unweighted) ``spec_loss`` from the training pass.

    Note:
        The model is left in eval mode when this function returns.
    """
    torch.cuda.empty_cache()

    # ---- gradient step on the training rows ----
    model.train()
    optimizer.zero_grad()
    train_out, spec_term = model(data.x, data.edge_index, epoch)
    train_rows = split['train']
    nll_term = F.nll_loss(train_out[train_rows], data.y[train_rows])
    objective = nll_term + args.beta_spec * spec_term
    objective.backward()
    optimizer.step()

    # ---- validation with the freshly updated weights ----
    model.eval()
    with torch.no_grad():
        valid_out, _ = model(data.x, data.edge_index)
        valid_nll = F.nll_loss(valid_out[split['valid']], data.y[split['valid']])
        valid_value = valid_nll.item()

    return objective.item(), valid_value, spec_term.item()

def test(model, data, split, best_epoch):
    """Reload the checkpoint written for ``best_epoch`` and score it on the test split.

    Args:
        model: Module whose weights are overwritten from ``'<best_epoch>.pkl'``
            in the current working directory; called as
            ``model(x, edge_index)`` and expected to return a pair whose first
            element holds per-row class scores.
        data: Object exposing ``x``, ``edge_index`` and ``y``.
        split: Mapping with a ``'test'`` index entry used to select rows.
        best_epoch: Epoch number identifying the checkpoint file to load.

    Returns:
        Dict with keys ``'micro_f1'``, ``'macro_f1'`` and ``'acc'``.
    """
    checkpoint_path = '{}.pkl'.format(best_epoch)
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    with torch.no_grad():
        scores, _ = model(data.x, data.edge_index)
        held_out = split['test']
        truth = data.y[held_out].detach().cpu().numpy()
        # The predicted class is the index of the largest score in each row.
        guess = scores[held_out].argmax(-1).detach().cpu().numpy()

        metrics = {}
        metrics['micro_f1'] = f1_score(truth, guess, average='micro')
        metrics['macro_f1'] = f1_score(truth, guess, average='macro')
        metrics['acc'] = accuracy_score(truth, guess)
        return metrics

def _numbered_checkpoints():
    """Return ``(path, epoch)`` pairs for ``<epoch>.pkl`` files in the working directory.

    Pickle files whose stem is not an integer are skipped, so an unrelated
    ``*.pkl`` in the same directory neither crashes the run nor gets deleted.
    """
    pairs = []
    for path in glob.glob('*.pkl'):
        try:
            pairs.append((path, int(path.split('.')[0])))
        except ValueError:
            continue
    return pairs


def main(args):
    """Train and evaluate BRFormer for ``args.runs`` runs and log aggregate metrics.

    Each run builds a fresh model, trains it with early stopping on the
    validation loss, reloads the best checkpoint through ``test`` and records
    micro-F1, macro-F1 and accuracy. The mean and standard deviation over all
    runs are printed (accuracy only) and appended, together with the
    hyper-parameters, to ``results/<graphconv>_<dataset>.csv``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file for the available fields).

    Side effects:
        Writes and deletes ``<epoch>.pkl`` checkpoint files in the current
        working directory; the best checkpoint of the last run is left behind.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = DataLoader(args.dataset)
    data = dataset[0].to(device)
    # ``data.y`` is passed to F.nll_loss as class indices, so the number of
    # classes is the largest label + 1. The previous
    # ``max(..., data.y.shape[0])`` compared against the length of y's first
    # dimension and therefore sized the output layer by that row count.
    num_class = int(data.y.max().item()) + 1
    data.name = args.dataset
    activation = ({'relu': nn.ReLU, 'prelu': nn.PReLU, 'lrelu': nn.LeakyReLU, 'elu': nn.ELU})[args.activation]

    # Draw the per-run seeds from a private generator seeded with --seed so the
    # whole experiment is reproducible and the ``seed`` column written to the
    # CSV actually identifies it. A dedicated instance is used because
    # ``set_seed`` may reseed the global ``random`` module.
    seed_rng = random.Random(args.seed)

    mi_list, ma_list, acc_list = [], [], []
    for run_id in range(args.runs):
        seed = seed_rng.randint(1, 1000)
        set_seed(seed)
        model = BRFormer(
            input_dim=data.num_features,
            hidden_dim=args.hidden,
            output_dim=num_class,
            activation=activation,
            num_gnns=args.num_gnns,
            num_trans=args.num_sa,
            num_heads=args.num_heads,
            dropout=args.dropout,
            graphconv=args.graphconv,
            k_blocks=args.k_blocks,
            g_blocks=args.g_blocks,
            hard=args.hard,
            w_attn=args.w_attn,
            keep_ratio=args.keep_ratio
        ).to(device)
        split = dataset_split(data, run_id)
        model.init_anchors_from_labels(data.x, data.y, train_idx=split['train'])
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        min_val_loss = 1e9
        best_epoch = 0
        bad_counter = 0

        with tqdm(total=args.epochs, desc=f'(Run {run_id+1}/{args.runs})') as pbar:
            for epoch in range(args.epochs):
                loss, val_loss, spec_val = train(model, args, data, optimizer, split, epoch)
                pbar.set_postfix({'TrL':f'{loss:.3f}','ValL':f'{val_loss:.3f}','Spec':f'{spec_val:.3f}'})
                pbar.update()
                if val_loss < min_val_loss:
                    # New best validation loss: checkpoint it and reset patience.
                    min_val_loss = val_loss
                    torch.save(model.state_dict(), '{}.pkl'.format(epoch))
                    best_epoch = epoch
                    bad_counter = 0
                else:
                    bad_counter += 1
                if bad_counter == args.patience:
                    break

                # Drop checkpoints superseded by the current best.
                for path, epoch_nb in _numbered_checkpoints():
                    if epoch_nb < best_epoch:
                        os.remove(path)

            # Drop checkpoints newer than the best (e.g. left over from an
            # earlier run) so only ``<best_epoch>.pkl`` remains for ``test``.
            for path, epoch_nb in _numbered_checkpoints():
                if epoch_nb > best_epoch:
                    os.remove(path)

        result = test(model, data, split, best_epoch)
        mic, mac, acc = result["micro_f1"], result["macro_f1"], result["acc"]
        mi_list.append(mic)
        ma_list.append(mac)
        acc_list.append(acc)

    mi_mean, mi_std = np.mean(mi_list), np.std(mi_list)
    ma_mean, ma_std = np.mean(ma_list), np.std(ma_list)
    ac_mean, ac_std = np.mean(acc_list), np.std(acc_list)

    print("ACC, mean ± std: {:.2f}%±{:.2f}".format(ac_mean * 100, ac_std * 100))

    os.makedirs("results", exist_ok=True)
    filename = f'results/{args.graphconv}_{args.dataset}.csv'
    print(f"Saving results to '{filename}'")
    # Append mode: every invocation adds one line to the per-dataset log.
    with open(filename, 'a') as f:
        f.write(
            f"mi:{mi_mean*100:.2f} ± {mi_std*100:.2f},"
            f"ma:{ma_mean*100:.2f} ± {ma_std*100:.2f},"
            f"ac:{ac_mean*100:.2f} ± {ac_std*100:.2f},"
            f"seed:{args.seed},"
            f"w_attn:{args.w_attn},"
            f"keep_ratio:{args.keep_ratio},"
            f"graphconv:{args.graphconv},"
            f"sa_backend:{args.sa_backend},"
            f"hard:{args.hard},"
            f"wd:{args.weight_decay},"
            f"activation:{args.activation},"
            f"hidden:{args.hidden},"
            f"num_gnns:{args.num_gnns},"
            f"num_sa:{args.num_sa},"
            f"num_heads:{args.num_heads},"
            f"dropout:{args.dropout},"
            f"k_blocks:{args.k_blocks},"
            f"g_blocks:{args.g_blocks},"
            f"beta_spec:{args.beta_spec},"
            f"epochs:{args.epochs},"
            f"patience:{args.patience},"
            f"lr:{args.lr},"
            f"runs:{args.runs}\n"
        )

def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter maps the usual spellings explicitly.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive). The empty string maps to ``False``.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for the training script.

    Boolean options accept an explicit value (``--hard False``) or no value at
    all (``--hard``, meaning ``True``).

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --fastmode is parsed but not read anywhere in this file.
    parser.add_argument('--fastmode', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--seed', type=int, default=2026)
    parser.add_argument('--runs', type=int, default=3)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--patience', type=int, default=50)
    parser.add_argument('--dataset', type=str, default='photo')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'prelu', 'elu', 'lrelu'])
    parser.add_argument('--hidden', type=int, default=64)
    parser.add_argument('--num_gnns', type=int, default=3)
    parser.add_argument('--num_sa', type=int, default=1)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--w_attn', type=float, default=0.3)
    parser.add_argument('--keep_ratio', type=float, default=0.8)
    parser.add_argument('--graphconv', type=str, default='sgc')
    # NOTE(review): --sa_backend is written to the results CSV by main() but is
    # not forwarded to the model constructor there -- confirm this is intended.
    parser.add_argument('--sa_backend', type=str, default='bdg', choices=['vanilla', 'bdg'])
    parser.add_argument('--k_blocks', type=int, default=32)
    parser.add_argument('--g_blocks', type=int, default=10)
    parser.add_argument('--beta_spec', type=float, default=1)
    parser.add_argument('--hard', type=str2bool, nargs='?', const=True, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)

