import os
import argparse
from pathlib import Path
from glob import glob
from tqdm import tqdm
import torch
from torchvision import transforms
from torchvision.transforms.functional import resize
from PIL import Image
from networks import get_model
from torch.utils.data import DataLoader, Dataset
import time
from cutonce.ncut import ncut


# Preprocessing shared by every dataset/model: convert the (already resized)
# PIL image to a float tensor, then normalize each RGB channel with the
# standard ImageNet mean/std statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])


class ImagesDataset(Dataset):
    """Map-style dataset that loads image files, resizes them and applies a transform.

    Each item is ``(image, image_file)``: the RGB image resized to ``resize``
    (then passed through ``transform`` when one is given) and the path it was
    loaded from.

    Args:
        images_files: indexable sequence of image paths.
        transform: optional callable applied to the resized PIL image.
        resize: target size handed to ``PIL.Image.Image.resize``, i.e.
            ``(width, height)``.
    """

    def __init__(self, images_files, transform=None, resize=(480, 480)):
        self.images_files = images_files
        self.transform = transform
        self.resize = resize

    def __getitem__(self, index):
        """Return ``(image, image_file)`` for the ``index``-th path."""
        image_file = self.images_files[index]

        # Use a context manager so the underlying file handle is closed as soon
        # as the pixels are decoded; convert() returns an independent in-memory
        # image, so it stays valid after the source is closed.
        with Image.open(image_file) as src:
            img = src.convert('RGB')
        img = img.resize(self.resize, Image.Resampling.LANCZOS)

        if self.transform is not None:
            img = self.transform(img)

        return img, image_file

    def __len__(self):
        return len(self.images_files)

    @property
    def images_size(self):
        """The ``(width, height)`` every image is resized to."""
        return self.resize


def save_batch_eig_vecs(batch_files, eigenvectors, save_dir):
    """Save one eigenvector tensor per image as ``<save_dir>/<image stem>.pt``.

    Args:
        batch_files: image paths of the batch, aligned with the first
            (batch) index of ``eigenvectors``.
        eigenvectors: tensor (or sequence of tensors) indexed by batch position.
        save_dir: existing directory the ``.pt`` files are written into.
    """
    for i, file in enumerate(batch_files):
        # Path.stem keeps every dot except the extension's ("img.v2.jpg" ->
        # "img.v2") and handles OS-specific separators; splitting on the first
        # "." would make such names collide and overwrite each other.
        file_name = Path(file).stem
        # clone() so only this item is serialized: torch.save writes a view's
        # whole underlying storage, which here would be the entire batch.
        torch.save(eigenvectors[i].detach().cpu().clone(), os.path.join(save_dir, f"{file_name}.pt"))


def get_images_imagenet_files(imagenet_root, split):
    """List the ``*.JPEG`` files of an ImageNet split.

    ``train`` images are expected one directory deeper
    (``<root>/train/<class>/``) than ``val`` images (``<root>/val/``).

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
    """
    print("Loading image files...")
    if split not in ("train", "val"):
        raise ValueError(f"Invalid split {split} provided. Must be one of ['train', 'val']")
    sub_pattern = "train/*/*.JPEG" if split == "train" else "val/*.JPEG"
    return glob(f"{imagenet_root}/{sub_pattern}")


def extract_eig_vecs(img_files, models_batch_list, output_dir, device="cuda"):
    """Compute NCut eigenvectors of patch features and save one file per image.

    For every ``(model_name, batch_size)`` pair, the model from ``get_model``
    is run over ``img_files`` (in order, no shuffling). Its patch features go
    through ``ncut(features, tau=0.15)``, and the eigenvectors are written to
    ``<output_dir>/<model_name>/`` via ``save_batch_eig_vecs``.

    Args:
        img_files: list of image paths to process.
        models_batch_list: iterable of ``(model_name, batch_size)`` tuples.
        output_dir: root directory; one sub-directory per model is created.
        device: torch device string the model and the batches are moved to.
    """
    for model_name, batch_size in models_batch_list:
        # get_model returns the network together with its patch size.
        model, patch_size = get_model(model_name, device)
        model.eval()
        save_dir = f"{output_dir}/{model_name}"
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        is_dinov2 = "dinov2" in model_name
        if is_dinov2:
            # 476 = 34 * 14; presumably chosen as a multiple of the 14-px
            # DINOv2 patch size -- TODO confirm.
            dataset = ImagesDataset(img_files, transform=transform, resize=(476, 476))
        else:
            dataset = ImagesDataset(img_files, transform=transform)

        ts = time.perf_counter()
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        # images_size is PIL's (width, height); count patches along both axes
        # instead of assuming a square resize.
        width, height = dataset.images_size
        token_num = (height // patch_size) * (width // patch_size)

        with torch.no_grad():
            for batch_images, batch_files in tqdm(data_loader, desc=f"Processing {model_name} model"):
                batch_images = batch_images.to(device)
                # Following prior work: DINO uses the last block's "key"
                # features, DINOv2 uses the last layer's patch outputs.
                if is_dinov2:
                    features = model(batch_images, return_patches=True)
                else:
                    _, k, _ = model.get_last_qkv(batch_images)
                    # Presumably k is (batch, heads, tokens, head_dim); the
                    # transpose + reshape concatenates heads per token.
                    k = k.transpose(1, 2).reshape(batch_images.shape[0], token_num + 1, -1)
                    # Drop the first token (the extra one counted in
                    # token_num + 1, presumably [CLS]) to keep patch tokens only.
                    features = k[:, 1:, :]
                eigenvectors, _ = ncut(features, tau=0.15)
                save_batch_eig_vecs(batch_files, eigenvectors, save_dir)
                del features

        del model
        if str(device).startswith("cuda"):
            # Hand cached blocks of the finished model back before loading the next one.
            torch.cuda.empty_cache()
        inference_time = time.perf_counter() - ts
        print(f"Total inference time: {inference_time}, model: {model_name}")

# Local dataset roots, keyed by the values accepted by --dataset. The paths are
# machine-specific ("xxx" appears to be a placeholder) -- edit them for your setup.
DATASET_PATH = dict(
    imagenet_train="/data/xxx/datasets/imagenet/train",
    imagenet_val="/data/xxx/datasets/imagenet/val",
    coco_val2017="/data/xxx/datasets/coco/val2017",
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Create eigenvectors for all models")
    parser.add_argument("--dataset", type=str, default="imagenet_train", choices=["imagenet_train", "imagenet_val", "coco_val2017"], help="Dataset")
    parser.add_argument("--output-dir", type=str, default="eigen_vecs/", help="Output directory for json results evaluation file")
    parser.add_argument("--device", type=str, default="cuda:1", help="computation device", choices=["cpu", "cuda"])
    # add argument of list of models to use
    parser.add_argument("--models-batch-list", nargs='+',
                        # default=[("dino_s16", 512), ("dinov2_b14", 256), ("dinov2_s14", 256), ("dino_b16", 256), ("dino_s8", 32), ("dino_b8", 16)],
                        default=[("dino_s16", 128), ("dinov2_b14", 64), ("dinov2_s14", 64), ("dino_b16", 64), ("dino_s8", 16), ("dino_b8", 12)], # on GPU 4090
                        help="List of models to use. Each model is a tuple of (model_name, batch_size)")
    args = parser.parse_args()
    print(args)

    device = args.device
    dataset_root = DATASET_PATH[args.dataset]

    if args.dataset == "imagenet_train":
        img_files = glob(f"{dataset_root}/*/*.JPEG")
    elif args.dataset == "imagenet_val":
        img_files = glob(f"{dataset_root}/*.JPEG")
    elif args.dataset == "coco_val2017":
        img_files = glob(f"{dataset_root}/*.jpg")
    else:
        raise ValueError(f"Invalid dataset: {args.dataset} provided.")
    
    models_batch_list = args.models_batch_list
    output_dir = os.path.join(args.output_dir, args.dataset)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    models_batch_list = models_batch_list[-1:]
    print(f"models_batch_list: {models_batch_list}")
    existing_eig_vec_files = glob(f"{output_dir}/dino_b8/*.pt")

    num_files = len(img_files) - len(existing_eig_vec_files)
    if num_files == 0:
        print(f"No images left to process for {args.dataset}")
        exit(0)
    print('num_files: ', num_files)
    extract_eig_vecs(img_files, models_batch_list, output_dir, device)

    exit(0)

