import os

import yaml
from ml_collections import ConfigDict

import torch
from torch.utils.data import DataLoader, Subset

from bin_cp.experiments.image_utils.cifar_resnet import ResNet
from bin_cp.experiments.image_utils.architectures import get_architecture
from bin_cp.experiments.image_utils.image_datasets import get_dataset
from bin_cp.helpers.lightner import ModelManager, Output
from bin_cp.robust.smoothing import standard_l2_norm
from bin_cp.helpers.storage import smooth_prediction_filename

from sacred import Experiment
# Sacred experiment object; the `config` scope and the `run` entry point below
# register themselves on it through @ex.config and @ex.automain.
ex = Experiment('SmoothPredictions')

# Sacred config scope for this experiment. Every local assigned here is a config
# entry, and run() receives each one through a parameter of the same name.
# Renaming a variable therefore renames the config key.
@ex.config
def config():
    dataset_name = "cifar10"  # key into general.yaml "models" and the models_dir layout
    model_sigma = 0.25  # noise level of the checkpoint to load (selects the noise_<sigma> directory)
    n_datapoints = 2048  # number of leading test-set points evaluated (indices 0..n-1)
    smoothing_sigma = 0.25  # sigma for standard_l2_norm smoothing; also noise_sd for "pgd_rs"
    n_samples = 10000  # n_samples passed to ModelManager.smooth_adv_predict
    r=0.0  # passed as r to smooth_adv_predict; presumably the attack radius -- TODO confirm
    attack = "pgd"  # attack name understood by load_attack ("pgd" or "pgd_rs")
    attack_conf = None  # optional attack kwarg overrides -- NOTE(review): verify run() applies them
    recompute = False  # if True, ignore any cached logits file and recompute
    save=True  # if True, write freshly computed logits to conf.logits_dir

def load_attack(attack_name, smoothing_sigma=None):
    """Resolve an attack name to a torchattacks class and its default kwargs.

    Args:
        attack_name: ``"pgd"`` selects ``torchattacks.PGDL2``; ``"pgd_rs"``
            selects ``torchattacks.PGDRSL2``.
        smoothing_sigma: Forwarded as ``noise_sd`` for ``"pgd_rs"``; ignored
            for ``"pgd"``.

    Returns:
        A ``(attack_class, kwargs)`` tuple. ``kwargs`` is a fresh dict on every
        call, so callers may update it freely.

    Raises:
        ValueError: If ``attack_name`` is not one of the supported attacks.
    """
    # torchattacks is imported lazily, inside the branch that needs it.
    if attack_name == "pgd":
        from torchattacks import PGDL2
        return PGDL2, {"steps": 100}
    if attack_name == "pgd_rs":
        from torchattacks import PGDRSL2
        return PGDRSL2, {"steps": 100, "noise_sd": smoothing_sigma}
    # Fail loudly here. Implicitly returning None would only surface later, when the
    # caller tries to unpack the result.
    raise ValueError(f"Unknown attack {attack_name!r}; expected 'pgd' or 'pgd_rs'.")

# Helpers are defined before @ex.automain on purpose: when this file is executed as
# a script, that decorator runs the experiment immediately, so anything defined
# after it would not exist yet.


def _checkpoint_path(models_dir, dataset_name, model_name, model_sigma):
    """Return ``<models_dir>/<dataset>/<model>/noise_<sigma>/checkpoint.pth.tar``.

    A zero ``model_sigma`` maps to the ``noise_0.00`` directory. Any other value
    is interpolated with ``str(model_sigma)`` (e.g. ``0.25`` -> ``noise_0.25``).
    """
    # NOTE(review): non-zero sigmas are not zero-padded (0.5 -> "noise_0.5"), unlike
    # the zero case; confirm this matches the directory names on disk.
    noise_dir = "noise_0.00" if model_sigma == 0 else f"noise_{model_sigma}"
    return os.path.join(models_dir, dataset_name, model_name, noise_dir, "checkpoint.pth.tar")


def _load_cached_prediction(path):
    """Load a prediction dict previously saved by ``run`` and wrap it in ``Output``."""
    # map_location="cpu" lets a cache written on a GPU host be read on a CPU-only one.
    cached = torch.load(path, map_location="cpu")
    return Output(y_pred=cached["y_pred"], logits=cached["logits"], y_true=cached["y_true"])


@ex.automain
def run(dataset_name, model_sigma, n_datapoints, smoothing_sigma, n_samples, attack, r, attack_conf, recompute, save):
    """Evaluate the smoothed model under attack on the first ``n_datapoints`` test points.

    Loads the checkpoint trained with noise ``model_sigma``. It then either reuses
    cached predictions from ``conf.logits_dir`` or computes them with
    ``ModelManager.smooth_adv_predict`` (saving them when ``save`` is True), and
    prints the accuracy of ``y_pred`` against ``y_true``.

    Args:
        dataset_name: Dataset key; selects the model in general.yaml, the
            checkpoint directory and the test split that is loaded.
        model_sigma: Noise level of the checkpoint to load.
        n_datapoints: Number of leading test points to evaluate.
        smoothing_sigma: Sigma for ``standard_l2_norm`` smoothing (and
            ``noise_sd`` for ``"pgd_rs"``).
        n_samples: Forwarded to ``smooth_adv_predict``.
        attack: Attack name understood by ``load_attack``.
        r: Forwarded as ``r`` to ``smooth_adv_predict``.
        attack_conf: Optional mapping of attack kwargs that override
            ``load_attack``'s defaults; ``None``/empty keeps the defaults.
        recompute: If True, skip the cached logits and recompute.
        save: If True, write freshly computed predictions to the cache file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Loading and processing configs. The path is relative to the current working
    # directory, not to this file.
    with open("../conf/general.yaml", "r") as fh:
        general_config = yaml.safe_load(fh)
    conf = ConfigDict(general_config["general"])
    model_name = general_config["models"][dataset_name]

    # Loading model. Tensors are mapped to CPU so CUDA-saved checkpoints also load on
    # CPU-only hosts; load_state_dict copies them onto the model's own parameters.
    model_file = _checkpoint_path(conf.models_dir, dataset_name, model_name, model_sigma)
    model_dict = torch.load(model_file, map_location="cpu")
    model = get_architecture(model_dict["arch"], dataset_name)
    model.load_state_dict(model_dict["state_dict"])
    model_obj = ModelManager(model, device=device)

    # Loading dataset. Use dataset_name so the evaluated data matches the loaded model
    # and the cache file name.
    dataset = get_dataset(dataset_name, 'test', root=conf.dataset_dir)
    print(f"dataset size = {len(dataset)}")
    dataset = Subset(dataset, list(range(n_datapoints)))
    print(f"dataset size = {len(dataset)}")

    # Pinned host memory only helps host->GPU copies, so enable it only on CUDA.
    test_dataset = DataLoader(dataset, batch_size=128, shuffle=False, pin_memory=device.type == "cuda")

    # Cache location for the predictions.
    logits_file_name = smooth_prediction_filename(dataset_name=dataset_name,
        model_sigma=model_sigma,
        n_datapoints=n_datapoints,
        smoothing_sigma=smoothing_sigma,
        attack=attack,
        n_samples=n_samples, r=r)
    # NOTE(review): the cache name does not encode attack_conf, so runs that differ
    # only in attack_conf share a cache file; use recompute=True when changing it.
    logits_path = os.path.join(conf.logits_dir, logits_file_name)

    prediction = None
    if recompute:
        print("file_name = ", conf.logits_dir, logits_file_name)
        print("Recompute is set to True, so we will recompute the logits.")
    else:
        try:
            prediction = _load_cached_prediction(logits_path)
            print("Loaded logits from file")
        except Exception as e:
            # Best-effort cache: a missing or unreadable file just triggers recomputation.
            print(f"Error loading logits from file: {e}")

    if prediction is None:
        print("Computing logits")
        adv, adv_config = load_attack(attack, smoothing_sigma=smoothing_sigma)
        if attack_conf:
            # User-supplied attack kwargs override load_attack's defaults.
            adv_config = {**adv_config, **attack_conf}
        prediction = model_obj.smooth_adv_predict(
            test_dataset, n_samples=n_samples,
            smoothing_function=lambda x: standard_l2_norm(x, sigma=smoothing_sigma),
            adv_class=adv, r=r, adv_conf=adv_config)
        if save:
            # The logits directory (or a subdirectory in the file name) may not exist yet.
            os.makedirs(os.path.dirname(logits_path) or ".", exist_ok=True)
            torch.save({
                "y_pred": prediction.y_pred, "logits": prediction.logits, "y_true": prediction.y_true,
            }, logits_path)

    acc = (prediction.y_pred == prediction.y_true).float().mean().item()
    print(f"Accuracy = {acc:.3f}")
