import cv2
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from datetime import timedelta
import json

import torch
import torch.distributed as dist
from torchvision.utils import save_image
from torch.utils.data import DataLoader, DistributedSampler

# from models import fluid # video_fluid_arbitrary
from models import fluid_arbitrary_mar as fluid
from models.utils import T5_Embedding
import argparse
from pathlib import Path
from diffusers.models import AutoencoderKL
from torch.utils.data import Dataset
import time
from datetime import datetime

'''
CUDA_VISIBLE_DEVICES=6,7 torchrun  --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_port 29533 -m scripts.geneval_mar
Generates the images used for ImageReward evaluation.
cd ./runtime
conda activate everlyn_video
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_port 29533 -m scripts.imagereward_mar
'''


def get_args_parser():
    """Build the command-line parser for this generation script.

    The parser is created with ``add_help=False``. It exposes three boolean
    switches (caching / FLOPs), the number of autoregressive steps, the DiSA
    diffusion-step options, and the order / mask strategy selectors.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser('MAR training with Diffassusion Loss', add_help=False)

    # Boolean switches; each is False unless present on the command line.
    for switch in ('--token_cache', '--cfg_cache', '--cal_flops'):
        parser.add_argument(switch, action='store_true')

    # (flag, add_argument keyword arguments), listed in registration order.
    valued_options = [
        # Generation parameters
        ('--num_ar_steps',
         dict(type=int, default=32, help='number of autoregressive steps')),
        # DiSA specific parameters
        ('--diff_upper_steps',
         dict(type=int, default=50, help='upper bound of diffusion steps')),
        ('--diff_lower_steps',
         dict(type=int, default=50, help='lower bound of diffusion steps')),
        ('--diff_annealing_strategy',
         dict(type=str, default="linear", help='diffusion annealing strategy')),
        ('--diff_sampler',
         dict(type=str, default="default", help='diffusion sampler type')),
        ('--pivot_step_threshold',
         dict(type=int, default=50, help='step threshold to start using pivot strategy')),
        ('--pivot_diffusion_steps',
         dict(type=int, default=50, help='diffusion steps for important tokens')),
        ('--token_selection_strategy',
         dict(type=str, default='random', choices=['pivotal', 'random'],
              help='token selection strategy: pivotal or random')),
        ('--pivot_token_percentage',
         dict(type=float, default=0.1,
              help='percentage of tokens to select as important tokens')),
        # Order and mask strategies
        ('--order_strategy',
         dict(type=str, default='autoregressive', choices=['autoregressive', 'random'],
              help='order generation strategy: autoregressive or random')),
        ('--mask_strategy',
         dict(type=str, default='cosine', choices=['cosine', 'fixed', 'piecewise_cosine'],
              help='mask length strategy: cosine, fixed, or piecewise_cosine predefined lengths')),
    ]
    for flag, options in valued_options:
        parser.add_argument(flag, **options)

    return parser


def convert_torch_to_int(data):
    """Recursively replace tensors inside ``data`` with plain Python ints.

    Dicts and lists are rebuilt with their values / items converted; a
    ``torch.Tensor`` is converted via ``int(tensor.item())``; every other
    value is returned unchanged.
    """
    if isinstance(data, dict):
        return {name: convert_torch_to_int(entry) for name, entry in data.items()}
    if isinstance(data, list):
        return list(map(convert_torch_to_int, data))
    if isinstance(data, torch.Tensor):
        return int(data.item())
    return data


def default_dump(obj):
    """Return a JSON-friendly equivalent of a numpy value.

    numpy integer, floating and boolean scalars are unwrapped with
    ``.item()``; arrays are converted with ``.tolist()``; any other object is
    returned unchanged.
    """
    numpy_scalars = (np.integer, np.floating, np.bool_)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj.item() if isinstance(obj, numpy_scalars) else obj


class GenEvalDataset(Dataset):
    """Prompt dataset backed by a plain-text file with one prompt per line.

    Blank lines are skipped. Each kept line becomes a dict with an ``id`` of
    the form ``prompt_<line index>`` (the index counts blank lines too) and
    the stripped ``prompt`` text.
    """

    def __init__(self, metadata_file):
        # metadata_file is expected to be a text file where every line is a prompt.
        with open(metadata_file, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            # Minimal metadata record per prompt; extend with more fields if needed.
            self.metadatas = [
                {'id': f'prompt_{line_no}', 'prompt': text}
                for line_no, text in enumerate(stripped_lines)
                if text
            ]

    def __len__(self):
        """Number of non-empty prompts that were loaded."""
        return len(self.metadatas)

    def __getitem__(self, idx):
        """Return ``(idx, metadata_dict)`` for the prompt at position ``idx``."""
        return idx, self.metadatas[idx]


def main(args):
    """Generate images for every prompt in the DrawBench list and log the results.

    Runs as a ``torch.distributed`` job on the NCCL backend. Each rank loads
    the model, the VAE and the T5 text embedder, processes the prompts its
    ``DistributedSampler`` shard yields, and writes one JPEG per sample under
    a timestamped directory below ``./runtime``. At the end the per-rank
    records are gathered and rank 0 writes ``metadata.jsonl``.

    Args:
        args: Namespace produced by ``get_args_parser()``. ``start_step`` and
            ``fresh_t`` are set on it here, and the namespace is forwarded to
            ``model.sample_tokens``.
    """
    # Fixed model / sampling configuration for this evaluation script.
    vae_stride = 8
    patch_size = 2
    diffloss_d = 8
    diffloss_w = 1536
    model_type = 'fluid_large'
    batch_size = 1
    max_length = 128
    num_ar_steps = args.num_ar_steps
    model_checkpoint_path = './models/checkpoint-512.pth'
    vae_checkpoint_path = './models/stabilityai/stable-diffusion-3.5-large'
    mt5_cache_dir = './cache/flan-t5-xxl'
    num_workers = 0
    seed = 42
    cfg_scale = 4.0
    cfg_schedule = "constant"
    temperature = 1.0
    # Prompt list: a text file with one prompt per line.
    metadata_file = './prompts/image_reward/DrawBench200.txt'
    # Extra settings attached to args; presumably consumed inside
    # model.sample_tokens (not visible from this file).
    args.start_step = 10
    args.fresh_t = 4
    n_samples = 1

    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = rank == 0
    torch.cuda.set_device(rank % torch.cuda.device_count())

    output_root = "./runtime"
    # Only rank 0 creates the timestamp; it is then broadcast so that every
    # rank derives the same output directory name.
    timestamp = [datetime.now().strftime('%Y%m%d_%H%M%S') if is_main else None]
    dist.broadcast_object_list(timestamp, src=0)
    timestamp = timestamp[0]
    dist.barrier()

    # Short tags for the strategies, used in the output directory name.
    order_short = "ar" if args.order_strategy == "autoregressive" else "rand"
    if args.mask_strategy == "cosine":
        mask_short = "cos"
    elif args.mask_strategy == "fixed":
        mask_short = "fix"
    else:  # piecewise_cosine
        mask_short = "pcos"
    annealing_short = "lin" if args.diff_annealing_strategy == "linear" else args.diff_annealing_strategy[:3]

    output_dir = os.path.join(output_root,
                              f"imagereward-c{args.cfg_cache}-t{args.token_cache}-steps{num_ar_steps}"
                              f"-disa-u{args.diff_upper_steps}-l{args.diff_lower_steps}-{annealing_short}"
                              f"-ord_{order_short}-mask_{mask_short}-{timestamp}")

    if is_main:
        os.makedirs(output_dir, exist_ok=True)
    # Real synchronisation point (instead of a fixed sleep): no rank proceeds
    # until rank 0 has created the output directory.
    dist.barrier()

    model = fluid.__dict__[model_type](
        vae_stride=vae_stride,
        patch_size=patch_size,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        text_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=diffloss_d,
        diffloss_w=diffloss_w,
        max_length=max_length,
        token_cache=args.token_cache,
        cfg_cache=args.cfg_cache,
        # DiSA specific parameters
        diff_upper_steps=args.diff_upper_steps,
        diff_lower_steps=args.diff_lower_steps,
        diff_annealing_strategy=args.diff_annealing_strategy,
        diff_sampler=args.diff_sampler,
        pivot_step_threshold=args.pivot_step_threshold,
        pivot_diffusion_steps=args.pivot_diffusion_steps,
        token_selection_strategy=args.token_selection_strategy,
        pivot_token_percentage=args.pivot_token_percentage,
        order_strategy=args.order_strategy,
        mask_strategy=args.mask_strategy,
    ).cuda()

    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Load the VAE from the "vae" sub-folder and freeze its parameters.
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_checkpoint_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    t5_emb = T5_Embedding(mt5_cache_dir, mt5_cache_dir, max_length).cuda()

    # Seed the torch and numpy RNGs (the same seed is used on every rank).
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset = GenEvalDataset(metadata_file=metadata_file)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        ),
    )

    # (dataset index, metadata) pairs produced by this rank.
    local_records = []
    for index, metadata in tqdm(dataloader, disable=not is_main):
        # The collated batch holds each field as a length-1 sequence
        # (batch_size is 1); unwrap to plain strings.
        prompt_id = metadata['id'] = metadata['id'][0]
        prompt = metadata['prompt'] = metadata['prompt'][0]
        print(prompt_id)
        print(output_dir)
        outpath = os.path.join(output_dir, prompt_id)
        os.makedirs(outpath, exist_ok=True)
        caption = [prompt] * n_samples
        # Sample latent tokens for the prompt, then decode them to images.
        with torch.no_grad():
            text_emb = t5_emb(caption)
            with torch.cuda.amp.autocast():
                sampled_tokens = model.sample_tokens(
                    bsz=len(caption), num_iter=num_ar_steps,
                    cfg=cfg_scale, cfg_schedule=cfg_schedule,
                    texts=text_emb, height=512, width=512,
                    temperature=temperature, progress=True, args=args
                )

                # Undo the VAE latent scaling (and shift, when configured).
                if vae.config.shift_factor is not None:
                    samples = sampled_tokens / vae.config.scaling_factor + vae.config.shift_factor
                else:
                    samples = sampled_tokens / vae.config.scaling_factor
                output_samples = vae.decode(samples).sample

        # Map the decoder output from [-1, 1] to [0, 1].
        output_samples = (output_samples / 2 + 0.5).clamp(0, 1)
        output_samples = output_samples.cpu().float()
        metadata['gen_image_paths'] = []
        for b_id in range(output_samples.size(0)):
            # CHW float in [0, 1] -> HWC uint8 in [0, 255].
            gen_img = np.round(np.clip(output_samples[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            # Reverse the channel axis: cv2.imwrite expects BGR
            # (the decoded image is presumably RGB).
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
            sample_path = os.path.join(outpath, f"{prompt_id}_{b_id}.jpg")
            cv2.imwrite(sample_path, gen_img)
            metadata['gen_image_paths'].append(sample_path)
        local_records.append((int(index[0]), metadata))

    # Previously every rank opened the same file with "w" and dumped only its
    # own shard, so multi-GPU runs kept just one rank's records. Gather all
    # shards and let rank 0 write the file once.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_records)
    if is_main:
        # DistributedSampler(drop_last=False) may repeat samples to even out
        # the shards; keep one record per dataset index, in dataset order.
        merged = {}
        for rank_records in gathered:
            for dataset_idx, record in rank_records:
                merged.setdefault(dataset_idx, record)
        save_metadatas = [merged[dataset_idx] for dataset_idx in sorted(merged)]
        # NOTE: despite the ".jsonl" name the file holds a single JSON array;
        # the format is kept unchanged for existing consumers.
        save_metadata_file_path = os.path.join(output_dir, "metadata.jsonl")
        with open(save_metadata_file_path, "w") as fp:
            json.dump(save_metadatas, fp)

    # Keep all ranks alive until the metadata is on disk, then shut down.
    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    # Parse the CLI options, echo the key settings, then run generation.
    args = get_args_parser().parse_args()
    print(
        f"cfg_cache: {args.cfg_cache}, token_cache: {args.token_cache}, "
        f"order_strategy: {args.order_strategy}, mask_strategy: {args.mask_strategy}"
    )
    print(
        f"num_ar_steps: {args.num_ar_steps}, diff_upper_steps: {args.diff_upper_steps}, "
        f"diff_lower_steps: {args.diff_lower_steps}, "
        f"diff_annealing_strategy: {args.diff_annealing_strategy}"
    )
    main(args)
