import hydra
from omegaconf import DictConfig, OmegaConf
from argparse import Namespace

import os
import json
import torch
import logging
import numpy as np

from sentence_transformers import SentenceTransformer, util

import torch
import os

from pycocotools.coco import COCO
from torch.utils.data import Dataset
import einops
from PIL import Image

def create_logger(logging_dir):
    """
    Set up root logging with a console handler and a file handler.

    Records at INFO level and above are sent to a ``StreamHandler`` (which
    writes to ``sys.stderr`` by default) and to ``<logging_dir>/log.txt``.
    Each line starts with a timestamp wrapped in ANSI blue escape codes.

    Note: ``logging.basicConfig`` does nothing when the root logger already
    has handlers, so the configuration only takes effect if logging has not
    been configured before this call.

    Args:
        logging_dir: Existing directory in which ``log.txt`` is opened.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(f"{logging_dir}/log.txt")
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[console_handler, file_handler],
    )
    return logging.getLogger(__name__)

def normalize_to_neg1_pos1(tensor):
    """
    Apply the affine map ``x -> 2 * x - 1``.

    For inputs in the range [0, 1] the result lies in [-1, 1]. Works on
    anything that supports multiplication and subtraction with floats.
    """
    doubled = tensor * 2.0
    return doubled - 1.0

def recursive_namespace(d):
    """
    Turn a (possibly nested) dict into nested ``argparse.Namespace`` objects.

    Only dict values are descended into; any other value, including lists
    that contain dicts, is returned unchanged.
    """
    if not isinstance(d, dict):
        return d

    converted = {}
    for key, value in d.items():
        converted[key] = recursive_namespace(value)
    return Namespace(**converted)

def dictconfig_to_namespace(cfg: DictConfig) -> Namespace:
    """
    Convert an OmegaConf ``DictConfig`` into nested ``Namespace`` objects.

    The config is first materialised as plain Python containers with all
    interpolations resolved, then handed to ``recursive_namespace``.
    """
    return recursive_namespace(OmegaConf.to_container(cfg, resolve=True))

def center_crop(width, height, img):
    """
    Cut the largest centered square out of an image array and resize it.

    Args:
        width: Output width in pixels.
        height: Output height in pixels.
        img: NumPy image array whose first two axes are rows and columns.

    Returns:
        The resized crop as a ``uint8`` NumPy array.
    """
    side = np.min(img.shape[:2])
    top = (img.shape[0] - side) // 2
    left = (img.shape[1] - side) // 2
    square = img[top:top + side, left:left + side]

    try:
        pil_img = Image.fromarray(square, 'RGB')
    except Exception:
        # The array could not be read as RGB (presumably a non-3-channel
        # input); let Pillow infer the mode instead. Catching Exception
        # rather than using a bare except keeps KeyboardInterrupt and
        # SystemExit propagating.
        pil_img = Image.fromarray(square)

    # PIL's resize takes (width, height), the reverse of NumPy's axis order.
    pil_img = pil_img.resize((width, height), Image.LANCZOS)

    return np.array(pil_img).astype(np.uint8)

class MSCOCODatabase(Dataset):
    """
    Dataset of MS-COCO images, each paired with a single caption.

    Every item is a center-cropped, resized image returned as a float32
    array in channel-first layout with values scaled to [-1, 1], together
    with one caption string picked at random from the image's annotations.
    """

    def __init__(self, root, annFile, size=None):
        """
        Args:
            root: Directory that holds the image files.
            annFile: Annotation file path passed to ``COCO``.
            size: Side length used for both the output height and width.
        """
        self.root = root
        self.width = size
        self.height = size

        self.coco = COCO(annFile)
        # Sorting the image ids makes the index -> image mapping deterministic.
        self.keys = sorted(self.coco.imgs.keys())

    def _load_image(self, key: int):
        """Open the image with id ``key`` and return it as an RGB PIL image."""
        file_name = self.coco.loadImgs(key)[0]["file_name"]
        full_path = os.path.join(self.root, file_name)
        return Image.open(full_path).convert("RGB")

    def _load_target(self, key: int):
        """Return the annotation records attached to image id ``key``."""
        annotation_ids = self.coco.getAnnIds(key)
        return self.coco.loadAnns(annotation_ids)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(image, caption)`` for the image at position ``index``."""
        image_id = self.keys[index]

        pixels = np.array(self._load_image(image_id)).astype(np.uint8)
        pixels = center_crop(self.width, self.height, pixels).astype(np.float32)
        # Rescale pixel values from [0, 255] to [-1, 1].
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        pixels = einops.rearrange(pixels, 'h w c -> c h w')

        # Only one caption per call is returned, chosen at random among the
        # image's annotations (not the full caption list).
        annotations = self._load_target(image_id)
        choice = np.random.randint(len(annotations))
        caption = annotations[choice]['caption']

        return pixels, caption

@hydra.main(config_path="configs", config_name="likelihood_bench")
def main(cfg: DictConfig):
    """
    Score generated captions against reference captions by cosine similarity.

    Reads a JSON list of records carrying ``real_prompt`` and
    ``predicted_prompt``, embeds both caption sets with a sentence embedding
    model, computes the similarity of each real/generated pair, prints
    summary statistics and per-pair scores, and writes one score per line
    to ``<output_dir>/<exp_name>/models/results.txt``.

    Args:
        cfg: Hydra config; ``output_dir`` and ``exp_name`` are read from it.
    """
    args = dictconfig_to_namespace(cfg)

    os.makedirs("logs", exist_ok=True)

    # ------------------------------------
    # set up experiments directory
    # ------------------------------------
    os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
    save_dir = os.path.join(args.output_dir, args.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    models_dir = f"{save_dir}/models"
    os.makedirs(models_dir, exist_ok=True)
    logger = create_logger(save_dir)
    logger.info(f"Experiment directory created at {save_dir}")

    tag = os.environ.get("GIT_COMMIT_SHORT")
    # Use the module logger here as well, consistent with the call above.
    logger.info(f"Tag: {tag}")

    # ------------------------------------
    # Setup data:
    # ------------------------------------
    # NOTE: placeholder path; point it at the generated-captions JSON file.
    index_path = '/your/path/to/generated_captions.json'
    with open(index_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    val_captions = [item['real_prompt'] for item in data]
    gen_captions = [item['predicted_prompt'] for item in data]

    # Load a sentence embedding model
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Fall back to the CPU when no GPU is available instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encode_kwargs = {
        'batch_size': 64,
        'convert_to_numpy': True,
        'show_progress_bar': True,
        'device': device,
    }

    # Compute embeddings
    val_embeddings = embedder.encode(val_captions, **encode_kwargs)
    gen_embeddings = embedder.encode(gen_captions, **encode_kwargs)

    # Similarity of each real caption with the generated caption at the
    # same position (element-wise pairs, not an all-vs-all matrix).
    similarities = util.pairwise_cos_sim(val_embeddings, gen_embeddings)

    # Compute statistics (these are similarities, not distances)
    print("Cosine similarity statistics between real and synthetic captions:")
    print(f"  Min:    {torch.min(similarities).item():.4f}")
    print(f"  Max:    {torch.max(similarities).item():.4f}")
    print(f"  Mean:   {torch.mean(similarities).item():.4f}")
    print(f"  Median: {torch.median(similarities).item():.4f}")
    print(f"  Std:    {torch.std(similarities).item():.4f}")

    # Convert to plain floats once so the results file holds numbers
    # rather than "tensor(...)" representations.
    scores = similarities.tolist()
    for real_caption, gen_caption, score in zip(val_captions, gen_captions, scores):
        print(f"Captions {real_caption} // {gen_caption} // have cos. sim.: {score:.4f}")

    file_path = f"{models_dir}/results"
    with open(f"{file_path}.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{score:.6f}\n" for score in scores)

    print("Done! Bye.")


# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here and the decorator supplies the config.
if __name__ == "__main__":
    main()
