import cv2
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from datetime import timedelta
import json

import torch
import torch.distributed as dist
from torchvision.utils import save_image

# from models import fluid # video_fluid_arbitrary
from models import fluid_arbitrary_mar as fluid
from models.utils import T5_Embedding
import argparse
from pathlib import Path
from diffusers.models import AutoencoderKL
import time
from datetime import datetime

'''
CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_port 29578 -m scripts.pipeline_geneval
'''

def get_args_parser():
    """Build the command-line parser for pipeline-prompt image generation.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` by itself.

    Returns:
        argparse.ArgumentParser: parser exposing the caching switches, the
        generation settings, the DiSA diffusion-step controls, the order/mask
        strategies and the prompts file path.
    """
    parser = argparse.ArgumentParser('Pipeline prompts generation with Diffusion Loss', add_help=False)
    add = parser.add_argument

    # Boolean switches (all default to False), registered in their original order.
    for switch in ('--token_cache', '--cfg_cache', '--cal_flops'):
        add(switch, action='store_true')

    # Generation parameters
    add('--num_ar_steps', type=int, default=32, help='number of autoregressive steps')
    add('--k_samples', type=int, default=12, help='number of images to generate per prompt')

    # DiSA specific parameters
    add('--diff_upper_steps', type=int, default=100, help='upper bound of diffusion steps')
    add('--diff_lower_steps', type=int, default=100, help='lower bound of diffusion steps')
    add('--diff_annealing_strategy', type=str, default="linear", help='diffusion annealing strategy')
    add('--diff_sampler', type=str, default="default", help='diffusion sampler type')
    add('--pivot_step_threshold', type=int, default=100, help='step threshold to start using pivot strategy')
    add('--pivot_diffusion_steps', type=int, default=100, help='diffusion steps for important tokens')
    add('--token_selection_strategy', type=str, default='random', choices=['pivotal', 'random'],
        help='token selection strategy: pivotal or random')
    add('--pivot_token_percentage', type=float, default=0.1,
        help='percentage of tokens to select as important tokens')

    # Token-order and mask-length strategies
    add('--order_strategy', type=str, default='autoregressive', choices=['autoregressive', 'random'],
        help='order generation strategy: autoregressive or random')
    add('--mask_strategy', type=str, default='cosine', choices=['cosine', 'fixed', 'piecewise_cosine'],
        help='mask length strategy: cosine, fixed, or piecewise_cosine predefined lengths')

    # Location of the prompts text file
    add('--prompts_file', type=str, default='./prompts/prompts/pipeline_prompts.txt',
        help='path to pipeline prompts txt file')

    return parser

def convert_torch_to_int(data):
    """Recursively replace ``torch.Tensor`` leaves with Python ints.

    Dicts and lists are rebuilt (the input is not mutated). Each tensor leaf
    becomes ``int(tensor.item())``, so it must hold exactly one element, and
    fractional values are truncated toward zero. Every other value, including
    tuples, is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: convert_torch_to_int(val) for key, val in data.items()}
    if isinstance(data, list):
        return list(map(convert_torch_to_int, data))
    if isinstance(data, torch.Tensor):
        return int(data.item())
    return data


def default_dump(obj):
    """Convert NumPy scalars and arrays into JSON-serializable Python objects.

    Intended as the ``default=`` hook of ``json.dump``/``json.dumps``, which
    call it only for objects the encoder cannot serialize natively.

    Args:
        obj: The object the JSON encoder could not serialize.

    Returns:
        A plain Python scalar for NumPy integer/floating/bool scalars, or a
        (nested) list for ``np.ndarray``.

    Raises:
        TypeError: For any other type. The ``json`` module requires the hook to
            raise here; returning ``obj`` unchanged would make the encoder fail
            with a misleading "Circular reference detected" ValueError.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def load_pipeline_prompts(prompts_file):
    """Load one prompt per non-empty line of a text prompts file.

    Each line is stripped of surrounding whitespace. A line wrapped in a pair
    of double quotes has that pair removed. Lines that are empty or
    whitespace-only, before or after unquoting, are skipped. A leading UTF-8
    BOM is ignored.

    Args:
        prompts_file: Path (str or path-like) to a UTF-8 text file.

    Returns:
        list[dict]: one dict per prompt, shaped like the records of
        evaluation_metadata.jsonl, with keys:
            ``'prompt'`` - the prompt text;
            ``'index'``  - 0-based line number in the file (skipped lines are
                           still counted, so it points back to the source line);
            ``'source'`` - base name of ``prompts_file``.
    """
    source_name = os.path.basename(prompts_file)
    prompts = []
    # 'utf-8-sig' reads plain UTF-8 identically but also drops a leading BOM,
    # which would otherwise end up inside the first prompt and block unquoting.
    with open(prompts_file, 'r', encoding='utf-8-sig') as f:
        for line_num, raw_line in enumerate(f):
            text = raw_line.strip()
            # Require two characters: a lone '"' both starts and ends with '"',
            # and slicing it would silently yield an empty prompt.
            if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
                text = text[1:-1]
            if not text.strip():
                continue
            prompts.append({
                'prompt': text,
                'index': line_num,
                'source': source_name,
            })
    return prompts


def _broadcast_timestamp():
    """Return a ``%Y%m%d_%H%M%S`` timestamp chosen by rank 0 and shared with every rank."""
    # Only rank 0 reads the clock; broadcasting its value guarantees all ranks
    # build the same output directory name.
    payload = [datetime.now().strftime('%Y%m%d_%H%M%S') if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]


def _run_dir_name(args, num_ar_steps, k_samples, timestamp):
    """Encode the main generation settings and the run timestamp into a directory name."""
    order_short = "ar" if args.order_strategy == "autoregressive" else "rand"
    # Anything other than cosine/fixed is piecewise_cosine (argparse restricts the choices).
    mask_short = {"cosine": "cos", "fixed": "fix"}.get(args.mask_strategy, "pcos")
    annealing_short = "lin" if args.diff_annealing_strategy == "linear" else args.diff_annealing_strategy[:3]
    return (f"pipeline-c{args.cfg_cache}-t{args.token_cache}-steps{num_ar_steps}"
            f"-k{k_samples}-disa-u{args.diff_upper_steps}-l{args.diff_lower_steps}-{annealing_short}"
            f"-ord_{order_short}-mask_{mask_short}-{timestamp}")


def _save_samples(output_samples, outpath, k_samples):
    """Write ``grid.png`` plus one ``samples/NNNN.png`` per image of a decoded batch.

    ``output_samples`` is treated as values in [-1, 1]: that is the
    ``value_range`` given to ``save_image`` and the range mapped to [0, 1]
    before 8-bit quantization.
    """
    save_image(
        output_samples, os.path.join(outpath, "grid.png"),
        nrow=int(k_samples), normalize=True, value_range=(-1, 1)
    )
    images = (output_samples / 2 + 0.5).clamp(0, 1).cpu().float()

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    for b_id in range(images.size(0)):
        gen_img = np.round(np.clip(images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
        # CHW -> HWC above; here RGB -> BGR for cv2. Made contiguous because the
        # reversed-channel view has a negative stride.
        gen_img = np.ascontiguousarray(gen_img.astype(np.uint8)[:, :, ::-1])
        cv2.imwrite(os.path.join(sample_path, '{}.png'.format(str(b_id).zfill(4))), gen_img)


def main(args):
    """Generate ``args.k_samples`` images for every prompt in ``args.prompts_file``.

    Initializes an NCCL process group, so it expects to be launched through
    torchrun (see the usage string at the top of the file). Prompts are split
    round-robin across ranks. The outputs for prompt number ``i`` (0-based, in
    file order) go to ``./runtime/<run name>/<i:05d>/``: ``metadata.jsonl``,
    ``grid.png`` and ``samples/NNNN.png``.

    Args:
        args: Namespace from ``get_args_parser()``. ``start_step`` and
            ``fresh_t`` are overwritten here before it is forwarded to
            ``model.sample_tokens``.
    """
    vae_stride = 8
    patch_size = 2
    diffloss_d = 8
    diffloss_w = 1536
    model_type = 'fluid_large'
    max_length = 128
    num_ar_steps = args.num_ar_steps
    k_samples = args.k_samples
    model_checkpoint_path = './models/checkpoint-512.pth'
    vae_checkpoint_path = './models/stabilityai/stable-diffusion-3.5-large'
    mt5_cache_dir = './cache/flan-t5-xxl'
    seed = 42
    cfg_scale = 4.0
    cfg_schedule = "constant"
    temperature = 1.0

    # Pipeline prompts setting
    args.start_step = 5
    args.fresh_t = 6

    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = rank == 0
    torch.cuda.set_device(rank % torch.cuda.device_count())

    timestamp = _broadcast_timestamp()
    dist.barrier()

    output_root = "./runtime"
    output_dir = os.path.join(output_root, _run_dir_name(args, num_ar_steps, k_samples, timestamp))
    if is_main:
        print(f"Output directory: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)
    # Wait until rank 0 has created the directory before any rank writes into it.
    dist.barrier()

    if is_main:
        print(f"Loading prompts from: {args.prompts_file}")
    metadatas = load_pipeline_prompts(args.prompts_file)
    if is_main:
        print(f"Loaded {len(metadatas)} prompts")

    model = fluid.__dict__[model_type](
        vae_stride=vae_stride,
        patch_size=patch_size,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        text_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=diffloss_d,
        diffloss_w=diffloss_w,
        max_length=max_length,
        token_cache=args.token_cache,
        cfg_cache=args.cfg_cache,
        # DiSA specific parameters
        diff_upper_steps=args.diff_upper_steps,
        diff_lower_steps=args.diff_lower_steps,
        diff_annealing_strategy=args.diff_annealing_strategy,
        diff_sampler=args.diff_sampler,
        pivot_step_threshold=args.pivot_step_threshold,
        pivot_diffusion_steps=args.pivot_diffusion_steps,
        token_selection_strategy=args.token_selection_strategy,
        pivot_token_percentage=args.pivot_token_percentage,
        order_strategy=args.order_strategy,
        mask_strategy=args.mask_strategy,
    ).cuda()

    # NOTE(review): torch.load unpickles the file; only load trusted local checkpoints.
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Load the VAE (the "vae" subfolder of the pretrained pipeline directory), frozen.
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_checkpoint_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    t5_emb = T5_Embedding(mt5_cache_dir, mt5_cache_dir, max_length).cuda()

    torch.manual_seed(seed)
    np.random.seed(seed)

    # Round-robin shard: this rank handles prompts rank, rank + world_size, ...
    # The global index is kept so output directory names match a single-rank run.
    jobs = list(enumerate(metadatas))[rank::world_size]
    if is_main:
        jobs = tqdm(jobs, desc="Generating images")

    for index, metadata in jobs:
        outpath = os.path.join(output_dir, f"{index:0>5}")
        os.makedirs(outpath, exist_ok=True)
        prompt = metadata['prompt']

        # Store the prompt record together with the generation settings.
        metadata = convert_torch_to_int(metadata)
        metadata['k_samples'] = k_samples
        metadata['num_ar_steps'] = num_ar_steps
        metadata['cfg_scale'] = cfg_scale
        metadata['seed'] = seed

        # NOTE(review): despite the .jsonl name this writes a single indented JSON object.
        with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
            json.dump(metadata, fp, indent=4, default=default_dump)

        caption = [prompt] * k_samples

        # Generate latent tokens, then decode them to images with the VAE.
        with torch.no_grad():
            text_emb = t5_emb(caption)
            with torch.autocast(device_type="cuda"):
                sampled_tokens = model.sample_tokens(
                    bsz=len(caption), num_iter=num_ar_steps,
                    cfg=cfg_scale, cfg_schedule=cfg_schedule,
                    texts=text_emb, height=512, width=512,
                    temperature=temperature, progress=False, args=args
                )

                # Undo the VAE latent scaling (and shift, when configured) before decoding.
                samples = sampled_tokens / vae.config.scaling_factor
                if vae.config.shift_factor is not None:
                    samples = samples + vae.config.shift_factor
                output_samples = vae.decode(samples).sample

        _save_samples(output_samples, outpath, k_samples)

    # Make sure every rank has finished writing before reporting completion.
    dist.barrier()
    if is_main:
        print(f"Generation completed! Results saved to: {output_dir}")
    dist.destroy_process_group()

if __name__ == '__main__':
    # Parse the CLI, echo the key settings, then run generation.
    cli_args = get_args_parser().parse_args()
    print(f"cfg_cache: {cli_args.cfg_cache}, token_cache: {cli_args.token_cache}, order_strategy: {cli_args.order_strategy}, mask_strategy: {cli_args.mask_strategy}")
    print(f"num_ar_steps: {cli_args.num_ar_steps}, k_samples: {cli_args.k_samples}, diff_upper_steps: {cli_args.diff_upper_steps}, diff_lower_steps: {cli_args.diff_lower_steps}, diff_annealing_strategy: {cli_args.diff_annealing_strategy}")
    main(cli_args)
