import os
from typing import Dict
import math

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.obs_core as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.dataset.pusht_image_dynamics_dataset import DynamicsModelDataset
from diffusion_policy.model.OOD.knn import KNN_torch, MahalanobisOODModule
from diffusion_policy.model.OOD.svm import SVMOODModule, train_ocsvm
from diffusion_policy.model.dynamics.dynamics_model import CVAE, HybridDynamicsModel, ImageBasedDynamicsModel, StateBasedDynamicsModel, StateBasedEncoderDecoderDynamicsModel, StateEncodeBasedDynamicsModel


def boundary_penalty(action, lower_bound=-1.0, upper_bound=1.0):
    """Return a scalar penalty measuring how far ``action`` leaves the bounds.

    Element-wise the penalty is ``relu(action - upper_bound) +
    relu(lower_bound - action)``. With ``lower_bound <= upper_bound`` this is
    the distance to the interval ``[lower_bound, upper_bound]`` (zero inside
    it). All elements are summed into a single tensor, so the result can be
    added directly to an optimisation loss.
    """
    overshoot = torch.relu(action - upper_bound)
    undershoot = torch.relu(lower_bound - action)
    return (overshoot + undershoot).sum()


class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            crop_shape=(76, 76),
            diffusion_step_embed_dim=256,
            down_dims=(256,512,1024),
            kernel_size=5,
            n_groups=8,
            cond_predict_scale=True,
            obs_encoder_group_norm=False,
            eval_fixed_crop=False,
            # parameters passed to step
            **kwargs):
        """Build the robomimic observation encoder and the conditional 1D UNet.

        Args:
            shape_meta: Dict with an ``'action'`` entry whose ``'shape'`` is
                1-D, and an ``'obs'`` entry mapping each observation key to a
                dict with ``'shape'`` and an optional ``'type'`` (``'rgb'`` or
                ``'low_dim'``; defaults to ``'low_dim'``).
            noise_scheduler: DDPM scheduler used for training and sampling.
            horizon: Length of the sampled trajectory.
            n_action_steps: Number of actions returned per ``predict_action``.
            n_obs_steps: Number of observation steps used for conditioning.
            num_inference_steps: Denoising steps at inference; defaults to the
                scheduler's ``config.num_train_timesteps``.
            obs_as_global_cond: If True, observation features are fed to the
                UNet as a global condition; otherwise they are concatenated to
                the trajectory and conditioned on by inpainting.
            crop_shape: ``(height, width)`` written into every encoder using
                ``CropRandomizer``; ``None`` removes the randomizer instead.
            diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
            cond_predict_scale: Forwarded to ``ConditionalUnet1D``.
            obs_encoder_group_norm: Replace each ``nn.BatchNorm2d`` in the
                encoder with ``nn.GroupNorm(num_features // 16, num_features)``.
            eval_fixed_crop: Replace robomimic ``CropRandomizer`` modules with
                ``diffusion_policy``'s ``CropRandomizer``.
            **kwargs: Stored and forwarded to ``noise_scheduler.step`` during
                sampling.

        Raises:
            RuntimeError: If an observation declares an unsupported type.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        # modality -> list of obs keys, in the layout robomimic's config expects
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            obs_key_shapes[key] = list(attr['shape'])

            # named `obs_type` so the builtin `type` is not shadowed
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config; only its observation encoder is kept below
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')
        
        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping entirely
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model
        policy: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        # keep only the observation encoder of the robomimic policy
        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']
        
        if obs_encoder_group_norm:
            # replace batch norm with group norm
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features//16, 
                    num_channels=x.num_features)
            )
        
        if eval_fixed_crop:
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # the UNet then only denoises actions; obs enter via global_cond
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        # populated by set_ood_quantification
        self.dynamics_model_normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs
        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))
    
    def set_ood_quantification(self, 
                               demo_data_file_path, 
                               demo_data_val_ratio, 
                               play_data_file_path,
                               dynamics_model_type, 
                               dynamics_model_ckpt, 
                               ood_method='vanila_nn'):
        """Load a pretrained dynamics model and build an OOD scorer over demo latents.

        Side effects on ``self``:
          * ``normalizer`` is REPLACED by the demo dataset's normalizer.
          * ``dynamics_model_normalizer`` is set from the play dataset.
          * ``dynamics_model_type`` / ``dynamics_model`` are set, weights are
            loaded from ``dynamics_model_ckpt`` and the model is put in eval mode.
          * ``ood_module`` is a ``KNN_torch`` over the dynamics-model latents of
            every demo observation (plus the raw demo images).

        Everything is placed on CUDA when available, otherwise on CPU.

        Args:
            demo_data_file_path: Demo dataset defining the in-distribution set.
            demo_data_val_ratio: ``val_ratio`` for the demo ``DynamicsModelDataset``.
            play_data_file_path: Dataset whose normalizer the dynamics model uses.
            dynamics_model_type: ``'image'``, ``'state'``,
                ``'state_encoder_decoder'``, ``'state_encoder'`` or ``'hybrid'``.
            dynamics_model_ckpt: Checkpoint file; ``config.yaml`` must sit in
                the same directory.
            ood_method: Forwarded to ``KNN_torch`` as ``method``.

        Raises:
            FileNotFoundError: If ``config.yaml`` is missing.
            ValueError: If ``dynamics_model_type`` is unknown, or a value the
                selected model needs is missing from ``config.yaml``.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        dynamics_dataset = DynamicsModelDataset(play_data_file_path, val_ratio=0.0, random=False)
        demo_dataset = DynamicsModelDataset(demo_data_file_path, val_ratio=demo_data_val_ratio, random=False)
        # NOTE: overwrites any previously loaded policy normalizer (e.g. one
        # restored from a policy checkpoint) with the demo dataset's normalizer.
        self.normalizer = demo_dataset.get_normalizer().to(device)

        normalizer = dynamics_dataset.get_normalizer()
        print('Loaded dynamics_model_normalizer max', normalizer.params_dict['state']['input_stats']['max'])
        print('Loaded dynamics_model_normalizer min', normalizer.params_dict['state']['input_stats']['min'])
        self.dynamics_model_normalizer = normalizer.to(device)

        hidden_dim, run_name = self._load_dynamics_config(dynamics_model_ckpt)
        self.dynamics_model_type = dynamics_model_type
        self.dynamics_model = self._build_dynamics_model(
            dynamics_model_type, hidden_dim, run_name).to(device)

        # map_location follows the chosen device so CPU-only hosts can load it
        checkpoint = torch.load(dynamics_model_ckpt, map_location=device)
        self.dynamics_model.load_state_dict(checkpoint['model_state_dict'])
        # eval mode BEFORE encoding the reference set, so any dropout /
        # batch-norm layers behave as they will at query time
        self.dynamics_model.eval()

        train_loader = DataLoader(demo_dataset, batch_size=32, shuffle=False, num_workers=4)
        train_latents, images = self._encode_demo_latents(train_loader, device)
        self.ood_module = KNN_torch(train_latents=train_latents, images=images, method=ood_method).to(device)

    def _load_dynamics_config(self, dynamics_model_ckpt):
        """Return ``(hidden_dim, run_name)`` read from ``config.yaml`` beside the checkpoint.

        Values are looked up under ``_content.value``; either may be ``None``
        when absent. ``hidden_dim`` is converted to ``int`` when present.
        """
        config_path = os.path.join(os.path.dirname(dynamics_model_ckpt), 'config.yaml')
        print('dynamics_config_path ', config_path)
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"{config_path} does not exist!")
        with open(config_path, 'r', encoding='utf-8') as f:
            dynamics_config = yaml.safe_load(f)
        # safe_load returns None for an empty file
        values = (dynamics_config or {}).get('_content', {}).get('value', {})
        hidden_dim = values.get('hidden_dim', None)
        run_name = values.get('wandb_dir', None)
        return (int(hidden_dim) if hidden_dim is not None else None), run_name

    def _build_dynamics_model(self, dynamics_model_type, hidden_dim, run_name):
        """Instantiate the (untrained) dynamics model for ``dynamics_model_type``."""
        if dynamics_model_type == 'image':
            return ImageBasedDynamicsModel(action_dim=2, latent_size=32)
        if dynamics_model_type == 'state':
            return StateBasedDynamicsModel(state_dim=6, action_dim=2, hidden_dim=32)
        if dynamics_model_type == 'state_encoder_decoder':
            return StateBasedEncoderDecoderDynamicsModel(state_dim=6, action_dim=2, hidden_dim=32)
        if dynamics_model_type == 'state_encoder':
            # the run name decides between the CVAE and the deterministic encoder
            if run_name is None:
                raise ValueError(
                    "config.yaml has no 'wandb_dir'; cannot choose the "
                    "'state_encoder' model variant")
            if 'vae' in run_name:
                return CVAE(state_dim=6, action_dim=2, latent_dim=12)
            if hidden_dim is None:
                raise ValueError(
                    "config.yaml has no 'hidden_dim', required by "
                    "StateEncodeBasedDynamicsModel")
            return StateEncodeBasedDynamicsModel(state_dim=6, action_dim=2, hidden_dim=hidden_dim)
        if dynamics_model_type == 'hybrid':
            return HybridDynamicsModel(state_dim=6, action_dim=2, latent_size=32)
        raise ValueError(f"Invalid model_type: {dynamics_model_type}")

    def _encode_demo_latents(self, loader, device):
        """Encode every ``o_t`` in ``loader``; return ``(latents, raw_images)`` on ``device``."""
        latents = []
        images = []
        # reference latents are constants for the OOD scorer: no autograd graph
        # is kept, so memory does not grow with the size of the demo set
        with torch.no_grad():
            for batch in loader:
                img_t = batch['o_t']['image'].to(device)
                images.append(img_t)  # raw (unnormalized) image is kept
                img_t = self.dynamics_model_normalizer['image'].normalize(img_t)

                state_t = batch['o_t']['state'].to(device)
                state_t = self.dynamics_model_normalizer['state'].normalize(state_t)

                if self.dynamics_model_type == 'image':
                    z_t = self.dynamics_model.encode(img_t)
                elif 'state' in self.dynamics_model_type:
                    z_t = self.dynamics_model.encode(state_t)
                elif self.dynamics_model_type == 'hybrid':
                    z_t = self.dynamics_model.encode(img_t, state_t)
                else:
                    raise ValueError(f"Invalid model_type: {self.dynamics_model_type}")

                # some encoders return a tuple; the first element is used as the latent
                if isinstance(z_t, tuple):
                    z_t = z_t[0]
                latents.append(z_t)
        return torch.cat(latents, dim=0), torch.cat(images, dim=0)

    # ========= inference  ============
    def conditional_sample(self, 
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Run the reverse diffusion process with inpainting-style conditioning.

        The sample starts as Gaussian noise with the shape, dtype and device of
        ``condition_data``. It is denoised over ``self.num_inference_steps``
        scheduler steps. Before every model call, and once more at the end,
        entries selected by ``condition_mask`` are overwritten with the
        corresponding ``condition_data`` values.

        Args:
            condition_data: Tensor providing shape/dtype/device and pinned values.
            condition_mask: Boolean mask of the entries to pin.
            local_cond, global_cond: Forwarded to ``self.model``.
            generator: Optional RNG for the initial noise and ``scheduler.step``.
            **kwargs: Forwarded to ``self.noise_scheduler.step``.

        Returns:
            The denoised tensor.
        """
        denoiser = self.model
        sched = self.noise_scheduler

        sample = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator,
        )

        # configure the inference timestep schedule
        sched.set_timesteps(self.num_inference_steps)

        # values to re-impose at every step (condition_data is never modified)
        pinned = condition_data[condition_mask]

        for timestep in sched.timesteps:
            sample[condition_mask] = pinned
            prediction = denoiser(sample, timestep,
                local_cond=local_cond, global_cond=global_cond)
            # x_t -> x_{t-1}
            step_output = sched.step(
                prediction, timestep, sample,
                generator=generator,
                **kwargs)
            sample = step_output.prev_sample

        # make sure the final sample honours the conditioning as well
        sample[condition_mask] = pinned
        return sample
                

    def predict_action(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, optim_lr=1e-4, h_step=1, num_ood_steps=1, weight_decay=0, input_type='state', num_optim_iters=20) -> Dict[str, torch.Tensor]:
        """Sample an action chunk, optionally refining it to stay in-distribution.

        Args:
            obs_dict: Observation tensors keyed by observation name, each with
                leading dims ``(B, T_obs, ...)``. Must not contain
                ``'past_action'``. With ``avoid_ood`` it must also contain
                ``'image'`` and the keys the dynamics-model normalizer knows.
            avoid_ood: If True, refine the executed actions by gradient descent
                on the OOD score (requires ``set_ood_quantification``).
            optim_lr: Adam learning rate for that refinement.
            h_step, num_ood_steps, weight_decay, input_type: Currently unused;
                accepted for call-site compatibility.
            num_optim_iters: Refinement iterations (default 20, the value that
                was previously hard-coded).

        Returns:
            ``{'action': ..., 'action_pred': ...}``. ``'action'`` holds the
            ``n_action_steps`` actions to execute, unnormalized, and refined
            and detached when ``avoid_ood`` is set. ``'action_pred'`` is the
            full unnormalized diffusion sample.
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        # encode the first To observation steps, flattened to (B*To, ...)
        local_cond = None
        global_cond = None
        this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)
        if self.obs_as_global_cond:
            # condition through global feature: reshape back to (B, To*Do)
            global_cond = nobs_features.reshape(B, -1)
            # empty data for action
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting: reshape back to (B, To, Do)
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:,:To,Da:] = nobs_features
            cond_mask[:,:To,Da:] = True

        # run sampling
        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data, 
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)
        naction_pred = nsample[...,:Da]

        # the executed chunk starts at the last observed step
        start = To - 1
        end = start + self.n_action_steps
        naction = naction_pred[:,start:end]

        action_pred = self.normalizer['action'].unnormalize(naction_pred)
        if avoid_ood:
            action = self._optimize_action_for_ood(naction, obs_dict, optim_lr, num_optim_iters)
        else:
            action = self.normalizer['action'].unnormalize(naction)
        return {
            'action': action,
            'action_pred': action_pred
        }

    def _optimize_action_for_ood(self, naction, obs_dict, optim_lr, num_iters):
        """Refine ``naction`` (policy-normalized) to lower the OOD score; return unnormalized actions.

        The dynamics model is rolled out in latent space from the latest
        observation through the action chunk. The final latent is scored by
        ``self.ood_module`` and ``0.1 * boundary_penalty`` is added to the
        score. Adam then updates the actions, expressed in the dynamics-model
        normalization, for ``num_iters`` iterations.
        """
        # re-express the policy's actions in the dynamics model's normalization
        action_unnormalized = self.normalizer['action'].unnormalize(naction)
        action_var = self.dynamics_model_normalizer['action'].normalize(
            action_unnormalized).detach().requires_grad_(True)
        print(' ------------------ ')
        print('action before ', action_var[:3])
        optimizer = optim.Adam([action_var], lr=optim_lr)

        # only the most recent observation step seeds the latent rollout
        initial_obs_dict = dict_apply(obs_dict, lambda x: x[:, -1, ...])
        current_img = initial_obs_dict['image'].detach().cpu().numpy()
        initial_obs = self.dynamics_model_normalizer.normalize(initial_obs_dict)

        # the start latent does not depend on the actions, so no graph is needed
        with torch.no_grad():
            if self.dynamics_model_type == 'image':
                initial_latent = self.dynamics_model.encode(initial_obs['image'])
            elif 'state' in self.dynamics_model_type:
                initial_latent = self.dynamics_model.encode(initial_obs['state'])
            elif self.dynamics_model_type == 'hybrid':
                initial_latent = self.dynamics_model.encode(initial_obs['image'], initial_obs['state'])
            else:
                raise ValueError(f"Invalid model_type: {self.dynamics_model_type}")
        if isinstance(initial_latent, tuple):
            initial_latent = initial_latent[0]

        loss_history = []
        for step in range(num_iters):
            optimizer.zero_grad()
            total_ood_score = 0
            latent = initial_latent.clone()
            for i in range(self.n_action_steps):
                latent = self.dynamics_model.latent_transition(latent, action_var[:, i, :])
                # only the final predicted latent is scored
                if i == self.n_action_steps - 1:
                    ood_score = self.ood_module(latent, current_img)
                    total_ood_score += ood_score.sum()
            print(f"Total OOD score at step {step}: {total_ood_score}")

            # keep the refined actions inside [-1, 1] (soft constraint)
            total_loss = total_ood_score + 0.1 * boundary_penalty(action_var)
            if total_loss == 0.0:
                continue
            # retain_graph kept: the OOD module's reference data may carry an
            # autograd graph reused across iterations (TODO confirm if removable)
            total_loss.backward(retain_graph=True)
            optimizer.step()

            loss_history.append(total_loss.item())
            if len(loss_history) > 1 and abs(loss_history[-1] - loss_history[-2]) < 1e-4:
                # loss has plateaued: jitter the actions to escape the flat region
                with torch.no_grad():
                    action_var += torch.randn_like(action_var) * 5e-3
                print('action grad ', action_var.grad)
            # hard constraint in the dynamics-model normalized action space
            with torch.no_grad():
                action_var.clamp_(-1, 1)

        print('action after ', action_var[:3])
        # detach so callers receive a plain tensor without an optimisation graph
        return self.dynamics_model_normalizer['action'].unnormalize(action_var.detach())
        
    def save_videos(self):
        """Delegate to ``self.ood_module.save_videos()``.

        ``ood_module`` is assigned in ``set_ood_quantification`` (it is not set
        in ``__init__``), so that method must have been called first.
        """
        ood_module = self.ood_module
        ood_module.save_videos()
        
    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy ``normalizer``'s state into this policy's normalizer in place."""
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """Return the scalar diffusion training loss for one batch.

        ``batch['obs']`` is a dict of observation tensors and ``batch['action']``
        an action tensor, both with leading dims ``(B, horizon, ...)``. Noise is
        added at a random timestep per sample. The model's prediction is
        compared with the noise (``prediction_type == 'epsilon'``) or the clean
        trajectory (``'sample'``). Conditioned entries are zeroed out of the
        MSE, and the result is averaged over all elements.

        Raises:
            ValueError: For an unsupported scheduler ``prediction_type``.
        """
        # batches with padding masks are not supported
        assert 'valid_mask' not in batch
        normalized_obs = self.normalizer.normalize(batch['obs'])
        normalized_actions = self.normalizer['action'].normalize(batch['action'])
        batch_size = normalized_actions.shape[0]
        horizon = normalized_actions.shape[1]

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # encode the first n_obs_steps frames as (B*To, ...), then fold
            # them into one conditioning vector per sample
            obs_window = dict_apply(
                normalized_obs,
                lambda x: x[:, :self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
            global_cond = self.obs_encoder(obs_window).reshape(batch_size, -1)
            trajectory = normalized_actions
            cond_data = normalized_actions
        else:
            # encode every frame and append the features to the actions
            flat_obs = dict_apply(normalized_obs, lambda x: x.reshape(-1, *x.shape[2:]))
            obs_features = self.obs_encoder(flat_obs).reshape(batch_size, horizon, -1)
            cond_data = torch.cat([normalized_actions, obs_features], dim=-1)
            trajectory = cond_data.detach()

        # inpainting mask over the trajectory
        condition_mask = self.mask_generator(trajectory.shape)

        # forward diffusion: noise plus one random timestep per sample
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (trajectory.shape[0],), device=trajectory.device
        ).long()
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)

        # conditioned positions are excluded from the loss
        loss_mask = ~condition_mask
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        pred = self.model(noisy_trajectory, timesteps,
            local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        targets_by_type = {'epsilon': noise, 'sample': trajectory}
        if pred_type not in targets_by_type:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        target = targets_by_type[pred_type]

        per_element = F.mse_loss(pred, target, reduction='none')
        per_element = per_element * loss_mask.type(per_element.dtype)
        per_element = reduce(per_element, 'b ... -> b (...)', 'mean')
        return per_element.mean()