import os
import os.path as osp
import random

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader
from torch_geometric.datasets import PPI
import time

from gnn import *
from utils.early_stopping import EarlyStopping


def size_of_model(model):
    """Print and return the serialized size of ``model``'s state dict in MB.

    The state dict is written with ``torch.save`` into a private temporary
    directory, measured with ``os.path.getsize`` and discarded again.  Using
    a scratch directory (instead of a fixed ``temp.p`` in the working
    directory) means an existing ``temp.p`` is never overwritten or deleted,
    concurrent runs cannot collide, and the file is cleaned up even if
    saving or measuring raises.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        File size in bytes divided by ``1e6``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same basename as before so the measured archive should be
        # comparable with sizes reported by earlier runs.
        checkpoint_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        size = os.path.getsize(checkpoint_path) / 1e6

    print('Size (MB):', size)
    return size

def node_test(x, y, multi_label=False):
    """Score a batch of node predictions as a node-weighted count.

    Args:
        x: model outputs for the batch (raw logits).
        y: ground-truth labels for the same nodes.
        multi_label: when true, threshold the logits at zero and return the
            micro-averaged F1 multiplied by ``len(x)``; otherwise return the
            number of nodes whose arg-max class equals the label.

    Returns:
        A count-like value the caller can sum over batches and divide by the
        total number of nodes.
    """
    if not multi_label:
        labels = y.cpu()
        class_probs = F.softmax(x, dim=1)
        predicted = torch.argmax(class_probs, dim=1).cpu()
        # normalize=False -> number of correct nodes, not a fraction
        return metrics.accuracy_score(labels, predicted, normalize=False)

    targets = y.cpu().detach().numpy()
    # A positive logit is treated as "label present".
    decisions = (x > 0).cpu().detach().numpy()
    micro_f1 = metrics.f1_score(targets, decisions, average='micro')
    # Scale by the node count so batches can be averaged per node later.
    return micro_f1 * len(x)


def train(
        model,
        optimizer,
        loader,
        device,
        criterion,
        node_multi_label=True,
        mode="train"):
    """Run one pass over ``loader`` and return ``(mean_loss, node_accuracy)``.

    Args:
        model: network called as ``model(x=..., edge_index=..., batch=...)``.
        optimizer: optimizer to step, or ``None`` for a pass that must not
            update parameters.
        loader: iterable of batches exposing ``num_graphs``, ``to(device)``,
            ``x``, ``edge_index``, ``batch`` and ``y``.
        device: device every batch is moved to.
        criterion: callable ``criterion(logits, y)`` returning a scalar loss.
        node_multi_label: forwarded to ``node_test`` to select the metric.
        mode: ``"train"`` puts the model in training mode; any other value
            puts it in evaluation mode.

    Returns:
        Tuple of the per-graph average loss (each batch loss weighted by its
        ``num_graphs``) and the summed ``node_test`` score divided by the
        total number of nodes seen.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.

    Note:
        ``mode="train"`` with ``optimizer=None`` still switches the model to
        training mode (so train-mode-only layers stay active) but performs no
        parameter update; use a different ``mode`` for a pure evaluation.
    """
    in_train_mode = mode == "train"
    if in_train_mode:
        model.train()
    else:
        model.eval()

    # Gradients are only needed when a parameter update will actually happen;
    # otherwise the forward pass runs without building an autograd graph.
    updates_weights = in_train_mode and optimizer is not None

    total_loss = 0
    total_node = 0
    node_acc_count = 0
    data_count = 0

    for data in loader:
        num_graphs = data.num_graphs
        data_count += num_graphs

        data = data.to(device)

        if updates_weights:
            optimizer.zero_grad()

        with torch.set_grad_enabled(updates_weights):
            logits = model(x=data.x,
                           edge_index=data.edge_index,
                           batch=data.batch)
            loss = criterion(logits, data.y)

        node_acc_count += node_test(logits,
                                    data.y,
                                    node_multi_label)
        total_node += len(logits)

        # Weight by graph count so the final division gives a per-graph mean.
        total_loss += loss.item() * num_graphs

        if updates_weights:
            loss.backward()
            optimizer.step()

    node_acc = float(node_acc_count) / total_node
    return total_loss / data_count, node_acc


def load_data(path):
    """Build the PPI train/val/test datasets and a loader for each.

    Args:
        path: root directory handed to ``PPI`` for every split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, train_loader,
        val_loader, test_loader)``.  The training loader shuffles and uses a
        batch size of 1; the validation and test loaders keep dataset order
        and use a batch size of 2.  All loaders use ``num_workers=0``.
    """
    split_names = ("train", "val", "test")
    datasets = {name: PPI(path, split=name) for name in split_names}

    train_loader = DataLoader(
        datasets["train"],
        batch_size=1,
        shuffle=True,
        num_workers=0)

    # Validation and test share the same (non-shuffled) loader settings.
    val_loader, test_loader = (
        DataLoader(
            datasets[name],
            batch_size=2,
            shuffle=False,
            num_workers=0)
        for name in ("val", "test")
    )

    return (
        datasets["train"],
        datasets["val"],
        datasets["test"],
        train_loader,
        val_loader,
        test_loader,
    )


def trainer(
        model,
        logger,
        summary_file,
        train_loader,
        val_loader,
        test_loader,
        device,
        criterion,
        max_epoch=200,
        early_stopping=None,
        save_model=None):
    """Train ``model`` with Adam and log train/val/test metrics every epoch.

    Args:
        model: network to optimise; also passed to ``early_stopping``.
        logger: object with ``write(str)``; receives one CSV line per epoch
            (``train_loss,val_loss,test_loss,train_acc,val_acc,test_acc``).
        summary_file: path appended with the last epoch's metrics as a
            comma-terminated record (no trailing newline).
        train_loader, val_loader, test_loader: loaders handed to ``train``.
        device: device handed to ``train``.
        criterion: loss callable handed to ``train``.
        max_epoch: maximum number of epochs; must be at least 1.
        early_stopping: optional callable ``early_stopping(val_loss, model)``
            exposing an ``early_stop`` attribute; when ``None`` training
            always runs for ``max_epoch`` epochs.
        save_model: optional destination for ``model.state_dict()`` once the
            loop ends.

    Returns:
        ``(train_node_acc, val_node_acc, test_node_acc)`` of the last epoch
        that ran (not of the best epoch).

    Raises:
        ValueError: if ``max_epoch`` is smaller than 1, since no metrics
            would exist to log or return.
    """
    if max_epoch < 1:
        raise ValueError("max_epoch must be at least 1")

    lr = 2e-4

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=0)
    # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases;
    # confirm against the pinned version before upgrading.
    scheduler = ReduceLROnPlateau(
        optimizer,
        "min",
        patience=100,
        verbose=True,
        factor=0.5,
        cooldown=30,
        min_lr=lr / 100)

    for epoch in range(0, max_epoch):
        train_loss, train_node_acc = train(
            model=model, optimizer=optimizer, loader=train_loader, device=device, mode="train", criterion=criterion)

        # Evaluation passes never step, so no optimizer is handed over.
        val_loss, val_node_acc = train(
            model=model, optimizer=None, loader=val_loader, device=device, mode="val", criterion=criterion)

        test_loss, test_node_acc = train(
            model=model, optimizer=None, loader=test_loader, device=device, mode="test", criterion=criterion)

        logger.write(
            f"{train_loss},{val_loss},{test_loss},{train_node_acc},{val_node_acc},{test_node_acc}\n")

        print(
            f"Epoch: {epoch}/{max_epoch}\nTrain:\t{train_loss}\t{train_node_acc}\nVal:\t{val_loss}\t{val_node_acc}\nTest:\t{test_loss}\t{test_node_acc}")

        # ``early_stopping`` is optional (its default is None), so only
        # consult it when the caller supplied one.
        if early_stopping is not None:
            early_stopping(val_loss, model)

            if early_stopping.early_stop:
                print("Early stopping")
                break

        # The plateau scheduler is driven by the training loss.
        scheduler.step(train_loss)

    # Save the model after training completes.
    if save_model is not None:
        torch.save(model.state_dict(), save_model)
        print(f"Model saved to {save_model}")

    with open(summary_file, "a") as f:
        f.write(
            f"{train_loss},{val_loss},{test_loss},{train_node_acc},{val_node_acc},{test_node_acc},")

    return train_node_acc, val_node_acc, test_node_acc


def run_std(runs, file_name, **kwargs):
    """Repeat training ``runs`` times and write accuracy mean/std to a file.

    Before every run the model in ``kwargs["model"]`` has its parameters
    reset and a fresh ``EarlyStopping(patience=20)`` is created.  All
    remaining keyword arguments are forwarded to ``trainer`` unchanged.

    Args:
        runs: number of independent training runs.
        file_name: output path, overwritten with three lines of
            ``"<mean>, <std>"`` for train, validation and test accuracy,
            in that order.
        **kwargs: arguments for ``trainer``; must contain ``model``.
    """
    split_order = ("train", "val", "test")
    accuracies = {split: [] for split in split_order}

    for _ in range(runs):
        kwargs["model"].reset_parameters()

        stopper = EarlyStopping(
            patience=20)

        train_node_acc, val_node_acc, test_node_acc = trainer(
            early_stopping=stopper, **kwargs)

        accuracies["train"].append(train_node_acc)
        accuracies["val"].append(val_node_acc)
        accuracies["test"].append(test_node_acc)

    with open(file_name, "w") as std_file:
        for split in split_order:
            scores = accuracies[split]
            std_file.write(f"{np.mean(scores)}, {np.std(scores)}\n")


def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, PyTorch and NumPy, stores the seed in the
    ``PYTHONHASHSEED`` environment variable, and, when CUDA is available,
    seeds all CUDA devices and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: value handed to each generator.
    """
    for apply_seed in (random.seed, torch.manual_seed, np.random.seed):
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _timed_eval(net, loader, run_device, mode):
    """Run one optimizer-free pass of ``train`` and time it.

    Returns ``(loss, accuracy, elapsed_seconds)``.  A fresh
    ``BCEWithLogitsLoss`` is used as the criterion for every call.
    """
    started = time.time()
    loss, acc = train(
        model=net,
        optimizer=None,  # evaluation only: nothing is optimised
        loader=loader,
        device=run_device,
        criterion=torch.nn.BCEWithLogitsLoss(),
        mode=mode)
    return loss, acc, time.time() - started


if __name__ == "__main__":
    # Evaluates saved PPI checkpoints for 1..10 layers, applies post-training
    # dynamic quantization, and appends loss/accuracy, size and timing figures
    # to the analysis logs.
    import torch.quantization

    h = 256

    # (aggregate, readout, combine) types, each mapped to its abbreviation.
    _networks = [
        {"add": "S"}, {"add": "S"}, {"simple": "T"}
    ]

    file_path = "Supplement_materials/Code/src"
    extra_name = "results_ppi"

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    # Dynamically quantized layers are evaluated on the CPU (see below).
    cpu_device = torch.device("cpu")

    for _net_class in [
        #"acgnn",
        #"gin",
        "acrgnn",
        # "acrgnn-single"
    ]:

        filename = f"{file_path}/logging/{extra_name}/ppi.mix"

        (_agg, _agg_abr) = list(_networks[0].items())[0]
        (_read, _read_abr) = list(_networks[1].items())[0]
        (_comb, _comb_abr) = list(_networks[2].items())[0]

        for comb_layers in [1]:
            # Skip architecture/option combinations that are not evaluated.
            if _net_class == "acgnn" and (
                    _read == "max" or _read == "add"):
                continue
            elif _net_class == "gin" and (_agg == "mean" or _agg == "max" or _comb == "mlp" or _read == "max" or _read == "add"):
                continue

            if _comb == "mlp" and comb_layers > 1:
                continue

            for l in range(1, 11):
                print(_networks, _net_class, l, comb_layers)

                _log_file = f"{file_path}/logging/{extra_name}/ppi-{_net_class}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{l}-h{h}.log"

                # NOTE(review): opening with "w" truncates any existing
                # per-epoch log even while training (STEP 1) is skipped.
                with open(_log_file, "w") as log_file:
                    log_file.write(
                        "train_loss,val_loss,test_loss,train_acc,val_acc,test_acc\n")

                    train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader = load_data(
                        f"{file_path}/data/ppi")

                    # Kept as an if/elif chain so only the selected class
                    # name has to be exported by ``gnn``.
                    if _net_class == "acgnn":
                        _model = ACGNN
                    elif _net_class == "acrgnn":
                        _model = ACRGNN
                    elif _net_class == "acrgnn-single":
                        _model = SingleACRGNN
                    elif _net_class == "gin":
                        _model = GIN

                    seed_everything(0)

                    model = _model(
                        input_dim=train_dataset.num_features,
                        hidden_dim=h,
                        output_dim=train_dataset.num_classes,
                        num_layers=l,
                        aggregate_type=_agg,
                        readout_type=_read,
                        combine_type=_comb,
                        combine_layers=comb_layers,
                        num_mlp_layers=2,
                        task="node",
                        truncated_fn=None)
                    model = model.to(device)

                    # *** STEP 1: Skip training if you want to avoid re-running the model ***
                    # run_std(
                    #     runs=10,
                    #     file_name=f"logging/results_ppi/{_net_class}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{l}-h{h}",
                    #     model=model,
                    #     logger=log_file,
                    #     summary_file=filename,
                    #     train_loader=train_loader,
                    #     val_loader=val_loader,
                    #     test_loader=test_loader,
                    #     device=device,
                    #     criterion=torch.nn.BCEWithLogitsLoss(),
                    #     max_epoch=500,
                    #     save_model=f"saved_models/ppi/{_net_class}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{l}-h{h}.pth")

                    # *** STEP 2: Load the saved model and run testing/validation ***

                    saved_model_path = f"{file_path}/saved_models/ppi/{_net_class}-agg{_agg_abr}-read{_read_abr}-comb{_comb_abr}-cl{comb_layers}-L{l}-h{h}.pth"
                    # map_location lets a checkpoint saved on a GPU be
                    # loaded on a CPU-only machine (and vice versa).
                    model.load_state_dict(
                        torch.load(saved_model_path, map_location=device))
                    print(f"Loaded model from {saved_model_path}")

                    size_original = size_of_model(model)

                    # Evaluate on the training split.  A non-"train" mode is
                    # used on purpose: mode="train" would run the model in
                    # training mode, which skews the metric and can change
                    # model state before the evaluations below.
                    train_loss, train_node_acc, elapsed_time_train = _timed_eval(
                        model, train_loader, device, "eval")
                    print(f"Train Loss: {train_loss}, Train Accuracy: {train_node_acc}, Elapsed Time: {elapsed_time_train:.3f} sec")

                    # Run test evaluation
                    test_loss, test_node_acc, elapsed_time_test = _timed_eval(
                        model, test_loader, device, "test")
                    print(f"Test Loss: {test_loss}, Test Accuracy: {test_node_acc}")

                    # Run validation evaluation
                    val_loss, val_node_acc, elapsed_time_val = _timed_eval(
                        model, val_loader, device, "val")
                    print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_node_acc}")

                    # *** STEP 3: Apply post-training dynamic quantization ***
                    # Dynamic quantization targets CPU execution, so the
                    # float model is moved to the CPU first and the quantized
                    # copy is evaluated there.
                    # NOTE(review): when ``device`` is a GPU the float and
                    # quantized timings are measured on different hardware.
                    quantized_model = torch.quantization.quantize_dynamic(
                        model.to(cpu_device), {torch.nn.Linear}, dtype=torch.qint8)
                    print("Applied dynamic quantization.")
                    size_dptq = size_of_model(quantized_model)

                    # Same three evaluations on the quantized model.
                    train_loss_q, train_node_acc_q, elapsed_time_train_dptq = _timed_eval(
                        quantized_model, train_loader, cpu_device, "eval")
                    print(f"Quantized Model - train Loss: {train_loss_q}, train Accuracy: {train_node_acc_q}")

                    test_loss_q, test_node_acc_q, elapsed_time_test_dptq = _timed_eval(
                        quantized_model, test_loader, cpu_device, "test")
                    print(f"Quantized Model - Test Loss: {test_loss_q}, Test Accuracy: {test_node_acc_q}")

                    val_loss_q, val_node_acc_q, elapsed_time_val_dptq = _timed_eval(
                        quantized_model, val_loader, cpu_device, "val")
                    print(f"Quantized Model - Validation Loss: {val_loss_q}, Validation Accuracy: {val_node_acc_q}")

                    # Create the output directory up front so a missing
                    # folder cannot discard the measurements just taken.
                    analysis_dir = f"{file_path}/for_analysis/{extra_name}"
                    os.makedirs(analysis_dir, exist_ok=True)

                    results_file = f"{analysis_dir}/ppi_relu_results_for_appendix.log"
                    with open(results_file, "a") as f:
                        f.write(f"{_net_class}-L{l}-h{h}:"
                                f"Train Loss: {train_loss}, Train Acc: {train_node_acc}, Elapsed Time Train: {elapsed_time_train:.3f}, "
                                f"Test Loss: {test_loss}, Test Acc: {test_node_acc}, Elapsed Time Test: {elapsed_time_test:.3f}, "
                                f"Val Loss: {val_loss}, Val Acc: {val_node_acc}, Elapsed Time VAl: {elapsed_time_val:.3f}\n")

                    # Save the quantized model's results
                    quant_results_file = f"{analysis_dir}/ppi_relu_quantized_results_for_appendix.log"
                    with open(quant_results_file, "a") as qf:
                        qf.write(f"{_net_class}-L{l}-h{h}:"
                                 f"Train Loss: {train_loss_q}, Train Acc: {train_node_acc_q}, Elapsed Time Train: {elapsed_time_train_dptq:.3f}, "
                                 f"Test Loss: {test_loss_q}, Test Acc: {test_node_acc_q}, Elapsed Time Test: {elapsed_time_test_dptq:.3f}, "
                                 f"Val Loss: {val_loss_q}, Val Acc: {val_node_acc_q}, Elapsed Time Val: {elapsed_time_val_dptq:.3f}\n")

                    size_results_file = f"{analysis_dir}/ppi_relu_results_size_for_appendix.log"
                    with open(size_results_file, "a") as f:
                        f.write(f"{_net_class}-L{l}-h{h}:"
                                f"Original model: {size_original}, Quantized model: {size_dptq}\n")

                    time_results_file = f"{analysis_dir}/ppi_relu_results_time_for_appendix.log"
                    with open(time_results_file, "a") as f:
                        f.write(f"{_net_class}-L{l}-h{h}:"
                                f"Elapsed Time Train: {elapsed_time_train}, dPTQ Elapsed Time Train: {elapsed_time_train_dptq:.3f}, "
                                f"Elapsed Time Test: {elapsed_time_test}, dPTQ Elapsed Time Test: {elapsed_time_test_dptq:.3f}, "
                                f"Elapsed Time Val: {elapsed_time_val}, dPTQ Elapsed Time Val: {elapsed_time_val_dptq:.3f}\n"
                                )
            # Terminate the summary record for this configuration.
            with open(filename, "a") as f:
                f.write("\n")
