import os
import sys
from pathlib import Path
import pickle
import numpy as np
import shutil
from tqdm import tqdm

import torch
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
from PIL import Image

# HOME_PATH: directory holding the cc12m metadata pickles (the eval-set dict
#            is loaded from here by AestheticDataset).
# SAVE_PATH: root directory that receives the copied aesthetic-score arrays.
HOME_PATH, SAVE_PATH = (
    "../../metadata/cc12m/",
    "../../outputs/evaluations/aesthetic/",
)


class AestheticDataset(torch.utils.data.Dataset):
    """Dataset of images to be scored by the aesthetic predictor.

    The list of image files is chosen from the folder path:

    * a path containing ``"generations"`` indexes generated images named
      ``0.png`` ... ``{num_generations - 1}.png``;
    * a path containing ``"eval"`` indexes real images named ``{name}.jpg``,
      where the names are the keys of ``eval_set_dict[complexity]`` loaded
      from a pickle under ``HOME_PATH``.

    Args:
        image_folder: Directory containing the images.
        preprocessor: Image processor called as
            ``preprocessor(images=..., return_tensors="pt")``; its
            ``pixel_values`` output is what each item returns.
        complexity: Key selecting the eval subset (used only for "eval").
        num_generations: Number of generated images to index (used only for
            "generations"). Defaults to 100000.

    Raises:
        ValueError: If the folder path matches neither layout.
    """

    def __init__(self, image_folder, preprocessor, complexity,
                 num_generations=100000):
        self.image_folder = Path(image_folder)
        folder_str = str(self.image_folder)
        # Eval image names (keys of the eval dict); None for generations so
        # the attribute always exists.
        self.image_names = None
        if "generations" in folder_str:
            self.image_paths = [
                os.path.join(self.image_folder, f"{img_idx}.png")
                for img_idx in range(num_generations)
            ]
        elif "eval" in folder_str:
            dict_path = os.path.join(
                HOME_PATH,
                "full_dict_gemma3_eval_clean_siglip_real_5k_4caps.pkl",
            )
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at trusted, locally produced metadata files.
            with open(dict_path, "rb") as f:
                eval_set_dict = pickle.load(f)
            self.image_names = list(eval_set_dict[complexity].keys())
            self.image_paths = [
                os.path.join(self.image_folder, f"{img_name}.jpg")
                for img_name in self.image_names
            ]
        else:
            raise ValueError(
                f"Invalid image folder path: {image_folder!r} "
                "(expected it to contain 'generations' or 'eval')."
            )
        self.preprocessor = preprocessor

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, resize to 256x256, and preprocess it."""
        image_path = self.image_paths[idx]
        # The context manager closes the file handle as soon as the pixels
        # are converted, instead of leaving it to garbage collection in each
        # DataLoader worker.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = image.resize((256, 256))
        pixel_values = self.preprocessor(
            images=image,
            return_tensors="pt",
        ).pixel_values
        # Presumably (1, C, H, W) for a single image; dropping dim 0 lets the
        # DataLoader stack items into a batch — TODO confirm processor output.
        return pixel_values.squeeze(0)


def get_aesthetic_score(eval_setting, complexity, save_folder,
                        batch_size=2048, num_workers=10):
    """Score a folder of images with the v2.5 aesthetic predictor.

    Images are read from the per-job scratch directory
    ``/tmp/jobid_{SLURM_JOB_ID}/``, where ``SLURM_JOB_ID`` falls back to "0".
    The ``eval`` subfolder is used when ``eval_setting`` contains "DATA";
    otherwise the ``generations`` subfolder is used.

    The per-image scores, in dataset order, are saved as
    ``aes_{eval_setting}_scores.npy`` in the scratch directory and copied
    into ``SAVE_PATH/save_folder``. The mean score is printed.

    Args:
        eval_setting: Setting name; selects the image folder and names the
            output file.
        complexity: Eval-subset key forwarded to AestheticDataset.
        save_folder: Subdirectory of SAVE_PATH that receives the score file.
        batch_size: Images per forward pass.
        num_workers: Number of DataLoader worker processes.
    """
    root_path = f"/tmp/jobid_{os.environ.get('SLURM_JOB_ID', '0')}/"
    subfolder = "eval" if "DATA" in eval_setting else "generations"
    sample_image_path = os.path.join(root_path, subfolder)

    # Load the model and its matching preprocessor; run inference in bf16.
    model, preprocessor = convert_v2_5_from_siglip(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model = model.to(torch.bfloat16).cuda()

    dataset = AestheticDataset(sample_image_path, preprocessor, complexity)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keeps scores aligned with dataset.image_paths
        drop_last=False,
        num_workers=num_workers,
    )

    scores = []
    with torch.inference_mode():
        for pixel_values in tqdm(dataloader):
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            logits = model(pixel_values).logits
            # squeeze() turns a batch of one into a 0-d array whose tolist()
            # is a bare float, which list.extend rejects. atleast_1d keeps
            # every batch a sequence. (Assumes logits are (B, 1) — confirm.)
            batch_scores = np.atleast_1d(
                logits.squeeze().float().cpu().numpy()
            )
            scores.extend(batch_scores.tolist())

    score_file = f"aes_{eval_setting}_scores.npy"
    src = os.path.join(root_path, score_file)
    dst = os.path.join(SAVE_PATH, save_folder)
    np.save(src, scores)
    os.makedirs(dst, exist_ok=True)
    shutil.copy(src, dst)

    # Summarise instead of dumping every per-image score to stdout.
    if scores:
        print(
            f"Aesthetics score: {np.mean(scores):.4f} "
            f"(mean over {len(scores)} images)"
        )
    else:
        print("Aesthetics score: no images were scored")


if __name__ == "__main__":
    # The last three CLI arguments are used (extra leading arguments are
    # ignored): EVAL_SETTING COMPLEXITY SAVE_FOLDER.
    if len(sys.argv) < 4:
        sys.exit(
            f"usage: {sys.argv[0]} EVAL_SETTING COMPLEXITY SAVE_FOLDER"
        )
    eval_setting, complexity_arg, save_folder = sys.argv[-3:]
    # complexity is used as a dict key by the dataset, so it must be an int.
    get_aesthetic_score(eval_setting, int(complexity_arg), save_folder)
