from typing import Dict
import math
import cv2
import kornia.losses
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
import copy
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.base_nets as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
import torchvision.models as models
from train_cifar10 import UNet


class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
    """Diffusion policy over action trajectories conditioned on encoded observations.

    Observations are encoded with the observation encoder of a robomimic ``bc_rnn``
    policy and passed to a ``ConditionalUnet1D`` either as a flattened global
    condition (``obs_as_global_cond=True``) or by inpainting the features into the
    trajectory. ``self.model`` is called as returning a pair
    ``(prediction, predicted_condition)``; ``compute_loss`` uses the second output
    to regress / track per-sample observation features.
    """

    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            crop_shape=(76, 76),
            diffusion_step_embed_dim=256,
            down_dims=(256,512,1024),
            kernel_size=5,
            n_groups=8,
            cond_predict_scale=True,
            obs_encoder_group_norm=False,
            eval_fixed_crop=False,
            # parameters passed to step
            **kwargs):
        """Build the observation encoder, the denoising network and the mask generator.

        Args:
            shape_meta: must contain ``'action'`` with a 1-D ``'shape'`` and ``'obs'``
                mapping each key to a dict with ``'shape'`` and an optional ``'type'``
                (``'rgb'`` or ``'low_dim'``, default ``'low_dim'``).
            noise_scheduler: scheduler used for training noise and sampling.
            horizon: length of the predicted trajectory.
            n_action_steps: number of actions returned by ``predict_action``.
            n_obs_steps: number of observation steps used as conditioning.
            num_inference_steps: denoising steps at inference; defaults to the
                scheduler's ``num_train_timesteps``.
            obs_as_global_cond: condition via global feature (True) or inpainting.
            crop_shape: ``(height, width)`` for the robomimic CropRandomizer, or
                ``None`` to disable random cropping.
            obs_encoder_group_norm: replace BatchNorm2d in the encoder by GroupNorm.
            eval_fixed_crop: swap robomimic CropRandomizer for the diffusion_policy one.
            **kwargs: forwarded to ``scheduler.step`` during sampling.

        Raises:
            RuntimeError: if an observation has an unsupported ``'type'``.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            # named obs_type to avoid shadowing the builtin `type`
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')
        
        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping entirely
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state (robomimic keeps observation modality info globally)
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model; only its observation encoder is kept
        policy: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']
        
        if obs_encoder_group_norm:
            # replace batch norm with group norm
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features//16, 
                    num_channels=x.num_features)
            )
        
        if eval_fixed_crop:
            # swap in the diffusion_policy CropRandomizer with identical parameters
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # features of all observation steps are flattened into one global vector
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps
        self.n_img_count = 0
        self.index_list = []

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))
    
    # ========= inference  ============
    def conditional_sample_origin(self, 
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Denoise from pure noise while enforcing ``condition_data`` where ``condition_mask`` is set.

        Args:
            condition_data: tensor giving the trajectory shape/dtype/device and the
                values to pin.
            condition_mask: boolean mask selecting the pinned entries.
            local_cond, global_cond: conditioning forwarded to ``self.model``.
            generator: optional torch generator for reproducible sampling.
            **kwargs: forwarded to ``scheduler.step``.

        Returns:
            The denoised trajectory with the conditioning re-applied.
        """
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
    
        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)
            # self.model is unpacked as (prediction, predicted_condition) elsewhere in
            # this class; keep only the prediction so scheduler.step gets a tensor.
            if isinstance(model_output, tuple):
                model_output = model_output[0]

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample
        
        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory

    # ========= inference  ============
    def conditional_sample(self, 
            obs_dict=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Sample a normalized trajectory conditioned on the observations in ``obs_dict``.

        Args:
            obs_dict: raw observations keyed like ``shape_meta['obs']``; each value is
                indexed as ``(batch, time, ...)``. Must not contain ``'past_action'``.
            generator: optional torch generator for the initial noise and scheduler.
            **kwargs: forwarded to ``scheduler.step``.

        Returns:
            Normalized trajectory of shape ``(B, horizon, Da)`` (global conditioning)
            or ``(B, horizon, Da + Do)`` (inpainting).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        # handle different ways of passing observation
        local_cond = None
        global_cond = None

        model = self.model
        scheduler = self.noise_scheduler
        if self.obs_as_global_cond:
            # empty data for action
            condition_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
        else:
            condition_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
        condition_mask = torch.zeros_like(condition_data, dtype=torch.bool)
            
        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # The observations do not change during denoising, so encode them once
        # instead of re-running the encoder at every timestep.
        this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)
        if self.obs_as_global_cond:
            # condition through global feature: reshape back to B, To*Do
            global_cond = nobs_features.reshape(B, -1)
        else:
            # condition through inpainting: reshape back to B, To, Do
            nobs_features = nobs_features.reshape(B, To, -1)
            condition_data[:,:To,Da:] = nobs_features
            condition_mask[:,:To,Da:] = True
    
        # set step values
        scheduler.set_timesteps(self.num_inference_steps)
        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output (the predicted condition is not needed here)
            model_output, _ = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample
        
        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        obs_dict: must include "obs" key
        result: must include "action" key
        """
        Da = self.action_dim
        To = self.n_obs_steps
        # run sampling
        nsample = self.conditional_sample(
            obs_dict=obs_dict,
            **self.kwargs)
        
        # unnormalize prediction
        naction_pred = nsample[...,:Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # get action: the executed window starts at the last observed step
        start = To - 1
        end = start + self.n_action_steps
        action = action_pred[:,start:end]
        
        result = {
            'action': action,
            'action_pred': action_pred
        }
        return result

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    
    def compute_loss(self, batch, target_obs, epoch=0, beta = 0.2, training=False, stop_epoch=150, scale=0.1):
        """Diffusion loss plus an auxiliary loss on the observation condition.

        While ``training and epoch <= stop_epoch`` (the warm-up phase), the
        standardized global condition is mixed with random noise through
        ``noise_scheduler.add_noise`` and scaled by ``scale``. The model's predicted
        condition is regressed onto ``scale`` times the standardized condition, and
        it updates ``target_obs[batch['index']]`` in place as an exponential moving
        average with weight ``beta``. Outside warm-up, the clean global condition
        is regressed onto ``target_obs[batch['index']]``.

        Args:
            batch: dict with ``'obs'``, ``'action'`` and ``'index'`` entries.
            target_obs: tensor indexed by ``batch['index']`` holding per-sample
                target features (presumably ``(num_samples, feature_dim)`` — TODO
                confirm), or ``None`` to disable the target-feature terms.
            epoch, stop_epoch: warm-up is active while ``epoch <= stop_epoch``.
            beta: EMA weight for the ``target_obs`` update.
            training: enables the warm-up behavior.
            scale: multiplier on the noisy condition and on the regression target.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if warm-up is active but ``obs_as_global_cond`` is False,
                since the condition noising needs a global feature vector.
        """
        # normalize input
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]
        index = batch['index']
        cond_warmup = training and epoch <= stop_epoch

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory

        # Sample a random timestep for each sample in the batch
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, 
            (batch_size,), device=trajectory.device
        ).long()
        if self.obs_as_global_cond:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs,
                lambda x: x[:,:self.n_obs_steps,...].reshape(-1,*x.shape[2:]))
            noisy_features = self.obs_encoder(this_nobs)
            global_cond = noisy_features.reshape(batch_size, -1)
        else:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()
        
        tn_cond = None
        if cond_warmup:
            if global_cond is None:
                raise ValueError(
                    "condition noising during warm-up requires obs_as_global_cond=True")
            noise_shape = target_obs[index].shape if target_obs is not None else global_cond.shape
            t0_cond = torch.randn(noise_shape, device=trajectory.device) * 0.5
            # standardize each sample's feature vector
            mu = global_cond.mean(dim=1, keepdim=True)
            sigma = global_cond.std(dim=1, keepdim=True) + 1e-6
            tn_cond = ((global_cond - mu) / sigma).to(trajectory.device)
            # add_noise(original, noise, t): the random tensor is the "original"
            # and the standardized features play the role of the noise term.
            noisy_cond = self.noise_scheduler.add_noise(
                t0_cond, tn_cond, timesteps) * scale
        else:
            noisy_cond = global_cond
        
        # generate impainting mask
        condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise that we'll add to the trajectories
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        
        # Add noise to the clean trajectories according to the noise magnitude at
        # each timestep (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # compute loss mask
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = cond_data[condition_mask]
        
        # Predict the noise residual and the condition
        pred, pred_cond = self.model(noisy_trajectory, timesteps, 
            local_cond=local_cond, global_cond=noisy_cond)

        if cond_warmup and target_obs is not None:
            # exponential moving average of the predicted condition, stored per sample
            target_obs[index] = (1 - beta) * target_obs[index] + beta * pred_cond.detach()

        pred_type = self.noise_scheduler.config.prediction_type 
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        if cond_warmup:
            loss_cond = F.mse_loss(pred_cond, tn_cond.detach() * scale, reduction='none')
            loss = loss + loss_cond.mean()
        elif noisy_cond is not None and target_obs is not None:
            loss_cond = F.mse_loss(noisy_cond, target_obs[index].detach(), reduction='none')
            loss = loss + loss_cond.mean()
        return loss
    

    def fft_loss(self, pred, target):
        """L1 distance between FFT magnitude spectra over the last two dimensions."""
        pred_fft = torch.fft.fftn(pred, dim=(-2, -1)).abs()  # magnitude spectrum
        target_fft = torch.fft.fftn(target, dim=(-2, -1)).abs()
        return F.l1_loss(pred_fft, target_fft)

    def tensor_to_numpy(self, tensor):
        """Convert a (C, H, W) torch.Tensor to an (H, W, C) uint8 numpy array.

        Values are min-max scaled to [0, 255]. A constant tensor (max == min) maps
        to all zeros instead of dividing by zero.
        """
        array = tensor.detach().cpu().numpy()
        array = np.transpose(array, (1, 2, 0))  # (C, H, W) -> (H, W, C)
        lo = array.min()
        value_range = array.max() - lo
        if value_range == 0:
            return np.zeros(array.shape, dtype=np.uint8)
        array = (array - lo) / value_range * 255  # min-max normalize
        return array.astype(np.uint8)

    def compute_fft(self, image):
        """Return the centered log-magnitude spectrum (20*log(|F|+1)) of an RGB image."""
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)  # convert to grayscale
        f = np.fft.fft2(gray)
        fshift = np.fft.fftshift(f)  # move the zero-frequency component to the center
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)
        return magnitude_spectrum

    def save_images(self, original_tensor, reconstructed_tensor):
        """Save the frequency spectra of an original and a reconstructed image.

        Writes ``fft_original.png`` and ``fft_reconstructed.png`` into the current
        working directory, overwriting any existing files.
        """
        # convert tensors to uint8 HWC images
        original_img = self.tensor_to_numpy(original_tensor)
        reconstructed_img = self.tensor_to_numpy(reconstructed_tensor)

        # compute spectra
        fft_original = self.compute_fft(original_img)
        fft_reconstructed = self.compute_fft(reconstructed_img)

        # rescale to 0-255 for saving
        fft_original = cv2.normalize(fft_original, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        fft_reconstructed = cv2.normalize(fft_reconstructed, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        # save spectrum images
        cv2.imwrite("fft_original.png", fft_original)
        cv2.imwrite("fft_reconstructed.png", fft_reconstructed)
