# encoding = 'utf-8'
import os.path as osp

import math
from tqdm import tqdm
from rich.progress import track

from omegaconf import OmegaConf

import torch
import torch.nn as nn

from .talking_head_dit_v15 import TalkingHeadDiT_models
import sys
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching2 import ModelSamplingDiscreteFlow
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))

from src.criterions import mse_loss, velocity_loss, smooth_loss
# Module-level L1 criterion instance.
loss_fn = nn.L1Loss()

# Maps the `noise_scheduler.type` config string to the scheduler class to build.
scheduler_map = dict(
    ddim=DDIMScheduler,
    flow_matching=ModelSamplingDiscreteFlow,
)

# Channel indices, taken within the expression slice (`:exp_dim`) of the motion
# vector, whose losses MotionDiffusion.forward reports separately as lip losses.
lip_dims = [
    18, 19, 20,
    36, 37, 38,
    42, 43, 44,
    51, 52, 53,
    57, 58, 59,
    60, 61, 62,
]

class MotionDiffusion(nn.Module):
    """Audio-driven motion generator built on a talking-head DiT and a noise scheduler.

    ``forward`` computes the training losses for one batch. ``sample`` generates
    motion for an audio clip of any length: it denoises overlapping windows of
    ``n_pred_frames`` frames with ``sample_subclip`` (optionally with
    classifier-free guidance) and cross-fades them together.
    """

    def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
        """Build the denoiser, the scheduler and the learnable null embeddings.

        Args:
            config: OmegaConf config; reads ``model_name``,
                ``motion_generator.params``, ``audio_projector.params``,
                ``noise_scheduler`` (``type``, ``sample_mode``, ``params``) and
                ``train`` (``audio_drop_ratio``, ``pre_drop_ratio``).
            device: device for the inference timesteps, null embeddings, fusion
                weights and sampling tensors.
            dtype: dtype of the null embeddings, fusion weights and sampling inputs.
            smo_wsize: moving-average window used to smooth head pose at
                inference; values <= 1 disable smoothing.
            loss_type: passed as ``loss_type`` to the ``src.criterions`` losses.
        """
        super().__init__()

        self.config = config
        self.smo_wsize = smo_wsize
        print(f"================================== Init Motion GeneratorV2 ==================================")
        print(OmegaConf.to_yaml(self.config))

        motion_gen_params = config.motion_generator.params
        audio_proj_params = config.audio_projector.params
        scheduler_config = config.noise_scheduler
        scheduler_params = scheduler_config.params

        self.device = device

        # init motion generator; its input is [previous frame, noisy frame]
        # concatenated on the feature axis, hence `input_dim * 2`.
        self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
            input_dim=motion_gen_params.input_dim * 2,
            output_dim=motion_gen_params.output_dim,
            seq_len=motion_gen_params.n_pred_frames,
            audio_unit_len=audio_proj_params.sequence_length,
            audio_blocks=audio_proj_params.blocks,
            audio_dim=audio_proj_params.audio_feat_dim,
            audio_tokens=audio_proj_params.context_tokens,
            audio_embedder_type=audio_proj_params.audio_embedder_type,
            audio_cond_dim=audio_proj_params.audio_cond_dim,
            norm_type=motion_gen_params.norm_type,
            qk_norm=motion_gen_params.qk_norm,
            exp_dim=motion_gen_params.exp_dim,
        )
        self.input_dim = motion_gen_params.input_dim
        self.exp_dim = motion_gen_params.exp_dim

        self.audio_feat_dim = audio_proj_params.audio_feat_dim
        self.audio_seq_len = audio_proj_params.sequence_length
        self.audio_blocks = audio_proj_params.blocks
        # Each frame sees a centred window of `sequence_length` audio frames;
        # `indices` holds the offsets [-margin, ..., margin], shape 1 x (2*margin+1).
        self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
        self.indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)

        self.n_prev_frames = motion_gen_params.n_prev_frames
        self.n_pred_frames = motion_gen_params.n_pred_frames

        # init diffusion schedule
        self.scheduler = scheduler_map[scheduler_config.type](
            num_train_timesteps=scheduler_params.num_train_timesteps,
            beta_start=scheduler_params.beta_start,
            beta_end=scheduler_params.beta_end,
            beta_schedule=scheduler_params.mode,
            prediction_type=scheduler_config.sample_mode,
            time_shifting=scheduler_params.time_shifting,
        )
        self.scheduler_type = scheduler_config.type
        self.eta = scheduler_params.eta
        self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
        self.timesteps = self.scheduler.timesteps
        print(f"time steps: {self.timesteps}")

        self.sample_mode = scheduler_config.sample_mode
        # Warn rather than raise: the mode is forwarded to the scheduler as
        # `prediction_type`, and the values each scheduler accepts (notably the
        # flow-matching one) are not visible from this module.
        if self.sample_mode not in ("noise", "sample"):
            import warnings
            warnings.warn(f"Unknown sample mode {self.sample_mode}, should be noise or sample")

        # init other params
        self.audio_drop_ratio = config.train.audio_drop_ratio
        self.pre_drop_ratio = config.train.pre_drop_ratio

        # Learnable null audio & motion embeddings, used for dropped conditions
        # and for the unconditional classifier-free-guidance branch.
        # The tensor is moved BEFORE being wrapped: `nn.Parameter(...).to(...)`
        # returns a plain tensor whenever a copy is needed. That tensor is not
        # registered on the module (absent from parameters()/state_dict, never
        # optimized).
        # NOTE(review): checkpoints saved while these were unregistered lack the
        # keys `null_audio_feat` / `null_motion_feat`; load such checkpoints
        # with strict=False.
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim).to(device=self.device, dtype=dtype),
            requires_grad=True,
        )
        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim).to(device=self.device, dtype=dtype),
            requires_grad=True,
        )

        # For segment fusion: consecutive windows overlap by `overlap_len` frames
        # (at most 16, clamped at 0 for short windows). The overlap is blended
        # with weights 0, 1/n, ..., (n-1)/n.
        self.overlap_len = max(0, min(16, self.n_pred_frames - 16))
        self.fuse_alpha = torch.arange(
            self.overlap_len, device=self.device, dtype=dtype
        ).reshape(1, -1, 1) / max(self.overlap_len, 1)

        self.dtype = dtype
        self.loss_type = loss_type

        total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameter: % .4fM' % (total_params / 1e6))
        print(f"================================== init Motion GeneratorV2: Done ==================================")

    def drop_feature(self, audio, motion):
        """Randomly replace per-sample conditions with the null embeddings.

        Each sample's audio is swapped for ``null_audio_feat`` with probability
        ``audio_drop_ratio``. Independently, its motion condition is swapped for
        ``null_motion_feat`` with probability ``pre_drop_ratio``.

        Args:
            audio: audio features, B x T x L x b x D.
            motion: previous-motion condition, B x 1 x D as passed by ``forward``.
        Returns:
            ``(audio, motion)`` with the dropped samples replaced.
        """
        batch_size = audio.shape[0]
        # Draw the masks on the inputs' device so they always match the data.
        device = audio.device

        # drop audio features for classifier-free guidance
        p_a = torch.rand(batch_size, device=device)
        mask_audio = p_a < self.audio_drop_ratio
        audio = torch.where(
            mask_audio.view(-1, 1, 1, 1, 1),
            self.null_audio_feat.expand_as(audio),
            audio,
        )

        # drop the previous-motion condition
        p_m = torch.rand(batch_size, device=device)
        mask_pre = p_m < self.pre_drop_ratio
        motion = torch.where(
            mask_pre.view(-1, 1, 1),
            self.null_motion_feat.expand(motion.shape[0], -1, -1),
            motion,
        )

        return audio, motion

    def _component_losses(self, pred, gt, smooth, weight):
        """Return the ``denoise`` / ``velocity`` / ``smooth`` losses for one channel slice."""
        return {
            "denoise": mse_loss(pred, gt, loss_weight=weight, loss_type=self.loss_type),
            "velocity": velocity_loss(pred, gt, loss_weight=weight, loss_type=self.loss_type),
            "smooth": smooth_loss(smooth, loss_weight=weight, loss_type=self.loss_type),
        }

    def forward(self, motion, audio, ref_kp, loss_weight, mask, emo=None, timesteps=None):
        """Compute the training losses for one batch.

        The first ``n_prev_frames`` frames are context. The last context frame
        conditions the prediction of the remaining (current) frames, which are
        noised by the scheduler and then predicted by the DiT.

        Args:
            motion: B x T x D motion sequence.
            audio: B x T x L x b x aD audio windows, aligned with ``motion``.
            ref_kp: reference keypoints, B x T x kD.
            loss_weight: per-channel loss weights for the current frames,
                B x T_cur x D (or broadcastable to it).
            mask: B x T frame mask. Its current-frame part scales
                ``loss_weight`` and is returned in ``sync_sample``.
            emo: optional emotion labels forwarded to the DiT.
            timesteps: optional timesteps forwarded to ``scheduler.add_noise``.
        Returns:
            ``(exp_losses, pose_losses, sync_sample, disentagle_loss, lip_losses)``:
            - ``exp_losses``, ``pose_losses``, ``lip_losses``: dicts of
              ``denoise`` / ``velocity`` / ``smooth`` losses.
            - ``sync_sample``: dict with the current-frame audio (before
              dropping), the motion used for the smooth loss, and the
              current-frame mask.
            - ``disentagle_loss``: dict of constant zeros.
        """
        # split the sequence into conditioning context and frames to predict
        cur_motion = motion[:, self.n_prev_frames:]
        cur_mask = mask[:, self.n_prev_frames:]
        prev_motion = motion[:, self.n_prev_frames - 1:self.n_prev_frames]
        # keep the current-frame audio before conditions are randomly dropped
        audio_sync_feat = audio[:, self.n_prev_frames:].clone()

        # randomly drop conditions for classifier-free guidance
        audio, prev_motion = self.drop_feature(audio, prev_motion)
        audio = audio[:, self.n_prev_frames:]
        ref_kp = ref_kp[:, self.n_prev_frames:]

        # add noise; `gt` is the target the prediction is compared against
        noisy_cur_motion_feat, _noise, gt, timesteps = self.scheduler.add_noise(
            cur_motion, timesteps=timesteps
        )

        # predict
        prev_motion = prev_motion.expand_as(noisy_cur_motion_feat)
        motion_inputs = torch.cat([prev_motion, noisy_cur_motion_feat], dim=-1)  # B, T, 2D
        pred = self.talking_head_dit(
            motion=motion_inputs,
            times=timesteps,
            audio=audio,
            emo=emo,
            audio_cond=ref_kp,
            mask=None,
        ).float()

        # Under flow matching, the smooth loss uses the scheduler's
        # original-sample estimate instead of the raw prediction.
        if self.scheduler_type == "flow_matching":
            smooth_input = self.scheduler.get_pred_original_sample(pred, timesteps, noisy_cur_motion_feat)
        else:
            smooth_input = pred
        sync_sample = {
            "audio": audio_sync_feat,
            "motion": smooth_input,
            "y_mask": cur_mask,
        }

        # zero the loss weights of masked-out frames
        cur_mask = cur_mask.unsqueeze(-1).to(device=loss_weight.device)
        loss_weight = loss_weight * cur_mask

        exp_dim = self.exp_dim
        gt_exp = gt[:, :, :exp_dim]
        pred_exp = pred[:, :, :exp_dim]
        smooth_exp = smooth_input[:, :, :exp_dim]
        exp_loss_weight = loss_weight[:, :, :exp_dim]

        # expression losses
        exp_losses = self._component_losses(pred_exp, gt_exp, smooth_exp, exp_loss_weight)

        # constant placeholder, kept for the return signature
        disentagle_loss = {'kp_loss': 0.0, 'emo_loss': 0.0}

        # pose losses: channels after the expression slice
        pose_losses = self._component_losses(
            pred[:, :, exp_dim:],
            gt[:, :, exp_dim:],
            smooth_input[:, :, exp_dim:],
            loss_weight[:, :, exp_dim:],
        )

        # lip losses: the `lip_dims` subset of the expression channels
        lip_losses = self._component_losses(
            pred_exp[:, :, lip_dims],
            gt_exp[:, :, lip_dims],
            smooth_exp[:, :, lip_dims],
            exp_loss_weight[:, :, lip_dims],
        )

        return exp_losses, pose_losses, sync_sample, disentagle_loss, lip_losses

    def _smooth(self, motion):
        """Moving-average the head-pose channels (``exp_dim:``) along time.

        Uses a centred window of ``smo_wsize`` frames, truncated at the clip
        edges. Expression channels are untouched. ``motion`` (B x L x D) is
        written in place, with averages read from a copy, and then returned.
        """
        if self.smo_wsize <= 1:
            return motion
        source = motion.clone()
        n = motion.shape[1]
        half_k = self.smo_wsize // 2
        for i in range(n):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            # only smooth head pose motion
            motion[:, i, self.exp_dim:] = torch.mean(source[:, ss:ee, self.exp_dim:], dim=1)
        return motion

    def _fuse(self, prev_motion, cur_motion):
        """Cross-fade the tail of ``prev_motion`` with the head of ``cur_motion``.

        The last ``overlap_len`` frames of ``prev_motion`` are overwritten in
        place with ``prev * (1 - fuse_alpha) + cur * fuse_alpha``.
        ``prev_motion`` is then returned. With no overlap it is returned
        unchanged.
        """
        if self.overlap_len <= 0:
            # `x[:, -0:]` would select the whole clip, not an empty tail.
            return prev_motion
        r1 = prev_motion[:, -self.overlap_len:]
        r2 = cur_motion[:, :self.overlap_len]
        prev_motion[:, -self.overlap_len:] = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
        return prev_motion

    @torch.no_grad()
    def sample_subclip(
        self,
        audio,
        ref_kp,
        prev_motion,
        emo=None,
        cfg_scale=1.15,
        init_latents=None,
        dynamic_threshold=None,
    ):
        """Denoise one window of ``n_pred_frames`` motion frames.

        Args:
            audio: audio windows, B x n_pred_frames x L x b x aD. A 4-D input
                gets a singleton axis inserted at dim 2.
            ref_kp: reference keypoints, reshaped to B x 1 x kD and repeated over time.
            prev_motion: previous-frame motion, expanded to the latents' shape.
            emo: optional emotion-label tensor (one per sample). Under guidance,
                ``talking_head_dit.num_emo_class`` is used as the unconditional label.
            cfg_scale: classifier-free guidance scale; guidance runs only when > 1.
            init_latents: optional starting latents; standard normal otherwise.
            dynamic_threshold: optional ``(ratio, min, max)``. Each sample's
                prediction is clamped to its ``ratio``-quantile absolute value,
                itself clamped to [min, max].
        Returns:
            The latents after the last scheduler step.
        """
        batch_size = audio.shape[0]
        audio = audio.to(self.device)
        if audio.ndim == 4:
            audio = audio.unsqueeze(2)

        # reference keypoints
        ref_kp = ref_kp.view(batch_size, 1, -1)

        # cfg: stack [unconditional; conditional] along the batch axis
        use_cfg = cfg_scale > 1
        if use_cfg:
            uncond_audio = self.null_audio_feat.expand(
                batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
            )
            audio = torch.cat([uncond_audio, audio], dim=0)
            ref_kp = torch.cat([ref_kp] * 2, dim=0)
            if emo is not None:
                # one "null" emotion label per conditional sample
                uncond_emo = torch.full(
                    (emo.shape[0],),
                    self.talking_head_dit.num_emo_class,
                    dtype=torch.long,
                    device=emo.device,
                )
                emo = torch.cat([uncond_emo, emo], dim=0)
        ref_kp = ref_kp.repeat(1, audio.shape[1], 1)  # B, L, kD

        # prepare noisy motion
        if init_latents is None:
            latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
        else:
            latents = init_latents

        prev_motion = prev_motion.expand_as(latents).to(dtype=self.dtype)
        latents = latents.to(dtype=self.dtype)
        audio = audio.to(dtype=self.dtype)
        ref_kp = ref_kp.to(dtype=self.dtype)

        for t in track(self.timesteps, description='🚀Denosing', total=len(self.timesteps)):
            motion_in = torch.cat([prev_motion, latents], dim=-1)
            step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
            if use_cfg:
                motion_in = torch.cat([motion_in] * 2, dim=0)
                step_in = torch.cat([step_in] * 2, dim=0)
            # predict
            pred = self.talking_head_dit(
                motion=motion_in,
                times=step_in,
                audio=audio,
                emo=emo,
                audio_cond=ref_kp,
            )
            if dynamic_threshold:
                dt_ratio, dt_min, dt_max = dynamic_threshold
                # per-sample threshold over the actual (CFG-doubled or not) batch
                abs_results = pred.reshape(pred.shape[0], -1).abs()
                s = torch.quantile(abs_results, dt_ratio, dim=1)
                s = torch.clamp(s, min=dt_min, max=dt_max)
                s = s[..., None, None]
                pred = torch.clamp(pred, min=-s, max=s)

            # CFG
            if use_cfg:
                uncond_pred, cond_pred = pred.chunk(2, dim=0)
                pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
            # Step
            latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]
        self.talking_head_dit.bank = []
        return latents

    @torch.no_grad()
    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
        """Generate motion for an audio clip of arbitrary length.

        The clip is covered by windows of ``n_pred_frames`` frames that advance
        by ``stride = n_pred_frames - overlap_len``. Each window is denoised
        with ``sample_subclip`` and pose-smoothed. The next window is
        conditioned on the frame just before it starts. Overlapping frames are
        cross-faded, and the full result is smoothed again.

        Args:
            audio: per-frame audio features, T x b x aD.
            ref_kp: reference keypoints forwarded to ``sample_subclip``.
            prev_motion: initial previous-frame motion, B x 1 x D.
            cfg_scale: classifier-free guidance scale.
            audio_pad_mode: ``"zero"`` or ``"replicate"``; how the audio is
                extended when the windows run past the end of the clip.
            emo: optional integer emotion label.
            dynamic_threshold: forwarded to ``sample_subclip``.
        Returns:
            float32 motion with all singleton dims squeezed.
        Raises:
            ValueError: if ``audio_pad_mode`` is unknown and padding is needed.
        """
        # crop audio into n_subdivision windows according to n_pred_frames
        clip_len = audio.shape[0]
        stride = self.n_pred_frames - self.overlap_len
        if clip_len <= self.n_pred_frames:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1

        # frames appended so the last window is full; trimmed from the output
        n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len

        # Always add `audio_margin` edge frames on both sides. This gives every
        # frame, including the first and last, a full centred audio window, and
        # makes `audio_tensor` exactly clip_len + n_padding_frames long.
        pieces = [audio[:1]] * self.audio_margin + [audio]
        if n_padding_frames > 0:
            if audio_pad_mode == 'zero':
                padding_value = torch.zeros_like(audio[-1:])
            elif audio_pad_mode == 'replicate':
                padding_value = audio[-1:]
            else:
                raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
            pieces += [padding_value] * n_padding_frames
        pieces += [audio[-1:]] * self.audio_margin
        audio = torch.cat(pieces, dim=0)

        # gather the (2*margin+1)-frame window centred on every output frame
        center_indices = torch.arange(
            self.audio_margin,
            audio.shape[0] - self.audio_margin
        ).unsqueeze(1) + self.indices
        audio_tensor = audio[center_indices]   # T, L, b, aD

        motion_lst = []
        init_latents = None
        # emotion label
        if emo is not None:
            emo = torch.Tensor([emo]).long().to(self.device)
        start_idx = 0
        for i in range(n_subdivision):
            print(f"Sample subclip {i+1}/{n_subdivision}")
            end_idx = start_idx + self.n_pred_frames
            audio_segment = audio_tensor[start_idx:end_idx].unsqueeze(0)
            start_idx += stride

            motion_segment = self.sample_subclip(
                audio=audio_segment,
                ref_kp=ref_kp,
                prev_motion=prev_motion,
                emo=emo,
                cfg_scale=cfg_scale,
                init_latents=init_latents,
                dynamic_threshold=dynamic_threshold,
            )
            motion_segment = self._smooth(motion_segment)
            # condition the next window on the frame just before it starts
            prev_motion = motion_segment[:, stride - 1:stride].clone()

            motion_coef = motion_segment
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # delete padded frames

            if motion_lst:
                # fuse the overlap with the previous window, keep the rest
                motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
                motion_lst.append(motion_coef[:, self.overlap_len:])
            else:
                motion_lst.append(motion_coef)

        motion = torch.cat(motion_lst, dim=1)
        # smooth for full clip
        motion = self._smooth(motion)
        motion = motion.squeeze()
        return motion.float()
    