import h5py
import pandas as pd
import glob
import os
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import hashlib
import io
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import copy
import multiprocessing
import pdb

from .util.metadata_helper import load_metadata
from .util.hdf5_loader import load_data, load_camera_imgs
import argparse

class ACTION_MISMATCH:
    """Integer flags naming the available action-mismatch policies.

    NOTE(review): presumably these select how a loader reconciles an action
    vector whose size differs from ``target_adim`` (raise, zero-pad, or
    truncate) -- the consuming code is not in this file, so confirm there.
    """

    # ERROR == 0, PAD_ZERO == 1, CLEAVE == 2
    ERROR, PAD_ZERO, CLEAVE = range(3)

class STATE_MISMATCH:
    """Integer flags naming the available state-mismatch policies.

    NOTE(review): presumably these select how a loader reconciles a state
    vector whose size differs from ``target_sdim`` (raise, zero-pad, or
    truncate) -- the consuming code is not in this file, so confirm there.
    """

    # ERROR == 0, PAD_ZERO == 1, CLEAVE == 2
    ERROR, PAD_ZERO, CLEAVE = range(3)

def default_loader_hparams():
    """Return a newly built dict holding the default loader hyperparameters.

    A fresh dict (with fresh list values) is created on every call, so the
    caller may mutate the result without affecting later calls.
    """
    # NOTE(review): 3 is not one of the values defined on STATE_MISMATCH /
    # ACTION_MISMATCH in this file (0, 1, 2) -- confirm how the loader
    # interprets it.
    # TODO make better flag parsing (state_mismatch / action_mismatch)
    # TODO implement error checking for jagged reading (load_T)
    defaults = dict(
        target_adim=4,
        target_sdim=5,
        state_mismatch=3,
        action_mismatch=3,
        img_size=[64, 64],
        cams_to_load=[0],
        impute_autograsp_action=True,
        load_annotations=False,
        zero_if_missing_annotation=False,
        load_T=0,
    )
    return defaults

class HParams:
    """Minimal attribute bag for hyperparameters.

    Every keyword argument passed to the constructor becomes an instance
    attribute with the same name.
    """

    def __init__(self, **kwargs):
        # Copy the keywords straight into the instance namespace.
        vars(self).update(kwargs)

    def __repr__(self):
        return "HParams({})".format(vars(self))

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` if no such attribute exists.

        The lookup is a regular attribute lookup, so names defined on the
        class (e.g. method names) are found as well.
        """
        return getattr(self, key, default)

    def set(self, key, value):
        """Bind ``value`` to the attribute named ``key``."""
        setattr(self, key, value)

    def to_dict(self):
        """Return the instance's attribute dict itself (not a copy).

        Mutating the returned dict mutates this object's attributes.
        """
        return vars(self)


class RobonetSimpleDataset(Dataset):
    """Map-style dataset that loads one trajectory file per index.

    The train/test split is driven by a ``test_set.txt`` file located in the
    directory of the first file listed in the metadata. Each line of that
    file is a file name (without directory); files named there form the test
    split, every other file forms the train split.
    """

    def __init__(self, base_path, train=True, hparams=None):
        """Load the metadata under ``base_path`` and select the split.

        Args:
            base_path: Location handed to ``load_metadata``.
            train: If true keep the files NOT listed in ``test_set.txt``;
                otherwise keep only the listed files.
            hparams: Loader hyperparameters forwarded to ``load_data``. When
                ``None`` a new ``HParams`` built from
                ``default_loader_hparams()`` is used.

        Raises:
            IndexError: If the metadata lists no files.
            OSError: If ``test_set.txt`` cannot be opened.
        """
        self.base_path = base_path
        # Build the default per instance. A default constructed in the
        # signature is evaluated once and would be shared (and mutated) by
        # every dataset created without explicit hparams.
        if hparams is None:
            hparams = HParams(**default_loader_hparams())
        self.hparams = hparams
        self.metadata = load_metadata(base_path)
        files = self.metadata.files
        # len() rather than truthiness: the container type of ``files`` is
        # not visible here and may not support bool().
        if len(files) == 0:
            raise IndexError(
                'no files found in metadata for {!r}'.format(base_path)
            )

        # The test-set listing lives next to the data files.
        data_dir = os.path.dirname(files[0])
        # os.path.join keeps the path relative when data_dir is empty instead
        # of accidentally pointing at the filesystem root.
        with open(os.path.join(data_dir, 'test_set.txt'), 'r') as f:
            test_set = {line.strip() for line in f}

        if train:
            files = [f for f in files if os.path.basename(f) not in test_set]
        else:
            files = [f for f in files if os.path.basename(f) in test_set]
        print('train:', train)
        print('files:', len(files))
        self.files = files

    def __len__(self):
        """Return the number of files in the selected split."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the file at ``index`` and return ``(imgs, actions, states)``.

        When ``hparams.load_annotations`` is set, ``load_data`` returns a
        fourth item (the annotations); it is loaded but not returned.
        """
        file = self.files[index]
        loaded = load_data(file, self.metadata.get_file_metadata(file), self.hparams)
        if self.hparams.load_annotations:
            imgs, actions, states, _annot = loaded
        else:
            imgs, actions, states = loaded
        return imgs, actions, states