from diffusion_policy.common.normalize_util import array_to_stats, get_identity_normalizer_from_stat, get_image_range_normalizer, get_range_normalizer_from_stat, robomimic_abs_action_only_normalizer_from_stat
from diffusion_policy.dataset.robomimic_replay_image_dataset import _convert_robomimic_to_replay
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from dino_wm.datasets.img_transforms import default_transform, get_train_crop_transform, get_eval_crop_transform, get_train_crop_transform_resnet, get_eval_crop_transform_resnet
from dino_wm.datasets.utils import ReplayBuffer
import torch
from filelock import FileLock
import os
from torch.utils.data import Dataset
import numpy as np
import copy
import shutil

import zarr
from einops import rearrange
import random
from torch.utils.data import DataLoader

class RobomimicImageDynamicsModelDataset(Dataset):
    """Fixed-stride observation/action windows sampled from a robomimic replay buffer.

    The replay buffer is converted once with ``_convert_robomimic_to_replay`` and
    cached next to the dataset as ``<zarr_path>_lossless.zarr.zip``; later
    constructions load that cache into memory.

    Each item is a tuple ``(obs, act, state, multi_step_pos_imgs, rewards)``:

    * ``obs['visual'][view]``: float32 tensor of the ``num_hist + num_pred``
      sampled frames, channel axis moved to position 1 and scaled by 1/255.
    * ``obs['proprio']`` / ``state``: float32 tensors of the concatenated
      end-effector position, quaternion and gripper state at the same frames.
    * ``act``: float32 tensor of ``(num_hist + num_pred) * frameskip`` actions.
    * ``multi_step_pos_imgs``: per-view tensor of the last sampled frame when
      ``action_conditioned_time_contrastive`` is set, otherwise ``np.zeros((3,))``.
    * ``rewards``: float32 tensor when ``return_rewards`` is set, otherwise
      ``np.zeros((2,))``.
    """

    # Immutable default so instances never share (and accidentally mutate) one list.
    _DEFAULT_VIEW_NAMES = ('agentview', 'robot0_eye_in_hand')

    def __init__(self,
                 zarr_path,
                 num_hist=1,
                 num_pred=1,
                 frameskip=8,
                 view_names=None,
                 abs_action=False,
                 use_crop=False,
                 train=True,
                 return_rewards=False,
                 encoder_type='dino',
                 action_conditioned_time_contrastive=False,):
        """
        Loads (or builds and caches) the replay buffer and precomputes valid anchor indices.

        Args:
            zarr_path (str): Dataset path handed to ``_convert_robomimic_to_replay``.
                The substrings 'square', 'tool_hang' and 'transport' select which
                observation keys are requested.
            num_hist (int): Number of history frames per sample.
            num_pred (int): Number of future frames per sample.
            frameskip (int): Stride, in raw timesteps, between sampled frames.
            view_names (list[str] | None): Replay-buffer keys of the image arrays
                to load. ``None`` selects ``['agentview', 'robot0_eye_in_hand']``.
            abs_action (bool): Read actions from the 'abs_action' key instead of
                'action'.
            use_crop (bool): Select a crop transform for ``self.transform``
                instead of ``default_transform()``.
            train (bool): With ``use_crop``, pick the train crop rather than the
                eval crop.
            return_rewards (bool): Load 'rewards_closeness_to_demo' and return the
                rewards at the sampled frames.
            encoder_type (str): 'resnet' selects the resnet crop variants and adds
                per-view image normalizers in ``get_normalizer``.
            action_conditioned_time_contrastive (bool): Also return the last
                sampled frame of every view as ``multi_step_pos_imgs``.
        """
        self.abs_action = abs_action
        self.original_action_dim = self._infer_action_dim(zarr_path, abs_action)
        shape_meta = self._build_shape_meta(zarr_path, self.original_action_dim)
        replay_buffer = self._load_replay_buffer(
            zarr_path, shape_meta, abs_action, return_rewards)

        # Cumulative end offsets: episode i occupies [ends[i-1], ends[i]).
        self.episode_ends = replay_buffer.episode_ends[:]

        state_keys = ['robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
        if 'transport' in zarr_path:
            # Second arm of the two-robot task is appended after the first.
            state_keys += ['robot1_eef_pos', 'robot1_eef_quat', 'robot1_gripper_qpos']
        self.states = np.concatenate(
            [np.array(replay_buffer[key]) for key in state_keys], axis=1)

        self.return_rewards = return_rewards
        self.rewards = None
        if return_rewards:
            self.rewards = np.array(replay_buffer['rewards_closeness_to_demo'])

        if view_names is None:
            view_names = list(self._DEFAULT_VIEW_NAMES)
        self.view_names = view_names
        self.imgs = {}
        original_img_size = None
        for view_name in self.view_names:
            frames = np.array(replay_buffer[view_name])
            self.imgs[view_name] = frames
            print(f'self.imgs[{view_name}] ', frames.shape)
            original_img_size = frames.shape[1]
            assert original_img_size == 140, (
                f'expected 140-pixel frames for view {view_name!r}, '
                f'got {original_img_size}')

        self.states_dim = self.states.shape[1]
        self.proprio_dim = self.states.shape[1]
        # One sample step bundles `frameskip` consecutive raw actions.
        self.action_dim = self.original_action_dim * frameskip
        action_key = 'abs_action' if self.abs_action else 'action'
        self.actions = np.array(replay_buffer[action_key])

        # Per-dimension action statistics over the whole buffer.
        self.action_mean = np.mean(self.actions, axis=0)
        self.action_std = np.std(self.actions, axis=0)
        self.action_max = np.max(self.actions, axis=0)
        self.action_min = np.min(self.actions, axis=0)

        self.num_hist = num_hist
        self.num_pred = num_pred
        self.frameskip = frameskip
        self.num_frames = num_hist + num_pred
        self.action_conditioned_time_contrastive = action_conditioned_time_contrastive
        self.use_crop = use_crop
        self.train = train
        self.encoder_type = encoder_type

        print('episode_ends ', self.episode_ends)
        print('action shape ', self.actions.shape)
        # Zero-indexed first and last timestep of every episode.
        self.episode_start_indices = np.concatenate(([0], self.episode_ends[:-1]))
        self.episode_end_indices = self.episode_ends - 1

        self.valid_anchor_indices = self._compute_valid_anchor_indices(
            self.episode_ends, num_hist, num_pred, frameskip)
        self.num_valid = len(self.valid_anchor_indices)
        # NOTE: the transform is stored but __getitem__ does not apply it.
        self.transform = self._build_transform(original_img_size)
        print('len ', self.num_valid)

    @staticmethod
    def _infer_action_dim(zarr_path, abs_action):
        """Return the raw action width: 10 (absolute) or 7 (relative) per arm, two arms for 'transport'."""
        per_arm = 10 if abs_action else 7
        # 'transport' is the two-arm task, so both action encodings double in width.
        num_arms = 2 if 'transport' in zarr_path else 1
        return per_arm * num_arms

    @staticmethod
    def _build_shape_meta(zarr_path, action_dim):
        """Build the shape metadata dict passed to ``_convert_robomimic_to_replay``."""
        def rgb():
            # Fresh dict per key so entries are never aliased.
            return {'shape': [3, 140, 140], 'type': 'rgb'}

        obs_meta = {
            'robot0_eef_pos': {'shape': [3]},
            'robot0_eef_quat': {'shape': [4]},
            'robot0_gripper_qpos': {'shape': [2]},
        }
        if 'square' in zarr_path:
            obs_meta['agentview_image'] = rgb()
            obs_meta['robot0_eye_in_hand_image'] = rgb()
        elif 'tool_hang' in zarr_path:
            obs_meta['sideview_image'] = rgb()
            obs_meta['robot0_eye_in_hand_image'] = rgb()
        elif 'transport' in zarr_path:
            obs_meta['robot0_eye_in_hand_image'] = rgb()
            obs_meta['robot1_eye_in_hand_image'] = rgb()
            obs_meta['shouldercamera0_image'] = rgb()
            obs_meta['shouldercamera1_image'] = rgb()
            obs_meta['robot1_eef_pos'] = {'shape': [3]}
            obs_meta['robot1_eef_quat'] = {'shape': [4]}
            obs_meta['robot1_gripper_qpos'] = {'shape': [2]}
        return {'obs': obs_meta, 'action': {'shape': [action_dim]}}

    @staticmethod
    def _load_replay_buffer(zarr_path, shape_meta, abs_action, return_rewards):
        """Return an in-memory replay buffer, building the on-disk zip cache on first use.

        NOTE(review): the cache path depends only on ``zarr_path``; a cache built
        with one ``abs_action`` / ``return_rewards`` setting is reused for any
        other setting. Confirm the converter stores everything that is needed.
        """
        cache_zarr_path = zarr_path + '_lossless.zarr.zip'
        cache_lock_path = cache_zarr_path + '_lossless.lock'
        print('Acquiring lock on cache.')
        with FileLock(cache_lock_path):
            if os.path.exists(cache_zarr_path):
                print('Loading cached ReplayBuffer from Disk.')
                print('cache_zarr_path ', cache_zarr_path)
                with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                    replay_buffer = ReplayBuffer.copy_from_store(
                        src_store=zip_store, store=zarr.MemoryStore())
                print('Loaded!')
                return replay_buffer

            try:
                print('Cache does not exist. Creating!')
                rotation_transformer = RotationTransformer(
                    from_rep='axis_angle', to_rep='rotation_6d')
                replay_buffer = _convert_robomimic_to_replay(
                    store=zarr.MemoryStore(),
                    shape_meta=shape_meta,
                    dataset_path=zarr_path,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer,
                    get_reward=return_rewards)
                print('Saving cache to disk.')
                with zarr.ZipStore(cache_zarr_path) as zip_store:
                    replay_buffer.save_to_store(
                        store=zip_store
                    )
            except Exception:
                # Drop a partially written cache so the next run rebuilds it.
                # The cache is a single zip file (and may not exist yet), so
                # shutil.rmtree alone would raise and hide the real error.
                try:
                    if os.path.isdir(cache_zarr_path):
                        shutil.rmtree(cache_zarr_path)
                    elif os.path.exists(cache_zarr_path):
                        os.remove(cache_zarr_path)
                except OSError:
                    # Best-effort cleanup; the original failure is what matters.
                    pass
                raise
            return replay_buffer

    @staticmethod
    def _compute_valid_anchor_indices(episode_ends, num_hist, num_pred, frameskip):
        """Return every start index whose full frame window stays inside one episode.

        A window starting at ``a`` reads frames up to
        ``a + (num_hist + num_pred - 1) * frameskip``. The upper bound is kept
        exclusive, so for ``num_hist == 1`` the result is ``start`` up to (but
        excluding) ``last_index - num_pred * frameskip`` for every episode.
        """
        episode_ends = np.asarray(episode_ends)
        first_indices = np.concatenate(([0], episode_ends[:-1]))
        last_indices = episode_ends - 1
        # Distance from the anchor to the last frame the window touches.
        span = (num_hist + num_pred - 1) * frameskip
        chunks = [
            np.arange(first, last - span)  # empty when the episode is too short
            for first, last in zip(first_indices, last_indices)
        ]
        if not chunks:
            return np.array([], dtype=np.int64)
        return np.concatenate(chunks)

    def _build_transform(self, original_img_size):
        """Pick the image transform from ``use_crop``, ``train`` and ``encoder_type``."""
        transform = default_transform()
        if not self.use_crop:
            return transform

        print('using crop! ')
        if original_img_size is None:
            raise ValueError('use_crop=True requires at least one view in view_names')
        use_resnet = self.encoder_type == 'resnet'
        if self.train:
            print('using train random crop! ')
            if use_resnet:
                print('using resnet crop! ')
                return get_train_crop_transform_resnet(original_img_size)
            return get_train_crop_transform(original_img_size)

        print('using eval enter crop! ')
        if use_resnet:
            print('using resnet crop! ')
            return get_eval_crop_transform_resnet(original_img_size)
        return get_eval_crop_transform(original_img_size)

    @staticmethod
    def _frames_to_tensor(frames):
        """Move the last axis to position 1, cast to float32, scale by 1/255 and wrap as a tensor."""
        # assumes frames are channel-last uint8 images — TODO confirm
        scaled = np.moveaxis(frames, -1, 1).astype(np.float32) / 255
        return torch.from_numpy(scaled)

    def __len__(self):
        """
        Returns the number of valid anchor samples.
        """
        return self.num_valid

    def __getitem__(self, idx):
        """Return ``(obs, act, state, multi_step_pos_imgs, rewards)`` for anchor ``idx``."""
        start = self.valid_anchor_indices[idx]
        end = start + self.num_frames * self.frameskip
        obs_indices = list(range(start, end, self.frameskip))
        action_indices = list(range(start, end))
        # The final `frameskip` slots lie past the last sampled frame; fill them
        # with the action taken immediately before that frame.
        action_indices[-self.frameskip:] = [obs_indices[-1] - 1] * self.frameskip

        obs = {
            'visual': {
                view_name: self._frames_to_tensor(self.imgs[view_name][obs_indices])
                for view_name in self.view_names
            },
            'proprio': torch.from_numpy(self.states[obs_indices].astype(np.float32)),
        }
        act = torch.from_numpy(self.actions[action_indices].astype(np.float32))
        # Same values as obs['proprio'], but with independent storage.
        state = obs['proprio'].clone()

        # Placeholder keeps the tuple layout fixed when rewards are not loaded.
        rewards = np.zeros((2,))
        if self.return_rewards:
            rewards = torch.from_numpy(self.rewards[obs_indices].astype(np.float32))

        # Placeholder keeps the tuple layout fixed when the contrastive branch is off.
        multi_step_pos_imgs = np.zeros((3,))
        if self.action_conditioned_time_contrastive:
            # A single index: the last sampled frame of this window.
            multi_step_pos_indices = list(
                range(end - 1 * self.frameskip, end + 0 * self.frameskip, self.frameskip))
            multi_step_pos_imgs = {
                view_name: self._frames_to_tensor(
                    self.imgs[view_name][multi_step_pos_indices])
                for view_name in self.view_names
            }
        return (obs, act, state, multi_step_pos_imgs, rewards)

    def get_normalizer(self, **kwargs) -> "LinearNormalizer":
        """Build a ``LinearNormalizer`` with 'act', 'state' and, for resnet encoders, per-view entries."""
        normalizer = LinearNormalizer()

        act_stat = array_to_stats(self.actions)
        if self.abs_action:
            act_normalizer = robomimic_abs_action_only_normalizer_from_stat(act_stat)
        else:
            # Relative actions are treated as already normalized.
            act_normalizer = get_identity_normalizer_from_stat(act_stat)
        normalizer['act'] = act_normalizer

        state_stat = array_to_stats(self.states)
        normalizer['state'] = get_range_normalizer_from_stat(state_stat)

        if self.encoder_type == 'resnet':
            for view_name in self.view_names:
                normalizer[view_name] = get_image_range_normalizer()
        return normalizer

