import os
import shutil
import torch
from tqdm import tqdm
import pandas as pd

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from interpretable_resnet_torchvision import InterpretableResNet50

# Source images (torchvision ImageFolder layout) and the directory that receives
# copies of images the model classifies correctly with high confidence.
DATASET_PATH = "CalTech-256/Dataset/ImageFolder_test"
COPY_PATH = 'CalTech-256/Dataset/confident_images'
# Minimum top-1 softmax probability for a correctly classified image to be kept.
target_conf = 0.95
# Source image path -> top-1 softmax confidence, for every image that was kept.
conf_dict = {}

model = InterpretableResNet50(caltech256 = True)
# The checkpoint is loaded onto the CPU first (map_location) so it opens on any
# machine; the model is moved to the selected device below.
model.model.load_state_dict(torch.load("CalTech-256/checkpoints_two_layer/best_model_epoch_24_acc_82.84.pth", map_location="cpu")["model_state_dict"])

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model.to(device)
# Inference mode (single call; the earlier duplicate eval() was redundant).
model.eval()


dataset = ImageFolder(DATASET_PATH, transform=model.transforms)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)


def _class_index_from_path(path):
    """Return the 0-based class index encoded in an image's file name.

    The file name is expected to contain ``class<N>_`` with a 1-based ``N``,
    e.g. ``class012_0003.jpg`` -> 11. ``os.path.basename`` is used instead of
    splitting on '/' so the parsing also works with Windows path separators.
    """
    return int(os.path.basename(path).split('class')[1].split('_')[0]) - 1


# Ground-truth labels come from the file names rather than ImageFolder's own
# targets; they are aligned with dataset.imgs (and thus with the unshuffled loader).
file_classes = [_class_index_from_path(path) for path, _ in dataset.imgs]

# The copy destination may not exist yet; shutil.copy would fail without it.
os.makedirs(COPY_PATH, exist_ok=True)

# Gradients are never needed here, so disable autograd once for the whole loop.
with torch.no_grad():
    # Wrapping the DataLoader itself (not enumerate(...)) lets tqdm show a total.
    for idx, (image, _) in enumerate(tqdm(dataloader, desc="Getting confident images")):
        label = file_classes[idx]
        top_pred, top_class = torch.topk(model(image.to(device)).softmax(dim=1), k=1)
        confidence = top_pred.item()
        if top_class.item() == label and confidence >= target_conf:
            src_path = dataset.imgs[idx][0]
            conf_dict[src_path] = confidence
            # NOTE(review): copies are keyed by file name only; images with the same
            # name in different class folders would overwrite each other.
            shutil.copy(src_path, os.path.join(COPY_PATH, os.path.basename(src_path)))

print(f"Found {len(conf_dict)} confident images")
pd.DataFrame({'Filename': list(conf_dict.keys()), 'Confidence': list(conf_dict.values())}).to_csv(os.path.join(COPY_PATH, f"confident_images_conf_{target_conf}.csv"))
