from typing import Dict
import torch
import numpy as np
import copy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import get_image_range_normalizer
import zarr


class PushTImageDataset(BaseImageDataset):
    """Image-observation dataset over PushT demonstrations stored in a zarr group.

    The zarr group is expected to expose ``data/img``, ``data/state``,
    ``data/action`` and ``meta/episode_ends``. Fixed-length sequences are
    drawn through a ``SequenceSampler`` restricted to the training episodes;
    ``get_validation_dataset`` returns a sibling dataset over the remaining
    episodes.
    """

    def __init__(self,
            zarr_path,
            horizon=1,
            pad_before=0,
            pad_after=0,
            seed=42,
            val_ratio=0.0,
            max_train_episodes=None
            ):
        """Open the zarr store and build the training-episode sampler.

        Args:
            zarr_path: Location handed to ``zarr.open`` (opened read-only).
            horizon: Sequence length requested from the sampler.
            pad_before: Forwarded to ``SequenceSampler`` as ``pad_before``.
            pad_after: Forwarded to ``SequenceSampler`` as ``pad_after``.
            seed: Seed passed to both ``get_val_mask`` and ``downsample_mask``.
            val_ratio: Forwarded to ``get_val_mask`` as ``val_ratio``.
            max_train_episodes: Forwarded to ``downsample_mask`` as ``max_n``;
                presumably caps the number of training episodes (``None``
                appears to mean "no cap" -- confirm against ``downsample_mask``).
        """
        super().__init__()
        # Keep handles to the zarr arrays instead of copying the whole store
        # through ReplayBuffer.copy_from_path; the sampler and the normalizer
        # index into these handles directly.
        zarr_data = zarr.open(zarr_path, mode='r')
        self.replay_buffer = {
            'img': zarr_data['data']['img'],
            'state': zarr_data['data']['state'],
            'action': zarr_data['data']['action'],
        }
        # Materialize episode boundaries as an in-memory array.
        self.episode_ends = zarr_data['meta']['episode_ends'][:]
        n_episodes = len(self.episode_ends)
        val_mask = get_val_mask(
            n_episodes=n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        train_mask = downsample_mask(
            mask=train_mask,
            max_n=max_train_episodes,
            seed=seed)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            episode_ends=self.episode_ends,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask)
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples the held-out episodes.

        The copy shares ``replay_buffer`` and ``episode_ends`` with ``self``;
        only ``sampler`` and ``train_mask`` are replaced. The episode mask is
        the complement of ``self.train_mask``, so it covers every episode not
        selected for training (NOTE(review): this would include any episodes
        dropped by ``downsample_mask``, not only those chosen by
        ``get_val_mask`` -- confirm this is intended).
        """
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            episode_ends=self.episode_ends,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
            )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit and return a ``LinearNormalizer`` for actions, agent position and images.

        ``action`` and ``agent_pos`` (the first two state columns) are fitted
        on the full replay buffer, irrespective of ``train_mask``. The
        ``image`` entry is the fixed normalizer returned by
        ``get_image_range_normalizer``.

        Args:
            mode: Forwarded to ``LinearNormalizer.fit``.
            **kwargs: Extra keyword arguments forwarded to ``LinearNormalizer.fit``.
        """
        data = {
            'action': self.replay_buffer['action'],
            'agent_pos': self.replay_buffer['state'][...,:2]
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer

    def __len__(self) -> int:
        """Return the number of sequences available from the current sampler."""
        return len(self.sampler)

    def _sample_to_data(self, sample):
        """Convert a raw sampled sequence into the nested observation/action dict.

        Args:
            sample: Mapping with ``'img'``, ``'state'`` and ``'action'`` arrays
                for one sequence (assumed time-major with channel-last
                images -- TODO confirm against ``SequenceSampler``).

        Returns:
            ``{'obs': {'image', 'agent_pos'}, 'action'}`` where every array
            is float32.
        """
        agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3)
        # Move channels from the last axis to axis 1 and rescale by 1/255.
        # Cast to float32 first so an integer-typed image store does not get
        # promoted to float64 by the division, keeping the dtype consistent
        # with agent_pos and action.
        image = np.moveaxis(sample['img'],-1,1).astype(np.float32)/255

        data = {
            'obs': {
                'image': image, # T, 3, 96, 96
                'agent_pos': agent_pos, # T, 2
            },
            'action': sample['action'].astype(np.float32) # T, 2
        }
        return data

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the ``idx``-th sequence as torch tensors.

        The result mirrors ``_sample_to_data`` with every array converted via
        ``torch.from_numpy``, plus an ``'index'`` entry holding ``idx`` itself
        (a plain value, not a tensor).
        """
        sample = self.sampler.sample_sequence(idx)
        data = self._sample_to_data(sample)
        torch_data = dict_apply(data, torch.from_numpy)
        torch_data['index'] = idx
        return torch_data


def test():
    """Manual smoke check: build the dataset from the local PushT replay zarr."""
    from os.path import expanduser

    replay_zarr = expanduser('~/dev/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr')
    dataset = PushTImageDataset(zarr_path=replay_zarr, horizon=16)

    # Optional follow-up for inspecting step sizes of normalized actions:
    # from matplotlib import pyplot as plt
    # normalizer = dataset.get_normalizer()
    # nactions = normalizer['action'].normalize(dataset.replay_buffer['action'])
    # diff = np.diff(nactions, axis=0)
    # dists = np.linalg.norm(np.diff(nactions, axis=0), axis=-1)
