import torch
import numpy as np
import os
from torchvision.utils import save_image
from models import fluid_arbitrary_mar as fluid
from models.utils import T5_Embedding
from diffusers.models import AutoencoderKL
import argparse


'''
CUDA_VISIBLE_DEVICES=1 python pipeline_image_mar.py
'''



# Prefer CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_args_parser():
    """Build the command-line parser for this script.

    The parser is created without the automatic ``-h/--help`` option and
    exposes three boolean switches, each defaulting to ``False``:
    ``--token_cache``, ``--cfg_cache`` and ``--cal_flops``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(
        prog='MAR training with Diffusion Loss',
        add_help=False,
    )
    # Every option is a plain on/off switch, so register them uniformly.
    for switch in ('--token_cache', '--cfg_cache', '--cal_flops'):
        cli.add_argument(switch, action='store_true')
    return cli

def _decode_tokens(vae, sampled_tokens):
    """Decode sampled latent tokens into images with the VAE.

    The tokens are divided by ``vae.config.scaling_factor`` and, when the VAE
    config defines a ``shift_factor``, shifted by it before decoding.

    Args:
        vae: VAE exposing ``config.scaling_factor``, ``config.shift_factor``
            and ``decode(...)``.
        sampled_tokens: tensor returned by the model's ``sample_tokens``.

    Returns:
        The ``.sample`` attribute of ``vae.decode(...)``.
    """
    latents = sampled_tokens / vae.config.scaling_factor
    if vae.config.shift_factor is not None:
        latents = latents + vae.config.shift_factor
    return vae.decode(latents).sample


def _save_images(images, output_dir, samples_per_row=3):
    """Write a grid PNG plus one PNG per sample into ``output_dir``.

    Args:
        images: batch of channel-first images with values expected in
            ``[-1, 1]`` (the grid is normalized from that range and the
            per-image files are rescaled with ``x / 2 + 0.5``).
        output_dir: destination directory; created if missing.
        samples_per_row: number of images per row in the grid image.
    """
    # cv2 is only needed for the per-image files, so it is imported locally
    # (the file does not import it at module level).
    import cv2

    os.makedirs(output_dir, exist_ok=True)

    save_image(
        images,
        os.path.join(output_dir, "sampled_image.png"),
        nrow=int(samples_per_row),
        normalize=True,
        value_range=(-1, 1),
    )

    # Rescale to [0, 1] in the tensor's own dtype, then convert to float32 on
    # the CPU. After the clamp the scaled values already lie in [0, 255], so
    # no additional clipping is required before the uint8 cast.
    unit_range = (images / 2 + 0.5).clamp(0, 1).cpu().float()
    pixels = np.round(unit_range.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)

    for index, rgb in enumerate(pixels):
        # cv2.imwrite expects BGR channel order; make the reversed view
        # contiguous before handing it to OpenCV.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        cv2.imwrite(os.path.join(output_dir, f"{index:05d}.png"), bgr)


def main(args):
    """Sample 512x512 images for a fixed prompt list and save them to disk.

    Steps: build the ``fluid_large`` model and load
    ``./models/checkpoint-512.pth``; load the VAE from the local
    ``stable-diffusion-3.5-large`` folder; embed the prompts with
    ``T5_Embedding``; call ``model.sample_tokens``; decode the tokens with the
    VAE; write a grid image and per-sample PNGs to ``output_512_dpo_cherry``.

    The model, VAE and text encoder are moved with ``.cuda()``, so a CUDA
    device is required.

    Args:
        args: namespace providing ``token_cache`` and ``cfg_cache``. The
            attributes ``start_step`` and ``fresh_t`` are filled in with the
            script defaults (5 and 9) when absent or ``None``; values already
            set by the caller are kept. The namespace is forwarded unchanged
            to ``model.sample_tokens``.
    """
    # init the model architecture
    vae_stride = 8
    patch_size = 2
    diffloss_d = 8
    diffloss_w = 1536
    num_sampling_steps = '50'
    model_type = "fluid_large"
    max_length = 128  # shared by the model and the text encoder

    # Defaults consumed downstream through `args`; only applied when the
    # caller has not supplied its own values.
    if getattr(args, "start_step", None) is None:
        args.start_step = 5
    if getattr(args, "fresh_t", None) is None:
        args.fresh_t = 9

    model_checkpoint_path = "./models/checkpoint-512.pth"

    model = getattr(fluid, model_type)(
        vae_stride=vae_stride,
        patch_size=patch_size,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        text_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=diffloss_d,
        diffloss_w=diffloss_w,
        max_length=max_length,
        num_sampling_steps=num_sampling_steps,
        token_cache=args.token_cache,
        cfg_cache=args.cfg_cache,
    ).cuda()

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # The model already lives on the GPU, so no second .cuda() is needed.
    model.eval()

    # load the frozen vae
    vae_checkpoint_path = './models/stabilityai/stable-diffusion-3.5-large'
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_checkpoint_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    # The same directory is passed for both path arguments
    # (presumably tokenizer and encoder weights -- confirm in T5_Embedding).
    mt5_cache_dir = './cache/flan-t5-xxl'
    t5_emb = T5_Embedding(mt5_cache_dir, mt5_cache_dir, max_length).cuda()

    # set up user-specified or default values for generation
    seed = 4042
    torch.manual_seed(seed)
    np.random.seed(seed)

    num_ar_steps = 16
    cfg_scale = 4
    cfg_schedule = "constant"
    temperature = 1.0

    base_prompts = [
        'A woman stands on a city street, wearing an elegant white lace dress with intricate patterns and a deep neckline. Her long, wavy hair frames her face as she gazes into the distance.',
        'A close-up portrait of a person with long, wavy blonde hair styled in soft curls. The individual is wearing bold red lipstick and dramatic eye makeup, including dark eyeliner and mascara. The background is a neutral gray, emphasizing the subject\'s features.',
        'A man in a top hat and formal attire stands confidently in front of a large, rusted chain-link fence. The chains are thick and interlocked, creating a textured backdrop. The image has a vintage, sepia-toned quality.',
        'A light-colored bear stands on a moss-covered rock amidst a flowing river, surrounded by lush greenery and vibrant foliage.',
        'A group of white horses gallops through shallow water under a dramatic sky, creating splashes as they move. The scene is bathed in warm, golden light, highlighting the horses\' graceful motion.',
        'A fisherman sits on a bamboo raft by the riverbank, holding a fishing net, with a lantern hanging nearby. A cormorant stands on the water behind him, and the backdrop features towering karst mountains under a warm, sunset sky.',
    ]
    # Each prompt is sampled twice: the batch is the six prompts repeated in
    # the same order.
    text_prompt = base_prompts * 2

    with torch.no_grad():
        # Inference only: do not record an autograd graph for the text encoder.
        text_emb = t5_emb(text_prompt)

        # generate the tokens and images under CUDA mixed precision
        with torch.autocast(device_type="cuda"):
            sampled_tokens = model.sample_tokens(
                bsz=len(text_prompt), num_iter=num_ar_steps,
                cfg=cfg_scale, cfg_schedule=cfg_schedule,
                texts=text_emb, temperature=temperature,
                height=512, width=512, progress=True, args=args,
            )
            output_samples = _decode_tokens(vae, sampled_tokens)

    # save the images
    _save_images(output_samples, "output_512_dpo_cherry", samples_per_row=3)

if __name__ == '__main__':
    # get_args_parser() returns a parser built with add_help=False, so using
    # it directly makes `-h/--help` fail as an unrecognized argument. Wrap it
    # as a parent of a help-enabled parser, keeping the same program name.
    base_parser = get_args_parser()
    cli_parser = argparse.ArgumentParser(prog=base_parser.prog, parents=[base_parser])
    args = cli_parser.parse_args()
    print(f"cfg_cache: {args.cfg_cache}, token_cache: {args.token_cache} ")
    main(args)