import re
import os
import hydra
import numpy as np
from pathlib import Path

from omegaconf import DictConfig, OmegaConf
from local_paths import REPO_PATH
from metrics import ClipScore
from experiments_tools import update_sampler_config


# Generated images must be named ``prompt={prompt_index}_{sample_index}...``.
# Group 1 is the prompt index; group 2 is the sample index (may be empty).
_IMAGE_NAME_PATTERN = re.compile(r"prompt=(\d+)_(\d*)")


def _load_prompts(path, n_samples_per_prompt):
    """Read one prompt per line and repeat each ``n_samples_per_prompt`` times.

    Trailing blank lines (e.g. an extra newline at the end of the file) are
    dropped so they are not counted as prompts. Blank lines in the middle are
    kept, because prompt indices are positional.
    """
    with open(path, "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f]
    while prompts and not prompts[-1]:
        prompts.pop()
    return [prompt for prompt in prompts for _ in range(n_samples_per_prompt)]


def _image_sort_key(name):
    """Order image names by (prompt index, sample index), then by name."""
    match = _IMAGE_NAME_PATTERN.match(name)
    sample_idx = int(match[2]) if match[2] else -1
    return int(match[1]), sample_idx, name


def _list_generated_images(imgs_dir):
    """Return paths of the correctly named files in ``imgs_dir``, sorted.

    Entries that do not follow the ``prompt={prompt_index}_{sample_index}``
    convention (stray files, sub-directories) are skipped with a message
    instead of crashing the sort.
    """
    entries = os.listdir(imgs_dir)
    names = [
        name
        for name in entries
        if _IMAGE_NAME_PATTERN.match(name) and (imgs_dir / name).is_file()
    ]
    n_skipped = len(entries) - len(names)
    if n_skipped:
        print(
            f"Skipping {n_skipped} entries in {imgs_dir} not named "
            "'prompt={prompt_index}_{sample_index}'"
        )
    return [imgs_dir / name for name in sorted(names, key=_image_sort_key)]


@hydra.main(
    config_path=str(REPO_PATH / "configs/"),
    config_name="images",
)
def evaluation(config: DictConfig):
    """Compute CLIP scores of generated images against their prompts and print them.

    Reads from ``config``: ``eval.path_generated_data`` (image directory),
    ``eval.path_li_prompts`` (one prompt per line), ``eval.n_samples_per_prompt``,
    ``eval.batch_size``, ``eval.clip.model_name`` and ``device``.

    Raises:
        ValueError: if the number of images differs from
            ``number of prompts * n_samples_per_prompt``.
    """
    update_sampler_config(config)

    print(f"Evaluating config {config.sampler.parameters}")

    # Hydra may change the working directory to its run dir. Resolve relative
    # paths against the directory the script was launched from instead.
    imgs_save_dir = Path(hydra.utils.to_absolute_path(config.eval.path_generated_data))
    prompts_path = hydra.utils.to_absolute_path(config.eval.path_li_prompts)

    print(f"=========== {'Clip Score'} ===========")
    clip_scorer = ClipScore(
        batch_size=config.eval.batch_size,
        model_name=config.eval.clip.model_name,
        device=config.device,
    )

    # Each prompt is repeated n_samples_per_prompt times. Images are sorted by
    # prompt index, so both lists line up position by position.
    all_prompts = _load_prompts(prompts_path, int(config.eval.n_samples_per_prompt))
    all_images_names = _list_generated_images(imgs_save_dir)

    # The lists are aligned only by position. A count mismatch would silently
    # score images against the wrong prompts, so fail loudly instead.
    if len(all_images_names) != len(all_prompts):
        raise ValueError(
            f"Found {len(all_images_names)} images in {imgs_save_dir} but "
            f"{len(all_prompts)} prompts (after repeating each "
            f"{config.eval.n_samples_per_prompt} times) from {prompts_path}"
        )

    score = clip_scorer.compute_score(all_images_names, all_prompts)

    metrics = {
        "mean": float(score.mean()),
        "std": float(score.std()),
        "details": score.tolist(),
    }

    print(f"=========== {'Clip score:'} ===========")
    print(metrics)


if __name__ == "__main__":
    # The hydra.main decorator builds the DictConfig from the config files and
    # any command-line overrides, so no argument is passed explicitly here.
    evaluation()  # pylint: disable=no-value-for-parameter
