import os
import os.path as osp
import torch
import numpy as np
import io
import json
import mmengine.fileio as fileio
from PIL import Image
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, DistributedSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import h5py
import cv2
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn

from collections import deque
import utils.utils as utils

def build_base_transform(n_px, aug=True, to_tensor=True, apply_norm=True,
                         crop_scale=(0.75, 1.0), crop_ratio=(0.75, 1.33), crop_prob=1.0,
                         norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build an albumentations ``ReplayCompose`` image pipeline.

    Args:
        n_px: Output side length; images end up ``n_px`` x ``n_px``.
        aug: If True the first step is ``RandomResizedCrop`` (using
            ``crop_scale``, ``crop_ratio`` and probability ``crop_prob``);
            otherwise it is a deterministic ``Resize``.
        to_tensor: Append ``ToTensorV2`` as the last step.
        apply_norm: Append ``Normalize`` with ``norm_mean`` / ``norm_std`` and
            ``max_pixel_value=255.0``.
        crop_scale, crop_ratio, crop_prob: Parameters of the random crop.
        norm_mean, norm_std: Per-channel normalization statistics.

    Returns:
        ``A.ReplayCompose`` over the selected steps. Its output dict contains a
        ``'replay'`` entry that can re-apply the same sampled parameters.
    """
    if aug:
        resize_step = A.RandomResizedCrop(height=n_px, width=n_px, p=crop_prob,
                                          scale=crop_scale, ratio=crop_ratio)
    else:
        resize_step = A.Resize(height=n_px, width=n_px)

    steps = [resize_step]
    if apply_norm:
        steps.append(A.Normalize(mean=norm_mean, std=norm_std,
                                 max_pixel_value=255.0, p=1.0))
    if to_tensor:
        steps.append(ToTensorV2())

    return A.ReplayCompose(steps)

def build_dataset_statistics(dataset_path, cache_json_name='cache.json',stage=0, few_num=5):
    """Return per-dimension action/proprio min/max statistics for a LIBERO dataset.

    Statistics are cached as JSON inside ``dataset_path``. The few-shot stages
    (3 and 4) use ``cache_few.json``; all other stages use ``cache_json_name``.
    If the cache file exists it is loaded and returned unchanged. Otherwise the
    selected ``*.hdf5`` trajectories are read, the statistics are computed and
    the cache is written.

    File selection:
      * few-shot stages: at most ``few_num`` randomly chosen files per
        immediate sub-folder of ``dataset_path`` (via ``get_random_hdf5_files``);
      * other stages: every ``*.hdf5`` found recursively, in sorted order.

    Returns:
        dict with keys ``views``, ``action_max``, ``action_min``,
        ``proprio_max``, ``proprio_min`` (per-dimension lists of floats), and
        ``traj_paths`` / ``traj_lens`` (one entry per trajectory).

    Raises:
        ValueError: if no cache exists and no ``*.hdf5`` file is found.
    """
    few_shot = stage in (3, 4)
    if few_shot:
        print('libero_few cache')
        # NOTE(review): the few-shot cache name does not encode few_num, so an
        # existing cache_few.json is reused even if few_num changes.
        cache_json = osp.join(dataset_path, 'cache_few.json')
    else:
        print('libero cache')
        cache_json = osp.join(dataset_path, cache_json_name)

    if osp.isfile(cache_json):
        print('dataset statistics exists')
        # Context manager so the cache file handle is closed promptly.
        with open(cache_json, 'r') as f:
            return json.load(f)

    print('Beginning to build dataset statistics...')
    if few_shot:
        hdf5_files = get_random_hdf5_files(dataset_path, num_files=few_num)
    else:
        # Sorted so the cached traj_paths order does not depend on filesystem listing order.
        hdf5_files = sorted(str(file.resolve()) for file in Path(dataset_path).rglob('*.hdf5'))
    if not hdf5_files:
        # np.concatenate below would otherwise fail with an opaque error.
        raise ValueError(f'No *.hdf5 trajectories found under {dataset_path!r}')

    traj_lens = []
    proprios = []
    actions = []
    for file in tqdm(hdf5_files):
        with h5py.File(io.BytesIO(fileio.get(file)), 'r') as f:
            # Overwritten per file, so the last file's view names are recorded.
            views = list(f['observation'].keys())
            traj_actions = f['action'][()].astype('float32')
            traj_proprios = f['proprio'][()].astype('float32')
        actions.append(traj_actions)
        proprios.append(traj_proprios)
        traj_lens.append(traj_actions.shape[0])

    # Per-dimension extrema over all timesteps of all selected trajectories.
    actions = np.concatenate(actions, axis=0)
    proprios = np.concatenate(proprios, axis=0)
    dataset_statistics = dict(views=views,
                              action_max=actions.max(axis=0).tolist(),
                              action_min=actions.min(axis=0).tolist(),
                              proprio_max=proprios.max(axis=0).tolist(),
                              proprio_min=proprios.min(axis=0).tolist(),
                              traj_paths=hdf5_files, traj_lens=traj_lens)
    with open(cache_json, 'w') as f:
        json.dump(dataset_statistics, f, indent=4)
    return dataset_statistics

class LiberoProcessor(object):
    """Normalizes images, actions and proprioception for LIBERO trajectories.

    Actions and proprio vectors are min-max scaled per dimension so that
    ``[min, max]`` maps to ``[-1, 1]``. The min/max come from
    ``build_dataset_statistics``.
    """

    def __init__(self, dataset_path, img_size=224, training=True, stage=0, few_num=5):
        # Random-crop augmentation only when training.
        self.img_transform = build_base_transform(n_px=img_size, aug=training)
        dataset_statistics = build_dataset_statistics(dataset_path, stage=stage, few_num=few_num)
        self.action_max = np.array(dataset_statistics['action_max'])
        self.action_min = np.array(dataset_statistics['action_min'])
        self.proprio_max = np.array(dataset_statistics['proprio_max'])
        self.proprio_min = np.array(dataset_statistics['proprio_min'])
        # fixed parameters
        self.action_length = 7
        self.proprio_length = 9  # not read by this class

    @staticmethod
    def _safe_range(vmax, vmin):
        """Return ``vmax - vmin`` with zero entries replaced by 1 so constant dims do not yield 0/0 = NaN."""
        span = vmax - vmin
        return np.where(span == 0, 1.0, span)

    def preprocess_image(self, img, replay_params=None):
        """Apply the image transform to ``img``.

        When ``replay_params`` is None, new random parameters are sampled and
        returned. Otherwise the given parameters are replayed and echoed back.

        Returns:
            ``(transformed_image, replay_params)``.
        """
        if replay_params is None:
            transformed = self.img_transform(image=img)
            return transformed['image'], transformed['replay']
        transformed = A.ReplayCompose.replay(replay_params, image=img)
        return transformed['image'], replay_params

    def preprocess_action(self, action):
        """Scale an action (chunk) array to [-1, 1] per dimension and flatten it to a 1-D tensor."""
        span = self._safe_range(self.action_max, self.action_min)
        action = (action - self.action_min) / span * 2 - 1
        return torch.flatten(torch.from_numpy(action))

    def preprocess_proprio(self, proprio):
        """Scale a proprio array to [-1, 1] per dimension and flatten it to a 1-D tensor."""
        span = self._safe_range(self.proprio_max, self.proprio_min)
        proprio = (proprio - self.proprio_min) / span * 2 - 1
        return torch.flatten(torch.from_numpy(proprio))

    def postprocess_action(self, tensor_flatten_action):
        """Undo ``preprocess_action`` for a batch of flattened action chunks.

        Reshapes ``(B, N * action_length)`` to ``(B, N, action_length)``. The
        last dimension is snapped to its sign and the min-max scaling is
        inverted. The input tensor is not modified.

        Returns:
            numpy array of shape ``(B, N, action_length)``.
        """
        B, _ = tensor_flatten_action.shape
        # detach/cpu so CUDA or grad tensors work; clone so the caller's tensor is not mutated
        # (reshape of a contiguous tensor returns a view).
        action = tensor_flatten_action.detach().cpu().reshape(B, -1, self.action_length).clone()
        action[..., -1] = torch.sign(action[..., -1])
        # Raw (unguarded) range: a constant dimension decodes back to its min.
        action = (action + 1) / 2 * (self.action_max - self.action_min) + self.action_min
        return action.numpy()


class LiberoDataset(Dataset):
    """Per-timestep dataset over LIBERO HDF5 trajectories.

    One sample is produced for every timestep of every trajectory listed in
    the dataset statistics. The contents of a sample depend on ``stage``:

    * ``0`` / ``3``: the normalized action chunk only;
    * ``1`` / ``4``: all camera views, action chunk, proprio and instruction;
    * any other value: the above plus ``recursive_step`` sub-goal images of
      the main view. The sub-goals are transformed with the replayed
      parameters sampled for view 0.
    """

    def __init__(self, dataset_path, processor, chunk_length=6,
                 recursive_step=4, rec_plan_coef=0.5, stage=0, few_num=5):
        self.processor = processor
        self.dataset_path = dataset_path
        self.chunk_length = chunk_length
        self.recursive_step = recursive_step
        self.rec_plan_coef = rec_plan_coef
        self.stage = stage
        self.few_num = few_num
        self._load_metas()

    def _load_metas(self):
        """Index every (traj_path, timestep, last_frame_index) triple."""
        dataset_statistics = build_dataset_statistics(self.dataset_path, stage=self.stage, few_num=self.few_num)
        traj_paths = dataset_statistics['traj_paths']
        traj_lens = dataset_statistics['traj_lens']
        self.views = dataset_statistics['views']
        self.main_view = self.views[0]
        self.metas = [
            (traj_path, step, traj_len - 1)
            for traj_path, traj_len in zip(traj_paths, traj_lens)
            for step in range(traj_len)
        ]

    def _pad_action_chunk(self, actions, cur_idx):
        """Take ``chunk_length`` actions starting at ``cur_idx``, padding past the trajectory end.

        Padding rows are zero except the last dimension, which repeats the
        last available action's final component (presumably the gripper
        command — confirm). Works for any action dimensionality.
        """
        chunk = actions[cur_idx: cur_idx + self.chunk_length]
        missing = self.chunk_length - len(chunk)
        if missing > 0:
            padding = np.zeros((missing, chunk.shape[-1]))
            padding[:, -1] = chunk[-1][-1]
            chunk = np.concatenate([chunk, padding], axis=0)
        return chunk

    def _decode_views(self, f, idx):
        """Decode the frame at ``idx`` from every view in ``self.views``."""
        return [cv2.imdecode(f['observation'][view][idx], cv2.IMREAD_COLOR) for view in self.views]

    def _load_from_raw_traj(self, traj_path, cur_idx, goal_idx):
        """Read the raw data one sample needs; the returned tuple depends on ``self.stage``."""
        with h5py.File(traj_path, 'r') as f:
            np_action = self._pad_action_chunk(f['action'][()], cur_idx)
            if self.stage in (0, 3):
                return np_action

            raw_images = self._decode_views(f, cur_idx)
            raw_proprio = f['proprio'][()][cur_idx]
            instruction = f['language_instruction'][()].decode('utf-8')
            if self.stage in (1, 4):
                return raw_images, np_action, raw_proprio, instruction

            # Sub-goals: start at goal_idx, then move toward cur_idx by
            # rec_plan_coef each step.
            subgoals = []
            for _ in range(self.recursive_step):
                subgoals.append(cv2.imdecode(f['observation'][self.main_view][goal_idx], cv2.IMREAD_COLOR))
                goal_idx = cur_idx + int((goal_idx - cur_idx) * self.rec_plan_coef)
        return raw_images, subgoals, np_action, raw_proprio, instruction

    def _preprocess_views(self, raw_images):
        """Transform all views; view 0's sampled replay params are returned for reuse."""
        cur_image, replay_params = self.processor.preprocess_image(raw_images[0])
        others = [self.processor.preprocess_image(img)[0] for img in raw_images[1:]]
        return torch.stack([cur_image, *others]), replay_params

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, index):
        traj_path, cur_idx, goal_idx = self.metas[index]

        if self.stage in (0, 3):
            np_action = self._load_from_raw_traj(traj_path, cur_idx, goal_idx)
            return {
                'cur_actions': self.processor.preprocess_action(np_action),
                'traj_path': traj_path,
                'cur_idx': cur_idx,
            }

        if self.stage in (1, 4):
            raw_images, np_action, raw_proprio, instruction = self._load_from_raw_traj(traj_path, cur_idx, goal_idx)
            final_images, _ = self._preprocess_views(raw_images)
            return {
                'cur_images': final_images,
                'cur_actions': self.processor.preprocess_action(np_action),
                'cur_proprios': self.processor.preprocess_proprio(raw_proprio),
                'instruction': instruction,
                'traj_path': traj_path,
                'cur_idx': cur_idx,
            }

        raw_images, subgoals, np_action, raw_proprio, instruction = self._load_from_raw_traj(traj_path, cur_idx, goal_idx)
        final_images, replay_params = self._preprocess_views(raw_images)
        # Sub-goals reuse view 0's sampled augmentation so they align with the current frame.
        subgoals = torch.stack([self.processor.preprocess_image(img, replay_params)[0] for img in subgoals])
        return {
            'sub_goals': subgoals,
            'cur_images': final_images,
            'cur_actions': self.processor.preprocess_action(np_action),
            'cur_proprios': self.processor.preprocess_proprio(raw_proprio),
            'instruction': instruction,
            'traj_path': traj_path,
            'cur_idx': cur_idx,
        }

class LiberoAgent(object):
    """Wraps a policy to produce per-step actions during LIBERO rollouts.

    Two modes:
      * ``use_ac=True``: the policy is queried every step and its first
        ``action_step_ac`` predicted steps are temporally ensembled with
        earlier predictions covering the same step (``get_ac_action``).
      * ``use_ac=False``: the policy is queried when the queue is empty, up to
        ``action_chunk`` predicted steps are queued, and one step is popped per call.
    """

    def __init__(self, processor, use_ac = True, action_step_ac=8, action_chunk=12, device='cuda:0'):
        super().__init__()
        self.use_ac = use_ac
        # Sentinel marking "no prediction written" in the ensembling buffer.
        self.constant = 10000
        self.processor = processor
        self.policy = None
        self.cnt = 0
        # Device passed to utils.process_inputs (default keeps the previous hard-coded value).
        self.device = device

        self.action_step_ac = action_step_ac
        self.action_chunk = action_chunk

        self.action_queue = deque(maxlen=action_chunk)

    def set_policy(self, policy):
        """Attach ``policy``; it must expose a callable ``generate`` method."""
        assert hasattr(policy, 'generate') and callable(getattr(policy, 'generate')), \
        "The policy must have a callable 'generate' method."
        self.policy = policy

    def _init_action_chunking(self, eval_horizon: int=600, num_samples: int=1):
        """Reset per-episode state before a rollout.

        Allocates the ensembling buffer
        ``(num_samples, eval_horizon, eval_horizon + 50, 7)`` filled with the
        sentinel, and clears any actions still queued from a previous episode
        so they are not executed in the next one.
        """
        self.all_time_actions = np.ones([num_samples, eval_horizon, eval_horizon+50, 7]) * self.constant
        self.action_queue.clear()

    def get_ac_action(self, actions, t: int, k: float=0.25):
        """Temporally ensemble predictions for step ``t``.

        ``actions`` is ``(B, N, D)``: the N steps predicted at step ``t``. They
        are stored at row ``t`` / columns ``t..t+N-1``. Every stored prediction
        covering column ``t`` is averaged with weights ``exp(-k * i)``, where
        ``i = 0`` is the oldest prediction. The last dimension of the result is
        snapped to its sign.

        Returns:
            ``(B, D)`` array.
        """
        B, N, D = actions.shape

        self.all_time_actions[:, [t], t:t+N] = np.expand_dims(actions, axis=1)   # B, horizon, horizon+ac_num, 7
        actions_for_curr_step = self.all_time_actions[:, :, t]  # B, horizon, 7
        actions_populated = np.all(actions_for_curr_step != self.constant, axis=-1)  # B, horizon
        actions_for_curr_step = actions_for_curr_step[actions_populated].reshape(B, -1, D)  # B, n_pred, D
        exp_weights = np.exp(-k * np.arange(actions_for_curr_step.shape[1]))  # n_pred
        exp_weights = (exp_weights / exp_weights.sum()).reshape(1, -1, 1)
        actions = (actions_for_curr_step * exp_weights).sum(axis=1)
        actions[..., -1] = np.sign(actions[..., -1])
        return actions

    def get_action(self, agent_view_images, wrist_view_images, raw_proprio, instruction, t=-1):
        """Return the ``(B, D)`` action to execute at rollout step ``t``.

        Raises:
            ValueError: if ``use_ac`` is set and ``t`` is negative (validated
                before the policy is queried).
        """
        if self.use_ac and t < 0:
            raise ValueError(f"Invalid value for t: {t}. In action chunking, t must be equal to current rollout step.")

        if len(self.action_queue) == 0:
            agent_view_images = torch.stack([self.processor.preprocess_image(image)[0] for image in agent_view_images]).unsqueeze(1)
            wrist_view_images = torch.stack([self.processor.preprocess_image(image)[0] for image in wrist_view_images]).unsqueeze(1)
            final_images = torch.cat([agent_view_images, wrist_view_images], dim=1)
            final_proprio = torch.stack([self.processor.preprocess_proprio(proprio) for proprio in raw_proprio])
            batch = {
                'cur_images': final_images,
                'cur_proprios': final_proprio,
                'instruction': instruction,
            }
            batch = utils.process_inputs(self.device, torch.float32, batch)

            actions, _ = self.policy.generate(**batch)
            actions = self.processor.postprocess_action(actions)
            if self.use_ac:
                if t == 1:
                    print(f'use_ac:{self.use_ac}; action_step_ac:{self.action_step_ac}')
                return self.get_ac_action(actions[:, :self.action_step_ac, :], t)
            if t == 1:
                print(f'use_ac:{self.use_ac}')
            # (B, N, D) -> (N, B, D): each queued item is one step for every env.
            self.action_queue.extend(np.transpose(actions, (1, 0, 2))[:self.action_chunk])
        return self.action_queue.popleft()

def build_libero_processor(dataset_path, img_size=224, training=True, stage=0, few_num=5):
    """Construct a ``LiberoProcessor`` for ``dataset_path``.

    All arguments are forwarded unchanged; ``training`` selects whether the
    image transform uses random augmentation.
    """
    return LiberoProcessor(
        dataset_path=dataset_path,
        img_size=img_size,
        training=training,
        stage=stage,
        few_num=few_num,
    )

def build_libero_dataloader(dataset_path, processor, chunk_length=6, recursive_step=4, rec_plan_coef=0.5, stage=0, few_num=5,
                            batch_size=2, num_workers=2, shuffle=True, pin_mem=True, drop_last=True,
                            world_size=1, global_rank=0):
    """Wrap a ``LiberoDataset`` in a ``DataLoader`` driven by a ``DistributedSampler``.

    The sampler is used even when ``world_size == 1``. Shuffling is done by
    the sampler (``shuffle``), not by the loader. Per the PyTorch docs,
    callers should call ``sampler.set_epoch`` each epoch to vary the order.
    """
    dataset = LiberoDataset(
        dataset_path=dataset_path,
        processor=processor,
        chunk_length=chunk_length,
        recursive_step=recursive_step,
        rec_plan_coef=rec_plan_coef,
        stage=stage,
        few_num=few_num,
    )
    dist_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=shuffle)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=dist_sampler,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

def build_libero_agent(processor, use_ac=True, action_step_ac=8, action_chunk=12):
    """Construct a ``LiberoAgent`` around ``processor`` with the given action-chunking settings."""
    return LiberoAgent(
        processor,
        use_ac=use_ac,
        action_step_ac=action_step_ac,
        action_chunk=action_chunk,
    )

def build_libero_engine(dataset_path, img_size=224,                              # processor
                        recursive_step=4, rec_plan_coef=0.5, stage=0, few_num=5,  # dataloader
                        chunk_length=6, batch_size=2, num_workers=2,              # dataloader
                        shuffle=True, pin_mem=True, drop_last=True,               # dataloader
                        world_size=1, global_rank=0,                              # dataloader
                        use_ac=True, action_step_ac=8, action_chunk=12,           # agent
                        **kwargs):
    """Build the training dataloader and an evaluation agent for a LIBERO dataset.

    Two processors are created, each loading the dataset statistics: one
    with training augmentation (fed to the dataloader) and one without it
    (wrapped by the agent). Extra ``**kwargs`` are accepted and ignored.

    Returns:
        ``(train_dataloader, agent)``.
    """
    train_processor = build_libero_processor(dataset_path, img_size=img_size, training=True,
                                             stage=stage, few_num=few_num)
    loader_options = dict(
        chunk_length=chunk_length,
        recursive_step=recursive_step,
        rec_plan_coef=rec_plan_coef,
        stage=stage,
        few_num=few_num,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_mem=pin_mem,
        drop_last=drop_last,
        world_size=world_size,
        global_rank=global_rank,
    )
    train_dataloader = build_libero_dataloader(dataset_path, processor=train_processor, **loader_options)

    eval_processor = build_libero_processor(dataset_path, img_size=img_size, training=False,
                                            stage=stage, few_num=few_num)
    agent = build_libero_agent(processor=eval_processor, use_ac=use_ac,
                               action_step_ac=action_step_ac, action_chunk=action_chunk)
    return train_dataloader, agent

# simulation env
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv, SubprocVectorEnv
import imageio

EPS = 1e-5

# Evaluation-suite key -> list of LIBERO benchmark names it covers.
LIBERO_DATASETS = dict(
    LIBERO_GOAL=["libero_goal"],
    LIBERO_OBJECT=["libero_object"],
    LIBERO_SPATIAL=["libero_spatial"],
    LIBERO_10=["libero_10"],
    LIBERO_90=["libero_90"],
    libero30=["libero_goal", "libero_object", "libero_spatial"],
    libero130=["libero_goal", "libero_object", "libero_spatial", "libero_10", "libero_90"],
)

# Evaluation-suite key -> number of policy steps allowed per rollout.
LIBERO_DATASETS_HORIZON = dict(
    LIBERO_GOAL=300,
    LIBERO_OBJECT=300,
    LIBERO_SPATIAL=300,
    LIBERO_10=600,
    LIBERO_90=300,
    libero30=300,
    libero130=150,
)

# Registry of LIBERO benchmark suites, fetched once at import time. It is indexed
# by benchmark name (e.g. "libero_goal"), and each entry is called to get a suite object.
benchmark_dict = benchmark.get_benchmark_dict()

class LIBEROEval():
    """Evaluate a policy in LIBERO simulation with vectorized environments.

    For every task of every benchmark listed under ``task_suite_name`` in
    ``LIBERO_DATASETS``, ``num_episodes`` environments are run in parallel
    from the suite's fixed initial states. A video of the agent view is saved
    and the per-task success rate is logged.

    Args:
        task_suite_name: Key into ``LIBERO_DATASETS`` (e.g. ``"LIBERO_GOAL"``).
        use_ac: If True, ``policy._init_action_chunking`` is called before each rollout.
        obs_key: Observation entries collected from each env at every step.
        data_statistics: Stored; not used by this class.
        logger: Optional object with ``log_metrics`` / ``save_metrics``. When
            None, metrics are printed and appended to ``<base_dir>/results.json``.
        eval_horizon: Maximum number of policy steps per rollout.
        camera_heights, camera_widths: Render resolution passed to the env.
        num_episodes: Number of parallel environments (episodes) per task.
        eval_freq: Stored; not used by this class.
        seed: Environments are seeded with ``seed + 100``.
        rank: Process rank (stored).
    """

    def __init__(self, task_suite_name: str, use_ac = True,
                obs_key=('agentview_image', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos', 'robot0_eef_pos', 'robot0_eef_quat'),
                data_statistics: dict=None, logger = None, eval_horizon: int=600, camera_heights=256, camera_widths=256,
                num_episodes: int=10, eval_freq: int=10, seed: int=42, rank: int=0):

        self.task_suite_name = task_suite_name
        self.task_list = LIBERO_DATASETS[self.task_suite_name]
        self.task_suite_list = [benchmark_dict[task]() for task in self.task_list]
        # Copy into a list so the (immutable) default is never shared/mutated.
        self.obs_key = list(obs_key)
        self.data_statistics = data_statistics
        self.eval_horizon = eval_horizon
        self.num_episodes = num_episodes
        self.eval_freq = eval_freq
        self.logger = logger
        self.seed = seed
        self.rank = rank
        self.use_ac = use_ac
        self.camera_heights = camera_heights
        self.camera_widths = camera_widths

    def _make_dir(self, save_path):
        """Set ``self.base_dir`` to ``save_path/<task_suite_name>`` and make sure it exists.

        Every rank needs ``base_dir`` for video/result paths. ``exist_ok`` makes
        concurrent creation from several ranks safe.
        """
        path = os.path.join(save_path, self.task_suite_name)
        os.makedirs(path, exist_ok=True)
        self.base_dir = path

    def _init_env(self, task_suite, task_id: int=0):
        """Create ``num_episodes`` parallel envs for one task and reset them to the benchmark init states.

        Returns:
            dict with ``'env'`` (the vector env), ``'language_instruction'`` and ``'obs'``.
        """
        task = task_suite.get_task(task_id)
        task_description = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
        print(f"[info] retrieving task {task_id} from suite {self.task_suite_name}, the " + \
                f"language instruction is {task_description}, and the bddl file is {task_bddl_file}")

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": self.camera_heights,
            "camera_widths": self.camera_widths
        }

        # Init the subprocess vector environment.
        env_num = self.num_episodes
        env = SubprocVectorEnv(
            [lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)]
        )

        env.seed(self.seed + 100)
        env.reset()
        # For benchmarking, episodes start from the suite's fixed set of initial states (cycled if fewer).
        init_states = task_suite.get_task_init_states(task_id)
        init_state_id = np.arange(self.num_episodes) % init_states.shape[0]
        obs = env.set_init_state(init_states[init_state_id])

        return {
            'env': env,
            'language_instruction': task_description,
            'obs': obs,
        }

    def _log_results(self, metrics: dict, steps: int):
        """Send ``metrics`` to the logger, or print them and append a JSON line to ``<base_dir>/results.json``."""
        if self.logger is None:
            print(metrics)
            save_name = os.path.join(self.base_dir, 'results.json')
            with open(save_name, 'a+') as f:
                f.write(json.dumps(metrics) + '\n')
        else:
            self.logger.log_metrics(metrics, steps)
            self.logger.save_metrics(metrics, steps, self.base_dir)

    def raw_obs_to_stacked_obs(self, obs, lang):
        """Stack each ``obs_key`` entry across the per-env observation dicts (new leading env axis)."""
        stacked = {key: np.stack([env_obs[key] for env_obs in obs]) for key in self.obs_key}
        return {"obs": stacked, "lang": lang}

    def _rollout(self, task_suite, policy, task_id: int=0):
        """Run ``num_episodes`` parallel episodes of one task and log its success rate.

        Returns:
            ``(avg_succ_rate, {language_instruction: avg_succ_rate})``.
        """
        if self.use_ac:
            policy._init_action_chunking(eval_horizon=self.eval_horizon, num_samples=self.num_episodes)

        env = self._init_env(task_suite, task_id)
        lang = env['language_instruction']
        obs = env['obs']
        images = []
        try:
            # Warm-up: a few steps with zero motion and gripper command -1 before the policy acts.
            init_action = np.array([[0., 0., 0., 0., 0., 0., -1.]]).repeat(self.num_episodes, axis=0)
            for _ in range(5):
                obs, reward, done, info = env['env'].step(init_action)

            # Count an episode as successful if its env reported done at ANY policy
            # step, not only at the final one (mirrors LIBERO's reference evaluation
            # loop — confirm against upstream).
            success = np.zeros(self.num_episodes, dtype=bool)
            for t in tqdm(range(self.eval_horizon), desc=f'{lang}'):
                data = self.raw_obs_to_stacked_obs(obs, lang)
                obs, lang = data['obs'], data['lang']
                gripper_qpos = obs['robot0_gripper_qpos']
                eef_pos = obs['robot0_eef_pos']
                eef_quat = obs['robot0_eef_quat']
                # Flip the agent view along axes 1 and 2 (H and W of the (B, H, W, C) batch).
                agent_view = np.flip(np.flip(obs['agentview_image'], 1), 2)
                wrist_view = obs['robot0_eye_in_hand_image']
                proprios = np.concatenate([gripper_qpos, eef_pos, eef_quat], axis=-1)
                lang_instruction = [lang] * self.num_episodes

                action = policy.get_action(agent_view, wrist_view, proprios, lang_instruction, t)

                # Record one frame: all envs' agent views stacked vertically.
                B, H, W, C = agent_view.shape
                images.append(agent_view.reshape(B * H, W, C))

                obs, reward, done, info = env['env'].step(action)
                success |= np.asarray(done, dtype=bool)
                if success.all():
                    break
        finally:
            # Always release the subprocess environments, even if the policy raises.
            env['env'].close()

        save_path = f'{self.base_dir}/{lang}.mp4'
        self._save_video(save_path, images, success, fps=30)

        avg_succ_rate = int(success.sum()) / self.num_episodes

        metrics = {f'sim/{self.task_suite_name}/{lang}': avg_succ_rate}
        self._log_results(metrics, self.step)

        re_metrics = {f'{lang}': avg_succ_rate}
        return avg_succ_rate, re_metrics

    def _save_video(self, save_path: str, images: list, done: list, fps=30):
        """Write ``images`` to ``save_path`` with imageio; ``done`` is accepted but unused."""
        imageio.mimsave(save_path, images, fps=fps)

    def eval_episodes(self, policy, steps: int, save_path: str):
        """Evaluate ``policy`` on every task of every configured suite.

        Returns:
            ``(mean_success_rate, details)``. ``details`` maps each task's
            instruction to its success rate and also holds
            ``overall_success_rate`` and ``overall_solved`` (tasks with a
            nonzero success rate). The mean is 0.0 if there are no tasks.
        """
        self._make_dir(save_path)
        self.step = steps

        rews = []
        re_rollout = {}
        solved = 0
        for task_suite in self.task_suite_list:
            for task_id in tqdm(range(len(task_suite.tasks)), desc="Evaluating..."):
                avg_succ_rate, rollout_results = self._rollout(task_suite, policy, task_id)
                rews.append(avg_succ_rate)
                re_rollout.update(rollout_results)
                if avg_succ_rate > 0:
                    solved += 1

        eval_rewards = sum(rews) / len(rews) if rews else 0.0
        metrics = {f'sim_summary/{self.task_suite_name}/all': eval_rewards}

        re_rollout.update({
            'overall_success_rate': eval_rewards,
            'overall_solved': solved
        })

        self._log_results(metrics, self.step)

        return eval_rewards, re_rollout

    def close_env(self):
        """Close environments tracked on ``self.env``, if any.

        ``_rollout`` closes its own environment and nothing in this class sets
        ``self.env``, so this is a no-op unless a caller assigned it.
        """
        for env in getattr(self, 'env', None) or []:
            env['env'].close()

def eval_libero(agent, result_path, num_episodes=1, seed=42,
                task_suites=("libero_goal", "libero_spatial", "libero_10")):
    """Evaluate ``agent`` on each suite in ``task_suites`` and append a summary to ``result_path/results.json``.

    A suite may be named by its key in ``LIBERO_DATASETS_HORIZON`` (e.g.
    ``"LIBERO_GOAL"``, ``"libero30"``) or by the lowercase benchmark name
    (``"libero_goal"``), which is mapped to its upper-case key. The
    lowercase form is what the defaults use; previously those names raised
    ``KeyError``.

    Returns:
        The detail dict from ``LIBEROEval.eval_episodes`` for the LAST suite
        only. Earlier suites' per-task metrics are not merged.

    Raises:
        KeyError: if a suite name matches no known suite.
    """
    result_dict = {}
    metrics = {}
    for suite_name in task_suites:
        suite_key = suite_name if suite_name in LIBERO_DATASETS_HORIZON else suite_name.upper()
        if suite_key not in LIBERO_DATASETS_HORIZON:
            raise KeyError(f"Unknown LIBERO task suite {suite_name!r}; "
                           f"expected one of {sorted(LIBERO_DATASETS_HORIZON)}")
        horizon = LIBERO_DATASETS_HORIZON[suite_key]
        evaluator = LIBEROEval(task_suite_name=suite_key, eval_horizon=horizon,
                               num_episodes=num_episodes, seed=seed)
        eval_rewards, metrics = evaluator.eval_episodes(agent, 0, save_path=result_path)
        result_dict[suite_name] = eval_rewards

    with open(f"{result_path}/results.json", "a+") as f:
        json.dump(result_dict, f, indent=4)
        # Terminate each appended record so repeated runs don't glue JSON objects together.
        f.write('\n')

    return metrics


import random
from pathlib import Path


def get_random_hdf5_files(dataset_path, num_files=5, seed=None):
    """Randomly pick up to ``num_files`` ``*.hdf5`` files from each immediate sub-folder.

    Files are searched recursively inside every direct sub-directory of
    ``dataset_path``; files placed directly in ``dataset_path`` are ignored.
    Folders and candidates are sorted first, so for a given RNG state the
    selection no longer depends on filesystem listing order.

    Args:
        dataset_path: Root directory (str or Path).
        num_files: Maximum number of files kept per sub-folder.
        seed: Optional seed for a private ``random.Random``. When None, the
            module-level ``random`` generator is used (previous behavior).

    Returns:
        Flat list of resolved file paths (str).
    """
    rng = random if seed is None else random.Random(seed)
    hdf5_files = []
    for folder in sorted(Path(dataset_path).iterdir()):
        if not folder.is_dir():
            continue
        candidates = sorted(str(file.resolve()) for file in folder.rglob('*.hdf5'))
        if len(candidates) > num_files:
            candidates = rng.sample(candidates, num_files)
        hdf5_files.extend(candidates)
    return hdf5_files


def print_selected_files(hdf5_files_dict):
    """Print a ``{folder_name: [file_path, ...]}`` mapping.

    Each folder gets a header with its file count, followed by one numbered
    line per file showing only the file's base name.
    """
    for folder_name, paths in hdf5_files_dict.items():
        count = len(paths)
        print(f"\nFolder: {folder_name}")
        print(f"Number of Selected Files: {count}")
        numbered = (f"  {idx}. {Path(path).name}" for idx, path in enumerate(paths, start=1))
        for line in numbered:
            print(line)


if __name__ == "__main__":
    dataset_path = "/path/to/libero_10"
    selected_files = get_random_hdf5_files(dataset_path, num_files=5)
    # get_random_hdf5_files returns a flat list, but print_selected_files expects
    # {folder: [files]}. Group by each file's parent directory name first
    # (calling .items()/.values() on the list crashed with AttributeError).
    files_by_folder = {}
    for file_path in selected_files:
        files_by_folder.setdefault(Path(file_path).parent.name, []).append(file_path)
    print_selected_files(files_by_folder)
    total_folders = len(files_by_folder)
    total_files = sum(len(files) for files in files_by_folder.values())
    print(f"\nTotal: {total_files} files from {total_folders} folders")

