"""
Generate a large dataset of samples using the Generator Matching (GM) model.
""" 
import os
import json
import hydra
from omegaconf import DictConfig, OmegaConf
from argparse import Namespace
import math
from tqdm import tqdm
import torchvision.transforms as T

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

from transformers import CLIPTokenizer, CLIPTextModel

from .utils_gen import load_encoders, preprocess_raw_image
from .model.cond_transformer import Transformer
from .logic.flow import MaskedSourceDistribution, UniformSourceDistribution
from .sampling_i2t import euler_sampler

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler

from pycocotools.coco import COCO
import numpy as np
import einops
from PIL import Image


def center_crop(width, height, img):
    """Center-crop an image array to a square and resize it.

    Args:
        width: Output width in pixels.
        height: Output height in pixels.
        img: Image as a NumPy array whose first two axes are (rows, columns).

    Returns:
        The cropped and resized image as a ``np.uint8`` array.
    """
    rows, cols = img.shape[:2]
    crop = min(rows, cols)
    top = (rows - crop) // 2
    left = (cols - crop) // 2
    img = img[top: top + crop, left: left + crop]

    try:
        img = Image.fromarray(img, 'RGB')
    except Exception:
        # Best-effort fallback: let Pillow infer the mode when the array cannot
        # be interpreted as RGB. `Exception` (rather than a bare `except`) keeps
        # KeyboardInterrupt / SystemExit propagating.
        img = Image.fromarray(img)
    img = img.resize((width, height), Image.LANCZOS)

    return np.array(img).astype(np.uint8)

class MSCOCODatabase(Dataset):
    """MS-COCO image/caption dataset indexed through a pycocotools ``COCO`` object.

    Each item is a center-cropped, square-resized image scaled to [-1, 1] in
    channel-first layout, paired with one caption picked at random from the
    annotations of that image.
    """

    def __init__(self, root, annFile, size=None):
        """Index the annotation file.

        Args:
            root: Directory that the annotation file names are joined onto.
            annFile: Path of the COCO annotation file.
            size: Side length used for both the output height and width.
        """
        self.coco = COCO(annFile)
        # Sorted image ids give a stable index -> image mapping.
        self.keys = sorted(self.coco.imgs)

        self.root = root
        self.width = size
        self.height = size

    def _load_image(self, key: int):
        """Open the image with id ``key`` and convert it to RGB."""
        image_info = self.coco.loadImgs(key)[0]
        full_path = os.path.join(self.root, image_info["file_name"])
        return Image.open(full_path).convert("RGB")

    def _load_target(self, key: int):
        """Return all annotations attached to the image with id ``key``."""
        annotation_ids = self.coco.getAnnIds(key)
        return self.coco.loadAnns(annotation_ids)

    def __len__(self):
        """Number of indexed images."""
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(image, caption)`` for the image at position ``index``."""
        image_id = self.keys[index]

        pixels = np.array(self._load_image(image_id)).astype(np.uint8)
        pixels = center_crop(self.width, self.height, pixels).astype(np.float32)
        # Map [0, 255] to [-1, 1].
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        pixels = einops.rearrange(pixels, 'h w c -> c h w')

        # Pick a single caption at random among those available for the image.
        annotations = self._load_target(image_id)
        choice = np.random.randint(len(annotations))
        caption = annotations[choice]['caption']

        return pixels, caption

def recursive_namespace(d):
    """
    Turn a (possibly nested) dict into nested argparse.Namespace objects.

    Only dict values are descended into; any other value, including lists,
    is returned unchanged.
    """
    if not isinstance(d, dict):
        return d

    converted = {}
    for name, value in d.items():
        converted[name] = recursive_namespace(value)
    return Namespace(**converted)

def dictconfig_to_namespace(cfg: DictConfig) -> Namespace:
    """
    Convert an OmegaConf DictConfig into nested argparse.Namespace objects.

    The config is first turned into a plain container with interpolations
    resolved (``resolve=True``), then handed to ``recursive_namespace``.
    """
    return recursive_namespace(OmegaConf.to_container(cfg, resolve=True))

@hydra.main(config_path="../configs", config_name="gm_config")
def main(cfg: DictConfig):
    """Caption MS-COCO images with the image-to-text model under DDP.

    Steps:
      1. Initialise the NCCL process group and seed each rank.
      2. Load the CLIP tokenizer/text encoder, the image encoder and the
         EMA weights (``ema_i2t``) of the text transformer.
      3. Iterate over the COCO dataset, sharded across ranks, and sample one
         caption per image with ``euler_sampler``.
      4. Save every processed image as PNG and, on rank 0, write the gathered
         real/predicted captions to ``index.json`` in the sample folder.

    Args:
        cfg: Hydra configuration; converted to nested ``Namespace`` objects.
    """
    # Function-scope import: only this entry point needs the sampler.
    from torch.utils.data.distributed import DistributedSampler

    args = dictconfig_to_namespace(cfg)

    # -----------------------------
    # Initialize distributed environment
    # -----------------------------
    torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = (args.seed + rank) * world_size // 2
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
    print("Device:", device)

    # -----------------------------
    # Load the models
    # -----------------------------
    # Load Tokenizer and CLIP model
    text_encoder = CLIPTextModel.from_pretrained(args.i2t.clip_path, use_safetensors=False).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(args.i2t.clip_path)

    # Load the image encoder(s); only the first one is used for conditioning.
    if args.t2i.enc_type is None:
        raise NotImplementedError()
    encoders, encoder_types, _architectures = load_encoders(
        args.t2i.enc_type, device, args.t2i.enc_path, args.dataset.params.size
    )
    encoder = encoders[0]
    encoder_type = encoder_types[0]

    # Load the image-to-text model
    model_text = Transformer(
        config=args.i2t.params,
        vocab_size=tokenizer.vocab_size,
        masked=args.i2t.p0_dist == "masked",
    )

    # Load checkpoints
    ckpt_name = str(args.resume_step).zfill(7) + '.pt'
    ckpt = torch.load(
        f'{args.ckpt_dir}/checkpoints/{ckpt_name}',
        map_location='cpu',
    )
    model_text.load_state_dict(ckpt['ema_i2t'])
    # Drop the CPU copy of the checkpoint once the weights are loaded.
    del ckpt
    model_text.to(device).float()
    model_text.eval()

    # -----------------------------
    # Load the validation set
    # -----------------------------
    val_dataset = MSCOCODatabase(
        root="*********************************",
        annFile="*************************************",
        size=args.t2i.resolution
    )
    # Shard the dataset so that each rank captions a different subset of
    # images instead of every rank processing the same ones.
    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.sampling.pproc_batch_size,
        sampler=val_sampler,
        num_workers=args.dataset.dataloader.num_workers,
        pin_memory=True,
        drop_last=False
    )

    print(f"Dataset contains {len(val_dataset):,} images.")

    # ------------------------------
    # Previous steps before sampling
    # ------------------------------
    folder_name = f"Generation-image-to-text-Steps{args.sampling.num_steps}"
    sample_folder_dir = os.path.join(args.ckpt_dir, folder_name)
    sample_folder_images_dir = os.path.join(sample_folder_dir, "images")
    npz_path = f"{sample_folder_dir}/{folder_name}.npz"
    skip = torch.tensor([False], device=device)
    if rank == 0:
        if os.path.exists(npz_path):
            skip[0] = True
            # Report the path that was actually checked.
            print(f"Skipping sampling as {npz_path} already exists.")
        else:
            os.makedirs(sample_folder_dir, exist_ok=True)
            os.makedirs(sample_folder_images_dir, exist_ok=True)
            print(f"Saving .png samples at {sample_folder_dir}")

    # Broadcast the skip flag to all processes
    dist.broadcast(skip, src=0)
    if skip.item():
        dist.destroy_process_group()
        return
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.sampling.pproc_batch_size
    global_batch_size = n * world_size
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.sampling.n_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"Text Model Parameters: {sum(p.numel() for p in model_text.parameters()):,}")
    assert total_samples % world_size == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // world_size)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    if rank == 0 and len(val_dataloader) < iterations:
        print(
            f"Dataset shard provides only {len(val_dataloader)} batches per rank; "
            f"sampling stops after those instead of {iterations} iterations."
        )
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar

    # -----------------------------
    # Sampling loop
    # -----------------------------
    # Text path
    if args.i2t.prob_path == "PolynomialDiscrete":
        text_path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
    else:
        raise NotImplementedError(f"Path {args.i2t.prob_path} not implemented.")

    # Source text distribution
    if args.i2t.p0_dist == "uniform":
        source_distribution_text = UniformSourceDistribution(vocab_size=tokenizer.vocab_size)
    elif args.i2t.p0_dist == "masked":
        tokenizer.add_tokens(["<MASK>"])
        text_encoder.resize_token_embeddings(len(tokenizer))
        mask_token_id = tokenizer.convert_tokens_to_ids("<MASK>")
        source_distribution_text = MaskedSourceDistribution(mask_token=mask_token_id)
    else:
        raise NotImplementedError(f"Distribution {args.i2t.p0_dist} not implemented.")

    # Template token tensor used only to give the initial text state its
    # shape; moved to the device once instead of on every iteration.
    prompts_tok = tokenizer(
        [""] * n, return_tensors="pt", padding="max_length", truncation=True, max_length=77
    ).input_ids.to(device)

    # Transform tensor to PIL Image
    transform = T.ToPILImage()

    # Sampling loop. Iterating the dataloader directly advances through the
    # dataset; zip() stops at whichever of `pbar` / the dataloader ends first.
    local_metadata = []
    for j, (raw_image, real_caption) in zip(pbar, val_dataloader):
        # The final batch may hold fewer than `n` images (drop_last=False).
        batch_size = raw_image.shape[0]

        # Map [-1, 1] back to [0, 255]; round before the cast so float error
        # cannot truncate a pixel value down by one.
        raw_image = (raw_image + 1) / 2.
        raw_image = raw_image * 255.
        raw_image = raw_image.round().clamp(0, 255).to(torch.uint8).to(device)

        # Initialize the text state with one row per image in this batch
        xT_text = source_distribution_text.sample_like(prompts_tok[:batch_size]).to(device)

        with torch.no_grad():
            # Extract image features
            raw_image_ = preprocess_raw_image(raw_image, encoder_type, resolution=args.t2i.resolution).to(device)
            z = encoder.forward_features_attn(raw_image_)['x_norm_patchtokens']

            # Perform sampling
            samples_text = euler_sampler(
                text_model=model_text,
                image_embedding=z,
                initial_text=xT_text,
                vocab_size=tokenizer.vocab_size,
                path=text_path,
                num_steps=args.sampling.num_steps
            )

        # Decode the text samples
        captions = tokenizer.batch_decode(
            samples_text, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        for i, caption in enumerate(captions):
            # The rank is part of the file name so ranks never overwrite each other.
            img_name = f"image_{j:04d}_{rank * n + i:06d}.png"
            img_path = f"{sample_folder_images_dir}/{img_name}"
            transform(raw_image[i].cpu()).save(img_path)

            local_metadata.append({
                "real_prompt": real_caption[i],
                "predicted_prompt": caption,
                "image_path": os.path.abspath(img_path)
            })

    # === Gather metadata to rank 0 ===
    all_metadata = [None for _ in range(world_size)]
    dist.all_gather_object(all_metadata, local_metadata)

    # === Save metadata on rank 0 ===
    if rank == 0:
        flat_metadata = [item for sublist in all_metadata for item in sublist]

        # Save metadata JSON
        json_path = os.path.join(sample_folder_dir, "index.json")
        with open(json_path, "w") as f:
            json.dump(flat_metadata, f, indent=4)

        print("Done.")

    # Ensure all processes wait
    dist.barrier()
    dist.destroy_process_group()


# Script entry point. `main` is wrapped by `@hydra.main`, so it is invoked
# here without arguments.
# NOTE(review): this module uses relative imports, so it presumably has to be
# launched as a module of its package rather than as a plain file — confirm.
if __name__ == "__main__":
    main()