import argparse
import copy
import random
from tqdm import tqdm

import numpy as np

import torch
from torch.nn import Linear, ReLU, Dropout, ELU
import torch.nn.functional as F
from torch.utils.data import Subset
from torch.utils.data import WeightedRandomSampler
from torchmetrics import Accuracy, MeanSquaredError, MeanAbsoluteError, AUROC, F1Score

from torch_geometric.loader import DataLoader
from torch_geometric.nn import Sequential, GCNConv, GATConv, GAT, GCN, GIN
from torch_geometric.nn import global_add_pool, global_max_pool

from sklearn.metrics import f1_score, root_mean_squared_error as mse_score
from torchmetrics import MeanSquaredError, F1Score

from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from datasetss import GraphFeaturizer, build_dataset


def args_parser():
    """Parse the command-line options for one training run.

    Returns:
        argparse.Namespace holding the parsed options. The ``--data-set``
        flag is exposed as ``args.data_set`` (argparse converts the dash).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Reproducibility / data split selection.
    add("--seed", type=int, default=0, help="seed")
    add("--split", type=int, default=0, help="split")

    # Architecture choice and checkpoint destination.
    add("--model_type", type=str, default="GIN", choices=["GCN", "GIN", "GAT"])
    add("--save_path", type=str, default="model.pt")

    # Optimisation schedule.
    add("--batch_size", type=int, default=64)
    add("--epochs", type=int, default=1)
    add("--warmup_epochs", type=int, default=50)
    add("--lr", type=float, default=0.001)
    add("--weight_decay", type=float, default=0.0001)

    # Model size / regularisation.
    add("--hidden_dim", type=int, default=32)
    add("--num_layers", type=int, default=3)
    add("--dropout", type=float, default=0.2, help="Dropout rate")
    add("--linear_dim", type=int, default=32)

    # Dataset, task type and early stopping.
    add("--data-set", type=str, default="cyp", choices=["sol", "cyp", "herg", "herg_K"])
    add("--task", type=str, choices=['classification', 'regression'], default='classification')
    add("--patience", type=int, default=30, help="Early stopping patience")

    return parser.parse_args()




def get_model(model_type, num_node_features, num_classes, hidden_dim, num_layers, linear_dim=32, dropout=0.2):
    """Build a graph-level predictor: GNN backbone -> sum pooling -> linear head.

    The returned ``torch_geometric.nn.Sequential`` is called as
    ``model(x, edge_index, batch)`` and yields one row of ``num_classes``
    outputs per graph in the batch.

    Args:
        model_type: Backbone to use; one of ``"GCN"``, ``"GAT"``, ``"GIN"``.
        num_node_features: Size of the per-node input feature vector.
        num_classes: Width of the final linear layer.
        hidden_dim: Hidden (and output) channel count of the backbone.
        num_layers: Number of message-passing layers in the backbone.
        linear_dim: Currently unused; kept so existing callers and saved
            ``model_args`` dicts still work.
        dropout: Dropout rate forwarded to the backbone.

    Returns:
        The assembled ``Sequential`` model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported backbones.
    """
    supported = ("GCN", "GAT", "GIN")
    # Validate before touching any backbone class; an ``assert`` would be
    # stripped under ``python -O`` and leave the model undefined.
    if model_type not in supported:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {supported}")

    backbone_cls = {"GCN": GCN, "GAT": GAT, "GIN": GIN}[model_type]
    # The backbone is built before the Linear head (as in the original
    # per-branch code) so seeded parameter initialisation is unchanged.
    backbone = backbone_cls(
        num_node_features,
        hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        norm="batch_norm",
    )
    return Sequential(
        "x, edge_index, batch",
        [
            (backbone, "x, edge_index -> x"),
            (global_add_pool, "x, batch -> x"),
            Linear(hidden_dim, num_classes),
        ],
    )


def train(model, optimizer, dataloader, device, task):
    """Run one optimisation epoch over ``dataloader``.

    Classification uses cross-entropy against ``data.y`` cast to integer
    labels; any other task uses MSE against ``data.y`` reshaped to ``(N, 1)``.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Iterable of batches exposing ``.to``, ``.x``,
            ``.edge_index``, ``.batch`` and ``.y``.
        device: Device each batch is moved to.
        task: ``"classification"`` or anything else (treated as regression).

    Returns:
        The mean per-batch training loss for the epoch as a float, or NaN
        if ``dataloader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, batch=data.batch)
        if task == "classification":
            loss = F.cross_entropy(out, data.y.long())
        else:
            loss = F.mse_loss(out, data.y.float().reshape(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Average over all batches (the previous code divided only the LAST
    # batch's loss by the batch count).
    return total_loss / num_batches if num_batches else float("nan")


@torch.no_grad()
def test(model, dataloader, device, task):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``; for
            classification its output is assumed to be ``(N, C)`` class
            scores — TODO confirm for new model types.
        dataloader: Iterable of batches exposing ``.to``, ``.x``,
            ``.edge_index``, ``.batch`` and ``.y``.
        device: Device batches and metric objects are moved to.
        task: ``"classification"`` or anything else (treated as regression).

    Returns:
        ``(mean_loss, metrics)``. ``mean_loss`` is a Python float: the
        per-graph average of cross-entropy (classification) or MSE
        (regression) over the whole loader. ``metrics`` has keys ``F1``,
        ``AUROC``, ``accuracy`` for classification, or ``rmse`` and ``mae``
        otherwise.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    is_cls = task == "classification"
    ys, preds = [], []
    total_loss, total_count = 0.0, 0
    for data in dataloader:
        data = data.to(device)
        out = model(data.x, data.edge_index, batch=data.batch)
        if is_cls:
            loss = F.cross_entropy(out, data.y.long())
            ys.append(data.y.cpu())
        else:
            target = data.y.float()
            loss = F.mse_loss(out, target.reshape(-1, 1))
            ys.append(target.cpu())
        preds.append(out.cpu())
        # Weight each batch's mean loss by its size so a short final batch
        # does not skew the average; store a float, not a device tensor.
        n = out.size(0)
        total_loss += loss.item() * n
        total_count += n

    if total_count == 0:
        raise ValueError("test() received a dataloader that yielded no batches")

    y_all = torch.cat(ys).to(device)
    p_all = torch.cat(preds).to(device)
    metrics = {}
    if is_cls:
        labels = y_all.long()
        # Derive the class count from the output width instead of hard-coding 2.
        n_cls = p_all.size(-1)
        # Key order matters: main() writes these dict keys as CSV columns.
        scorers = (
            ("F1", F1Score(task='multiclass', num_classes=n_cls, average="weighted")),
            ("AUROC", AUROC(task='multiclass', num_classes=n_cls)),
            ("accuracy", Accuracy(task='multiclass', num_classes=n_cls)),
        )
        for key, scorer in scorers:
            metrics[key] = scorer.to(device)(p_all, labels).item()
    else:
        target_all = y_all.unsqueeze(-1)
        metrics['rmse'] = MeanSquaredError(squared=False).to(device)(p_all, target_all).item()
        metrics['mae'] = MeanAbsoluteError().to(device)(p_all, target_all).item()

    return total_loss / total_count, metrics


def _dataset_kwargs(args):
    """Return the configuration dict passed to ``build_dataset`` and the featurizer.

    The "sol" dataset gets ``mean=-2.86, std=2.38``; every other dataset gets
    ``mean=0.0, std=1.0``. NOTE(review): presumably these are target
    standardisation constants applied inside ``datasetss`` — confirm there.
    """
    if args.data_set == "sol":
        mean, std = -2.86, 2.38
    else:
        mean, std = 0.0, 1.0
    return {
        "data_set": args.data_set,
        "mean": mean,
        "std": std,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": args.task,
        "split": args.split,
    }


def _make_train_loader(train_set, args):
    """Return the training DataLoader, class-balanced for classification.

    For classification, graphs are drawn with replacement, each weighted by
    the inverse frequency of its label. For regression, the set is shuffled.
    Incomplete final batches are dropped in both cases.
    """
    if args.task == "classification":
        labels = torch.cat([data.y for data in train_set]).long()
        class_weights = 1.0 / torch.bincount(labels).float()
        sampler = WeightedRandomSampler(
            weights=class_weights[labels], num_samples=len(labels), replacement=True
        )
        return DataLoader(train_set, batch_size=args.batch_size, shuffle=False,
                          drop_last=True, sampler=sampler)
    return DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True)


def _append_csv_row(csv_path, row):
    """Append ``row`` (a dict) to ``csv_path``, writing a header if the file is new."""
    import csv
    import os

    file_exists = os.path.exists(csv_path)
    with open(csv_path, mode='a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(row)


def main():
    """Train a GNN on one dataset split and record the results.

    Validation loss selects the best weights, with early stopping after the
    warm-up epochs. Side effects: prints progress and appends a result row
    to ``test_random_{task}_{model_type}.csv`` in the working directory. It
    also writes a checkpoint dict to ``args.save_path``.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    # Seed every RNG the script (or code it calls) may draw from.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset_kwargs = _dataset_kwargs(args)
    dataset_train, dataset_val, dataset_test = build_dataset(dataset_kwargs)
    featurizer = GraphFeaturizer(y_column='Y')

    train_set = featurizer(dataset_train, dataset_kwargs)
    val_set = featurizer(dataset_val, dataset_kwargs)
    test_set = featurizer(dataset_test, dataset_kwargs)

    dataloader_train = _make_train_loader(train_set, args)
    dataloader_val = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    dataloader_test = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_args = {
        "model_type": args.model_type,
        "num_classes": 2 if args.task == "classification" else 1,
        "hidden_dim": args.hidden_dim,
        "num_layers": args.num_layers,
        "linear_dim": args.linear_dim,
        "num_node_features": train_set[0].x.shape[1],
        "dropout": args.dropout,
    }

    model = get_model(**model_args).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Metric shown in the progress bar and the final summary line.
    if args.task == "classification":
        report_key, report_label = 'F1', 'f1'
    else:
        report_key, report_label = 'rmse', 'rmse'

    min_loss_val = float("inf")
    best_state_dict = copy.deepcopy(model.state_dict())
    count = 0
    # Initialised up front so it is defined even when --epochs is 0.
    stop_epoch = 0

    pbar = tqdm(range(1, args.epochs + 1))
    for epoch in pbar:
        stop_epoch = epoch
        train(model, optimizer, dataloader_train, device, args.task)
        loss_train, metric_train = test(model, dataloader_train, device, args.task)
        loss_val, metric_val = test(model, dataloader_val, device, args.task)
        loss_test, metric_test = test(model, dataloader_test, device, args.task)

        if loss_val < min_loss_val:
            min_loss_val = loss_val
            best_state_dict = copy.deepcopy(model.state_dict())
            count = 0
        elif epoch > args.warmup_epochs:
            # Non-improving epochs only count toward patience after warm-up.
            count += 1

        if count > args.patience:
            print(f"Early stopping at epoch {epoch}")
            break

        pbar.set_description(
            f"e:{epoch} | train l:{loss_train:.4f} {report_label}:{metric_train[report_key]:.4f}"
            f" | val l:{loss_val:.4f} {report_label}:{metric_val[report_key]:.4f}"
            f" | test l:{loss_test:.4f} {report_label}:{metric_test[report_key]:.4f}"
        )
    pbar.close()

    model.load_state_dict(best_state_dict)
    model.eval()
    _, final_metric = test(model, dataloader_test, device, args.task)
    print(f"Final metric: {final_metric[report_key]:.4f}")

    # Run configuration appended after the metric columns (order = CSV columns).
    final_metric.update({
        'data_set': args.data_set,
        'task': args.task,
        'batch': args.batch_size,
        'dropout': args.dropout,
        'epoch': args.epochs,
        'stop_epoch': stop_epoch,
        'hidden': args.hidden_dim,
        'lr': args.lr,
        'layers': args.num_layers,
        'model_type': args.model_type,
        'split': args.split,
    })
    _append_csv_row(f"test_random_{args.task}_{args.model_type}.csv", final_metric)

    if args.task == "classification":
        checkpoint_metrics = {
            "f1": final_metric['F1'],
            "roc_auc": final_metric['AUROC'],
            "accuracy": final_metric['accuracy'],
        }
    else:
        checkpoint_metrics = {
            "mae": final_metric['mae'],
            "rmse": final_metric['rmse'],
        }
    torch.save(
        {
            "state_dict": model.state_dict(),
            "model_args": model_args,
            "args": vars(args),
            **checkpoint_metrics,
        },
        args.save_path,
    )
    print(f"Saved to {args.save_path}")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
