from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler 
import random
import torch.fft as fft 
import pywt 
import math 
from torch.distributions import Normal
import numpy as np
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator



class DiffusionUnetLowdimPolicy(BaseLowdimPolicy):
    def __init__(self, 
            model: ConditionalUnet1D,
            noise_scheduler: DDIMScheduler,
            horizon, 
            obs_dim, 
            action_dim, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            obs_as_local_cond=False,
            obs_as_global_cond=False,
            pred_action_steps_only=False,
            oa_step_convention=False,
            **kwargs):
        """
        Store the denoising network, the scheduler and the policy hyper-parameters.

        model: network called as model(trajectory, t, local_cond=..., global_cond=...)
        noise_scheduler: scheduler used for both training noise and sampling
        horizon, obs_dim, action_dim, n_action_steps, n_obs_steps: stored as-is
        num_inference_steps: falls back to the scheduler's
            config.num_train_timesteps when None
        obs_as_local_cond / obs_as_global_cond: mutually exclusive flags
        pred_action_steps_only: only allowed together with obs_as_global_cond
        oa_step_convention: stored as-is
        kwargs: kept in self.kwargs; predict_action forwards them to
            conditional_sample, which reads "alpha" from them
        """
        super().__init__()

        # The two observation-conditioning modes cannot be combined.
        assert not obs_as_local_cond or not obs_as_global_cond
        # Predicting only the action steps needs the global-conditioning mode.
        assert obs_as_global_cond or not pred_action_steps_only

        # Observations are part of the diffused trajectory only when they are
        # not supplied through local/global conditioning.
        obs_in_trajectory = not (obs_as_local_cond or obs_as_global_cond)
        mask_obs_dim = obs_dim if obs_in_trajectory else 0

        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=mask_obs_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()

        # dimensions and step counts
        self.horizon = horizon
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps

        # behaviour flags
        self.obs_as_local_cond = obs_as_local_cond
        self.obs_as_global_cond = obs_as_global_cond
        self.pred_action_steps_only = pred_action_steps_only
        self.oa_step_convention = oa_step_convention

        self.kwargs = kwargs
        self.num_inference_steps = (
            noise_scheduler.config.num_train_timesteps
            if num_inference_steps is None
            else num_inference_steps
        )
    
    
    def conditional_sample(self, 
            condition_data, condition_mask, 
            condition_data1, condition_mask1,          
            local_cond=None, local_cond1=None,
            global_cond=None, global_cond1=None,  
            generator=None,
            prior=None,
            weak = None,
            cond_data_null_shape = None,
            **kwargs
            ):
        """
        Run the reverse diffusion loop with two conditioning streams.

        At every scheduler timestep the entries of the trajectory selected by
        condition_mask are overwritten with condition_data. A copy of that
        trajectory additionally gets the entries selected by condition_mask1
        overwritten with condition_data1. The network output fed to the
        scheduler is then one of:

        * weak is truthy:  w * model(primary) + (1 - w) * weak.model(secondary),
          with w = self.kwargs["alpha"] (default 1.0). The weak model receives
          the secondary trajectory but the primary local/global conditioning.
        * otherwise:       o1 + w * (o1 - o2), where o1 = model(primary),
          o2 = model(secondary with local_cond1/global_cond1) and
          w = self.kwargs["alpha"] (default 0.0). When both inputs and both
          conditionings are element-wise identical, o2 is not evaluated and
          o1 is used directly.

        Args:
            condition_data, condition_mask: values and mask inpainted into the
                primary trajectory. When prior is None the initial noise takes
                condition_data's shape, dtype and device.
            condition_data1, condition_mask1: values and mask inpainted into
                the secondary trajectory.
            local_cond, global_cond: conditioning passed to the network for the
                primary trajectory (may be None).
            local_cond1, global_cond1: conditioning for the secondary
                trajectory (may be None).
            generator: forwarded to torch.randn and to scheduler.step.
            prior: optional initial trajectory used instead of Gaussian noise.
                It is copied, so the caller's tensor is never modified.
            weak: optional object exposing a `.model` attribute.
            cond_data_null_shape: accepted for interface compatibility; unused.
            **kwargs: "timesteps" and "alpha" are dropped, everything else is
                forwarded to scheduler.step.

        Returns:
            The trajectory after the last step, with condition_data enforced
            once more on the entries selected by condition_mask.
        """

        def _conds_equal(a, b):
            # Two conditionings match when both are absent, or both are present
            # and every element compares equal.
            if a is None or b is None:
                return a is None and b is None
            return bool((a == b).all())

        model = self.model
        scheduler = self.noise_scheduler

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        if prior is None:
            trajectory = torch.randn(
                size=condition_data.shape,
                dtype=condition_data.dtype,
                device=condition_data.device,
                generator=generator)
        else:
            # `.to()` / `.contiguous()` hand back the very same tensor when no
            # change is needed, and the inpainting below writes in place, so
            # take an explicit copy to keep the caller's prior intact.
            trajectory = prior.to(device=condition_data.device).clone(
                memory_format=torch.contiguous_format)

        # These keys are not meant for scheduler.step.
        kwargs.pop('timesteps', None)
        kwargs.pop('alpha', None)

        condition_data = condition_data.contiguous()
        condition_mask = condition_mask.to(dtype=torch.bool).contiguous()

        condition_data1 = condition_data1.contiguous()
        condition_mask1 = condition_mask1.to(dtype=torch.bool).contiguous()

        # Everything below is constant across timesteps, so resolve it once.
        use_weak = bool(weak)
        if use_weak:
            w = self.kwargs.get("alpha", 1.0)
            conds_match = False  # not consulted in this mode
        else:
            w = self.kwargs.get("alpha", 0.0)
            conds_match = _conds_equal(global_cond, global_cond1) \
                and _conds_equal(local_cond, local_cond1)

        for t in scheduler.timesteps: 
            trajectory[condition_mask] = condition_data[condition_mask]
            trajectory1 = trajectory.clone()
            trajectory1[condition_mask1] = condition_data1[condition_mask1]

            with torch.no_grad():
                model_output1 = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)

                if use_weak:  # for AutoGuidance
                    model_output2 = weak.model(trajectory1, t, local_cond=local_cond, global_cond=global_cond)
                    model_output = w * model_output1 + (1 - w) * model_output2

                elif conds_match and bool((trajectory1 == trajectory).all()):
                    # Both streams are identical: the guidance term is zero.
                    model_output = model_output1

                else:
                    model_output2 = model(trajectory1, t, local_cond=local_cond1, global_cond=global_cond1)
                    model_output = w * (model_output1 - model_output2) + model_output1

            trajectory = scheduler.step(model_output, t, trajectory, generator=generator, **kwargs).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]
        return trajectory


    def predict_action(self, obs_dict: Dict[str, torch.Tensor], previous_obs_dict: Dict[str, torch.Tensor] = None, weak=None) -> Dict[str, torch.Tensor]:
        """
        Sample an action sequence conditioned on the current observations.

        obs_dict: must include "obs" key (3-D tensor whose last dimension is
            obs_dim); may include "prior", an initial trajectory handed to
            conditional_sample in place of Gaussian noise.
        previous_obs_dict: must include "obs" key; supplies the observations
            of the secondary conditioning stream. When None, the current
            observations are used for both streams.
        weak: optional object with a `.model` attribute, forwarded to
            conditional_sample.
        result: must include "action" key. Also contains "action_pred" and,
            when observations are conditioned through inpainting,
            "action_obs_pred" and "obs_pred".
        """
        assert 'obs' in obs_dict
        assert 'past_action' not in obs_dict # not implemented yet

        obs_normalizer = self.normalizer['obs']
        nobs = obs_normalizer.normalize(obs_dict['obs'])
        if previous_obs_dict is None:
            # No separate history given: both streams see the same observations.
            nobs1 = nobs
        else:
            nobs1 = obs_normalizer.normalize(previous_obs_dict['obs'])

        B, _, Do = nobs.shape
        To = self.n_obs_steps
        assert Do == self.obs_dim
        T = self.horizon
        Da = self.action_dim

        prior = obs_dict.get('prior')

        # build input
        device = self.device
        dtype = self.dtype

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        local_cond1 = None
        global_cond1 = None
        inpaint_obs = not (self.obs_as_local_cond or self.obs_as_global_cond)

        if self.obs_as_local_cond:
            # observations occupy the first To steps, the rest stays zero
            local_cond = torch.zeros(size=(B,T,Do), device=device, dtype=dtype)
            local_cond[:,:To] = nobs[:,:To]

            local_cond1 = torch.zeros(size=(B,T,Do), device=device, dtype=dtype)
            local_cond1[:,:To] = nobs1[:,:To]

            shape = (B, T, Da)
        elif self.obs_as_global_cond:
            # flatten the first To observation steps into one vector per sample
            global_cond = nobs[:,:To].reshape(B, -1)
            global_cond1 = nobs1[:,:To].reshape(B, -1)

            shape = (B, T, Da)
            if self.pred_action_steps_only:
                shape = (B, self.n_action_steps, Da)
        else:
            # condition through inpainting: observations live in the trajectory
            shape = (B, T, Da+Do)

        # Inpainting buffers for both streams; all-zero data with an all-False
        # mask means "nothing to inpaint".
        cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        cond_data1 = torch.zeros(size=shape, device=device, dtype=dtype)
        cond_mask1 = torch.zeros_like(cond_data1, dtype=torch.bool)

        if inpaint_obs:
            cond_data[:,:To,Da:] = nobs[:,:To]
            cond_data1[:,:To,Da:] = nobs1[:,:To]

            cond_mask[:,:To,Da:] = True
            cond_mask1[:,:To,Da:] = True

        # run sampling
        nsample = self.conditional_sample(
            cond_data, 
            cond_mask,
            cond_data1,
            cond_mask1,            
            local_cond=local_cond,
            local_cond1=local_cond1,
            global_cond=global_cond,
            global_cond1=global_cond1,
            prior=prior,
            weak = weak,
            cond_data_null_shape = [B,To,Do],
            **self.kwargs)

        # unnormalize prediction
        naction_pred = nsample[...,:Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # Window of the predicted sequence that is returned as the action;
        # computed unconditionally so the observation slice below can rely on it.
        start = To - 1 if self.oa_step_convention else To
        end = start + self.n_action_steps

        # get action
        if self.pred_action_steps_only:
            action = action_pred
        else:
            action = action_pred[:,start:end]

        result = {
            'action': action,
            'action_pred': action_pred
        }
        if inpaint_obs:
            nobs_pred = nsample[...,Da:]
            obs_pred = self.normalizer['obs'].unnormalize(nobs_pred)
            action_obs_pred = obs_pred[:,start:end]
            result['action_obs_pred'] = action_obs_pred
            result['obs_pred'] = obs_pred
        return result

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """
        Load the state of the given normalizer into this policy's normalizer.
        """
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """
        Compute the denoising training loss for one batch.

        batch: mapping accepted by self.normalizer.normalize whose normalized
            form has "obs" and "action" entries; must not contain "valid_mask".
        returns: scalar tensor, the mean of the per-sample mean squared error
            between the network output and the target (the added noise for
            prediction_type "epsilon", the clean trajectory for "sample").
            Entries covered by the conditioning mask contribute zero error.
        raises: ValueError for any other scheduler prediction_type.
        """
        # normalize input
        assert 'valid_mask' not in batch
        nbatch = self.normalizer.normalize(batch)
        obs = nbatch['obs']
        action = nbatch['action']

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = action
        if self.obs_as_local_cond:
            # zero out observations after n_obs_steps; work on a copy so the
            # tensor stored in nbatch is not edited in place
            local_cond = obs.clone()
            local_cond[:,self.n_obs_steps:,:] = 0
        elif self.obs_as_global_cond:
            global_cond = obs[:,:self.n_obs_steps,:].reshape(
                obs.shape[0], -1)
            if self.pred_action_steps_only:
                # train only on the action window that predict_action returns
                To = self.n_obs_steps
                start = To
                if self.oa_step_convention:
                    start = To - 1
                end = start + self.n_action_steps
                trajectory = action[:,start:end]
        else:
            # observations are diffused together with the actions
            trajectory = torch.cat([action, obs], dim=-1)

        # generate inpainting mask
        if self.pred_action_steps_only:
            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        else:
            condition_mask = self.mask_generator(trajectory.shape)

        # Sample the noise that will be added to the trajectory. Match the
        # trajectory's dtype so the noisy sample is not promoted to a different
        # precision than the data it is mixed with.
        noise = torch.randn(
            trajectory.shape,
            dtype=trajectory.dtype,
            device=trajectory.device)
        bsz = trajectory.shape[0]
        # Sample a random timestep for each trajectory in the batch
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, 
            (bsz,), dtype=torch.long, device=trajectory.device
        )

        # Add noise to the clean trajectory according to the noise magnitude
        # at each timestep (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # compute loss mask
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = trajectory[condition_mask]

        # Predict the noise residual (or the clean sample, see below)
        pred = self.model(noisy_trajectory, timesteps, 
            local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type 
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        # conditioned entries are given, not predicted: zero their error
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss
