import numpy as np
import pandas as pn
import torch
from tqdm import tqdm

# Result directories (under results/) to merge, in processing order.
# CIFAR models are named "<architecture>_<dataset>": every architecture is
# listed for cifar10 first, then the same architectures for cifar100.
# ImageNet-1k models keep their full timm-style identifiers.
model_names = [
    *(
        f"{arch}_{dataset}"
        for dataset in ("cifar10", "cifar100")
        for arch in (
            "densenet121",
            "resnet18",
            "resnet34",
            "vgg16_bn",
            "vit_base_patch16_224_in21k_ft",
        )
    ),
    "densenet121.tv_in1k",
    "mobilenetv3_large_100.miil_in21k_ft_in1k",
    "mobilenetv3_small_100.lamb_in1k",
    "resnet101.tv_in1k",
    "resnet34.tv_in1k",
    "resnet50.tv_in1k",
    "vit_base_patch16_224.augreg_in21k_ft_in1k",
    "vit_large_patch16_224.augreg_in21k_ft_in1k",
    "vit_small_patch16_224.augreg_in21k_ft_in1k",
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
]
# Full column header of the per-model result tables (the leading empty field
# is presumably the saved index). Only the subset selected below is kept:
# seed, model, ood_dataset, method_sc, method_ood, ood_ds_idx, risks_sc,
# coverages_sc, thrs_sc, ood_performance, ood_performance_with_CDF,
# scod_performance, tprs_scod_ood, fprs_scod_ood, risks_scod_ood,
# risks_num_scod_ood, risks_den_scod_ood, coverages_scod_ood, lbds_scod_ood,
# s_ood, s_ood_prob, labels_array_ood, labels_array_sc, datasets_idx_ood,
# datasets_idx_sc, scores_array_ood, scores_array_sc, idx_in_d, idx_out_d,
# s_ood_label


def _quantize_to_float32(df):
    """Return a copy of *df* whose list/ndarray cells are cast to float32 arrays.

    A column is converted when its FIRST cell is a ``list`` (every cell becomes
    ``np.array(x, dtype=np.float32)``) or an ``np.ndarray`` (every cell becomes
    ``x.astype(np.float32)``). Other columns are left untouched. Only the first
    row is inspected, so each column is assumed to be homogeneous. Every
    converted column is reported once on stdout.
    """
    # Work on an explicit copy so column assignment never hits a view of the
    # caller's frame (avoids SettingWithCopyWarning).
    df = df.copy()
    if df.empty:
        return df
    for col in df.columns:
        # Positional access: after drop_duplicates the label 0 may not exist.
        first = df[col].iloc[0]
        if isinstance(first, list):
            elem_type = type(first[0]) if first else None
            df[col] = df[col].apply(lambda x: np.array(x, dtype=np.float32))
        elif isinstance(first, np.ndarray):
            # elif: a list column converted above must not be cast (and
            # logged) a second time.
            elem_type = type(first.flat[0]) if first.size else None
            df[col] = df[col].apply(lambda x: x.astype(np.float32))
        else:
            continue
        print("original type", elem_type, f"quantized col {col} to float32")
    return df


lens = 0  # total rows loaded across all models, before de-duplication
all_dfs = []
for m in tqdm(model_names):
    print(m)
    # res.pt holds a pickled DataFrame rather than a tensor checkpoint, so the
    # full unpickler is required: torch>=2.6 defaults to weights_only=True and
    # would reject it. Unpickling can execute arbitrary code -- only load
    # result files you produced yourself.
    df = torch.load(f"results/{m}/res.pt", weights_only=False)
    lens += len(df)
    df = df[
        [
            "seed",
            "model",
            "ood_dataset",
            "method_sc",
            "method_ood",
            "ood_performance",
            "ood_performance_with_CDF",
            "scod_performance",
            "s_ood",
            "s_ood_prob",
            "s_ood_label",
            "risks_scod_ood",
            "coverages_scod_ood",
            # "lbds_scod_ood",
        ]
    ]
    # One row per experiment key; keep="last" lets the last occurrence in the
    # file win (presumably the most recent run).
    df = df.drop_duplicates(
        subset=["seed", "model", "ood_dataset", "method_sc", "method_ood"],
        keep="last",
    )
    all_dfs.append(_quantize_to_float32(df))

print(f"loaded {lens} rows, {sum(len(d) for d in all_dfs)} after de-duplication")
print("concatenating...")
all_dfs = pn.concat(all_dfs, ignore_index=True)
# NOTE(review): array cells are written to CSV via their string form; NumPy
# summarizes arrays longer than its print threshold (default 1000) with "...",
# so long array columns are likely truncated in the CSV -- verify before
# relying on it. The .pt file keeps the full arrays.
print("saving to csv")
all_dfs.to_csv("results/all_results.csv", index=False)
print("saving to pt")
torch.save(all_dfs, "results/all_results.pt")
