from nesim.eval.resnet import EvalSuite
from nesim.utils.json_stuff import load_json_as_dict
from nesim.experiments.resnet import create_val_loader
from nesim.utils.dimensionality import EffectiveDimensionality
from nesim.eval.resnet import load_resnet18_checkpoint
from nesim.utils.json_stuff import dict_to_json
from nesim.utils.pca import get_tensor_principal_components, find_num_components_explaining_variance

import argparse

# Command-line interface. The single optional switch --marchenko (store_true,
# so it defaults to False) is forwarded to EffectiveDimensionality's
# `marchenko` option and also selects which results file is written.
parser = argparse.ArgumentParser(description="Example script with --marchenko flag")
parser.add_argument(
    "--marchenko",
    action="store_true",
    required=False,
    help="Use Marchenko flag",
)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Cap on validation batches forwarded to EvalSuite.get_hook_outputs.
# NOTE(review): None presumably means "iterate the whole loader" — confirm
# against EvalSuite.
MAX_NUM_BATCHES = None
# Only the first `max_num_samples` rows of each layer's hook output are used
# for the PCA spectrum and the effective-dimensionality estimate.
max_num_samples = 20_000
# Spatial pooling mode passed to get_hook_outputs and recorded in each result.
spatial_pooling = "max"

# NOTE: `layer_names` is loaded exactly once, further down, after the model
# list; loading it here as well would only be overwritten before use.

# Effective-dimensionality estimator applied to every layer's activations;
# --marchenko toggles its `marchenko` option.
effective_dim = EffectiveDimensionality(
    flatten=False,
    batch_size=128,
    progress=False,
    marchenko=args.marchenko,
)

# ImageNet validation loader (FFCV) wrapped in the evaluation suite that
# collects per-layer hook outputs.
# NOTE(review): shuffled=True combined with truncating every layer to the
# first `max_num_samples` rows means the analysed subset depends on the
# loader's shuffle order; if that order is not seeded/identical across
# models, each model may be measured on a different subset — confirm.
val_dataloader = create_val_loader(
    val_dataset="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    num_workers=16,
    batch_size=128,
    resolution=224,
    distributed=False,
    gpu=0,
    shuffled=True,
)
eval_suite = EvalSuite(
    dataloader=val_dataloader,
)

# Models to evaluate. "eshed" is handled specially in the loop; for the
# non-baseline ResNet-18 checkpoints the 4th "_"-separated field of the name
# is parsed as the topographic scale.
# Every entry keeps a trailing comma: without it, re-enabling one of the
# commented-out names would silently concatenate two adjacent string
# literals into a single bogus model name.
model_names = [
    "baseline_scale_None_shrink_factor_3.0",
    "all_topo_scale_0.5_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_20.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0",
    "eshed",
    # "end_topo_scale_0.5_shrink_factor_3.0",
    # "end_topo_scale_1.0_shrink_factor_3.0",
    # "end_topo_scale_5.0_shrink_factor_3.0",
    # "end_topo_scale_10.0_shrink_factor_3.0",
    # "end_topo_scale_50.0_shrink_factor_3.0",
]

# Layer-name groups for ResNet-18; the evaluation loop reads the
# "all_conv_layers" entry. The path is relative, so it is resolved against
# the current working directory, not this file's location.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet18/layer_names.json"
)


# One entry per model: per-layer effective dimensionality and PCA spectrum.
results = []
# The conv-layer list is identical for every model; hoisted out of the loop.
conv_layer_names = layer_names["all_conv_layers"]

for model_name in model_names:
    if model_name == "eshed":
        # Weights: gdown https://drive.google.com/file/d/12CqTtPSuI66nJHllP4HlLNwq8BKmsHtk/view?usp=share_link
        from nesim.eval.eshed import load_eshed_checkpoint

        model = load_eshed_checkpoint("imagenet_default_tdann.torch")
        model.eval().to(device="cuda:0")
        # This model's hook names carry a "base_model." prefix
        # (presumably the attribute holding the wrapped backbone).
        selected_layer_names = ["base_model." + name for name in conv_layer_names]
    else:
        model = load_resnet18_checkpoint(
            checkpoints_folder="/home/mdeb6/repos/nesim/training/imagenet/resnet18/checkpoints",
            model_name=model_name,
            epoch="final",
        )
        model.eval()
        selected_layer_names = conv_layer_names

    # Honour the configured batch cap instead of a hard-coded None.
    hook_outputs, labels = eval_suite.get_hook_outputs(
        model=model,
        layer_names=selected_layer_names,
        max_num_batches=MAX_NUM_BATCHES,
        spatial_pooling=spatial_pooling,
    )

    data = []
    for layer_name in hook_outputs:
        # Cast to float32 and keep at most `max_num_samples` rows so every
        # layer is analysed on the same sample budget.
        layer_outputs = hook_outputs[layer_name].float()[:max_num_samples, :]

        # Full PCA spectrum (n_components=None), computed on the GPU.
        pca_output = get_tensor_principal_components(
            tensor=layer_outputs.to("cuda:0"),
            n_components=None,
        )
        ed = effective_dim.compute(x=layer_outputs)
        print(f"Model: {model_name} Layer: {layer_name} ED: {ed}")

        data.append(
            {
                "layer_name": layer_name,
                "ed": ed,
                # NOTE: despite the key name, this stores the full list of
                # per-component variances, not a component count. The key is
                # kept unchanged for compatibility with existing result files.
                "num_components_explaining_most_of_variance": {
                    "component_variances": pca_output.component_variances.cpu().tolist()
                },
            }
        )
        del layer_outputs

    # Label used to group results: "baseline", "eshed", or the numeric
    # topographic scale, e.g. "all_topo_scale_0.5_shrink_factor_3.0" -> 0.5.
    if "baseline" in model_name:
        topo_scale = "baseline"
    elif model_name == "eshed":
        topo_scale = "eshed"
    else:
        topo_scale = float(model_name.split("_")[3])

    results.append(
        {
            "model_name": model_name,
            "topo_scale": topo_scale,
            "effective_dims": data,
            "spatial_pooling": spatial_pooling,
        }
    )
    # Drop references to this model's activations and weights before loading
    # the next model.
    del hook_outputs
    del model
        
# Persist every model's results as JSON. The --marchenko run writes to its
# own filename rather than the default "results.json".
results_filename = "results_marchenko.json" if args.marchenko else "results.json"
dict_to_json(
    dictionary=results,
    filename=results_filename,
)