import cv2
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from datetime import timedelta
import json

import torch
import torch.distributed as dist
from torchvision.utils import save_image

# from models import fluid # video_fluid_arbitrary
from models import fluid_arbitrary_mar_checkboard as fluid
from models.utils import T5_Embedding
import argparse
from pathlib import Path
from diffusers.models import AutoencoderKL
import time
from datetime import datetime

'''
CUDA_VISIBLE_DEVICES=6,7 torchrun  --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_port 29533 -m scripts.geneval_mar_checkboard
'''

def get_args_parser(add_help=False):
    """Build the command-line parser for GenEval sampling.

    Only the sampling-strategy options are exposed on the command line; every
    other generation setting is hard-coded in ``main``.

    Args:
        add_help: Whether to register ``-h/--help``. Defaults to ``False`` (the
            original behaviour) so the parser can still be passed as a
            ``parents=`` entry of another ``ArgumentParser`` without a
            conflicting ``-h`` option.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--order_strategy`` and
        ``--mask_strategy``.
    """
    # The descriptive text belongs in ``description``; passed positionally it
    # would become ``prog`` and show up as the program name in the usage line.
    parser = argparse.ArgumentParser(
        description='MAR sampling with Diffusion Loss', add_help=add_help
    )
    # Strategy options forwarded to the model constructor.
    parser.add_argument('--order_strategy', type=str, default='autoregressive', choices=['autoregressive', 'random'],
                        help='order generation strategy: autoregressive or random')
    parser.add_argument('--mask_strategy', type=str, default='fixed', choices=['cosine', 'fixed'],
                        help='mask length strategy: cosine or fixed predefined lengths')
    return parser

def convert_torch_to_int(data):
    """Recursively replace torch tensors with Python ints.

    Dicts and lists are rebuilt with every element converted; any other value
    is returned untouched. Each tensor is reduced with ``.item()`` followed by
    ``int()``, so it must hold exactly one element, and floating-point values
    are truncated toward zero.
    """
    if isinstance(data, dict):
        return {key: convert_torch_to_int(val) for key, val in data.items()}
    if isinstance(data, list):
        return list(map(convert_torch_to_int, data))
    if isinstance(data, torch.Tensor):
        return int(data.item())
    return data


def default_dump(obj):
    """Fallback serializer for ``json.dump(..., default=default_dump)``.

    Converts NumPy scalars and arrays, and torch tensors, into plain Python
    values that ``json`` can encode.

    Raises:
        TypeError: for any other type. ``json`` requires a ``default`` hook to
            either return a serializable object or raise ``TypeError``.
            Returning the object unchanged makes the encoder fail with a
            misleading "Circular reference detected" ``ValueError``.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        # ``tolist`` yields a Python scalar for 0-d tensors, nested lists otherwise.
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _broadcast_timestamp():
    """Return a run timestamp chosen on rank 0 and shared with every rank."""
    # Only the main process picks the timestamp so all ranks agree on one run directory.
    payload = [datetime.now().strftime('%Y%m%d_%H%M%S') if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]


def _decode_latents(vae, latents):
    """Undo the VAE latent scaling/shift and decode ``latents`` to pixel space."""
    if vae.config.shift_factor is not None:
        latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    else:
        latents = latents / vae.config.scaling_factor
    return vae.decode(latents).sample


def _save_samples(images, outpath, nrow):
    """Write ``grid.png`` plus ``samples/NNNN.png`` (one per image) under ``outpath``."""
    # ``images`` are decoder outputs assumed to lie roughly in [-1, 1]
    # (matching the value_range used here).
    save_image(images, os.path.join(outpath, "grid.png"), nrow=int(nrow),
               normalize=True, value_range=(-1, 1))
    images = (images / 2 + 0.5).clamp(0, 1).cpu().float()

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    for b_id in range(images.size(0)):
        gen_img = np.round(np.clip(images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
        # CHW float RGB -> HWC uint8 BGR, the channel order cv2 writes.
        gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
        file_path = os.path.join(sample_path, f"{b_id:04d}.png")
        if not cv2.imwrite(file_path, gen_img):
            raise IOError(f"failed to write {file_path}")


def main(args):
    """Generate GenEval samples for every prompt in ``metadata_file``.

    Intended to be launched with ``torchrun``. Prompts are sharded across ranks
    by index, so each prompt is generated exactly once. For prompt ``i`` the
    directory ``<output_dir>/<i:05d>/`` receives ``metadata.jsonl``,
    ``grid.png`` and ``samples/0000.png`` ... (``n_samples`` images).

    Args:
        args: Parsed CLI namespace providing ``order_strategy`` and
            ``mask_strategy`` (see ``get_args_parser``).
    """
    vae_stride = 8
    patch_size = 2
    diffloss_d = 8
    diffloss_w = 1536
    num_sampling_steps = '100'
    model_type = 'fluid_large'
    batch_size = 1  # only recorded in the output directory name
    max_length = 128
    num_ar_steps = 32
    model_checkpoint_path = './models/checkpoint-512.pth'
    vae_checkpoint_path = './models/stabilityai/stable-diffusion-3.5-large'
    mt5_cache_dir = './cache/flan-t5-xxl'
    seed = 42
    cfg_scale = 4.0
    cfg_schedule = "constant"
    temperature = 1.0
    image_size = 512
    # GenEval setting
    metadata_file = './prompts/evaluation_metadata.jsonl'
    n_samples = 4  # images generated per prompt

    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # torchrun exports LOCAL_RANK; fall back to the global-rank heuristic otherwise.
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
    torch.cuda.set_device(local_rank)

    timestamp = _broadcast_timestamp()

    # Short names for the strategies, used in the output directory name.
    order_short = "ar" if args.order_strategy == "autoregressive" else "rand"
    mask_short = "cos" if args.mask_strategy == "cosine" else "fix"
    output_dir = os.path.join(
        "./runtime",
        f"steps{num_ar_steps}-maxlen{max_length}-ord_{order_short}-mask_{mask_short}-b{batch_size}-{timestamp}",
    )
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
    # Ensure the run directory exists before any rank writes into it.
    dist.barrier()

    # Load prompts (one JSON object per line; blank lines are ignored).
    with open(metadata_file, encoding="utf-8") as fp:
        metadatas = [json.loads(line) for line in fp if line.strip()]

    model = fluid.__dict__[model_type](
        vae_stride=vae_stride,
        patch_size=patch_size,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        text_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=diffloss_d,
        diffloss_w=diffloss_w,
        max_length=max_length,
        num_sampling_steps=num_sampling_steps,
        order_strategy=args.order_strategy,
        mask_strategy=args.mask_strategy,
    ).cuda()
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    del checkpoint
    model.eval()

    vae = AutoencoderKL.from_pretrained(os.path.join(vae_checkpoint_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    t5_emb = T5_Embedding(mt5_cache_dir, mt5_cache_dir, max_length).cuda()

    # Each rank handles prompts rank, rank + world_size, ...; indices stay global
    # so directory names are identical to a single-process run.
    for index in range(rank, len(metadatas), world_size):
        metadata = convert_torch_to_int(metadatas[index])
        outpath = os.path.join(output_dir, f"{index:0>5}")
        os.makedirs(outpath, exist_ok=True)
        with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
            json.dump(metadata, fp, indent=4, default=default_dump)

        # Per-prompt seeding keeps results independent of how prompts are sharded.
        torch.manual_seed(seed + index)
        np.random.seed(seed + index)

        caption = [metadata['prompt']] * n_samples
        with torch.no_grad():
            text_emb = t5_emb(caption)
            with torch.autocast(device_type="cuda"):
                sampled_tokens = model.sample_tokens(
                    bsz=len(caption), num_iter=num_ar_steps,
                    cfg=cfg_scale, cfg_schedule=cfg_schedule,
                    texts=text_emb, height=image_size, width=image_size,
                    temperature=temperature, progress=(rank == 0),
                )
                output_samples = _decode_latents(vae, sampled_tokens)

        _save_samples(output_samples, outpath, nrow=n_samples)

    # Wait for every shard to finish before tearing down the process group.
    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    # Parse CLI options, echo the chosen strategies, then run generation.
    cli_args = get_args_parser().parse_args()
    print(f"order_strategy: {cli_args.order_strategy}, mask_strategy: {cli_args.mask_strategy}")
    main(cli_args)
