import os
import argparse
from pathlib import Path
from glob import glob
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import resize
from PIL import Image
from networks import get_model
from torch.utils.data import DataLoader, Dataset
import time
from cutonce.ncut import ncut, ncut2


# Shared preprocessing for every dataset: PIL image -> float tensor, then
# per-channel normalization with the standard ImageNet mean / std statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


class ImagesDataset(Dataset):
    """Dataset of image files loaded as RGB, resized, and optionally transformed.

    Each item is an ``(image, image_file)`` pair so callers can map results back
    to the file they came from.

    Args:
        images_files: Indexable sequence of image file paths.
        transform: Optional callable applied to the resized PIL image.
        resize: Target size passed to ``PIL.Image.Image.resize``, i.e. ``(width, height)``.
    """

    def __init__(self, images_files, transform=None, resize=(480, 480)):
        self.images_files = images_files
        self.transform = transform
        self.resize = resize

    def __getitem__(self, index):
        image_file = self.images_files[index]

        # Use a context manager so the underlying file handle is released
        # deterministically, including when decoding/conversion raises.
        # convert() fully loads the pixels, so the result outlives the handle.
        with Image.open(image_file) as src:
            img = src.convert('RGB')
        img = img.resize(self.resize, Image.Resampling.LANCZOS)

        if self.transform is not None:
            img = self.transform(img)

        return img, image_file

    def __len__(self):
        return len(self.images_files)

    @property
    def images_size(self):
        """Target ``(width, height)`` every image is resized to."""
        return self.resize


def save_batch_eig_vecs(batch_files, eigenvectors, save_dir):
    """Save each image's eigenvectors to ``<save_dir>/<file stem>.pt``.

    Args:
        batch_files: Image file paths for the batch, in batch order.
        eigenvectors: Tensor (or sequence of tensors) indexed by batch position;
            entry ``i`` belongs to ``batch_files[i]``.
        save_dir: Existing directory to write the ``.pt`` files into.
    """
    for i, file in enumerate(batch_files):
        # basename() respects the platform separator, and splitext() strips only
        # the final extension ("a.b.jpg" -> "a.b"). Cutting at the first dot
        # would map distinct files onto the same output name.
        file_name = os.path.splitext(os.path.basename(file))[0]
        # clone() yields a compact tensor: torch.save on a view serializes the
        # whole underlying storage, not just the slice.
        torch.save(eigenvectors[i].detach().cpu().clone(), os.path.join(save_dir, f"{file_name}.pt"))

def _extract_features(model, model_name, batch_images, token_num):
    """Return the per-patch features of one image batch for the given backbone.

    To follow prior work, DINO models contribute the last-layer attention "key"
    features, while DINOv2 models contribute their last-layer patch outputs.

    Args:
        model: Backbone returned by ``get_model``.
        model_name: Model identifier; a ``"dinov2"`` substring selects DINOv2 handling.
        batch_images: Image batch already moved to the model's device.
        token_num: Number of patch tokens per image, excluding the extra leading token.
    """
    if "dinov2" in model_name:
        return model(batch_images, return_patches=True)
    _, k, _ = model.get_last_qkv(batch_images)
    # NOTE(review): assumes k is (batch, heads, tokens, head_dim); transpose +
    # reshape concatenates all heads per token. Token 0 (presumably CLS) is dropped.
    k = k.transpose(1, 2).reshape(batch_images.shape[0], token_num + 1, -1)
    return k[:, 1:, :]


def extract_eig_vecs(img_files, models_batch_list, output_dir, device="cuda"):
    """Compute NCut eigenvectors of patch features and save one ``.pt`` file per image.

    For every ``(model_name, batch_size)`` pair the model is loaded with
    ``get_model``. Each batch's patch features go through ``ncut2`` (tau=0.15),
    and the eigenvectors are written to ``<output_dir>/<model_name>/<image stem>.pt``.
    The elapsed wall-clock time per model is printed.

    Args:
        img_files: Image file paths to process.
        models_batch_list: Iterable of ``(model_name, batch_size)`` pairs;
            ``batch_size`` may be an int or an integer string.
        output_dir: Root directory for the per-model output folders.
        device: Torch device the models run on.
    """
    for model_name, batch_size in models_batch_list:
        model, patch_size = get_model(model_name, device)
        save_dir = f"{output_dir}/{model_name}"
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        if "dinov2" in model_name:
            # 476 = 34 * 14, a multiple of the patch size implied by the "*14" DINOv2 names.
            dataset = ImagesDataset(img_files, transform=transform, resize=(476, 476))
        else:
            dataset = ImagesDataset(img_files, transform=transform)
        # perf_counter is monotonic, unlike time.time, so it is safe for durations.
        start = time.perf_counter()
        model.eval()
        data_loader = DataLoader(dataset, batch_size=int(batch_size), shuffle=False, num_workers=0)
        # Count patches per axis separately so non-square resize targets also work.
        width, height = dataset.images_size
        token_num = (height // patch_size) * (width // patch_size)
        with torch.no_grad():
            for batch_images, batch_files in tqdm(data_loader, desc=f"Processing {model_name} model"):
                batch_images = batch_images.to(device)
                features = _extract_features(model, model_name, batch_images, token_num)
                eigenvectors, _eigenvalues = ncut2(features, tau=0.15)
                save_batch_eig_vecs(batch_files, eigenvectors, save_dir)
                del features
        del model
        inference_time = time.perf_counter() - start
        print(f"Total inference time: {inference_time}, model: {model_name}")

# Local directory that holds all supported datasets.
_DATASETS_ROOT = "/data/xxx/datasets"

# Maps each --dataset choice to the directory containing its images.
DATASET_PATH = {
    "imagenet_train": f"{_DATASETS_ROOT}/imagenet/train",
    "imagenet_val": f"{_DATASETS_ROOT}/imagenet/val",
    "coco_val2017": f"{_DATASETS_ROOT}/coco/val2017",
}

def parse_model_batch(spec):
    """Parse a ``MODEL:BATCH_SIZE`` command-line token into ``(model_name, batch_size)``.

    Raises:
        argparse.ArgumentTypeError: if the token is malformed or the batch size
            is not a positive integer (argparse turns this into a usage error).
    """
    name, sep, size = spec.rpartition(":")
    if not sep or not name:
        raise argparse.ArgumentTypeError(f"expected MODEL:BATCH_SIZE, got {spec!r}")
    try:
        batch_size = int(size)
    except ValueError:
        raise argparse.ArgumentTypeError(f"batch size must be an integer, got {size!r}") from None
    if batch_size <= 0:
        raise argparse.ArgumentTypeError(f"batch size must be positive, got {batch_size}")
    return name, batch_size


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`, so pass the
    # text as `description` to keep the usage line readable.
    parser = argparse.ArgumentParser(description="Create eigenvectors for all models")
    parser.add_argument("--dataset", type=str, default="coco_val2017", choices=list(DATASET_PATH), help="Dataset")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="eigen_vecs_no_binary/",
        help="Root output directory; eigenvectors are saved under <output-dir>/<dataset>/<model>/",
    )
    # No `choices` here: any torch device string (cpu, cuda, cuda:1, ...) is valid.
    parser.add_argument("--device", type=str, default="cuda:0", help="computation device, e.g. cpu, cuda, cuda:1")
    parser.add_argument(
        "--models-batch-list",
        nargs="+",
        type=parse_model_batch,
        metavar="MODEL:BATCH_SIZE",
        # Batch sizes tuned for a 4090 GPU. dino_b8 is not run by default;
        # request it explicitly, e.g. `--models-batch-list dino_b8:8`.
        default=[("dino_s16", 128), ("dinov2_b14", 64), ("dinov2_s14", 64), ("dino_b16", 64), ("dino_s8", 16)],
        help="Models to use, each given as MODEL:BATCH_SIZE (e.g. dino_s16:128)",
    )
    args = parser.parse_args()
    print(args)

    device = args.device
    dataset_root = DATASET_PATH[args.dataset]

    # Glob pattern (relative to the dataset root) that matches each dataset's images.
    image_patterns = {
        "imagenet_train": "*/*.JPEG",
        "imagenet_val": "*.JPEG",
        "coco_val2017": "*.jpg",
    }
    pattern = image_patterns.get(args.dataset)
    if pattern is None:
        raise ValueError(f"Invalid dataset: {args.dataset} provided.")
    # glob() order is filesystem-dependent; sort for reproducible batching.
    img_files = sorted(glob(f"{dataset_root}/{pattern}"))
    if not img_files:
        raise SystemExit(f"No images found for {args.dataset} under {dataset_root}")

    models_batch_list = args.models_batch_list
    output_dir = os.path.join(args.output_dir, args.dataset)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print(f"models_batch_list: {models_batch_list}")
    extract_eig_vecs(img_files, models_batch_list, output_dir, device)

