from nesim.eval.resnet import EvalSuite
from nesim.utils.json_stuff import load_json_as_dict
from nesim.experiments.resnet import create_val_loader
from nesim.utils.dimensionality import EffectiveDimensionality
from nesim.eval.resnet import load_resnet50_checkpoint
from nesim.utils.json_stuff import dict_to_json
from lightning import seed_everything
from nesim.utils.pca import get_tensor_principal_components, find_num_components_explaining_variance
import torch

import argparse

# Command-line interface: one optional boolean switch, --marchenko.
# The parsed value is forwarded to EffectiveDimensionality(marchenko=...) and
# also selects the name of the JSON file written at the end of the script.
parser = argparse.ArgumentParser(
    description="Example script with --marchenko flag",
)
parser.add_argument(
    "--marchenko",
    action="store_true",  # absent -> False, present -> True
    required=False,
    help="Use Marchenko flag",
)

# Reads sys.argv; exits with a usage message on unknown arguments.
args = parser.parse_args()

# Fix the global random seed before anything stochastic is created
# (the validation loader below is built with shuffled=True).
seed_everything(0)

# Intended cap on the number of validation batches per model
# (None = no cap requested).
MAX_NUM_BATCHES = None
# Only the first `max_num_samples` rows of each layer's activations are used
# for the effective-dimensionality estimate.
max_num_samples = 20_000
# Spatial pooling mode forwarded to EvalSuite.get_hook_outputs and recorded
# alongside the results.
spatial_pooling = "max"

# Layer-name groups for ResNet-50; the "all_conv_layers" entry is the one
# hooked below. The layer names are loaded exactly once, from the repo-relative
# path whose value the script actually uses.
# NOTE(review): the path is relative, so it presumably has to be resolved from
# the directory this script lives in -- confirm the expected working directory.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet50/layer_names.json"
)

# Effective-dimensionality estimator shared by every model and layer.
effective_dim = EffectiveDimensionality(
    flatten=False,
    batch_size=128,
    progress=False,
    marchenko=args.marchenko,
)

# Validation data (FFCV file) and the evaluation helper that runs a model
# over it and collects hooked layer outputs.
val_dataloader = create_val_loader(
    val_dataset="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    num_workers=16,
    batch_size=128,
    resolution=224,
    distributed=False,
    gpu=0,
    shuffled=True,
)
eval_suite = EvalSuite(
    dataloader=val_dataloader,
)

# Checkpoints to evaluate. For non-baseline names, the 4th "_"-separated
# token is parsed later as the topographic scale.
model_names = [
    "baseline_scale_None_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_30_shrink_factor_3.0",
]


def _compute_effective_dim(x):
    """Return the effective dimensionality of `x` using the shared estimator.

    The default backend is tried first. If it raises a torch linear-algebra
    error, the computation is retried once with
    ``effective_dim.use_numpy_backend`` set to True (presumably a NumPy
    implementation of the same computation -- confirm in
    EffectiveDimensionality). The flag is always reset to False afterwards,
    even if the retry itself fails, so later layers start from the default
    backend again.
    """
    try:
        return effective_dim.compute(x=x)
    except torch._C._LinAlgError:
        print("torch backend raised LinAlgError; retrying with numpy backend")
        effective_dim.use_numpy_backend = True
        try:
            return effective_dim.compute(x=x)
        finally:
            effective_dim.use_numpy_backend = False


def _parse_topo_scale(model_name):
    """Return "baseline" for baseline checkpoints, else the scale as a float.

    Non-baseline names are expected to look like
    ``all_topo_scale_<scale>_shrink_factor_<factor>``, i.e. the scale is the
    4th "_"-separated token.
    """
    if "baseline" in model_name:
        return "baseline"
    return float(model_name.split("_")[3])


# One entry per model: its name, topographic scale, pooling mode and the
# per-layer effective dimensionalities.
results = []
for model_name in model_names:
    model = load_resnet50_checkpoint(
        checkpoints_folder="/research/XXXX-1/toponets_resnet50_imagenet_checkpoints",
        model_name=model_name,
        epoch="final",
    )
    model.eval()

    # The second return value (labels) is not needed for this analysis.
    hook_outputs, _ = eval_suite.get_hook_outputs(
        model=model,
        layer_names=layer_names["all_conv_layers"],
        max_num_batches=MAX_NUM_BATCHES,
        spatial_pooling=spatial_pooling,
    )

    data = []
    for layer_name in hook_outputs:
        # Truncate before casting so only the rows that are kept get copied
        # to float32.
        layer_outputs = hook_outputs[layer_name][:max_num_samples, :].float()

        ed = _compute_effective_dim(layer_outputs)
        print(f"Model: {model_name} Layer: {layer_name} ED: {ed}")

        # NOTE: per-component variances could additionally be stored here via
        # get_tensor_principal_components(...).component_variances if needed.
        data.append(
            {
                "layer_name": layer_name,
                "ed": ed,
            }
        )
        del layer_outputs

    results.append(
        {
            "model_name": model_name,
            "topo_scale": _parse_topo_scale(model_name),
            "effective_dims": data,
            "spatial_pooling": spatial_pooling,
        }
    )
    # Drop references to this model's activations and weights before the next
    # checkpoint is loaded.
    del hook_outputs
    del model
        
# Runs with and without --marchenko are written to different files so that
# one does not overwrite the other.
output_filename = "results_marchenko.json" if args.marchenko else "results.json"
dict_to_json(dictionary=results, filename=output_filename)