from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.dataset.robomimic_replay_image_dataset import _convert_robomimic_to_replay
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
import torch
from filelock import FileLock
import os
from torch.utils.data import Dataset
import numpy as np
import copy
import shutil

import zarr
import random
from torch.utils.data import DataLoader

from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import array_to_stats, get_image_range_normalizer, get_range_normalizer_from_stat, robomimic_abs_action_only_normalizer_from_stat
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)


class RobomimicMultiStepImageDynamicsModelDataset(Dataset):
    """Single-view multi-step transition dataset read from a zarr store.

    Each sample is anchored at a step ``t`` of a training trajectory and
    contains the current state/image, the next ``horizon`` actions, the next
    ``horizon`` states/images (clamped to the last step of the trajectory when
    the window runs past its end) and ``n_neg`` negative images drawn uniformly
    from the *other* training trajectories.

    Args:
        zarr_path (str): Path of the zarr store. It must contain
            ``meta/episode_ends`` plus ``data/agentview_image``,
            ``data/robot0_eef_pos`` and the action array selected by
            ``action_dim``.
        horizon (int): Number of actions / future steps per sample.
        val_ratio (float): Forwarded to ``get_val_mask``; episodes selected by
            that mask are excluded from this dataset.
        n_neg (int): Number of negative images per sample.
        action_dim (int): ``7`` reads ``data/raw_action`` (cast to float32),
            ``10`` reads ``data/converted_action``.

    Raises:
        ValueError: If ``action_dim`` is neither 7 nor 10.
    """

    # Number of trailing steps of every trajectory that are never used as the
    # current step ``t`` (hard-coded margin kept from the original code).
    _TAIL_MARGIN = 6

    # Supported action dimensionalities and the zarr array each is read from.
    _ACTION_KEYS = {7: 'raw_action', 10: 'converted_action'}

    def __init__(self, zarr_path, horizon=1, val_ratio=0.0, n_neg=5, action_dim=10):
        action_key = self._ACTION_KEYS.get(action_dim)
        if action_key is None:
            # Fail early instead of leaving `self.actions` undefined.
            raise ValueError(
                f"Unsupported action_dim {action_dim!r}; "
                f"expected one of {sorted(self._ACTION_KEYS)}")

        # Only materialize the arrays that are actually used; copying every
        # array of the store into memory is wasteful for image datasets.
        zarr_data = zarr.open(zarr_path, mode='r')
        data_group = zarr_data['data']

        # Cumulative (exclusive) end index of every episode, e.g. [60, 110, ...]
        episode_ends = np.array(zarr_data['meta']['episode_ends'])

        self.imgs = np.array(data_group['agentview_image'])
        print('self.imgs ', self.imgs.shape)
        self.states = np.array(data_group['robot0_eef_pos'])  # (total_timesteps, state_dim)
        self.actions = np.array(data_group[action_key])  # (total_timesteps, action_dim)
        if action_dim == 7:
            self.actions = self.actions.astype(np.float32)
        self.n_contacts = None

        self.horizon = horizon
        self.val_ratio = val_ratio
        self.n_neg = n_neg

        val_mask = get_val_mask(
            n_episodes=len(episode_ends),
            val_ratio=val_ratio,
            seed=42,
            random=False)
        self.train_mask = ~val_mask
        print('self.train_mask ', self.train_mask)
        self.original_episode_ends = episode_ends
        self.episode_ends = episode_ends[self.train_mask]

        # Every episode starts where the previous episode (training or not)
        # ends. Deriving the starts from the unmasked ends keeps trajectories
        # from swallowing held-out episodes when the mask is not contiguous.
        episode_starts = np.zeros_like(episode_ends)
        episode_starts[1:] = episode_ends[:-1]
        train_starts = episode_starts[self.train_mask]

        # List of (start_idx, end_idx) with end_idx exclusive.
        self.trajectories = list(
            zip(train_starts.tolist(), self.episode_ends.tolist()))
        print('self.trajectories ', self.trajectories)

        # Precompute sample indices as (traj_idx, t) tuples.
        self.sample_indices = []
        for traj_idx, (start, end) in enumerate(self.trajectories):
            n_anchors = end - start - self._TAIL_MARGIN
            self.sample_indices.extend((traj_idx, t) for t in range(n_anchors))

    def __len__(self):
        """Return the number of precomputed (trajectory, step) samples."""
        return len(self.sample_indices)

    def _sample_negative_indices(self, traj_idx):
        """Draw ``n_neg`` distinct global step indices outside trajectory ``traj_idx``.

        Raises:
            ValueError: If the other trajectories hold fewer than ``n_neg`` steps.
        """
        other_ranges = [
            np.arange(other_start, other_end)
            for i, (other_start, other_end) in enumerate(self.trajectories)
            if i != traj_idx
        ]
        # One concatenate instead of growing an array inside the loop.
        if other_ranges:
            candidates = np.concatenate(other_ranges)
        else:
            candidates = np.empty(0, dtype=int)
        if len(candidates) < self.n_neg:
            raise ValueError(
                f"Cannot draw {self.n_neg} negative samples: only "
                f"{len(candidates)} steps are available in other trajectories")
        return np.random.choice(candidates, size=self.n_neg, replace=False)

    def __getitem__(self, idx):
        """Build the sample for the ``idx``-th (trajectory, step) pair.

        Returns:
            dict: float32 tensors under the keys ``s_t``, ``img_t``,
            ``actions_seq``, ``s_future_seq``, ``img_future_seq`` and
            ``negative_imgs``. Images are divided by 255 and have their last
            axis moved in front of the spatial axes (assumes channel-last
            storage — TODO confirm against the dataset).
        """
        traj_idx, t = self.sample_indices[idx]
        start, end = self.trajectories[traj_idx]  # end is exclusive

        s_t_idx = start + t
        last_idx = end - 1

        # Actions a_t .. a_{t+horizon-1} and futures s_{t+1} .. s_{t+horizon}.
        # Clamping the indices to the last step of the trajectory repeats the
        # final action/state/image whenever the window runs past the end.
        action_indices = np.minimum(
            np.arange(s_t_idx, s_t_idx + self.horizon), last_idx)
        future_indices = np.minimum(
            np.arange(s_t_idx + 1, s_t_idx + 1 + self.horizon), last_idx)

        s_t = torch.tensor(self.states[s_t_idx], dtype=torch.float32)
        actions_seq = torch.tensor(
            self.actions[action_indices], dtype=torch.float32)  # (horizon, action_dim)
        s_future_seq = torch.tensor(
            self.states[future_indices], dtype=torch.float32)  # (horizon, state_dim)

        img_t = np.moveaxis(self.imgs[s_t_idx], -1, 0) / 255
        img_t = torch.tensor(img_t, dtype=torch.float32)  # last axis moved first
        img_future_seq = np.moveaxis(self.imgs[future_indices], -1, 1) / 255
        img_future_seq = torch.tensor(
            img_future_seq, dtype=torch.float32)  # (horizon, ...) last axis moved to dim 1

        data = {
            's_t': s_t,
            'img_t': img_t,
            'actions_seq': actions_seq,
            's_future_seq': s_future_seq,
            'img_future_seq': img_future_seq
        }

        # Negative sampling: frames taken from any other training trajectory.
        negative_t = self._sample_negative_indices(traj_idx)
        negative_imgs = np.moveaxis(self.imgs[negative_t], -1, 1) / 255
        data['negative_imgs'] = torch.tensor(negative_imgs, dtype=torch.float32)

        return data

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a ``LinearNormalizer`` with ``action``, ``state`` and ``image`` entries.

        Action and state statistics are computed over the full loaded arrays
        (``self.actions`` / ``self.states``), not only the training episodes.
        """
        normalizer = LinearNormalizer()

        # action
        action_stat = array_to_stats(self.actions)
        normalizer['action'] = robomimic_abs_action_only_normalizer_from_stat(action_stat)

        # state
        state_stat = array_to_stats(self.states)
        normalizer['state'] = get_range_normalizer_from_stat(state_stat)

        # image
        normalizer['image'] = get_image_range_normalizer()
        return normalizer
    

class RobomimicMultiStepWithHistoryImageDynamicsModelDataset(Dataset):
    """Multi-view dataset of history/future windows around an anchor step.

    Every sample is centred on an anchor step that has at least
    ``horizon_history`` steps before it and ``horizon_future`` steps after it
    inside the same training episode, so no window crosses an episode
    boundary and no padding is needed.
    """

    # Minimum distance (in steps) between the anchor and the first candidate
    # when negatives are taken from the anchor's own future
    # (hard-coded value kept from the original code).
    _FUTURE_NEG_OFFSET = 25

    _NEG_SAMPLING_MODES = ('other_ep_as_neg', 'future_as_neg')

    def __init__(self,
                 zarr_path,
                 horizon_history=1,
                 horizon_future=8,
                 skip_ratio=4,
                 val_ratio=0.0,
                 n_neg=5,
                 neg_sampling='other_ep_as_neg',
                 patch=False,
                 view_names=('agentview', 'robot0_eye_in_hand')):
        """
        Initializes the dataset by loading data from a Zarr store and precomputing valid anchor indices.

        Args:
            zarr_path (str): Path to the Zarr dataset. It must contain
                ``meta/episode_ends``, ``data/robot0_eef_pos``,
                ``data/converted_action`` and ``data/<view>_image`` for every
                entry of ``view_names``.
            horizon_history (int): Number of steps before the anchor covered by a sample.
            horizon_future (int): Number of steps after the anchor covered by a sample.
            skip_ratio (int): Stride (in steps) between consecutive image frames.
            val_ratio (float): Forwarded to ``get_val_mask``; episodes selected
                by that mask are excluded from this dataset.
            n_neg (int): Number of negative images per view and sample.
            neg_sampling (str): ``'other_ep_as_neg'`` draws negatives from other
                training episodes; ``'future_as_neg'`` draws them from the
                anchor's own episode at least 25 steps ahead and falls back to
                other episodes when too few such steps remain.
            patch (bool): If True, crop every image to ``[:, :224, 16:, :]``.
            view_names (sequence of str): Camera views to load.

        Raises:
            ValueError: If ``neg_sampling`` is not one of the supported modes.
        """
        if neg_sampling not in self._NEG_SAMPLING_MODES:
            # Fail early; an unknown mode would otherwise only surface as an
            # obscure error when the first item is fetched.
            raise ValueError(
                f"Unknown neg_sampling {neg_sampling!r}; "
                f"expected one of {self._NEG_SAMPLING_MODES}")

        # Only materialize the arrays that are actually used.
        zarr_data = zarr.open(zarr_path, mode='r')
        data_group = zarr_data['data']
        print(list(data_group.keys()))

        # Cumulative (exclusive) end index of every episode, e.g. [60, 110, ...]
        episode_ends = np.array(zarr_data['meta']['episode_ends'])
        n_episodes = len(episode_ends)

        # Copy so that neither the default nor a caller-owned list is shared.
        self.view_names = list(view_names)
        self.imgs = {}
        for view_name in self.view_names:
            view_imgs = np.array(data_group[f'{view_name}_image'])
            print(f'self.imgs[{view_name}] ', view_imgs.shape)
            # NOTE: temporary workaround, later need to generate a new dataset with 224x224 images
            if patch:
                view_imgs = view_imgs[:, :224, 16:, :]
            self.imgs[view_name] = view_imgs

        self.states = np.array(data_group['robot0_eef_pos'])  # (total_timesteps, state_dim)
        self.actions = np.array(data_group['converted_action'])
        print('Loaded!')

        self.horizon_history = horizon_history
        self.horizon_future = horizon_future
        self.k = skip_ratio
        self.val_ratio = val_ratio
        self.n_neg = n_neg
        self.neg_sampling = neg_sampling

        # Generate validation mask
        val_mask = get_val_mask(
            n_episodes=n_episodes,
            val_ratio=val_ratio,
            seed=42,
            random=False)
        self.train_mask = ~val_mask
        self.episode_ends_train = episode_ends[self.train_mask]
        print('episode_ends_train ', self.episode_ends_train)

        # Every episode starts where the previous episode (training or not)
        # ends. Deriving the starts from the unmasked ends keeps training
        # windows out of held-out episodes when the mask is not a suffix.
        all_starts = np.zeros_like(episode_ends)
        all_starts[1:] = episode_ends[:-1]
        self.episode_start_indices = all_starts[self.train_mask]
        self.episode_end_indices = self.episode_ends_train - 1  # last index of each trajectory (inclusive)

        # Valid anchors run from start + horizon_history to end - horizon_future (inclusive).
        anchor_chunks = []
        for start, end in zip(self.episode_start_indices, self.episode_end_indices):
            anchor_start = start + self.horizon_history
            anchor_end = end - self.horizon_future
            if anchor_end >= anchor_start:
                anchor_chunks.append(np.arange(anchor_start, anchor_end + 1))
        if anchor_chunks:
            self.valid_anchor_indices = np.concatenate(anchor_chunks)
        else:
            self.valid_anchor_indices = np.empty(0, dtype=int)
        self.num_valid = len(self.valid_anchor_indices)

    def __len__(self):
        """
        Returns the number of valid anchor samples.
        """
        return self.num_valid

    def _sample_negative_indices(self, anchor_idx, episode_idx):
        """Draw ``n_neg`` distinct global step indices to use as negatives.

        Args:
            anchor_idx (int): Global step index of the anchor.
            episode_idx (int): Position of the anchor's episode among the
                training episodes.

        Raises:
            ValueError: If fewer than ``n_neg`` candidate steps exist.
        """
        episode_end = self.episode_end_indices[episode_idx]

        candidates = None
        if self.neg_sampling == 'future_as_neg':
            neg_start = anchor_idx + self._FUTURE_NEG_OFFSET
            # Use the episode's own future only when it holds more than n_neg
            # steps; otherwise fall through to the other episodes.
            if neg_start + self.n_neg < episode_end:
                candidates = np.arange(neg_start, episode_end)

        if candidates is None:
            # The upper bound is the inclusive last index, so the final frame
            # of every other episode is never a candidate (behaviour kept from
            # the original "sampling heuristic").
            other_ranges = [
                np.arange(start, end)
                for i, (start, end) in enumerate(
                    zip(self.episode_start_indices, self.episode_end_indices))
                if i != episode_idx
            ]
            # One concatenate instead of growing an array inside the loop.
            if other_ranges:
                candidates = np.concatenate(other_ranges)
            else:
                candidates = np.empty(0, dtype=int)

        if len(candidates) < self.n_neg:
            raise ValueError(
                f"Cannot draw {self.n_neg} negative samples: only "
                f"{len(candidates)} candidate steps are available")
        return np.random.choice(candidates, size=self.n_neg, replace=False)

    def _frames_to_tensor(self, view_name, indices):
        """Return the frames of ``view_name`` at ``indices`` as a float32 tensor scaled by 1/255.

        The last array axis is moved to position 1 (assumes channel-last
        storage, giving (N, C, H, W) — TODO confirm against the dataset).
        """
        frames = self.imgs[view_name][indices]
        return torch.tensor(np.moveaxis(frames, -1, 1) / 255, dtype=torch.float32)

    def __getitem__(self, idx):
        """
        Retrieves the sample at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: ``history_imgs``, ``future_imgs`` and ``negative_imgs`` map
            each view name to a float32 image tensor; ``history_actions`` and
            ``future_actions`` are float32 action tensors.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
            ValueError: If not enough negative candidates are available.
        """
        if idx < 0 or idx >= self.num_valid:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.num_valid}")
        anchor_idx = self.valid_anchor_indices[idx]
        # `episode_end_indices` holds inclusive last indices, so the owning
        # episode is the first one whose last index is >= the anchor
        # (side='left'); side='right' would pick the next episode whenever the
        # anchor is itself the last step (horizon_future == 0).
        episode_idx = np.searchsorted(self.episode_end_indices, anchor_idx, side='left')

        # Image frames are taken every `k` steps; actions at every step.
        # NOTE(review): when horizon_history is not a multiple of k the history
        # frames stop short of the anchor frame — confirm this is intended.
        history_imgs_indices = np.arange(anchor_idx - self.horizon_history, anchor_idx + 1, step=self.k)
        history_actions_indices = np.arange(anchor_idx - self.horizon_history, anchor_idx)
        future_actions_indices = np.arange(anchor_idx, anchor_idx + self.horizon_future)
        future_imgs_indices = np.arange(anchor_idx + self.k, anchor_idx + self.horizon_future + 1, step=self.k)

        history_actions = torch.tensor(self.actions[history_actions_indices], dtype=torch.float32)
        future_actions = torch.tensor(self.actions[future_actions_indices], dtype=torch.float32)

        history_imgs = {
            view_name: self._frames_to_tensor(view_name, history_imgs_indices)
            for view_name in self.view_names
        }
        future_imgs = {
            view_name: self._frames_to_tensor(view_name, future_imgs_indices)
            for view_name in self.view_names
        }

        # Negative sampling: the same step indices are used for every view.
        negative_t = self._sample_negative_indices(anchor_idx, episode_idx)
        neg_imgs = {
            view_name: self._frames_to_tensor(view_name, negative_t)
            for view_name in self.view_names
        }

        return {
            'history_imgs': history_imgs,           # per view: (ceil((horizon_history + 1) / k), ...)
            'history_actions': history_actions,     # Shape: (horizon_history, action_dim)
            'future_actions': future_actions,       # Shape: (horizon_future, action_dim)
            'future_imgs': future_imgs,             # per view: (horizon_future // k, ...)
            'negative_imgs': neg_imgs,              # per view: (n_neg, ...)
        }

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Return a ``LinearNormalizer`` with ``action``, ``state`` and ``image`` entries.

        Action and state statistics are computed over the full loaded arrays
        (``self.actions`` / ``self.states``), not only the training episodes.
        """
        normalizer = LinearNormalizer()

        # action
        action_stat = array_to_stats(self.actions)
        normalizer['action'] = robomimic_abs_action_only_normalizer_from_stat(action_stat)

        # state
        state_stat = array_to_stats(self.states)
        normalizer['state'] = get_range_normalizer_from_stat(state_stat)

        # image
        normalizer['image'] = get_image_range_normalizer()
        return normalizer


if __name__ == "__main__":
    # Smoke test: build the history dataset from a local zarr store, create
    # its normalizer and walk once over shuffled mini-batches.
    demo_zarr = 'data/square_exp_data/1737012334_462362_demo_image_abs_2_views_240.zarr'
    # Alternative store: 'data/square_exp_data/square_play_val_image_abs_v2_240.zarr'

    dataset_kwargs = dict(
        zarr_path=demo_zarr,
        horizon_history=4,
        horizon_future=8,
        skip_ratio=4,
        val_ratio=0.0,
        n_neg=5,
    )
    history_dataset = RobomimicMultiStepWithHistoryImageDynamicsModelDataset(**dataset_kwargs)

    demo_normalizer = history_dataset.get_normalizer()
    loader = DataLoader(history_dataset, batch_size=4, shuffle=True)
    for minibatch in loader:
        # Pull out the entries of interest; nothing is done with them here.
        agentview_future = minibatch['future_imgs']['agentview']
        past_actions = minibatch['history_actions']
        upcoming_actions = minibatch['future_actions']
        # For debugging, print the shape of every (possibly nested) batch
        # entry here, or run each view of minibatch['future_imgs'] through
        # demo_normalizer['image'].normalize(...) and print its min/max range.