import os
import torch
import torchvision
import numpy as np
import random
from tqdm import tqdm
import argparse
import pandas as pd

from PIL import Image
import matplotlib.pyplot as plt
from interpretable_resnet_torchvision import InterpretableResNet50
import torchvision.transforms as T



# Command-line options for building 2x2 "grid game" samples.
parser = argparse.ArgumentParser(
    description=(
        "Stitch four confidently-classified images into a 2x2 grid and save the "
        "grids whose predicted class is one of the four tile classes."
    )
)
parser.add_argument(
    "--dataset", type=str, default="IN-1k", choices=["IN-1k", "CT-256"],
    help="Dataset whose confident images, output directory and model weights are used.",
)
parser.add_argument(
    "--conf_threshold", type=float, default=0.65,
    help=(
        "Minimum softmax confidence each tile must reach on its own; the script "
        "aborts if a sampled tile falls below it."
    ),
)
parser.add_argument(
    "--num_samples", type=int, default=150,
    help="Number of grid samples to generate and save.",
)
args = parser.parse_args()

# Per-grid records filled in by the sampling loop and written to CSV at the end.
# Every column starts as its own empty list; one entry is appended per saved grid.
samples_data = {column: [] for column in ("file name", "classes", "confidences", "final class")}

# Display-only preprocessing for the human-facing PNG grid; the model itself is
# fed `model.transforms` output instead. Both pipelines produce 224x224 tiles.
ct_256_base_transforms = T.Compose([
    # Resize first, then crop: cropping the raw image first would discard
    # everything outside its central 224 px and then upscale to 232x232,
    # giving tiles of a different size from the IN-1k ones below.
    T.Resize((232, 232)),
    T.CenterCrop((224, 224)),
])

in_1k_base_transforms = T.Compose([
    # Fixed (256, 256) size ignores the aspect ratio; the crop then yields 224x224.
    T.Resize((256, 256)),
    T.CenterCrop(224),
])

# Select input/output directories and the matching model for the chosen dataset.
if args.dataset == "IN-1k":
    DATA_DIR = 'ImageNet-onek/confident_images'
    SAVE_DIR = 'ImageNet-onek/grid_game_samples'
    model = InterpretableResNet50()
elif args.dataset == "CT-256":
    DATA_DIR = 'CalTech-256/Dataset/confident_images'
    SAVE_DIR = 'CalTech-256/Dataset/grid_game_samples'
    model = InterpretableResNet50(caltech256=True)
    # Fine-tuned weights live under "model_state_dict" in the checkpoint dict.
    checkpoint = torch.load(
        "CalTech-256/checkpoints_two_layer/best_model_epoch_24_acc_82.84.pth",
        map_location="cpu",
    )
    model.model.load_state_dict(checkpoint["model_state_dict"])
    del checkpoint  # drop any extra checkpoint contents once the weights are loaded
else:
    # Unreachable while argparse `choices` guards --dataset; kept so a new choice
    # cannot silently leave DATA_DIR/SAVE_DIR/model undefined.
    raise ValueError(f"Unsupported dataset: {args.dataset}")
model.eval()

# The sampling loop writes .npy/.png/.csv files here; create it up front so the
# first save does not fail on a fresh checkout.
os.makedirs(SAVE_DIR, exist_ok=True)


# Run on the first available accelerator: Apple MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    backend = "mps"
elif torch.cuda.is_available():
    backend = "cuda"
else:
    backend = "cpu"
device = torch.device(backend)

model.to(device)
model.eval()

conf_threshold = args.conf_threshold
num_samples = args.num_samples

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# seeded random.sample calls in the loop pick the same files on every machine.
all_files = sorted(
    os.path.join(DATA_DIR, file)
    for file in os.listdir(DATA_DIR)
    if file.endswith(('.jpeg', '.jpg'))
)
random.seed(42)


def _class_index(path):
    """Return the model class index encoded in an image file name.

    The integer following the first ``class`` in the base name, up to the next
    ``_``, is parsed. For CT-256 it is shifted down by one to match the model's
    output indices.
    """
    label = int(os.path.basename(path).split('class')[1].split('_')[0])
    return label if args.dataset == "IN-1k" else label - 1


def _load_tiles(paths):
    """Load each image as a uint8 RGB CHW tensor on ``device``.

    Returns ``(tiles, display_tiles)``: ``tiles`` are the images as stored on disk
    (stitched into the grid the model classifies) and ``display_tiles`` are the
    same images after the dataset's base transform (stitched into the saved PNG).
    Each file is read once and reused for both.
    """
    tiles, display_tiles = [], []
    for path in paths:
        if args.dataset == "IN-1k":
            # RGB mode also expands grayscale JPEGs to three channels.
            image = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB).to(device)
            display = in_1k_base_transforms(image)
        else:
            with Image.open(path) as opened:
                pil_image = opened.convert("RGB")
            # torch.cat below needs tensors, so PIL images are converted HWC -> CHW.
            image = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).to(device)
            display = torch.from_numpy(np.array(ct_256_base_transforms(pil_image))).permute(2, 0, 1).to(device)
        tiles.append(image)
        display_tiles.append(display)
    return tiles, display_tiles


# Without four distinct classes the rejection loop below could never finish.
if num_samples > 0 and len({_class_index(path) for path in all_files}) < 4:
    raise ValueError(f"Need images from at least 4 distinct classes in {DATA_DIR} to build a 2x2 grid")

total_samples = num_samples
while num_samples > 0:
    sample_files = random.sample(all_files, 4)
    classes = [_class_index(path) for path in sample_files]
    if len(set(classes)) != 4:
        continue

    tiles, display_tiles = _load_tiles(sample_files)

    # Sanity check: DATA_DIR should only hold images the model already classifies
    # correctly with at least `conf_threshold` confidence.
    for tile, expected_class in zip(tiles, classes):
        probs = model(model.transforms(tile).unsqueeze(0).to(device)).softmax(dim=1)
        top_pred, top_class = torch.topk(probs, k=1)
        if not (top_pred.item() >= conf_threshold and top_class.item() == expected_class):
            raise RuntimeError(f"Top Pred: {top_pred.item()}, Top Class: {top_class.item()}, Class: {expected_class}")

    # Tensors are CHW, so dim=1 stacks vertically and dim=2 horizontally:
    # tiles 0/1 form the left column, tiles 2/3 the right column.
    # NOTE(review): stitching raw tiles assumes all images in DATA_DIR share one size — confirm.
    left_col = torch.cat((tiles[0], tiles[1]), dim=1)
    right_col = torch.cat((tiles[2], tiles[3]), dim=1)
    final_image = model.transforms(torch.cat((left_col, right_col), dim=2))

    display_left = torch.cat((display_tiles[0], display_tiles[1]), dim=1)
    display_right = torch.cat((display_tiles[2], display_tiles[3]), dim=1)
    final_image_untouched = torch.cat((display_left, display_right), dim=2)

    final_top_pred, final_top_class = torch.topk(model(final_image.unsqueeze(0).to(device)).softmax(dim=1), k=1)
    final_class = final_top_class.item()
    final_conf = final_top_pred.item()
    # Keep only grids whose prediction is one of the four tile classes.
    if final_class not in classes:
        continue
    print("Classes: ", classes)
    print(f"Final Class: {final_class}")
    print(f"Final Confidence: {final_conf}")

    # One 1-based name shared by the CSV row and both saved files.
    sample_no = total_samples - num_samples + 1
    file_stem = f"file_no_{sample_no}_grid_game_sample_class{final_class}"

    samples_data['file name'].append(file_stem)
    samples_data['classes'].append(classes)
    samples_data['confidences'].append(final_conf)
    samples_data['final class'].append(final_class)

    num_samples -= 1

    plt.imshow(final_image.to("cpu").permute(1, 2, 0).numpy())
    plt.show()
    np.save(os.path.join(SAVE_DIR, f"{file_stem}.npy"), final_image.to("cpu").numpy())
    torchvision.io.write_png(final_image_untouched.to("cpu"), os.path.join(SAVE_DIR, f"{file_stem}.png"))
    print(f"Saved {total_samples - num_samples} grid game samples")

# Write one row per saved grid: file stem, the four tile classes, and the model's
# confidence/class on the stitched grid.
csv_columns = ["file name", "classes", "confidences", "final class"]
csv_path = os.path.join(SAVE_DIR, f"grid_game_samples_{args.dataset}.csv")
summary = pd.DataFrame({column: samples_data[column] for column in csv_columns})
summary.to_csv(csv_path, index=False)

