import cv2
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from datetime import timedelta
import json

import torch
import torch.distributed as dist
from torchvision.utils import save_image

# from models import fluid # video_fluid_arbitrary
from models import fluid_arbitrary_mar as fluid
from models.utils import T5_Embedding
import argparse
from pathlib import Path
from diffusers.models import AutoencoderKL
import time
from datetime import datetime

'''
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_port 29576 -m scripts.geneval_mar
'''

def get_args_parser():
    """Build the command-line parser for the MAR/DiSA sampling script.

    The parser is created with ``add_help=False``, so it defines no ``-h``
    option of its own.

    Returns:
        argparse.ArgumentParser: parser exposing the cache switches, the
        number of autoregressive steps, the DiSA diffusion-step options and
        the order/mask strategy options.
    """
    # The first positional argument of ArgumentParser is ``prog``; it is the
    # name shown in usage and error messages.
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--token_cache', action='store_true',
                        help='enable the token cache (off by default)')
    # NOTE: store_false means the cache is ON by default and passing the flag
    # turns it OFF.
    parser.add_argument('--cfg_cache', action='store_false',
                        help='disable the CFG cache (on by default; passing this flag sets cfg_cache to False)')
    parser.add_argument('--cal_flops', action='store_true',
                        help='set the cal_flops flag (off by default)')

    # Generation parameters
    parser.add_argument('--num_ar_steps', type=int, default=16, help='number of autoregressive steps')

    # DiSA specific parameters
    parser.add_argument('--diff_upper_steps', type=int, default=50, help='upper bound of diffusion steps')
    parser.add_argument('--diff_lower_steps', type=int, default=20, help='lower bound of diffusion steps')
    parser.add_argument('--diff_annealing_strategy', type=str, default="linear", help='diffusion annealing strategy')
    parser.add_argument('--diff_sampler', type=str, default="default", help='diffusion sampler type')
    parser.add_argument('--pivot_step_threshold', type=int, default=8, help='step threshold to start using pivot strategy')
    parser.add_argument('--pivot_diffusion_steps', type=int, default=50, help='diffusion steps for important tokens')
    parser.add_argument('--token_selection_strategy', type=str, default='pivotal', choices=['pivotal', 'random'],
                        help='token selection strategy: pivotal or random')
    parser.add_argument('--pivot_token_percentage', type=float, default=0.1, help='percentage of tokens to select as important tokens')

    # Order and mask strategies
    parser.add_argument('--order_strategy', type=str, default='autoregressive', choices=['autoregressive', 'random'],
                        help='order generation strategy: autoregressive or random')
    parser.add_argument('--mask_strategy', type=str, default='piecewise_cosine', choices=['cosine', 'fixed', 'piecewise_cosine'],
                        help='mask length strategy: cosine, fixed, or piecewise_cosine predefined lengths')

    return parser

def convert_torch_to_int(data):
    """Recursively replace every ``torch.Tensor`` in ``data`` with a Python int.

    Lists and dicts are rebuilt with their items/values converted; a tensor is
    turned into ``int(tensor.item())`` (so it must hold exactly one element and
    fractional values are truncated). Anything else is returned untouched.
    """
    if isinstance(data, dict):
        return {name: convert_torch_to_int(entry) for name, entry in data.items()}
    if isinstance(data, list):
        return list(map(convert_torch_to_int, data))
    if isinstance(data, torch.Tensor):
        scalar = data.item()
        return int(scalar)
    return data


def default_dump(obj):
    """Convert numpy values to JSON-serializable objects (``json`` ``default=`` hook).

    Args:
        obj: an object the JSON encoder could not serialize on its own.

    Returns:
        The Python scalar for numpy integer/floating/bool scalars, or a
        (nested) list for ``np.ndarray``.

    Raises:
        TypeError: if ``obj`` is not a supported numpy type. Returning the
            object unchanged would make the encoder report a misleading
            "Circular reference detected" ValueError instead.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _build_output_dir(output_root, args, num_ar_steps, timestamp):
    """Return the run directory path that encodes the sampling configuration."""
    order_short = "ar" if args.order_strategy == "autoregressive" else "rand"
    # Anything that is not "cosine" or "fixed" is treated as piecewise_cosine.
    mask_short = {"cosine": "cos", "fixed": "fix"}.get(args.mask_strategy, "pcos")
    annealing_short = "lin" if args.diff_annealing_strategy == "linear" else args.diff_annealing_strategy[:3]
    return os.path.join(output_root,
                        f"c{args.cfg_cache}-t{args.token_cache}-steps{num_ar_steps}"
                        f"-disa-u{args.diff_upper_steps}-l{args.diff_lower_steps}-{annealing_short}"
                        f"-ord_{order_short}-mask_{mask_short}-{timestamp}")


def _save_outputs(output_samples, outpath, n_samples):
    """Write the image grid and the individual PNG samples for one prompt."""
    # Grid of all samples in a single row; values are rescaled from [-1, 1].
    save_image(
        output_samples, os.path.join(outpath, "grid.png"),
        nrow=int(n_samples), normalize=True, value_range=(-1, 1)
    )
    images = (output_samples / 2 + 0.5).clamp(0, 1)
    images = images.cpu().float()

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    for b_id in range(images.size(0)):
        # CHW float in [0, 1] -> HWC uint8, then reverse the channel axis
        # (RGB -> BGR) because cv2.imwrite expects BGR.
        gen_img = np.round(np.clip(images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
        gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
        file_path = os.path.join(sample_path, '{}.png'.format(str(b_id).zfill(4)))
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(file_path, gen_img):
            print(f"warning: failed to write {file_path}")


def main(args):
    """Generate images for every prompt in the metadata file and save them.

    Initializes the NCCL process group, builds the model, VAE and T5 text
    embedder, then for each prompt samples ``n_samples`` images and writes
    ``<output_dir>/<index>/metadata.jsonl``, ``grid.png`` and
    ``samples/NNNN.png``.

    Args:
        args: parsed namespace from ``get_args_parser()``. ``start_step`` and
            ``fresh_t`` are set on it here before it is forwarded to
            ``model.sample_tokens``.

    NOTE(review): every rank iterates over all prompts and writes to the same
    paths; the work is not sharded across ranks.
    """
    vae_stride = 8
    patch_size = 2
    diffloss_d = 8
    diffloss_w = 1536
    model_type = 'fluid_large'
    max_length = 128
    num_ar_steps = args.num_ar_steps
    model_checkpoint_path = './models/checkpoint-512.pth'
    vae_checkpoint_path = './models/stabilityai/stable-diffusion-3.5-large'
    mt5_cache_dir = './cache/flan-t5-xxl'
    seed = 42
    cfg_scale = 4.0
    cfg_schedule = "constant"
    temperature = 1.0
    # geneval setting
    metadata_file = './prompts/evaluation_metadata.jsonl'
    # Extra attributes attached to args; presumably read inside
    # model.sample_tokens, which receives args -- TODO confirm.
    args.start_step = 5
    args.fresh_t = 6
    n_samples = 4
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    output_root = "./runtime"
    # Only rank 0 generates the timestamp; it is then broadcast so that all
    # processes agree on the same output directory name.
    timestamp = [datetime.now().strftime('%Y%m%d_%H%M%S') if rank == 0 else None]
    dist.broadcast_object_list(timestamp, src=0)
    timestamp = timestamp[0]

    output_dir = _build_output_dir(output_root, args, num_ar_steps, timestamp)
    print(output_dir)
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
    # Wait until rank 0 has created the directory (replaces a fixed sleep).
    dist.barrier()

    # Load prompts
    with open(metadata_file) as fp:
        metadatas = [json.loads(line) for line in fp]

    model = fluid.__dict__[model_type](
        vae_stride=vae_stride,
        patch_size=patch_size,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        text_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=diffloss_d,
        diffloss_w=diffloss_w,
        max_length=max_length,
        token_cache=args.token_cache,
        cfg_cache=args.cfg_cache,
        # DiSA specific parameters
        diff_upper_steps=args.diff_upper_steps,
        diff_lower_steps=args.diff_lower_steps,
        diff_annealing_strategy=args.diff_annealing_strategy,
        diff_sampler=args.diff_sampler,
        pivot_step_threshold=args.pivot_step_threshold,
        pivot_diffusion_steps=args.pivot_diffusion_steps,
        token_selection_strategy=args.token_selection_strategy,
        pivot_token_percentage=args.pivot_token_percentage,
        order_strategy=args.order_strategy,
        mask_strategy=args.mask_strategy,
    ).cuda()

    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Load the VAE and freeze its parameters.
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_checkpoint_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    t5_emb = T5_Embedding(mt5_cache_dir, mt5_cache_dir, max_length).cuda()

    # Fixed seeds for generation
    torch.manual_seed(seed)
    np.random.seed(seed)

    for index, metadata in enumerate(metadatas):
        outpath = os.path.join(output_dir, f"{index:0>5}")
        os.makedirs(outpath, exist_ok=True)
        prompt = metadata['prompt']
        metadata = convert_torch_to_int(metadata)
        with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
            json.dump(metadata, fp, indent=4, default=default_dump)
        # The same prompt is repeated once per requested sample.
        caption = [prompt] * n_samples
        # generate the tokens and images
        with torch.no_grad():
            text_emb = t5_emb(caption)
            with torch.autocast(device_type="cuda"):
                sampled_tokens = model.sample_tokens(
                    bsz=len(caption), num_iter=num_ar_steps,
                    cfg=cfg_scale, cfg_schedule=cfg_schedule,
                    texts=text_emb, height=512, width=512,
                    temperature=temperature, progress=True, args=args
                )

                # Undo the VAE latent scaling (and shift, when configured)
                # before decoding.
                if vae.config.shift_factor is not None:
                    samples = sampled_tokens / vae.config.scaling_factor + vae.config.shift_factor
                else:
                    samples = sampled_tokens / vae.config.scaling_factor
                output_samples = vae.decode(samples).sample

        # save the images
        _save_outputs(output_samples, outpath, n_samples)

    # Let every rank finish writing before tearing down the process group.
    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    # Parse the command line, echo the key settings, then run generation.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    cache_summary = (
        f"cfg_cache: {args.cfg_cache}, token_cache: {args.token_cache}, "
        f"order_strategy: {args.order_strategy}, mask_strategy: {args.mask_strategy}"
    )
    print(cache_summary)
    step_summary = (
        f"num_ar_steps: {args.num_ar_steps}, diff_upper_steps: {args.diff_upper_steps}, "
        f"diff_lower_steps: {args.diff_lower_steps}, diff_annealing_strategy: {args.diff_annealing_strategy}"
    )
    print(step_summary)
    main(args)
