import sys
import random
import os
import numpy as np
from typing import List
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from torch.utils.checkpoint import checkpoint

from mvdream.camera_utils import convert_opengl_to_blender, normalize_camera
from mvdream.model_zoo import build_model

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup, parse_version
from threestudio.utils.typing import *
import contextlib

def hash_prompt(model: str, prompt: str) -> str:
    """Return the hexadecimal MD5 digest of ``"<model>-<prompt>"``.

    The digest is a compact, deterministic key for the (model, prompt) pair;
    it is used in this file to name cached text-embedding files.
    """
    import hashlib

    digest = hashlib.md5()
    digest.update("{}-{}".format(model, prompt).encode())
    return digest.hexdigest()

@threestudio.register("multiview-reward-embedding-checkpoint-guidance")
class MultiviewDiffusionGuidance(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        # Options of the multi-view reward guidance. The notes below describe
        # how each field is consumed by the code in this file.

        # Key passed to mvdream's build_model().
        model_name: str = "sd-v2.1-base-4view"  # check mvdream.model_zoo.PRETRAINED_MODELS
        ckpt_path: Optional[str] = None  # path to local checkpoint (None for loading from url)
        # Hashed together with a prompt to locate cached text-embedding files.
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"
        # The cached embedding of this prompt initialises the trainable
        # `learnable_text` parameter in configure().
        negative_prompt: str = "ugly, bad anatomy, blurry, pixelated obscure, unnatural colors, poor lighting, dull, and unclear, cropped, lowres, low quality, artifacts, duplicate, morbid, mutilated, poorly drawn face, deformed, dehydrated bad proportions"
        # Classifier-free-guidance scale for the noise prediction used in the
        # gradient. The image scored by the reward model uses a fixed 7.5.
        guidance_scale: float = 7.5
        # NOTE(review): not read anywhere in this file; configure() sets
        # `grad_clip_val` to None unconditionally.
        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        # Selects float16 (True) or float32 (False) as `weights_dtype`.
        half_precision_weights: bool = True

        # Resolved through C() into `min_step` / `max_step`. forward() takes
        # its timestep from `t_perc_ref`, not from these bounds.
        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        # Only "rotation" is implemented; any other value raises
        # NotImplementedError in get_camera_cond().
        camera_condition_type: str = "rotation"
        # Forwarded to prompt_utils.get_text_embeddings().
        view_dependent_prompting: bool = False

        # Passed to the UNet as "num_frames" and used to group views when
        # rescaling the reconstructed latents.
        n_view: int = 4
        # Side length the renders are resized to before VAE encoding.
        image_size: int = 256
        # True: x0-reconstruction loss; False: SDS-style loss on the gradient.
        recon_loss: bool = True
        # Blend weight of the std-rescaled reconstruction (values <= 0 skip it).
        recon_std_rescale: float = 0.5

        # Name resolved through threestudio.find() to build the scorer.
        reward_model: str = "hpsv2-score"  # ['hpsv2-score', 'imagereward-score', 'brique-score', 'reward3d-score']
        # NOTE(review): not referenced by the code in this file.
        beta_dpo: float = 0

        # NOTE(review): only referenced by commented-out code in configure().
        token_len: int = 32
        # Enables gradient checkpointing of the UNet call; forward() honours it
        # only once global_step reaches 5000.
        use_unet_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the frozen diffusion model, the reward model and the learnable embedding.

        Also precomputes the noise-schedule tensors and the initial timestep
        bounds. The learnable text embedding starts from the cached embedding
        of ``cfg.negative_prompt``, so that cache file must already exist.
        """
        threestudio.info("Loading Multiview Diffusion ...")
        self.global_step = 0

        # The diffusion model is only used as a frozen prior.
        self.model = build_model(self.cfg.model_name, ckpt_path=self.cfg.ckpt_path)
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.num_train_timesteps = 1000
        # Timestep bounds evaluated at epoch 0 / step 0; update_step() refreshes them.
        start_min_percent = C(self.cfg.min_step_percent, 0, 0)
        start_max_percent = C(self.cfg.max_step_percent, 0, 0)
        self.min_step = int(self.num_train_timesteps * start_min_percent)
        self.max_step = int(self.num_train_timesteps * start_max_percent)
        self.grad_clip_val: Optional[float] = None
        if self.cfg.half_precision_weights:
            self.weights_dtype = torch.float16
        else:
            self.weights_dtype = torch.float32

        # Reward model: looked up by name and constructed with an empty config.
        reward_cls = threestudio.find(self.cfg.reward_model)
        self.score_func = reward_cls({})
        print(f"Initialize Reward Model from {self.cfg.reward_model}")

        self.to(self.device)

        alphas_cumprod = self.model.alphas_cumprod
        self.alphas: Float[Tensor, "T"] = alphas_cumprod ** 0.5
        self.sigmas: Float[Tensor, "T"] = (1 - alphas_cumprod) ** 0.5
        # Ratio sigma / alpha per timestep (noise level relative to signal level).
        self.lambdas: Float[Tensor, "T"] = self.sigmas / self.alphas

        # The trainable embedding is initialised from the cached embedding of
        # the negative prompt, with a leading dimension of size 1 added.
        self._cache_dir = ".threestudio_cache/text_embeddings"
        negative_embedding = self.load_from_cache(self.cfg.negative_prompt)
        self.learnable_text = torch.nn.Parameter(negative_embedding.unsqueeze(0))

        threestudio.info("Loaded Multiview Diffusion!")
        
    def load_from_cache(self, prompt):
        """Load the cached text embedding of ``prompt`` onto ``self.device``.

        The file name is the hash of ``cfg.pretrained_model_name_or_path`` and
        the prompt, inside ``self._cache_dir``.

        Raises:
            FileNotFoundError: If no cache file exists for this prompt.
        """
        model_id = self.cfg.pretrained_model_name_or_path
        file_name = f"{hash_prompt(model_id, prompt)}.pt"
        cache_path = os.path.join(self._cache_dir, file_name)
        if os.path.exists(cache_path):
            return torch.load(cache_path, map_location=self.device)
        raise FileNotFoundError(
            f"Text embedding file {cache_path} for model {model_id} and prompt [{prompt}] not found."
        )

    def get_camera_cond(self,
                        camera: Float[Tensor, "B 4 4"],
                        fovy=None,
                        ):
        """Turn camera matrices into the flat conditioning vectors fed to the UNet.

        Args:
            camera: Batch of camera matrices.
            fovy: Accepted for interface compatibility; not used here.

        Returns:
            The normalized cameras with every dimension after the batch
            dimension flattened.

        Raises:
            NotImplementedError: If ``cfg.camera_condition_type`` is not
                ``"rotation"``.
        """
        condition_type = self.cfg.camera_condition_type
        if condition_type != "rotation":
            raise NotImplementedError(f"Unknown camera_condition_type={condition_type}")
        # Note: threestudio cameras are already in the blender coordinate
        # system, so no OpenGL -> blender conversion is applied.
        return normalize_camera(camera).flatten(start_dim=1)

    def encode_images(
            self, imgs: Float[Tensor, "B 3 256 256"]
    ) -> Float[Tensor, "B 4 32 32"]:
        """Encode images into first-stage latents, returned in the input dtype.

        The images are rescaled with ``x * 2 - 1`` before encoding (i.e. from
        [0, 1] to [-1, 1], assuming the input is in [0, 1]).
        """
        original_dtype = imgs.dtype
        rescaled = imgs * 2.0 - 1.0
        encoder_out = self.model.encode_first_stage(rescaled)
        encoded = self.model.get_first_stage_encoding(encoder_out)
        # Cast back so callers get the dtype they passed in.
        return encoded.to(original_dtype)

    def decode_latents(
            self,
            latents
    ):
        input_dtype = latents.dtype
        x_sample = self.model.decode_first_stage(latents)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        return x_sample.to(input_dtype)

    def process_noise_pred(
            self, noise_pred, latents_noisy, t, prompt,
            prompt_utils: PromptProcessorOutput,
            elevation: Float[Tensor, "B"],
            azimuth: Float[Tensor, "B"],
            camera_distances: Float[Tensor, "B"],
            normal_world: Optional[Float[Tensor, "B 3 H W"]] = None,
            opacity_map: Optional[Float[Tensor, "B 1 H W"]] = None,
            c2w: Optional[Float[Tensor, "B 4 4"]] = None,
    ):
        """Apply classifier-free guidance to a UNet output and score the predicted image.

        Args:
            noise_pred: UNet output whose first half is the text-conditioned
                prediction and whose second half is the unconditional one.
            latents_noisy: Noisy latents the prediction was made for.
            t: Timestep used to recover the clean latents.
            prompt: Prompt forwarded to the reward model.
            prompt_utils, elevation, azimuth, camera_distances, normal_world,
            opacity_map, c2w: Forwarded unchanged to the reward model.

        Returns:
            Tuple ``(score, guided_noise, text_noise, uncond_noise, cfg_norm)``
            where ``cfg_norm`` is the norm of the guidance term
            ``guidance_scale * (text_noise - uncond_noise)``.
        """
        eps_text, eps_uncond = noise_pred.chunk(2)

        guidance_term = self.cfg.guidance_scale * (eps_text - eps_uncond)
        eps_guided = eps_uncond + guidance_term

        # The image shown to the reward model is predicted with a fixed CFG
        # scale of 7.5, independent of cfg.guidance_scale.
        eps_for_image = eps_uncond + 7.5 * (eps_text - eps_uncond)
        x0_latents = self.model.predict_start_from_noise(latents_noisy, t, eps_for_image)
        predicted_image = self.decode_latents(x0_latents)

        score = self.score_func(
            predicted_image, prompt, normal_world, opacity_map, c2w,
            prompt_utils, elevation, azimuth, camera_distances,
        )
        return score, eps_guided, eps_text, eps_uncond, guidance_term.norm()
    
    # def train_embedding(
    #     self,
    #     latents: Float[Tensor, "..."],
    #     context: Float[Tensor, "..."],
    #     noise: Float[Tensor, "..."],
    #     t: Float[Tensor, "..."],
    # ):
    #     B = latents.shape[0]
    #     latents = latents.detach()
        
    #     noisy_latents = self.model.q_sample(latents, t, noise)
    #     target = noise
    #     noise_pred = self.model.apply_model(noisy_latents, t, context)
    #     return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

    def forward(
            self,
            rgb: Float[Tensor, "B H W C"],
            prompt_utils: PromptProcessorOutput,
            elevation: Float[Tensor, "B"],
            azimuth: Float[Tensor, "B"],
            camera_distances: Float[Tensor, "B"],
            c2w: Float[Tensor, "B 4 4"],
            t_perc_ref: Float[Tensor, "B"],
            rgb_as_latents: bool = False,
            fovy=None,
            noise=None,
            text_embeddings=None,
            input_is_latent=False,
            normal_world: Optional[Float[Tensor, "B 3 H W"]] = None,
            opacity_map: Optional[Float[Tensor, "B 1 H W"]] = None,
            **kwargs,
    ):
        """Compute the reward-guided preference loss for a batch of renders.

        The latents are noised twice at the same timestep (once with
        ``noise``, once with freshly sampled noise). Both noisy copies go
        through the UNet in one batched call, the reward model scores the
        image predicted from each copy, and per sample the higher-scoring copy
        is treated as the winner. The resulting gradient is the winner's
        residual (guided prediction minus injected noise) plus a preference
        term pointing from the loser's prediction towards the winner's,
        rescaled to the norm of the classifier-free-guidance term.

        Args:
            rgb: Rendered images, channels-last. Interpreted as latents when
                ``input_is_latent`` or ``rgb_as_latents`` is set.
            prompt_utils: Prompt processor output; provides ``prompt`` and
                ``get_text_embeddings``.
            elevation: Forwarded to the prompt processor and the reward model.
            azimuth: Forwarded to the prompt processor and the reward model.
            camera_distances: Forwarded to the prompt processor and the
                reward model.
            c2w: Camera matrices used as camera condition; ``None`` disables
                camera conditioning of the UNet.
            t_perc_ref: Timesteps as fractions of ``num_train_timesteps``.
            rgb_as_latents: Resize ``rgb`` to 64x64 and map it with
                ``x * 2 - 1`` instead of running the VAE encoder.
            fovy: Forwarded to ``get_camera_cond``.
            noise: Noise for the first candidate. Sampled with
                ``torch.randn_like`` when ``None``.
            text_embeddings: Precomputed embeddings whose first half is the
                positive and second half the unconditional embedding; queried
                from ``prompt_utils`` when ``None``.
            input_is_latent: Use ``rgb`` (moved to channels-first) directly
                as latents.
            normal_world: Forwarded to the reward model.
            opacity_map: Forwarded to the reward model.
            **kwargs: Ignored.

        Returns:
            Dict with ``loss_sds`` (the guidance loss), ``grad_norm`` (norm of
            the gradient tensor) and ``loss_embedding`` (mean hinge
            ``relu(4 - winner_score)``).
        """
        camera = c2w
        prompt = prompt_utils.prompt

        rgb_BCHW = rgb.to(self.weights_dtype).permute(0, 3, 1, 2)

        t_ref = torch.round(self.num_train_timesteps * t_perc_ref).to(
            dtype=torch.long, device=self.device
        )
        # Noising and x0 prediction use only the first timestep, while the
        # UNet receives every entry of t_ref, repeated for its 4 input groups.
        t = t_ref[:1]
        t_expand = torch.cat([t_ref] * 4)

        if text_embeddings is None:
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting,
                learnable_text=self.learnable_text
            )

        if input_is_latent:
            latents = rgb_BCHW
        elif rgb_as_latents:
            latents = F.interpolate(rgb_BCHW, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # Resize to the configured square resolution before VAE encoding.
            pred_rgb = F.interpolate(rgb_BCHW, (self.cfg.image_size, self.cfg.image_size), mode='bilinear',
                                     align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_images(pred_rgb)

        # The signature allows noise=None, but the winner/loser bookkeeping
        # below indexes the noise tensor, so sample it here when absent.
        if noise is None:
            noise = torch.randn_like(latents)

        # UNet checkpointing is only honoured from global step 5000 onwards
        # (hard-coded threshold).
        use_unet_checkpointing = (
            self.global_step >= 5000 and self.cfg.use_unet_checkpointing
        )

        # Layout of the UNet batch: [pos, uncond] for candidate 1 followed by
        # [pos, uncond] for candidate 2.
        pos_text_embeddings, uncond_text_embeddings = text_embeddings.chunk(2)
        text_embeddings = torch.cat(
            [pos_text_embeddings, uncond_text_embeddings, pos_text_embeddings, uncond_text_embeddings]
        )

        # Two independently noised copies of the same latents.
        noise_1 = noise
        latents_noisy_1 = self.model.q_sample(latents, t, noise_1)
        noise_2 = torch.randn_like(latents)
        latents_noisy_2 = self.model.q_sample(latents, t, noise_2)
        latent_model_input = torch.cat([latents_noisy_1] * 2 + [latents_noisy_2] * 2)

        if camera is not None:
            camera = self.get_camera_cond(camera, fovy)
            camera_cond = camera.repeat(4, 1).to(text_embeddings)
            context = {"context": text_embeddings, "camera": camera_cond, "num_frames": self.cfg.n_view}
        else:
            context = {"context": text_embeddings}

        if use_unet_checkpointing:
            def unet_wrapper(*args):
                _latent_model_input, _t_expand, _context = args
                return self.model.apply_model(_latent_model_input, _t_expand, _context)

            unet_out = checkpoint(unet_wrapper, latent_model_input, t_expand, context, use_reentrant=True)
        else:
            unet_out = self.model.apply_model(latent_model_input, t_expand, context)

        unet_out_1, unet_out_2 = unet_out.chunk(2)
        score_1, noise_pred_1, noise_pred_text_1, _, cfg_norm_1 = self.process_noise_pred(
            unet_out_1, latents_noisy_1, t, prompt, prompt_utils, elevation, azimuth, camera_distances,
            normal_world=normal_world,
            opacity_map=opacity_map,
            c2w=c2w,
        )
        score_2, noise_pred_2, noise_pred_text_2, _, cfg_norm_2 = self.process_noise_pred(
            unet_out_2, latents_noisy_2, t, prompt, prompt_utils, elevation, azimuth, camera_distances,
            normal_world=normal_world,
            opacity_map=opacity_map,
            c2w=c2w,
        )

        # Per-sample winner selection; ties go to candidate 1.
        # The indexing below assumes one score per sample — TODO confirm
        # against the reward model.
        win_mask = score_1 >= score_2
        lose_mask = torch.logical_not(win_mask)

        score_win = torch.zeros_like(score_1)
        reward_sds = torch.zeros_like(noise_1)
        noise_pred = torch.zeros_like(noise_1)
        noise_pred_text = torch.zeros_like(noise_pred_text_1)
        latents_noisy = torch.zeros_like(latents_noisy_1)

        # Each entry: (samples where this candidate wins, winner score, loser
        # score, winner noise, winner prediction, loser prediction, winner
        # text-only prediction, winner noisy latents, winner CFG-term norm).
        candidates = (
            (win_mask, score_1, score_2, noise_1, noise_pred_1, noise_pred_2,
             noise_pred_text_1, latents_noisy_1, cfg_norm_1),
            (lose_mask, score_2, score_1, noise_2, noise_pred_2, noise_pred_1,
             noise_pred_text_2, latents_noisy_2, cfg_norm_2),
        )
        for (mask, s_win, s_lose, eps_win, pred_win, pred_lose,
             pred_text_win, noisy_win, cfg_norm) in candidates:
            score_win[mask] = s_win[mask]
            score_gap = s_win[mask] - s_lose[mask]
            preference_term = (
                (1 - F.sigmoid(score_gap)[:, None, None, None])
                * (pred_win[mask] - pred_lose[mask])
            )
            # Dividing by beta rescales the preference term to the norm of the
            # CFG term. NOTE(review): a non-empty subset whose preference term
            # is exactly zero gives beta == 0 and NaN entries here.
            beta = preference_term.norm() / cfg_norm
            reward_sds[mask] = (pred_win[mask] - eps_win[mask]) + preference_term / beta
            noise_pred[mask] = reward_sds[mask] + eps_win[mask]
            latents_noisy[mask] = noisy_win[mask]
            noise_pred_text[mask] = pred_text_win[mask]

        if self.cfg.recon_loss:
            # reconstruct x0
            latents_recon = self.model.predict_start_from_noise(
                latents_noisy, t, noise_pred
            )

            # clip or rescale x0
            if self.cfg.recon_std_rescale > 0:
                latents_recon_nocfg = self.model.predict_start_from_noise(
                    latents_noisy, t, noise_pred_text
                )
                latents_recon_nocfg_reshape = latents_recon_nocfg.view(
                    -1, self.cfg.n_view, *latents_recon_nocfg.shape[1:]
                )
                latents_recon_reshape = latents_recon.view(
                    -1, self.cfg.n_view, *latents_recon.shape[1:]
                )
                # Per group of n_view views: ratio of the std of the text-only
                # reconstruction to the std of the guided reconstruction.
                factor = (
                    latents_recon_nocfg_reshape.std([1, 2, 3, 4], keepdim=True) + 1e-8
                ) / (latents_recon_reshape.std([1, 2, 3, 4], keepdim=True) + 1e-8)

                latents_recon_adjust = latents_recon.clone() * factor.squeeze(
                    1
                ).repeat_interleave(self.cfg.n_view, dim=0)
                latents_recon = (
                    self.cfg.recon_std_rescale * latents_recon_adjust
                    + (1 - self.cfg.recon_std_rescale) * latents_recon
                )

            # x0-reconstruction loss from Sec 3.2 and Appendix
            loss = (
                0.5
                * F.mse_loss(latents, latents_recon.detach(), reduction="sum")
                / latents.shape[0]
            )
            # Only reported through grad_norm in this branch.
            grad = reward_sds
        else:
            w = (1 - self.model.alphas_cumprod[t])
            grad = w * reward_sds

            # clip grad for stable training?
            if self.grad_clip_val is not None:
                grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
            grad = torch.nan_to_num(grad)

            # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
            target = (latents - grad).detach()
            loss = 0.5 * F.mse_loss(latents, target, reduction="sum") / latents.shape[0]

        # Hinge on the winner's reward: zero once the score reaches 4.
        loss_embedding = F.relu(4 - score_win).mean()

        return {
            "loss_sds": loss,
            "grad_norm": grad.norm(),
            "loss_embedding": loss_embedding,
        }

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Record the global step and refresh the timestep bounds.

        ``cfg.min_step_percent`` / ``cfg.max_step_percent`` are resolved
        through ``C`` for the given epoch and step, then converted to integer
        timesteps. ``on_load_weights`` is accepted but not used.
        """
        self.global_step = global_step
        total_steps = self.num_train_timesteps
        lower_percent = C(self.cfg.min_step_percent, epoch, global_step)
        upper_percent = C(self.cfg.max_step_percent, epoch, global_step)
        self.min_step = int(total_steps * lower_percent)
        self.max_step = int(total_steps * upper_percent)
