# ------------------------------------------
# VQ-Diffusion
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# written By Shuyang Gu
# ------------------------------------------

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))

import torch
import cv2
import argparse
import numpy as np
import torchvision
from PIL import Image

from image_synthesis.utils.io import load_yaml_config
from image_synthesis.modeling.build import build_model
from image_synthesis.utils.misc import get_model_parameters_info

class VQ_Diffusion():
    """Inference wrapper around a pretrained VQ-Diffusion model.

    On construction the model is built from a YAML config and its weights are
    loaded from a checkpoint (plus EMA weights when present). It is then moved
    to CUDA, put in eval mode, and frozen. The ``inference_*`` methods sample
    images and write them (and optionally their VQ tokens) to disk.
    """

    def __init__(self, config, path, imagenet_cf=False):
        """Build, load and freeze the model.

        Args:
            config: path to the model's YAML config file.
            path: path to the checkpoint file.
            imagenet_cf: if True, override the transformer's ``class_number``
                to 1001 before building (the ImageNet classifier-free
                checkpoint is loaded with this flag set).

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        self.info = self.get_model(ema=True, model_path=path, config_path=config, imagenet_cf=imagenet_cf)
        self.model = self.info['model']
        self.epoch = self.info['epoch']
        self.model_name = self.info['model_name']
        self.model = self.model.cuda()
        self.model.eval()
        # Inference only: freeze every parameter.
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _infer_model_name(model_path, config_path):
        """Return a display name for the model.

        Checkpoints whose path contains 'OUTPUT' are named after the path
        component two levels above the checkpoint file. Otherwise (or if the
        path is too short for that) the config file name without '.yaml' is
        used.
        """
        if 'OUTPUT' in model_path:
            parts = model_path.split(os.path.sep)
            if len(parts) >= 3:
                return parts[-3]
        return os.path.basename(config_path).replace('.yaml', '')

    def get_model(self, ema, model_path, config_path, imagenet_cf):
        """Build the model from ``config_path`` and load weights from ``model_path``.

        Args:
            ema: if truthy and the checkpoint has an ``'ema'`` entry, load it
                into ``model.get_ema_model()``.
            model_path: checkpoint path.
            config_path: YAML config path.
            imagenet_cf: force ``class_number = 1001`` in the transformer
                config (presumably 1000 classes + 1 "null" class for
                classifier-free guidance -- TODO confirm).

        Returns:
            dict with keys ``'model'``, ``'epoch'`` (from ``'last_epoch'`` or
            ``'epoch'`` in the checkpoint, else 0), ``'model_name'`` and
            ``'parameter'`` (output of ``get_model_parameters_info``).

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        # Check before the (potentially expensive) model build. Raising rather
        # than exit(0) avoids reporting success on failure and doesn't kill
        # processes that import this class.
        if not os.path.exists(model_path):
            raise FileNotFoundError("Model path: {} does not exist.".format(model_path))

        model_name = self._infer_model_name(model_path, config_path)

        config = load_yaml_config(config_path)
        if imagenet_cf:
            transformer_params = config['model']['params']['diffusion_config']['params']['transformer_config']['params']
            transformer_params['class_number'] = 1001

        model = build_model(config)
        model_parameters = get_model_parameters_info(model)
        print(model_parameters)

        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code -- only load checkpoints from trusted sources.
        ckpt = torch.load(model_path, map_location="cpu")
        if 'last_epoch' in ckpt:
            epoch = ckpt['last_epoch']
        elif 'epoch' in ckpt:
            epoch = ckpt['epoch']
        else:
            epoch = 0

        # strict=False: tolerate key mismatches, but report them.
        missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
        print('Model missing keys:\n', missing)
        print('Model unexpected keys:\n', unexpected)

        if ema and 'ema' in ckpt:
            print("Evaluate EMA model")
            ema_model = model.get_ema_model()
            missing, unexpected = ema_model.load_state_dict(ckpt['ema'], strict=False)
            print('EMA missing keys:\n', missing)
            print('EMA unexpected keys:\n', unexpected)

        return {'model': model, 'epoch': epoch, 'model_name': model_name, 'parameter': model_parameters}

    @staticmethod
    def _condition_dir_name(condition):
        """Turn a condition (prompt text or class id) into one path component.

        Path separators are replaced with '_' and '..' becomes '__', so a
        prompt can neither create nested directories nor escape ``save_root``.
        """
        name = str(condition)
        for sep in (os.sep, os.altsep):
            if sep:
                name = name.replace(sep, '_')
        if name == '..':
            name = '__'
        return name

    def _make_save_dir(self, save_root, condition):
        """Create ``save_root/<condition>`` (and ``save_root``) and return its path."""
        os.makedirs(save_root, exist_ok=True)
        save_dir = os.path.join(save_root, self._condition_dir_name(condition))
        os.makedirs(save_dir, exist_ok=True)
        return save_dir

    @staticmethod
    def _normalize_content(content):
        """Convert generated images to a float32 tensor scaled to [0, 1].

        Accepts a tensor, an array-like, or a list/tuple of per-image tensors
        (stacked along a new batch dim). Integer inputs (e.g. uint8) are always
        treated as 0-255. For float inputs the scale is inferred from the peak
        value: a finite max above 1.0 means 0-255. The result is not clamped.
        """
        if isinstance(content, (list, tuple)):
            if len(content) == 0:
                return torch.empty(0)
            content = torch.stack([torch.as_tensor(c) for c in content], dim=0)
        else:
            content = torch.as_tensor(content)

        is_integer = not content.is_floating_point()
        content = content.float()  # no-op (same tensor) when already float32
        if content.numel() == 0:
            return content
        if is_integer:
            return content / 255.0
        max_val = content.max()
        if torch.isfinite(max_val) and max_val > 1.0:
            content = content / 255.0
        return content

    @staticmethod
    def _extract_tokens(model_out, return_tokens):
        """Return ``model_out['content_token']`` as a numpy array, or None if not requested/available."""
        if not return_tokens or 'content_token' not in model_out:
            return None
        return torch.as_tensor(model_out['content_token']).detach().to('cpu').numpy()

    def _save_outputs(self, model_out, save_dir, ext, return_tokens):
        """Normalize ``model_out['content']``, save each image (and tokens), and return both.

        Images are written as ``save_dir/000000<ext>``, ``000001<ext>``, ... and
        tokens (when available) as ``save_dir/000000_tokens.npy``, ...
        """
        content = self._normalize_content(model_out['content'])
        tokens = self._extract_tokens(model_out, return_tokens)

        for idx in range(content.shape[0]):
            base_name = str(idx).zfill(6)
            torchvision.utils.save_image(content[idx].clamp(0, 1), os.path.join(save_dir, base_name + ext))
            if tokens is not None:
                np.save(os.path.join(save_dir, base_name + '_tokens.npy'), tokens[idx])

        return {'content': content, 'tokens': tokens}

    def inference_generate_sample_with_class(self, text, truncation_rate, save_root, batch_size, infer_speed=False, guidance_scale=1.0, return_tokens=False):
        """Sample ``batch_size`` images for a class label and save them as JPEGs.

        Args:
            text: class label (e.g. an ImageNet class id). It is passed as
                ``batch['label']`` and its string form names the output dir.
            truncation_rate: passed as ``sample_type="top<rate>r"``.
            save_root: root output directory. Images go to
                ``save_root/<label>/NNNNNN.jpg``.
            batch_size: number of samples (``replicate``).
            infer_speed: accepted for signature parity with
                ``inference_generate_sample_with_condition``; unused here.
            guidance_scale: set on ``self.model.guidance_scale``.
            return_tokens: also save and return ``model_out['content_token']``
                when the model provides it.

        Returns:
            dict with ``'content'`` (float tensor, B x C x H x W, scaled to
            [0, 1] but not clamped) and ``'tokens'`` (numpy array or None).
        """
        save_dir = self._make_save_dir(save_root, text)
        self.model.guidance_scale = guidance_scale

        batch = {'label': [text], 'image': None}
        with torch.no_grad():
            model_out = self.model.generate_content(
                batch=batch,
                filter_ratio=0,
                replicate=batch_size,
                content_ratio=1,
                return_att_weight=False,
                sample_type="top" + str(truncation_rate) + 'r',
            )  # B x C x H x W

        return self._save_outputs(model_out, save_dir, '.jpg', return_tokens)

    def inference_generate_sample_with_condition(self, text, truncation_rate, save_root, batch_size, infer_speed=False, guidance_scale=1.0, prior_rule=0, prior_weight=0, learnable_cf=True, return_tokens=False):
        """Sample ``batch_size`` images for a text prompt and save them as PNGs.

        Args:
            text: text prompt. It is passed as ``batch['text']`` and also names
                the output dir.
            truncation_rate: passed as ``sample_type="top<rate>r..."``.
            save_root: root output directory. Images go to
                ``save_root/<prompt>/NNNNNN.png``.
            batch_size: number of samples (``replicate``).
            infer_speed: falsy to disable. Otherwise appended as
                ``",time<infer_speed>"`` to ``sample_type`` (>1 faster, <1
                slower / higher quality).
            guidance_scale: set on ``self.model.guidance_scale``.
            prior_rule: inference rule. 0 for VQ-Diffusion v1, 1 for only
                high-quality inference, 2 for purity prior.
            prior_weight: probability adjust parameter, 'r' in Equation 11 of
                Improved VQ-Diffusion.
            learnable_cf: whether to use learnable classifier-free guidance.
            return_tokens: also save and return ``model_out['content_token']``
                when the model provides it.

        Returns:
            dict with ``'content'`` (float tensor, B x C x H x W, scaled to
            [0, 1] but not clamped) and ``'tokens'`` (numpy array or None).
        """
        save_dir = self._make_save_dir(save_root, text)

        self.model.guidance_scale = guidance_scale
        self.model.learnable_cf = self.model.transformer.learnable_cf = learnable_cf
        self.model.transformer.prior_rule = prior_rule
        self.model.transformer.prior_weight = prior_weight

        # Truthiness check so infer_speed=None/0 disables it (the previous
        # `!= False` test produced "r,timeNone" for None).
        sample_suffix = ('r,time' + str(infer_speed)) if infer_speed else 'r'

        batch = {'text': [text], 'image': None}
        with torch.no_grad():
            model_out = self.model.generate_content(
                batch=batch,
                filter_ratio=0,
                replicate=batch_size,
                content_ratio=1,
                return_att_weight=False,
                sample_type="top" + str(truncation_rate) + sample_suffix,
            )  # B x C x H x W

        return self._save_outputs(model_out, save_dir, '.png', return_tokens)


# Example invocations (equivalents of the previously hard-coded demos):
#   ITHQ model (default --config/--path):
#     VQ-Diffusion:                         --truncation-rate 0.86 --guidance-scale 1.0
#     zero-shot classifier-free sampling:   --guidance-scale 5.0 --zero-shot-cf
#     learnable classifier-free sampling:   (no flags; this is the default run)
#     high-quality inference (0.5x speed):  --truncation-rate 0.86 --guidance-scale 1.0 --infer-speed 0.5
#     fast inference (2x speed):            --truncation-rate 0.86 --guidance-scale 1.0 --infer-speed 2
#     purity sampling:                      --truncation-rate 0.86 --guidance-scale 1.0 --prior-rule 2 --prior-weight 1
#     learnable CF + fast inference:        --guidance-scale 5.0 --infer-speed 2
#     other prompt:                         --prompt "a long exposure photo of waterfall"
#   COCO (--config /checkpoints/pretrained_model/config_text.yaml
#         --prompt "A group of elephants walking in muddy water"):
#     learnable CF:  --path /checkpoints/pretrained_model/coco_learnable.pth --guidance-scale 3.0
#     zero-shot CF (model without classifier-free fine-tune, guidance_scale > 1):
#                    --path /checkpoints/pretrained_model/coco_pretrained.pth --truncation-rate 0.86 --guidance-scale 3.0 --zero-shot-cf
#   ImageNet (class-conditional):
#     VQ-Diffusion:  --config /checkpoints/pretrained_model/config_imagenet.yaml
#                    --path /checkpoints/pretrained_model/imagenet_pretrained.pth
#                    --class-label 407 --truncation-rate 0.86 --guidance-scale 1.0
#     classifier-free: --config configs/imagenet.yaml --path /checkpoints/pretrained_model/imagenet_learnable.pth
#                    --imagenet-cf --class-label 407 --truncation-rate 0.94 --guidance-scale 1.5
def _build_arg_parser():
    """Return the argument parser for the command-line sampling demo."""
    parser = argparse.ArgumentParser(description='Sample images with (Improved) VQ-Diffusion.')
    parser.add_argument('--config', default='configs/ithq.yaml', help='model YAML config')
    parser.add_argument('--path', default='/checkpoints/pretrained_model/ithq_learnable.pth', help='checkpoint path')
    parser.add_argument('--imagenet-cf', action='store_true',
                        help='set class_number=1001 (ImageNet classifier-free checkpoint)')
    parser.add_argument('--prompt', default='teddy bear playing in the pool', help='text prompt')
    parser.add_argument('--class-label', type=int, default=None,
                        help='class id; if given, class-conditional sampling is used and --prompt is ignored')
    parser.add_argument('--truncation-rate', type=float, default=1.0)
    parser.add_argument('--save-root', default='RESULT')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--guidance-scale', type=float, default=5.0)
    parser.add_argument('--infer-speed', type=float, default=None,
                        help='float in [0.1, 10]; larger is faster, smaller is slower/higher quality '
                             '(text-conditional only; omit for normal speed)')
    parser.add_argument('--prior-rule', type=int, default=0, choices=[0, 1, 2],
                        help='0: VQ-Diffusion v1, 1: only high-quality inference, 2: purity prior')
    parser.add_argument('--prior-weight', type=float, default=0,
                        help="probability adjust parameter ('r' in Eq. 11 of Improved VQ-Diffusion)")
    parser.add_argument('--zero-shot-cf', action='store_true',
                        help='use zero-shot instead of learnable classifier-free sampling (learnable_cf=False)')
    return parser


def main(argv=None):
    """Run one sampling job described by command-line arguments.

    With no arguments this reproduces the default demo: learnable
    classifier-free sampling of "teddy bear playing in the pool" with the ITHQ
    model (truncation_rate=1.0, guidance_scale=5.0, batch_size=4, output under
    RESULT/).
    """
    args = _build_arg_parser().parse_args(argv)
    model = VQ_Diffusion(config=args.config, path=args.path, imagenet_cf=args.imagenet_cf)

    common = dict(
        truncation_rate=args.truncation_rate,
        save_root=args.save_root,
        batch_size=args.batch_size,
        guidance_scale=args.guidance_scale,
    )
    if args.class_label is not None:
        model.inference_generate_sample_with_class(args.class_label, **common)
    else:
        model.inference_generate_sample_with_condition(
            args.prompt,
            infer_speed=False if args.infer_speed is None else args.infer_speed,
            prior_rule=args.prior_rule,
            prior_weight=args.prior_weight,
            learnable_cf=not args.zero_shot_cf,
            **common,
        )


if __name__ == '__main__':
    main()
