import argparse
import csv
import os
import torch
from safetensors import safe_open
from model_transformer.utils import one_hot_encode
from model_transformer.calibration_methods import TemperatureScaling, EnsembleTemperatureScaling, SplineCalibration, IsotonicRegressionCalibration, ParameterizedTemperatureScaling
from torchmetrics.classification import MulticlassCalibrationError, BinaryCalibrationError
from tabulate import tabulate
import pandas as pd
import pickle

def main(dataset_name):
    """Fit post-hoc calibrators for every saved model and report test-set ECE.

    For every ``*.safetensors`` file in ``./data/CalibrationStudy/<dataset_name>``
    this reads the ``logits_val``/``labels_val`` and ``logits_test``/``labels_test``
    tensors. It fits each calibration method on the validation split and measures
    the 15-bin calibration error on the test split. A grid table is printed per
    model. All scores are written to
    ``outputs/<dataset_name with '/' replaced by '_'>_results.csv``.

    Args:
        dataset_name: Sub-directory of ``./data/CalibrationStudy`` to evaluate.
    """
    data_dir = f"./data/CalibrationStudy/{dataset_name}"

    # Single source of truth for the method order, used for both the printed
    # table rows and the CSV columns so the two can never drift apart.
    method_names = [
        "Uncalibrated",
        "Temperature Scaling",
        "Ensemble Temperature Scaling",
        "Isotonic Regression",
        "Parameterized Temperature Scaling",
        "Spline Calibration",
    ]

    # os.listdir returns entries in arbitrary, filesystem-dependent order; sort
    # so models are processed (and written to the CSV) deterministically.
    tensor_names = sorted(
        file for file in os.listdir(data_dir) if file.endswith(".safetensors")
    )

    results = []
    for tensor_name in tensor_names:
        # Load every tensor stored in the safetensors file.
        with safe_open(f"{data_dir}/{tensor_name}", framework="pt") as f:
            tensors = {key: f.get_tensor(key) for key in f.keys()}

        # Validation split: used only to fit the calibrators.
        val_labels = tensors["labels_val"]
        val_logits = tensors["logits_val"]

        # Test split: used only to measure calibration error.
        test_labels = tensors["labels_test"]
        test_logits = tensors["logits_test"]

        model_name = tensor_name.removesuffix(".safetensors")
        os.makedirs(f"outputs/{model_name}", exist_ok=True)

        # Number of classes, taken from the logits' second dimension.
        num_classes = val_logits.shape[1]

        ts = TemperatureScaling()
        ts.fit(val_logits, one_hot_encode(val_labels, num_classes))

        ets = EnsembleTemperatureScaling()
        ets.fit(val_logits, one_hot_encode(val_labels, num_classes))

        ir = IsotonicRegressionCalibration()
        ir.fit(val_logits, one_hot_encode(val_labels, num_classes))

        spline = SplineCalibration()
        # NOTE(review): unlike the other calibrators, the spline one is given
        # a NumPy label array; presumably required by SplineCalibration — confirm.
        spline.fit(val_logits, one_hot_encode(val_labels.numpy(), num_classes))

        pts = ParameterizedTemperatureScaling(top_k_logits=num_classes)
        pts.fit(val_logits, one_hot_encode(val_labels, num_classes), verbose=True)

        ece_fn = MulticlassCalibrationError(num_classes=test_logits.shape[1], n_bins=15)
        bce_fn = BinaryCalibrationError(n_bins=15)

        uncalibrated_ece = ece_fn(preds=test_logits, target=test_labels)
        ts_calibrated_ece = ece_fn(preds=ts.transform(test_logits), target=test_labels)
        ets_calibrated_ece = ece_fn(preds=torch.tensor(ets.transform(test_logits)), target=test_labels)
        ir_calibrated_ece = ece_fn(preds=torch.tensor(ir.transform(test_logits)), target=test_labels)
        pts_calibrated_ece = ece_fn(preds=torch.tensor(pts.transform(test_logits)), target=test_labels)

        # spline.transform returns calibrated scores together with matching
        # labels, which are scored with the binary calibration error.
        spline_calibrated, spline_labels = spline.transform(
            test_logits, one_hot_encode(test_labels, test_logits.shape[1])
        )
        spline_calibrated_ece = bce_fn(
            preds=torch.tensor(spline_calibrated), target=torch.tensor(spline_labels)
        )

        # Same order as method_names.
        ece_values = [
            metric.item()
            for metric in (
                uncalibrated_ece,
                ts_calibrated_ece,
                ets_calibrated_ece,
                ir_calibrated_ece,
                pts_calibrated_ece,
                spline_calibrated_ece,
            )
        ]

        table_res = [[name, value] for name, value in zip(method_names, ece_values)]
        table = tabulate(
            table_res, headers=["Calibration Method", "ECE"], tablefmt="grid", floatfmt=".4f"
        )

        print(f"{tensor_name}:")
        print(table)

        results.append((model_name, *ece_values))

    # "outputs" is otherwise only created inside the loop; ensure it exists so
    # the CSV can be written even when no model files were found.
    os.makedirs("outputs", exist_ok=True)
    pd.DataFrame(results, columns=["Model", *method_names]).to_csv(
        f"outputs/{dataset_name.replace('/', '_')}_results.csv", index=False
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # required=True: without a dataset, main() would receive None and fail
    # later with an obscure path error; let argparse report a usage error.
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name to evaluate')
    args = parser.parse_args()

    main(args.dataset)
