from nesim.eval.resnet import EvalSuite
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.sparsity.conv import DownsampledConv2d
import math
from nesim.utils.model_info import count_model_parameters
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from nesim.utils.json_stuff import load_json_as_dict
from nesim.losses.laplacian_pyramid.loss import LaplacianPyramidLoss
from nesim.experiments.resnet import create_val_loader
from nesim.eval.resnet import load_resnet18_checkpoint
from robustness import eval_robustness

# Batch cap for evaluation; presumably None means "use every batch" — TODO confirm.
# NOTE(review): this constant is not referenced in the visible part of this
# script; the evaluation calls pass their own max_num_batches values.
MAX_NUM_BATCHES = None

def downsample_resnet(model, layer_names: list[str], downsample_factor=9.0, max_loss=None):
    """Replace the named layers of ``model`` with ``DownsampledConv2d`` wrappers.

    Each layer is looked up with ``get_module_by_name`` and swapped through
    ``setattr_pytorch_model``. The wrapper receives
    ``sqrt(downsample_factor)`` as both ``factor_h`` and ``factor_w``.

    Args:
        model: Model whose layers are replaced.
        layer_names: Names of the layers to consider.
        downsample_factor: Value whose square root is used as the per-axis
            factor for both height and width.
        max_loss: When ``None`` every listed layer is replaced. Otherwise a
            layer is replaced only if its ``LaplacianPyramidLoss`` value is
            strictly below this threshold; other layers are left untouched
            and a message is printed.

    Returns:
        The ``model`` object that was passed in.
    """
    for name in layer_names:
        conv = get_module_by_name(module=model, name=name)
        per_axis_factor = math.sqrt(downsample_factor)

        if max_loss is not None:
            # Gate the replacement on the layer's Laplacian-pyramid loss.
            pyramid_loss = LaplacianPyramidLoss(
                layer=conv,
                device=conv.weight.device,
                factor_h=[per_axis_factor],
                factor_w=[per_axis_factor],
            )
            loss = pyramid_loss.get_loss().item()

            if not loss < max_loss:
                print(f"Not downsampling layer: {name} because it has a high topo loss ({loss} > {max_loss})")
                continue

        replacement = DownsampledConv2d(
            conv_layer=conv,
            factor_h=per_axis_factor,
            factor_w=per_axis_factor,
        )
        setattr_pytorch_model(model=model, name=name, item=replacement)

    return model
        

# Validation data: a single-process (non-distributed) loader on GPU 0, wrapped
# in the EvalSuite that the evaluation code below uses.
val_loader_kwargs = {
    "val_dataset": "/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    "num_workers": 16,
    "batch_size": 128,
    "resolution": 224,
    "distributed": False,
    "gpu": 0,
}
val_dataloader = create_val_loader(**val_loader_kwargs)
eval_suite = EvalSuite(dataloader=val_dataloader)

# --- Experiment configuration -------------------------------------------------

# Checkpoint epoch tag passed to load_resnet18_checkpoint.
epoch = 'final'

# Perturbation budget used when evaluating the unmodified checkpoints.
epsilon = 3 / 255

# Checkpoints to evaluate. Commented-out entries are other runs that can be
# switched on by uncommenting them.
model_names = [
    "end_topo_scale_50.0_shrink_factor_3.0",
    # "end_topo_scale_10.0_shrink_factor_3.0",
    "end_topo_scale_5.0_shrink_factor_3.0",
    # "end_topo_scale_1.0_shrink_factor_3.0",
    # "all_topo_scale_50.0_shrink_factor_3.0",
    # "all_topo_scale_5_shrink_factor_3.0",
    "baseline_scale_None_shrink_factor_3.0",
]

# Layer-name groups loaded from JSON; the sparsification targets the
# "last_conv_layers_in_each_block" group.
layer_names = load_json_as_dict("../../../../training/imagenet/resnet18/layer_names.json")
topo_layer_names = layer_names["last_conv_layers_in_each_block"]

# 0.0, 0.1, ..., 0.9 followed by 0.95 and 0.98.
sparsity_values = [step / 10 for step in range(10)] + [0.95, 0.98]

# Per-model results, filled in by the evaluation loop.
results = {}

# Robustness sweep.
#
# For every checkpoint in `model_names`:
#   1. evaluate clean accuracy and robustness (at the module-level `epsilon`)
#      of the unmodified model, printing both;
#   2. for each attack epsilon, reload the checkpoint, apply L1 sparsity to
#      `topo_layer_names`, evaluate robustness, and append one record to
#      results[model_name]["l1"].
CHECKPOINTS_FOLDER = "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints"
NUM_ROBUSTNESS_BATCHES = 5
ATTACK_EPSILONS = [1/255, 2/255, 3/255]
fraction_of_masked_weights = 0.95

for model_name in model_names:
    results[model_name] = {}

    model = load_resnet18_checkpoint(
        checkpoints_folder=CHECKPOINTS_FOLDER,
        model_name=model_name,
        epoch=epoch
    )
    # NOTE(review): this value is not used below; the summary print reads
    # count_model_parameters(model)[1] instead — confirm which index is intended.
    original_param_count = count_model_parameters(
        model
    )[0]
    model.eval()

    original_accuracy = eval_suite.compute_accuracy(
        model=model,
        max_num_batches=None,
        progress=True
    )
    print(f"[{model_name}] ORIGINAL ACC NO ADV: {original_accuracy}")

    original_robustness = eval_robustness(
        model=model,
        epsilon=epsilon,
        dataloader=eval_suite.dataloader,
        max_num_batches=NUM_ROBUSTNESS_BATCHES
    )
    print(f"[Original] Model: {model_name} epsilon: {epsilon} robustness: {original_robustness} Parameters: {count_model_parameters(model)[1]}")

    results[model_name]["l1"] = []

    # The loop variable is deliberately NOT called `epsilon`: reusing that name
    # would overwrite the module-level budget that the baseline evaluation
    # above reads on the next model iteration.
    for attack_epsilon in ATTACK_EPSILONS:

        # Start from a freshly loaded checkpoint for every attack strength.
        model = load_resnet18_checkpoint(
            checkpoints_folder=CHECKPOINTS_FOLDER,
            model_name=model_name,
            epoch=epoch
        )
        model, total_num_masked_weights = apply_l1_sparsity_to_model(
            model=model,
            fraction_of_masked_weights=fraction_of_masked_weights,
            layer_names=topo_layer_names,
            return_num_masked_weights=True
        )
        # NOTE(review): the sparsified model is cast with .half() while the
        # baseline above is evaluated without that cast — confirm the two
        # robustness numbers are meant to be compared across precisions.
        model.eval().half()

        # Adversarial evaluation of the sparsified model.
        sparse_robustness = eval_robustness(
            model=model,
            epsilon=attack_epsilon,
            dataloader=eval_suite.dataloader,
            max_num_batches=NUM_ROBUSTNESS_BATCHES
        )

        results[model_name]["l1"].append(
            {
                "fraction_of_masked_weights": fraction_of_masked_weights,
                "robustness": sparse_robustness,
                # Total parameter count minus the number of masked weights.
                "parameters": count_model_parameters(model=model)[0] - total_num_masked_weights,
                "sparsity": fraction_of_masked_weights,
                "epsilon": attack_epsilon
            }
        )

        print(f"total_num_masked_weights: {total_num_masked_weights} epsilon: {attack_epsilon} robustness: {sparse_robustness}")

# NOTE(review): late import — conventionally this belongs with the other
# imports at the top of the file.
from nesim.utils.json_stuff import dict_to_json

# Hand the collected per-model results to dict_to_json with the (relative)
# target name "results_epsilon.json".
dict_to_json(
    results,
    "results_epsilon.json"
)