from nesim.utils.json_stuff import load_json_as_dict
import pandas as pd
import os
"""
Resnet18: (all-topo)
- validation accuracies
- effective dims (mean)
- smoothness (mean)
DONE

Resnet50: (all-topo)
- validation accuracies
- effective dims (mean)
- smoothness (mean)
DONE

GPT-Neo-125M:
- perplexity score on OpenWebText
- perplexity score on BookCorpus
- effective dims (mean of last 3 layers)
- smoothness (mean of last 3 layers)
"""

# Column-oriented accumulator for the vision results table: every key becomes
# a CSV column, and each model appends exactly one value to every list.
# Key order here fixes the column order of vision_results.csv.
VISION_DATA = {
    column: []
    for column in (
        "model_name",
        "imagenet_val_acc",
        "effective_dim",
        "smoothness",
        "effective_dim_marchenko",
    )
}

def parse_effective_dim_data(data: list) -> dict:
    """Average the per-layer effective dimensionality of each model.

    Args:
        data: Per-model records, as loaded from an
            ``effective_dimensionality/results.json`` file. Each record needs a
            ``"model_name"`` key and an ``"effective_dims"`` key. The
            ``"effective_dims"`` key holds a list of dicts, each with an
            ``"ed"`` value. Note that this is a sequence of records, not a
            mapping: the function iterates the records themselves.

    Returns:
        A dict mapping each model name to the arithmetic mean of its ``"ed"``
        values. If the same model name appears in several records, the last
        one wins.

    Raises:
        ValueError: If a record's ``"effective_dims"`` list is empty, since
            its mean is undefined.
    """
    result = {}
    for item in data:
        model_name = item["model_name"]
        effective_dims = [x["ed"] for x in item["effective_dims"]]
        if not effective_dims:
            # Fail with the offending model's name instead of a bare
            # ZeroDivisionError from the mean below.
            raise ValueError(
                f"No effective dims recorded for model {model_name!r}"
            )
        result[model_name] = sum(effective_dims) / len(effective_dims)
    return result

# --- ResNet-18 (ImageNet) ----------------------------------------------------
resnet18_dir = "/home/XXXX-4/repos/nesim/experiments/imagenet/resnet18"

resnet18_val_acc = load_json_as_dict(
    os.path.join(resnet18_dir, "validation_acc", "result_simple.json")
)
# Hand-entered accuracies for two extra models. Their origin is not recorded
# in this script — TODO document where these numbers come from.
resnet18_val_acc["eshed"] = 0.439
resnet18_val_acc["LLCNN-G"] = 0.53

resnet18_effective_dim = load_json_as_dict(
    os.path.join(resnet18_dir, "effective_dimensionality", "results.json")
)
# NOTE(review): the "marchenko" variant is loaded from the *same* file as the
# plain effective dims above, so the effective_dim and effective_dim_marchenko
# columns come out identical. This looks like a copy-paste placeholder.
# Point it at the (presumably Marchenko-Pastur) results file once its path
# is known.
resnet18_effective_dim_marchenko = load_json_as_dict(
    os.path.join(resnet18_dir, "effective_dimensionality", "results.json")
)

resnet18_smoothness = load_json_as_dict(
    os.path.join(resnet18_dir, "pouya_smoothness", "smoothness_values.json")
)
# Exclude every entry whose label starts with "All topo\n$\tau" from the
# table; the remaining entries keep their original order.
resnet18_smoothness = {
    key: value
    for key, value in resnet18_smoothness.items()
    if not key.startswith("All topo\n$\\tau")
}
# Hand-entered smoothness for the same two extra models as above.
resnet18_smoothness["eshed"] = 0.8160919540229886
resnet18_smoothness["LLCNN-G"] = 0.7931034482758621

resnet18_effective_dim = parse_effective_dim_data(data=resnet18_effective_dim)
resnet18_effective_dim_marchenko = parse_effective_dim_data(
    data=resnet18_effective_dim_marchenko
)

# The effective-dim results decide which models appear in the table. Every
# such model must also have an accuracy and a smoothness entry, or the
# lookups below raise KeyError.
resnet18_model_names = list(resnet18_effective_dim.keys())

for model_name in resnet18_model_names:
    VISION_DATA["model_name"].append("resnet18_" + model_name)
    VISION_DATA["imagenet_val_acc"].append(resnet18_val_acc[model_name])
    VISION_DATA["effective_dim"].append(resnet18_effective_dim[model_name])
    VISION_DATA["effective_dim_marchenko"].append(
        resnet18_effective_dim_marchenko[model_name]
    )
    VISION_DATA["smoothness"].append(resnet18_smoothness[model_name])

# --- ResNet-50 (ImageNet) ----------------------------------------------------
resnet50_dir = "/home/XXXX-4/repos/nesim/experiments/imagenet/resnet50"

resnet50_val_acc = load_json_as_dict(
    os.path.join(resnet50_dir, "validation_acc", "result_simple.json")
)
resnet50_effective_dim = load_json_as_dict(
    os.path.join(resnet50_dir, "effective_dimensionality", "results.json")
)
# NOTE(review): the "marchenko" variant is loaded from the *same* file as the
# plain effective dims above, so the two columns come out identical. This
# looks like a copy-paste placeholder — fix the path once the real results
# file is known.
resnet50_effective_dim_marchenko = load_json_as_dict(
    os.path.join(resnet50_dir, "effective_dimensionality", "results.json")
)
resnet50_smoothness = load_json_as_dict(
    os.path.join(resnet50_dir, "pouya_smoothness", "smoothness_values.json")
)
resnet50_effective_dim = parse_effective_dim_data(data=resnet50_effective_dim)
resnet50_effective_dim_marchenko = parse_effective_dim_data(
    data=resnet50_effective_dim_marchenko
)

# Exclude the "All topo\n$\tau..." entries from the ResNet-50 smoothness
# values. This must filter resnet50_smoothness, not resnet18_smoothness.
resnet50_smoothness = {
    key: value
    for key, value in resnet50_smoothness.items()
    if not key.startswith("All topo\n$\\tau")
}

# Models listed in the effective-dim results define the ResNet-50 rows.
resnet50_model_names = list(resnet50_effective_dim.keys())

for model_name in resnet50_model_names:
    VISION_DATA["model_name"].append("resnet50_" + model_name)
    VISION_DATA["imagenet_val_acc"].append(resnet50_val_acc[model_name])
    VISION_DATA["effective_dim"].append(resnet50_effective_dim[model_name])
    VISION_DATA["effective_dim_marchenko"].append(
        resnet50_effective_dim_marchenko[model_name]
    )
    VISION_DATA["smoothness"].append(resnet50_smoothness[model_name])

# Export the combined ResNet-18 + ResNet-50 table.
print(VISION_DATA)
df = pd.DataFrame(VISION_DATA)
df.to_csv("vision_results.csv")


# How many of the final layers are averaged for the GPT-Neo per-layer metrics.
NUM_LATE_LAYERS = 3

# Column-oriented accumulator for the language-model results table; key
# order fixes the column order of language_results.csv.
LANGUAGE_DATA = dict(
    model_name=[],
    perplexity_openwebtext=[],
    effective_dim_late_layers=[],
    smoothness_late_layers=[],
)

# OpenWebText evaluation score for each GPT-Neo-125M variant.
# NOTE(review): an untrained model was recorded at 10.8564, close to
# ln(50257) ~= 10.82 (presumably the tokenizer's vocabulary size). That
# suggests these are mean cross-entropy losses in nats, not perplexities.
# Confirm before relying on the "perplexity" column name.
lang_model_perplexity = {
    "baseline": 4.7252,
    "topo_1": 4.9957,
    "topo_5": 4.6094,
    "topo_10": 4.6843,
    "topo_50": 4.7816,
}

# Models to tabulate, in table order (same order as the scores above).
lang_model_names = list(lang_model_perplexity)

# Mean effective dimensionality over the last NUM_LATE_LAYERS entries of each
# model's per-layer results file.
# NOTE(review): "last entries" follows the JSON object's key order — this
# assumes the file lists layers from first to last; confirm.
lang_model_effective_dims = {}

for lang_model in lang_model_names:
    per_layer_ed = load_json_as_dict(
        os.path.join(
            "/home/XXXX-4/repos/nesim/experiments/gpt_neo_125m/effective_dimensionality/results",
            f"{lang_model}.json",
        )
    )
    late_layer_eds = list(per_layer_ed.values())[-NUM_LATE_LAYERS:]
    lang_model_effective_dims[lang_model] = sum(late_layer_eds) / len(late_layer_eds)

# Per-model, per-layer smoothness values. The loop below replaces each
# model's entry in place with the mean over its last NUM_LATE_LAYERS layers,
# so afterwards this dict maps model name -> float.
lang_model_smoothness_results = load_json_as_dict(
    "/home/XXXX-4/repos/nesim/figures/correlation_vs_pairwise_distance/smoothness_data.json"
)

for lang_model in lang_model_names:
    late_layer_smoothness = list(
        lang_model_smoothness_results[lang_model].values()
    )[-NUM_LATE_LAYERS:]
    late_mean = sum(late_layer_smoothness) / len(late_layer_smoothness)
    print(lang_model, late_mean)
    lang_model_smoothness_results[lang_model] = late_mean

# Append one row per language model to the column-oriented table, then
# export it.
for lang_model in lang_model_names:
    row = {
        "model_name": lang_model,
        "perplexity_openwebtext": lang_model_perplexity[lang_model],
        "effective_dim_late_layers": lang_model_effective_dims[lang_model],
        "smoothness_late_layers": lang_model_smoothness_results[lang_model],
    }
    for column, value in row.items():
        LANGUAGE_DATA[column].append(value)

print(LANGUAGE_DATA)
df = pd.DataFrame(LANGUAGE_DATA)
df.to_csv("language_results.csv")
