import argparse
import copy
import random
from tqdm import tqdm

import numpy as np

import torch
from torch.nn import Linear, ReLU, Dropout, ELU
import torch.nn.functional as F
from torch.utils.data import Subset
from torch.utils.data import WeightedRandomSampler

from torch_geometric.loader import DataLoader
from torch_geometric.nn import Sequential, GCNConv, GATConv, GAT, GCN, GIN
from torch_geometric.nn import global_add_pool, global_max_pool

from sklearn.metrics import f1_score
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from datasetss import SyntheticGraphFeaturizer, build_dataset
from architectures import SEALNetwork
from sklearn.metrics import f1_score, root_mean_squared_error as mse_score
from torchmetrics import MeanSquaredError, F1Score
from torchmetrics import Accuracy, MeanSquaredError, MeanAbsoluteError, AUROC, F1Score


def args_parser():
    """Parse the command-line options for one training run.

    ``--save_path`` is mandatory; every other option has a default.

    Returns:
        argparse.Namespace holding the parsed options (note that
        ``--data-set`` is exposed as ``data_set``).
    """
    dataset_choices = [
        "covid", "sol", "cyp", "herg", "herg_K",
        "rings-count", "rings-max", "X", "P", "B", "indole", "PAINS",
    ]
    task_choices = ["regression", "classification", "multiclassification"]

    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Run identity, data selection and output location.
    add("--seed", type=int, default=0, help="seed")
    add("--data-set", default="rings-count", choices=dataset_choices,
        type=str, help="dataset type")
    add("--split", type=int, default=0, help="split")
    add("--model_type", type=str, default="GIN", choices=["GCN", "GIN", "GAT"])
    add("--save_path", type=str, required=True)

    # Optimisation schedule.
    add("--batch_size", type=int, default=64)
    add("--epochs", type=int, default=100)
    add("--warmup_epochs", type=int, default=50)
    add("--lr", type=float, default=0.001)
    add("--weight_decay", type=float, default=0.0001)

    # Architecture.
    add("--hidden_dim", type=int, default=32)
    add("--num_layers", type=int, default=3)
    add("--dropout", type=float, default=0.2, help="Dropout rate")
    add("--linear_dim", type=int, default=32)

    # Early stopping and task type.
    add("--patience", type=int, default=30, help="Early stopping patience")
    add("--task", default="classification", type=str, choices=task_choices)

    return parser.parse_args()



def get_model(model_type, num_node_features, num_classes, hidden_dim, num_layers, linear_dim=32, dropout=0.2):
    """Build a graph-level classifier: GNN backbone -> sum pooling -> linear head.

    Args:
        model_type: Backbone name, one of ``"GCN"``, ``"GAT"`` or ``"GIN"``.
        num_node_features: Width of the input node-feature matrix ``x``.
        num_classes: Number of output logits per graph.
        hidden_dim: Hidden width of the backbone; also the input width of the
            final ``Linear`` head.
        num_layers: Number of message-passing layers in the backbone.
        linear_dim: Not used by the architecture; accepted so existing callers
            and saved ``model_args`` dictionaries keep working.
        dropout: Dropout probability passed to the backbone.

    Returns:
        A ``torch_geometric.nn.Sequential`` model called as
        ``model(x, edge_index, batch)``.

    Raises:
        ValueError: If ``model_type`` is not a supported backbone.
    """
    supported = ("GCN", "GAT", "GIN")
    # Explicit validation instead of `assert False`: asserts are stripped under
    # `python -O`, which would surface later as an UnboundLocalError.
    if model_type not in supported:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {supported}"
        )
    # All three backbones share the same constructor call; only the class differs.
    backbone_cls = {"GCN": GCN, "GAT": GAT, "GIN": GIN}[model_type]

    return Sequential(
        "x, edge_index, batch",
        [
            (
                backbone_cls(
                    num_node_features,
                    hidden_dim,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm="batch_norm",
                ),
                "x, edge_index -> x",
            ),
            # Sum node embeddings into one vector per graph in the batch.
            (global_add_pool, "x, batch -> x"),
            Linear(hidden_dim, num_classes),
        ],
    )


def train(model, optimizer, dataloader, device):
    """Run one optimisation epoch with a cross-entropy objective.

    Args:
        model: Called as ``model(x, edge_index, batch=batch)``; its output is
            treated as class logits for ``F.cross_entropy``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        dataloader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y`` plus a ``.to(device)`` method.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The mean of the per-batch training losses over the epoch, or ``0.0``
        if ``dataloader`` yields no batches (e.g. ``drop_last=True`` with a
        dataset smaller than one batch).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, batch=data.batch)
        loss = F.cross_entropy(out, data.y.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Average over the batches actually seen, so every batch contributes
    # (not just the last one) and loaders without __len__ also work.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


@torch.no_grad()
def test(model, dataloader, device, num_classes=None):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Args:
        model: Called as ``model(x, edge_index, batch=batch)``; its output is
            treated as per-graph class logits.
        dataloader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y`` plus a ``.to(device)`` method.
        device: Device each batch is moved to before the forward pass.
        num_classes: Number of classes for the metrics. Defaults to the size of
            the last dimension of the model output.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the average per-batch
        cross-entropy as a Python float and ``metrics`` maps ``'F1'``
        (weighted average), ``'AUROC'`` and ``'accuracy'`` to Python floats.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    target_chunks, logit_chunks = [], []
    total_loss = 0.0
    num_batches = 0
    for data in dataloader:
        data = data.to(device)
        out = model(data.x, data.edge_index, batch=data.batch)
        target_chunks.append(data.y.cpu())
        logit_chunks.append(out.cpu())
        # .item() keeps the running sum a plain float rather than a device tensor.
        total_loss += F.cross_entropy(out, data.y.long()).item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    ys = torch.cat(target_chunks).long()
    preds = torch.cat(logit_chunks)
    if num_classes is None:
        num_classes = preds.size(-1)

    # The gathered tensors already live on the CPU, so the metrics run there too.
    f1_metric = F1Score(task="multiclass", num_classes=num_classes, average="weighted")
    auroc_metric = AUROC(task="multiclass", num_classes=num_classes)
    acc_metric = Accuracy(task="multiclass", num_classes=num_classes)
    metrics = {
        "F1": f1_metric(preds, ys).item(),
        "AUROC": auroc_metric(preds, ys).item(),
        "accuracy": acc_metric(preds, ys).item(),
    }
    return total_loss / num_batches, metrics


def _make_balanced_sampler(dataset):
    """Return a sampler that draws each graph with weight 1 / (count of its label).

    Assumes every ``data.y`` in ``dataset`` holds integer class labels (they are
    cast with ``.long()`` and passed through ``torch.bincount``).
    """
    labels = torch.cat([data.y for data in dataset]).long()
    class_weights = 1.0 / torch.bincount(labels).float()
    sample_weights = class_weights[labels]
    return WeightedRandomSampler(
        weights=sample_weights, num_samples=len(sample_weights), replacement=True
    )


def _append_csv_row(csv_path, row):
    """Append ``row`` (a flat dict) to ``csv_path``, writing a header if the file is new."""
    import csv

    path = Path(csv_path)
    write_header = not path.exists()
    with path.open(mode="a", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def main():
    """Train a binary GNN graph classifier and report/save the best checkpoint.

    Parses CLI arguments, builds train/val/test graph datasets, trains with a
    class-balanced sampler, and keeps the weights with the lowest validation
    loss. Early stopping triggers once validation loss has failed to improve
    for more than ``--patience`` epochs after ``--warmup_epochs``. The best
    weights are evaluated on the test split, the result is appended to
    ``test_{task}_{model_type}.csv``, and the checkpoint is written to
    ``--save_path``.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    # `random` is imported by this module, so seed it alongside torch and numpy.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset_kwargs = {
        "data_set": args.data_set,
        "task": args.task,
        "mean": 0.0,
        "std": 1.0,
        "y_column": "Y",
        "smiles_col": "Drug",
        "split": args.split,
    }

    dataset_train, dataset_val, dataset_test = build_dataset(dataset_kwargs)
    featurizer = SyntheticGraphFeaturizer(y_column="Y")

    train_set = featurizer(dataset_train, dataset_kwargs)
    val_set = featurizer(dataset_val, dataset_kwargs)
    test_set = featurizer(dataset_test, dataset_kwargs)

    # A sampler and shuffle=True are mutually exclusive, hence shuffle=False.
    dataloader_train = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        sampler=_make_balanced_sampler(train_set),
    )
    dataloader_val = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    dataloader_test = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_args = {
        "model_type": args.model_type,
        "num_classes": 2,
        "hidden_dim": args.hidden_dim,
        "num_layers": args.num_layers,
        "linear_dim": args.linear_dim,
        "num_node_features": train_set[0].x.shape[1],
        "dropout": args.dropout,
    }

    model = get_model(**model_args).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    pbar = tqdm(range(1, args.epochs + 1))
    max_f1_train, max_f1_val, max_f1_test = 0.0, 0.0, 0.0
    best_state_dict = copy.deepcopy(model.state_dict())
    min_loss_val = float("inf")
    count = 0
    epoch = 0  # stays defined for the stop_epoch report even if --epochs is 0

    for epoch in pbar:
        train(model, optimizer, dataloader_train, device)
        loss_train, metric_train = test(model, dataloader_train, device)
        loss_val, metric_val = test(model, dataloader_val, device)
        loss_test, metric_test = test(model, dataloader_test, device)
        # Normalise to Python floats so stored/compared losses never hold device tensors.
        loss_train, loss_val, loss_test = float(loss_train), float(loss_val), float(loss_test)

        if loss_val < min_loss_val:
            min_loss_val = loss_val
            best_state_dict = copy.deepcopy(model.state_dict())
            count = 0
        elif epoch > args.warmup_epochs:
            # Patience only starts counting once the warm-up period is over.
            count += 1

        max_f1_train = max(max_f1_train, metric_train["F1"])
        max_f1_val = max(max_f1_val, metric_val["F1"])
        max_f1_test = max(max_f1_test, metric_test["F1"])

        if count > args.patience:
            print(f"Early stopping at epoch {epoch}")
            break
        pbar.set_description(
            f"e:{epoch} | train l:{loss_train:.4f} f1:{metric_train['F1']:.4f} | val l:{loss_val:.4f} f1:{metric_val['F1']:.4f} | test l:{loss_test:.4f} f1:{metric_test['F1']:.4f}"
        )

    pbar.set_description(f"Final {max_f1_train:.4f} {max_f1_val:.4f} {max_f1_test:.4f}")
    pbar.close()

    # Evaluate the checkpoint with the lowest validation loss.
    model.load_state_dict(best_state_dict)
    model.eval()
    _, final_metric = test(model, dataloader_test, device)

    final_metric["task"] = args.task
    final_metric["batch"] = args.batch_size
    final_metric["dropout"] = args.dropout
    final_metric["epoch"] = args.epochs
    final_metric["stop_epoch"] = epoch
    final_metric["hidden"] = args.hidden_dim
    final_metric["lr"] = args.lr
    final_metric["layers"] = args.num_layers
    final_metric["model_type"] = args.model_type
    final_metric["split"] = args.split

    _append_csv_row(f"test_{args.task}_{args.model_type}.csv", final_metric)

    print(f"Final metric: {final_metric['F1']:.4f}")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "model_args": model_args,
            "args": vars(args),
            "f1": final_metric["F1"],
            "roc_auc": final_metric["AUROC"],
            "accuracy": final_metric["accuracy"],
        },
        args.save_path,
    )
    print(f"Saved to {args.save_path}")


if __name__ == "__main__":
    # Script entry point: runs the full train/evaluate/save pipeline only when
    # this file is executed directly, not when it is imported.
    main()
