import argparse
import csv
import os
import torch
from safetensors import safe_open
from model_transformer.utils import one_hot_encode
from model_transformer.calibration_methods import TemperatureScaling, EnsembleTemperatureScaling, SplineCalibration, IsotonicRegressionCalibration, ParameterizedTemperatureScaling
from torchmetrics.classification import MulticlassCalibrationError, BinaryCalibrationError
from tabulate import tabulate
import pandas as pd
import pickle

# Column order of every results CSV; entries 1.. double as the row labels of the
# per-model ECE table printed to stdout.
_RESULT_COLUMNS = [
    "Model",
    "Uncalibrated",
    "Temperature Scaling",
    "Ensemble Temperature Scaling",
    "Isotonic Regression",
    "Parameterized Temperature Scaling",
    "Spline Calibration",
]

_DATA_ROOT = "./data/CalibrationStudy"


def _list_corruptions(dataset_name):
    """Return the sorted corruption folder names under ``medmnistc/<dataset_name>``.

    Only sub-directories count, so stray files such as ``.DS_Store`` are ignored.
    """
    corruption_root = f"{_DATA_ROOT}/medmnistc/{dataset_name}"
    return sorted(
        entry
        for entry in os.listdir(corruption_root)
        if os.path.isdir(os.path.join(corruption_root, entry))
    )


def _load_tensors(path):
    """Read every tensor stored in the safetensors file at *path* into a dict keyed by name."""
    with safe_open(path, framework="pt") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def _load_or_fit_calibrators(val_logits, val_labels, output_folder):
    """Return ``{method_key: calibrator}`` for all five calibration methods.

    A calibrator whose pickle already exists in *output_folder* is loaded instead
    of refit. Otherwise it is fit on the given validation logits/labels and then
    pickled there.
    """
    calibration_methods = {
        "ts": (TemperatureScaling, "temperature_scaling.pkl"),
        "ets": (EnsembleTemperatureScaling, "ensemble_temperature_scaling.pkl"),
        "spline": (SplineCalibration, "spline_calibration.pkl"),
        "ir": (IsotonicRegressionCalibration, "isotonic_regression.pkl"),
        "pts": (ParameterizedTemperatureScaling, "parameterized_temperature_scaling.pkl"),
    }
    num_classes = val_logits.shape[1]
    calibrators = {}

    for method_name, (calibration_class, pickle_name) in calibration_methods.items():
        pickle_path = os.path.join(output_folder, pickle_name)

        if os.path.exists(pickle_path):
            # NOTE: pickle.load can execute arbitrary code. These files are written by
            # this script, so only run it against a trusted outputs/ directory.
            with open(pickle_path, 'rb') as fh:
                calibrators[method_name] = pickle.load(fh)
            print(f"Loaded existing {method_name} from {pickle_path}")
            continue

        calibration_model = calibration_class()
        if method_name == "spline":
            # The spline path receives numpy labels, unlike the others. This is
            # presumably required by SplineCalibration — confirm before unifying.
            calibration_model.fit(val_logits, one_hot_encode(val_labels.numpy(), num_classes))
        elif method_name == "pts":
            calibration_model.top_k_logits = num_classes
            calibration_model.fit(val_logits, one_hot_encode(val_labels, num_classes), verbose=True)
        else:
            calibration_model.fit(val_logits, one_hot_encode(val_labels, num_classes))

        with open(pickle_path, 'wb') as fh:
            pickle.dump(calibration_model, fh)
        print(f"Fitted and saved new {method_name} to {pickle_path}")

        calibrators[method_name] = calibration_model

    return calibrators


def _evaluate_calibrators(calibrators, test_logits, test_labels):
    """Return ``[(label, ece), ...]`` for the uncalibrated logits and each calibrator.

    The multiclass ECE (15 bins) is used for every method except spline. The
    spline's ``transform`` returns ``(probs, labels)``, which are scored with the
    binary calibration error (15 bins).
    """
    num_classes = test_logits.shape[1]
    ece_fn = MulticlassCalibrationError(num_classes=num_classes, n_bins=15)
    bce_fn = BinaryCalibrationError(n_bins=15)

    def multiclass_ece(preds):
        # as_tensor accepts numpy arrays or tensors without the copy-construct warning.
        return ece_fn(preds=torch.as_tensor(preds), target=test_labels).item()

    eces = [multiclass_ece(test_logits)]
    for key in ("ts", "ets", "ir", "pts"):
        eces.append(multiclass_ece(calibrators[key].transform(test_logits)))

    spline_probs, spline_labels = calibrators["spline"].transform(
        test_logits, one_hot_encode(test_labels, num_classes)
    )
    eces.append(
        bce_fn(preds=torch.as_tensor(spline_probs), target=torch.as_tensor(spline_labels)).item()
    )

    return list(zip(_RESULT_COLUMNS[1:], eces))


def main(dataset_name):
    """Evaluate calibration methods on every MedMNIST-C corruption of *dataset_name*.

    Corruptions are the sub-directories of
    ``./data/CalibrationStudy/medmnistc/<dataset_name>``. Each one is expected to
    contain severity folders ``1``..``5``. For every ``*.safetensors`` file in a
    severity folder, the function:
    - fits or loads the calibrators on the ``logits_val``/``labels_val`` tensors;
    - prints an ECE table on ``logits_test``/``labels_test``.

    One CSV per severity folder is written to
    ``outputs/medmnistc_<dataset_name>_<corruption>_<severity>_results.csv``.
    """
    corrupted_datasets = [
        f"medmnistc/{dataset_name}/{corruption}/{severity}"
        for corruption in _list_corruptions(dataset_name)
        for severity in range(1, 6)
    ]

    # The results CSVs are written directly into outputs/. Create it up front so
    # an empty severity folder (no models) does not make to_csv fail.
    os.makedirs("outputs", exist_ok=True)

    for corrupted_name in corrupted_datasets:
        results = []
        data_folder = f"{_DATA_ROOT}/{corrupted_name}"
        tensor_names = sorted(f for f in os.listdir(data_folder) if f.endswith(".safetensors"))

        for tensor_name in tensor_names:
            tensors = _load_tensors(f"{data_folder}/{tensor_name}")
            model_name = tensor_name.removesuffix(".safetensors")

            # NOTE(review): only the first path component ("medmnistc") is kept.
            # Calibrators are therefore fit once per model name and reused for every
            # corruption, every severity and — if model names repeat — every MedMNIST
            # dataset. TODO confirm this sharing is intended.
            output_folder = f"outputs/{corrupted_name.split('/')[0]}/{model_name}"
            os.makedirs(output_folder, exist_ok=True)

            calibrators = _load_or_fit_calibrators(
                tensors["logits_val"], tensors["labels_val"], output_folder
            )
            rows = _evaluate_calibrators(calibrators, tensors["logits_test"], tensors["labels_test"])

            print(f"{tensor_name}:")
            print(tabulate(rows, headers=["Calibration Method", "ECE"], tablefmt="grid", floatfmt=".4f"))

            results.append((model_name, *(ece for _, ece in rows)))

        pd.DataFrame(results, columns=_RESULT_COLUMNS).to_csv(
            f"outputs/{corrupted_name.replace('/', '_')}_results.csv", index=False
        )

if __name__ == "__main__":
    # CLI entry point: the single required flag names the MedMNIST-C dataset folder.
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset_name", type=str, required=True)
    main(cli.parse_args().dataset_name)
