"""
Generate a large dataset of text samples conditioning with true images using the i2t model.
""" 
import os
import json
import hydra
from omegaconf import DictConfig, OmegaConf
from argparse import Namespace
import math
from tqdm import tqdm
import torchvision.transforms as T

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

from transformers import CLIPTokenizer, CLIPTextModel

from .utils_gen import load_encoders, preprocess_raw_image
from .model.cond_transformer import Transformer
from .logic.flow import MaskedSourceDistribution, UniformSourceDistribution
from .sampling_i2t import euler_sampler

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler

from accelerate import Accelerator
from accelerate.utils import set_seed

from pycocotools.coco import COCO
import numpy as np
import einops
from PIL import Image


class MSCOCODatabase(Dataset):
    """MS-COCO image/caption pairs read through the pycocotools ``COCO`` API.

    Each item is ``(image, caption)``:
      * ``image`` is center-cropped and resized by ``center_crop`` to
        ``size`` x ``size``, scaled from uint8 to float32 in [-1, 1], and laid
        out channels-first.
      * ``caption`` is the ``'caption'`` field of one annotation of that image,
        chosen with NumPy's global RNG.
    """

    def __init__(self, root, annFile, size=None):
        self.root = root
        # Output is square: both sides equal ``size``.
        self.height = size
        self.width = size

        self.coco = COCO(annFile)
        # Sorted image ids give a stable index -> image mapping.
        self.keys = sorted(self.coco.imgs.keys())

    def _load_image(self, key: int):
        """Open image ``key`` from ``root`` and convert it to RGB."""
        file_name = self.coco.loadImgs(key)[0]["file_name"]
        full_path = os.path.join(self.root, file_name)
        return Image.open(full_path).convert("RGB")

    def _load_target(self, key: int):
        """Return the annotation dicts attached to image ``key``."""
        ann_ids = self.coco.getAnnIds(key)
        return self.coco.loadAnns(ann_ids)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]

        pixels = np.array(self._load_image(key)).astype(np.uint8)
        pixels = center_crop(self.width, self.height, pixels).astype(np.float32)
        # uint8 [0, 255] -> float32 [-1, 1]
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        # HWC -> CHW
        pixels = np.transpose(pixels, (2, 0, 1))

        # Pick one of the image's captions at random.
        annotations = self._load_target(key)
        pick = np.random.randint(len(annotations))
        caption = annotations[pick]["caption"]

        return pixels, caption
    
def center_crop(width, height, img, resample=None):
    """Center-crop ``img`` to a square, then resize it to ``width`` x ``height``.

    Args:
        width: Output width in pixels.
        height: Output height in pixels.
        img: Image as a NumPy array of shape ``(H, W)`` or ``(H, W, C)``. The
            callers in this file pass ``uint8`` RGB arrays.
        resample: Optional Pillow resampling filter. Defaults to Lanczos, the
            filter this function has always used.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width)`` or
        ``(height, width, C)``.
    """
    if resample is None:
        resample = Image.LANCZOS

    # Largest centered square: trim equally from the longer side.
    side = min(img.shape[:2])
    top = (img.shape[0] - side) // 2
    left = (img.shape[1] - side) // 2
    img = img[top:top + side, left:left + side]

    # Let Pillow infer the mode from dtype/shape (uint8 HxWx3 -> "RGB",
    # uint8 HxW -> "L"). Forcing mode="RGB" silently reinterprets the raw bytes
    # for any other layout, and newer Pillow releases deprecate that argument.
    pil_img = Image.fromarray(img)
    pil_img = pil_img.resize((width, height), resample)

    return np.array(pil_img).astype(np.uint8)

def _gather_objects(local_obj):
    """Return a list with one entry per process, each holding that process's ``local_obj``.

    Uses ``torch.distributed.all_gather_object`` when a process group is
    initialized. Otherwise (single-process run) it returns ``[local_obj]``.
    """
    if dist.is_available() and dist.is_initialized():
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, local_obj)
        return gathered
    return [local_obj]


@hydra.main(config_path="../configs", config_name="gm_config")
def main(cfg: DictConfig):
    """Caption MS-COCO images with the image-to-text (i2t) model.

    Steps:
      1. Encode every batch of the process-sharded validation loader with the
         configured image encoder.
      2. Caption each batch with ``euler_sampler``, starting from the
         configured source text distribution.
      3. Write each input image as a PNG under
         ``<ckpt_dir>/Generation-image-to-text-Steps<num_steps>/images``.
      4. Gather the per-image metadata from all processes; the main process
         saves it as ``index.json`` in the generation folder.
    """
    # -----------------------------
    # Initialize distributed environment
    # -----------------------------
    torch.backends.cuda.matmul.allow_tf32 = cfg.allow_tf32  # True: fast but may lead to some small numerical differences
    if not torch.cuda.is_available():
        raise RuntimeError("Image-to-text sampling requires at least one GPU.")
    # Autograd is never needed in this script; disable it globally.
    torch.set_grad_enabled(False)

    accelerator = Accelerator(mixed_precision=None)
    device = accelerator.device
    num_processes = accelerator.num_processes
    if cfg.seed is not None:
        # Offset by rank so processes do not share the same RNG stream.
        set_seed(cfg.seed + accelerator.process_index)

    # -----------------------------
    # Load the validation dataset
    # -----------------------------
    n = cfg.sampling.pproc_batch_size  # per-process batch size
    val_dataset = MSCOCODatabase(
        root="***********************",
        annFile="*********************",
        size=256
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=n,
        shuffle=False,
        num_workers=cfg.dataset.dataloader.num_workers,
        pin_memory=True,
        drop_last=False
    )
    # Shards the batches across processes.
    val_dataloader = accelerator.prepare(val_dataloader)

    if accelerator.is_main_process:
        print(f"Dataset contains {len(val_dataset):,} images.")

    # -----------------------------
    # Load the models
    # -----------------------------
    # Only the CLIP tokenizer is needed: it defines the vocabulary and decodes
    # the token ids produced by the i2t model.
    tokenizer = CLIPTokenizer.from_pretrained(cfg.i2t.clip_path)

    # Image encoder that produces the conditioning patch tokens.
    if cfg.t2i.enc_type in (None, "None"):
        raise NotImplementedError("Sampling requires an image encoder (cfg.t2i.enc_type).")
    encoders, encoder_types, _ = load_encoders(
        cfg.t2i.enc_type, device, cfg.t2i.enc_path, cfg.dataset.params.size
    )
    encoder = encoders[0]
    encoder_type = encoder_types[0]

    # Load the image-to-text model
    model_text = Transformer(
        config=cfg.i2t.params,
        vocab_size=tokenizer.vocab_size,
        masked=cfg.i2t.p0_dist == "masked",
    )

    # Load the EMA weights of the i2t model from the requested checkpoint.
    ckpt_name = str(cfg.resume_step).zfill(7) + '.pt'
    ckpt = torch.load(
        f'{cfg.ckpt_dir}/checkpoints/{ckpt_name}',
        map_location='cpu',
    )
    model_text.load_state_dict(ckpt['ema_i2t'])
    del ckpt
    model_text.to(device).float()
    model_text.eval()

    if accelerator.is_main_process:
        print(f"Model text Parameters: {sum(p.numel() for p in model_text.parameters()):,}")

    # ------------------------------
    # Previous steps before sampling
    # ------------------------------
    folder_name = f"Generation-image-to-text-Steps{cfg.sampling.num_steps}"
    sample_folder_dir = os.path.join(cfg.ckpt_dir, folder_name)
    sample_folder_images_dir = os.path.join(sample_folder_dir, "images")
    # makedirs also creates the parent sample_folder_dir.
    os.makedirs(sample_folder_images_dir, exist_ok=True)

    # Text path
    if cfg.i2t.prob_path == "PolynomialDiscrete":
        text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
    else:
        raise NotImplementedError(f"Path {cfg.i2t.prob_path} not implemented.")

    # Source text distribution
    if cfg.i2t.p0_dist == "uniform":
        source_distribution_text = UniformSourceDistribution(vocab_size=tokenizer.vocab_size)
    elif cfg.i2t.p0_dist == "masked":
        tokenizer.add_tokens(["<MASK>"])
        mask_token_id = tokenizer.convert_tokens_to_ids("<MASK>")
        source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
    else:
        raise NotImplementedError(f"Distribution {cfg.i2t.p0_dist} not implemented.")

    # n padded empty prompts, used as the template passed to `sample_like`
    # (presumably only its shape/dtype matter — confirm in logic.flow).
    prompts_tok = tokenizer(
        [""] * n, return_tensors="pt", padding="max_length", truncation=True, max_length=77
    ).input_ids.to(device)

    # -----------------------------
    # Sampling loop
    # -----------------------------
    to_pil = T.ToPILImage()
    pbar = tqdm(total=len(val_dataloader)) if accelerator.is_main_process else None

    local_metadata = []
    for batch_idx, (raw_image, _) in enumerate(val_dataloader):
        # The DataLoader already yields tensors; this is a no-op if they are on `device`.
        raw_image = raw_image.to(device)
        batch_size = raw_image.shape[0]

        # drop_last=False: the final batch can hold fewer than n images.
        xT_text = source_distribution_text.sample_like(prompts_tok[:batch_size]).to(device)

        # Encode the images.
        # NOTE(review): raw_image is in [-1, 1] here (dataset normalization);
        # confirm preprocess_raw_image expects that range rather than [0, 255].
        raw_image_ = preprocess_raw_image(raw_image, encoder_type, resolution=cfg.t2i.resolution)
        z = encoder.forward_features_attn(raw_image_)['x_norm_patchtokens']

        # Perform sampling
        samples_text = euler_sampler(
            text_model=model_text,
            image_embedding=z,
            initial_text=xT_text,
            vocab_size=tokenizer.vocab_size,
            path=text_path,
            num_steps=cfg.sampling.num_steps
        )
        # Decode the text samples
        captions = tokenizer.batch_decode(
            samples_text, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        # Batch `batch_idx` on process `p` owns the index range starting at
        # (batch_idx * num_processes + p) * n, so file names never collide
        # across batches, processes, or nodes.
        first_index = (batch_idx * num_processes + accelerator.process_index) * n
        for i, caption in enumerate(captions):
            index = first_index + i
            path = f"{sample_folder_images_dir}/{index:06d}.png"
            # Invert the [-1, 1] normalization exactly. ToPILImage on a float
            # tensor would multiply by 255 and cast to byte, corrupting the
            # negative values.
            pixels = ((raw_image[i] + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)
            to_pil(pixels.cpu()).save(path)

            local_metadata.append({
                "prompt": caption,
                "condition": {
                    "class_id": caption
                },
                "class_label": None,
                "group": None,
                "dsg_questions": None,
                "dsg_children": None,
                "dsg_parents": None,
                "image_path": path
            })

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    # === Gather metadata from every process (rank order) ===
    all_metadata = _gather_objects(local_metadata)

    # === Save metadata on the main process ===
    if accelerator.is_main_process:
        flat_metadata = [item for sublist in all_metadata for item in sublist]

        json_path = os.path.join(sample_folder_dir, "index.json")
        with open(json_path, "w") as f:
            json.dump(flat_metadata, f, indent=4)

        print("Done.")

    # Ensure all processes are done before tearing down the process group.
    accelerator.wait_for_everyone()
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator on main() supplies
    # `cfg` (config "gm_config" from ../configs).
    main()