from typing import Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
# from consistency_policy.diffusion_unet_with_dropout import ValueUnet1D
from diffusion_policy.model.diffusion.conditional_unet1d import ValueUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.base_nets as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules


class DiffusionUnetD4RLReward(BaseImagePolicy):
    """Reward regressor over concatenated (action, observation) vectors.

    A ``ValueUnet1D`` is built with ``input_dim = action_dim + obs_dim`` and
    trained in ``compute_loss`` with a mean-squared error against the
    normalized rewards of a batch.

    NOTE(review): ``compute_loss`` normalizes observations, actions and
    rewards before using them, whereas ``predict_reward`` forwards its
    argument to the model unchanged. Callers of ``predict_reward`` presumably
    have to normalize the input (and un-normalize the output) themselves --
    confirm against the call sites.
    """

    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        dropout_rate=0.2,
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs
    ):
        """Build the reward network and store the policy configuration.

        Args:
            shape_meta: mapping read at ``['action']['shape']`` and
                ``['observation']['shape']``; both shapes must have length 1.
            noise_scheduler: stored as-is; its
                ``config.num_train_timesteps`` is the fallback for
                ``num_inference_steps``.
            horizon, n_action_steps, n_obs_steps: stored on the instance;
                ``n_obs_steps`` is also passed to the mask generator.
            num_inference_steps: stored on the instance; ``None`` selects the
                scheduler's ``num_train_timesteps``.
            obs_as_global_cond: when true the mask generator is created with
                ``obs_dim=0``, otherwise with the observation dimension.
            diffusion_step_embed_dim, down_dims: forwarded to ``ValueUnet1D``.
            dropout_rate, kernel_size, n_groups, cond_predict_scale:
                NOTE(review): accepted but not used anywhere in this class.
            **kwargs: kept verbatim in ``self.kwargs``.
        """
        super().__init__()

        def _flat_dim(key):
            # Each shape_meta entry is expected to describe a 1-D vector.
            shape = shape_meta[key]['shape']
            assert len(shape) == 1
            return shape[0]

        # Action is parsed before observation, matching the lookup order
        # callers may already rely on for error reporting.
        action_dim = _flat_dim('action')
        obs_dim = _flat_dim('observation')

        # Submodules are assigned in a fixed order (model, mask generator,
        # normalizer) so parameter / state-dict ordering stays stable.
        self.model = ValueUnet1D(
            input_dim=action_dim + obs_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            # dropout_rate is intentionally not forwarded here
        )
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()

        # Plain configuration attributes.
        self.horizon = horizon
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        # Fall back to the scheduler's training step count when unspecified.
        self.num_inference_steps = (
            noise_scheduler.config.num_train_timesteps
            if num_inference_steps is None
            else num_inference_steps
        )

        n_params = sum(p.numel() for p in self.model.parameters())
        print("Diffusion reward params: %e" % n_params)

    def predict_reward(self, trajectory) -> Dict[str, torch.Tensor]:
        """Run the reward model on ``trajectory``.

        Args:
            trajectory: passed to ``self.model`` without any normalization or
                reshaping. Presumably the same concatenated, normalized
                (action, observation) layout that ``compute_loss`` builds --
                TODO confirm with callers.

        Returns:
            A dict with the single key ``'reward'`` holding the raw model
            output (no reshape or un-normalization is applied).
        """
        return {'reward': self.model(trajectory)}

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """Return the scalar MSE between predicted and normalized rewards.

        Args:
            batch: mapping with ``'observations'``, ``'actions'`` and
                ``'rewards'`` entries; it must not contain ``'valid_mask'``.
                Observations and actions are presumably laid out as
                (batch, horizon, dim) -- the first two axes are merged below.

        Returns:
            The mean over all elements of the element-wise squared error.
        """
        assert 'valid_mask' not in batch

        # Normalize every input with its own entry of the normalizer.
        norm_obs = self.normalizer['observations'].normalize(batch['observations'])
        norm_act = self.normalizer['actions'].normalize(batch['actions'])
        ## TODO:  add normalizer to reward ??
        target = self.normalizer['rewards'].normalize(batch['rewards'])

        # Sizes are read from the actions before their leading axes are merged.
        batch_size, horizon = norm_act.shape[0], norm_act.shape[1]
        n_samples = batch_size * horizon

        # Merge the first two axes so each row is one sample.
        flat_obs = norm_obs.reshape(-1, *norm_obs.shape[2:])
        flat_act = norm_act.reshape(-1, *norm_act.shape[2:])

        # Model input is [action, observation] along the last axis, detached
        # from any upstream graph. (An author note recorded a shape of
        # torch.Size([65536, 14]) here for one configuration.)
        model_input = torch.cat([flat_act, flat_obs], dim=-1).detach()

        # Bring prediction and target to matching (n_samples, 1) columns so
        # the element-wise loss cannot broadcast across samples.
        prediction = self.model(model_input).reshape(n_samples, 1)
        target = target.reshape(-1, 1)

        per_element = F.mse_loss(prediction, target, reduction='none')
        return per_element.mean()
