import random
import torchvision.transforms as transforms

import numpy as np
import torch

import imageio
import os
import PIL.Image
from typing import Union, Optional, List
from peft import get_peft_model_state_dict

from vm.modules.posemb_layers import get_nd_rotary_pos_embed
from vm.vae import AutoencoderKLCausal3D

from pathlib import Path
from einops import rearrange
from PIL import Image

from vm.constants import PRECISION_TO_TYPE
from safetensors.torch import load_file


def convert_kohya_to_peft_keys(
    kohya_dict: dict,
    kohya_prefix="",
    peft_prefix: str = "base_model.model",
    device="cpu",
) -> dict:
    peft_dict = {}
    for k, v in kohya_dict.items():
        if ".alpha" in k:
            continue
        new_key = k.replace(f"{kohya_prefix}_lora_", f"{peft_prefix}.")
        new_key = new_key.replace("single_blocks_", "single_blocks.")
        new_key = new_key.replace("double_blocks_", "double_blocks.")
        new_key = new_key.replace("_img_attn_proj", ".img_attn_proj")
        new_key = new_key.replace("_img_attn_qkv", ".img_attn_qkv")
        new_key = new_key.replace("_img_mlp_fc", ".img_mlp.fc")
        new_key = new_key.replace("_txt_mlp_fc", ".txt_mlp.fc")
        new_key = new_key.replace("_img_mod", ".img_mod")
        new_key = new_key.replace("_txt", ".txt")
        new_key = new_key.replace("_modulation", ".modulation")
        new_key = new_key.replace("_linear", ".linear")
        new_key = new_key.replace("lora_down", "lora_A.default")
        new_key = new_key.replace("lora_up", "lora_B.default")
        new_key = new_key.replace(
            "_individual_token_refiner_blocks_", ".individual_token_refiner.blocks."
        )
        new_key = new_key.replace("_mlp_fc", ".mlp.fc")

        peft_dict[new_key] = v.to(device)
    return peft_dict


def load_lora(model, lora_path, device):
    kohya_weights = load_file(lora_path)
    peft_weights = convert_kohya_to_peft_keys(
        kohya_weights, kohya_prefix="Hunyuan_video_I2V", device=device
    )
    model.load_state_dict(peft_weights, strict=False)
    return model


def black_image(width, height):
    black_image = Image.new("RGB", (width, height), (0, 0, 0))
    return black_image


def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def get_cond_latents(args, latents, vae):
    """get conditioned latent by decode and encode the first frame latents"""
    first_image_latents = latents[:, :, 0, ...] if len(latents.shape) == 5 else latents
    first_image_latents = 1 / vae.config.scaling_factor * first_image_latents
    first_images = vae.decode(
        first_image_latents.unsqueeze(2).to(vae.dtype), return_dict=False
    )[0]
    first_images = first_images.squeeze(2)
    first_images = (first_images / 2 + 0.5).clamp(0, 1)
    first_images = first_images.cpu().permute(0, 2, 3, 1).float().numpy()
    first_images = numpy_to_pil(first_images)

    image_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )
    first_images_pixel_values = [image_transform(image) for image in first_images]
    first_images_pixel_values = (
        torch.cat(first_images_pixel_values).unsqueeze(0).unsqueeze(2).to(vae.device)
    )

    vae_dtype = PRECISION_TO_TYPE[args.vae_precision]
    with torch.autocast(
        device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32
    ):
        cond_latents = vae.encode(
            first_images_pixel_values
        ).latent_dist.sample()  # B, C, F, H, W
        cond_latents.mul_(vae.config.scaling_factor)

    return cond_latents


def get_cond_images(args, latents, vae, is_uncond=False):
    """get conditioned images by decode the first frame latents"""
    sematic_image_latents = (
        latents[:, :, 0, ...] if len(latents.shape) == 5 else latents
    )
    sematic_image_latents = 1 / vae.config.scaling_factor * sematic_image_latents
    semantic_images = vae.decode(
        sematic_image_latents.unsqueeze(2).to(vae.dtype), return_dict=False
    )[0]
    semantic_images = semantic_images.squeeze(2)
    semantic_images = (semantic_images / 2 + 0.5).clamp(0, 1)
    semantic_images = semantic_images.cpu().permute(0, 2, 3, 1).float().numpy()
    semantic_images = numpy_to_pil(semantic_images)
    if is_uncond:
        semantic_images = [
            black_image(img.size[0], img.size[1]) for img in semantic_images
        ]

    return semantic_images


def load_state_dict(args, model, logger):
    pretrained_model_path = Path(args.model_base)
    if not pretrained_model_path.exists():
        raise ValueError(f"`models_root` not exists: {pretrained_model_path}")

    load_key = args.load_key
    if args.i2v_mode:
        dit_weight = Path(args.i2v_dit_weight)
    else:
        dit_weight = Path(args.dit_weight)

    if dit_weight is None:
        model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
        files = list(model_dir.glob("*.pt"))
        if len(files) == 0:
            raise ValueError(f"No model weights found in {model_dir}")
        if str(files[0]).startswith("pytorch_model_"):
            model_path = dit_weight / f"pytorch_model_{load_key}.pt"
            bare_model = True
        elif any(str(f).endswith("_model_states.pt") for f in files):
            files = [f for f in files if str(f).endswith("_model_states.pt")]
            model_path = files[0]
            if len(files) > 1:
                logger.warning(
                    f"Multiple model weights found in {dit_weight}, using {model_path}"
                )
            bare_model = False
        else:
            raise ValueError(
                f"Invalid model path: {dit_weight} with unrecognized weight format: "
                f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                f"`pytorch_model_*.pt`(provided by HunyuanVideo official) and "
                f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
                f"specific weight file, please provide the full path to the file."
            )
    else:
        if dit_weight.is_dir():
            files = list(dit_weight.glob("*.pt"))
            if len(files) == 0:
                raise ValueError(f"No model weights found in {dit_weight}")
            if str(files[0]).startswith("pytorch_model_"):
                model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                bare_model = True
            elif any(str(f).endswith("_model_states.pt") for f in files):
                files = [f for f in files if str(f).endswith("_model_states.pt")]
                model_path = files[0]
                if len(files) > 1:
                    logger.warning(
                        f"Multiple model weights found in {dit_weight}, using {model_path}"
                    )
                bare_model = False
            else:
                raise ValueError(
                    f"Invalid model path: {dit_weight} with unrecognized weight format: "
                    f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                    f"`pytorch_model_*.pt`(provided by HunyuanVideo official) and "
                    f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
                    f"specific weight file, please provide the full path to the file."
                )
        elif dit_weight.is_file():
            model_path = dit_weight
            bare_model = "unknown"
        else:
            raise ValueError(f"Invalid model path: {dit_weight}")

    if not model_path.exists():
        raise ValueError(f"model_path not exists: {model_path}")
    logger.info(f"Loading torch model {model_path}...")
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

    if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
        bare_model = False
    if bare_model is False:
        if load_key in state_dict:
            state_dict = state_dict[load_key]
        else:
            raise KeyError(
                f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                f"are: {list(state_dict.keys())}."
            )
    model.load_state_dict(state_dict, strict=True)
    return model


class set_worker_seed_builder:
    def __init__(self, global_rank):
        self.global_rank = global_rank

    def __call__(self, worker_id):
        set_manual_seed(torch.initial_seed() % (2 ** 32 - 1))


def set_reproducibility(enable, global_seed=None):
    if enable:
        # Configure the seed for reproducibility
        set_manual_seed(global_seed)
    # Set following debug environment variable
    # See the link for details: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Cudnn benchmarking
    torch.backends.cudnn.benchmark = not enable
    # Use deterministic algorithms in PyTorch
    torch.use_deterministic_algorithms(enable)

    # LSTM and RNN networks are not deterministic


def prepare_model_inputs(
    args,
    batch: tuple,
    device: Union[int, str],
    model,
    vae,
    text_encoder,
    text_encoder_2=None,
    rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
    rope_interpolation_factor: Union[float, List[float]] = 1.0,
):
    media, latents, *batch_args = batch
    if len(batch_args) == 3:
        text_ids, text_mask, kwargs = batch_args
        text_ids_2, text_mask_2 = None, None
    elif len(batch_args) == 5:
        text_ids, text_mask, text_ids_2, text_mask_2, kwargs = batch_args
    else:
        raise ValueError(f"Unexpected batch_args.")
    data_type = kwargs["type"][0]

    # Move batch to device
    media = media.to(device)
    latents = latents.to(device)
    text_ids = text_ids.to(device)
    text_mask = text_mask.to(device)

    # ======================================== Encode media ======================================
    # Used for 3D VAE with 2D inputs(image).
    # Prepare media shape for 2D/3D VAE
    if len(latents.shape) == 1:
        if len(media.shape) == 4:
            # media is a batch of image with shape [b, c, h, w]
            if isinstance(vae, AutoencoderKLCausal3D):
                media = media.unsqueeze(2)  # [b, c, 1, h, w]
        elif len(media.shape) == 5:
            # media is a batch of video with shape [b, c, f, h, w]
            if not isinstance(vae, AutoencoderKLCausal3D):
                media = rearrange(media, "b c f h w -> (b f) c h w")
        else:
            raise ValueError(
                f"Only support media with shape (b, c, h, w) or (b, c, f, h, w), but got {media.shape}."
            )

        vae_dtype = PRECISION_TO_TYPE[args.vae_precision]
        with torch.autocast(
            device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32
        ):
            latents = vae.encode(media).latent_dist.sample()
            if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
                latents.sub_(vae.config.shift_factor).mul_(vae.config.scaling_factor)
            else:
                latents.mul_(vae.config.scaling_factor)
    elif len(latents.shape) == 5 or len(latents.shape) == 4:  # Using video/image cache
        latents = (
            latents * vae.config.scaling_factor
        )  # vae cache is not multiplied by scaling_factor
    else:
        raise ValueError(
            f"Only support media/latent with shape (b, c, h, w) or (b, c, f, h, w), but got {media.shape} {latents.shape}."
        )

    cond_latents = get_cond_latents(args, latents, vae)
    is_uncond = (
        torch.tensor(1).to(torch.int64)
        if random.random() < args.sematic_cond_drop_p
        else torch.tensor(0).to(torch.int64)
    )
    semantic_images = get_cond_images(args, latents, vae, is_uncond=is_uncond)

    # ======================================== Encode text ======================================
    # Autocast is handled by text_encoder itself.
    # Whether to apply text_mask is determined by args.use_attention_mask.
    text_outputs = text_encoder.encode(
        {"input_ids": text_ids, "attention_mask": text_mask},
        data_type=batch_args[-1]["type"][0],
        semantic_images=semantic_images,
    )
    text_states = text_outputs.hidden_state
    text_mask = text_outputs.attention_mask
    text_states_2 = (
        text_encoder_2.encode(
            {"input_ids": text_ids_2, "attention_mask": text_mask_2},
            data_type=data_type,
        ).hidden_state
        if text_encoder_2 is not None
        else None
    )

    # ======================================== Build RoPE ======================================
    target_ndim = 3  # n-d RoPE
    ndim = len(latents.shape) - 2
    latents_size = list(latents.shape[-ndim:])
    freqs_cos, freqs_sin = get_rope_freq_from_size(
        args,
        model,
        latents_size,
        ndim,
        target_ndim,
        rope_theta_rescale_factor=rope_theta_rescale_factor,
        rope_interpolation_factor=rope_interpolation_factor,
    )

    # ===================================== Pack model kwargs ==================================
    model_kwargs = dict(
        text_states=text_states,  # [b, 256, 4096]
        text_mask=text_mask,  # [b, 256]
        text_states_2=text_states_2,  # [b, 768]
        freqs_cos=freqs_cos,  # [seqlen, head_dim]
        freqs_sin=freqs_sin,  # [seqlen, head_dim]
        return_dict=True,
    )

    return latents, model_kwargs, freqs_cos.shape[0], cond_latents


def format_params(params):
    if params < 1e6:
        return f"{params} (less than 1M)"
    elif params < 1e9:
        return f"{params / 1e6:.2f}M"
    else:
        return f"{params / 1e9:.2f}B"


def set_manual_seed(global_seed):
    random.seed(global_seed)
    np.random.seed(global_seed)
    torch.manual_seed(global_seed)


def get_rope_freq_from_size(
    args,
    model,
    latents_size,
    ndim,
    target_ndim,
    rope_theta_rescale_factor=1.0,
    rope_interpolation_factor=1.0,
):

    if isinstance(model.patch_size, int):
        assert all(s % model.patch_size == 0 for s in latents_size), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size for s in latents_size]

    elif isinstance(model.patch_size, list):
        assert all(
            s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)
        ), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
    head_dim = model.hidden_size // model.heads_num
    rope_dim_list = model.rope_dim_list

    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert (
        sum(rope_dim_list) == head_dim
    ), "sum(rope_dim_list) should equal to head_dim of attention layer"

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list,
        rope_sizes,
        theta=args.rope_theta,
        use_real=True,
        theta_rescale_factor=rope_theta_rescale_factor,
        interpolation_factor=rope_interpolation_factor,
    )

    return freqs_cos, freqs_sin


# copy from https://github.com/huggingface/diffusers/blob/ec9bfa9e148b7764137dd92247ce859d915abcb0/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py#L258
# get kohya lora state dict
def get_module_kohya_state_dict(module, prefix, dtype, adapter_name="default"):
    kohya_ss_state_dict = {}
    for peft_key, weight in get_peft_model_state_dict(
        module, adapter_name=adapter_name
    ).items():
        kohya_key = peft_key.replace("base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)

        # Set alpha parameter
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(
                module.peft_config[adapter_name].lora_alpha
            ).to(dtype)

    return kohya_ss_state_dict


# get diffusers lora state dict
def get_module_diffusers_state_dict(module, dtype, adapter_name="default"):
    diffusers_ss_state_dict = {}
    for peft_key, weight in get_peft_model_state_dict(
        module, adapter_name=adapter_name
    ).items():
        diffusers_key = peft_key.replace("base_model.model", "diffusion_model")
        diffusers_ss_state_dict[diffusers_key] = weight.to(dtype)

    return diffusers_ss_state_dict
