import os
import gymnasium as gym
import dsrl
import numpy as np
from model.data.dataset import Dataset
import h5py


# Sample budgets used when re-balancing unsafe vs. safe transitions; each is
# clamped to the number of available samples of that class.
_NUM_UNSAFE_SAMPLES = int(1e5 + 5e4)
_NUM_SAFE_SAMPLES = int(1e6 + 2e5)
# Each sampled unsafe transition is repeated this many times (oversampling).
_UNSAFE_REPEAT = 24

# Data locations with dedicated HDF5 loaders; anything else uses the DSRL benchmark.
_TOY_DATA_LOCATIONS = ("data/toydata[19:1].hdf5",)  # toydata[10:1].hdf5 // toydata[5:1].hdf5
_COLLECTED_DATA_LOCATIONS = ("data/first_carpush1.hdf5", "data/second_carpush1.hdf5")

# Output key -> dataset name inside the toy HDF5 files.
_TOY_KEY_MAP = {
    "observations": "state",
    "actions": "action",
    "next_observations": "next_state",
    "rewards": "reward",
    "dones": "done",
    "costs": "h",
    "vio": "cost",
}
# Output key -> dataset name inside the collected carpush HDF5 files (names match).
_COLLECTED_KEY_MAP = {
    key: key
    for key in ("observations", "actions", "next_observations", "rewards", "dones", "costs")
}


def _read_hdf5(path, key_map):
    """Load the datasets named in ``key_map`` from the HDF5 file at ``path``.

    ``key_map`` maps each output key to the dataset name inside the file. Each
    dataset is fully materialized with ``np.array``. The file is closed before
    returning.
    """
    with h5py.File(path, 'r') as f:
        return {out_key: np.array(f[h5_key]) for out_key, h5_key in key_map.items()}


def _print_episode_stats(env, dataset_dict):
    """Print the env's episode reward/cost bounds and estimated mean episode reward/cost.

    The mean per-episode value is ``env._max_episode_steps`` times the per-step
    mean over the whole dataset.
    """
    print('max_episode_reward', env.max_episode_reward,
          'min_episode_reward', env.min_episode_reward,
          'mean_episode_reward', env._max_episode_steps * np.mean(dataset_dict['rewards']))
    print('max_episode_cost', env.max_episode_cost,
          'min_episode_cost', env.min_episode_cost,
          'mean_episode_cost', env._max_episode_steps * np.mean(dataset_dict['costs']))


def _load_benchmark_dataset(env, ratio):
    """Return the DSRL benchmark dataset of ``env``, optionally a pre-split subset.

    With ``ratio == 1.0`` this is ``env.get_dataset()``. Otherwise a file name
    is derived from ``env.dataset_url``:

    * its last ``-``-separated field (up to the first ``.``) is scaled by ``ratio``;
    * ``<prefix>-<scaled>-<ratio>.hdf5`` is loaded from the local ``data`` directory.
    """
    if ratio == 1.0:
        return env.get_dataset()
    _, dataset_name = os.path.split(env.dataset_url)
    file_list = dataset_name.split('-')
    ratio_num = int(float(file_list[-1].split('.')[0]) * ratio)
    dataset_ratio = '-'.join(file_list[:-1]) + '-' + str(ratio_num) + '-' + str(ratio) + '.hdf5'
    return env.get_dataset(os.path.join('data', dataset_ratio))


def _balance_safe_unsafe(dataset_dict):
    """Resample ``dataset_dict`` in place to rebalance unsafe and safe rows.

    Unsafe rows have cost > 0 and safe rows have cost < 0. The function:

    * draws up to ``_NUM_UNSAFE_SAMPLES`` unsafe and up to ``_NUM_SAFE_SAMPLES``
      safe indices without replacement;
    * tiles the unsafe indices ``_UNSAFE_REPEAT`` times and shuffles all indices;
    * re-indexes every array in the dict with the result.

    Rows with cost == 0 are dropped. If either class is empty, nothing is changed.
    """
    positive_indices = np.where(dataset_dict['costs'] > 0)[0]
    negative_indices = np.where(dataset_dict['costs'] < 0)[0]
    print('safe_indices:', negative_indices)
    print('unsafe_indices:', positive_indices)
    min_count = min(len(positive_indices), len(negative_indices))
    print('MIN count:', min_count)
    if min_count == 0:  # need at least one sample in each class
        return

    # Clamp the budgets: np.random.choice(..., replace=False) raises ValueError
    # when asked for more samples than the population contains.
    num_unsafe = min(_NUM_UNSAFE_SAMPLES, len(positive_indices))
    num_safe = min(_NUM_SAFE_SAMPLES, len(negative_indices))
    balanced_positive_indices = np.random.choice(positive_indices, num_unsafe, replace=False)
    balanced_negative_indices = np.random.choice(negative_indices, num_safe, replace=False)

    repeated_positive_indices = np.tile(balanced_positive_indices, _UNSAFE_REPEAT)
    print("Use Unsafe count", len(repeated_positive_indices), "Use Safe count", len(balanced_negative_indices))
    balanced_indices = np.concatenate([repeated_positive_indices, balanced_negative_indices])
    print("Total count", len(balanced_indices))
    np.random.shuffle(balanced_indices)  # mix the unsafe and safe samples

    for key in dataset_dict:
        dataset_dict[key] = dataset_dict[key][balanced_indices]


class DSRLDataset(Dataset):
    """Offline safe-RL transition dataset built from DSRL or local HDF5 files.

    The data source is selected by ``data_location``:

    * a path in ``_TOY_DATA_LOCATIONS``: a toy HDF5 file. ``costs`` holds the
      file's ``h`` dataset (kept as-is, not binarized) and ``vio`` holds its
      ``cost`` dataset.
    * a path in ``_COLLECTED_DATA_LOCATIONS``: a collected carpush HDF5 file.
    * anything else (including ``None``): the DSRL benchmark dataset of
      ``env``, with ``dones = terminals OR timeouts``.

    For the last two sources, costs are binarized to ``cost_scale`` (cost > 0)
    or -1 (otherwise). After loading:

    * rows are rebalanced between unsafe (cost > 0) and safe (cost < 0);
    * actions are optionally clipped to ``[-(1 - eps), 1 - eps]``;
    * every array is cast to float32;
    * ``dones`` is replaced by ``masks = 1 - dones``.

    The resulting dict is passed to ``Dataset.__init__``.

    Args:
        env: environment the data belongs to. Depending on the source, it reads
            ``get_dataset``, ``dataset_url``, ``_max_episode_steps`` and
            ``max/min_episode_reward`` / ``max/min_episode_cost``.
        clip_to_eps: whether to clip actions strictly inside ``[-1, 1]``.
        eps: clipping margin.
        data_location: path selecting a local HDF5 source (see above).
        cost_scale: cost value assigned to unsafe transitions when binarizing.
        ratio: fraction of the benchmark dataset to use. Values other than 1.0
            load a pre-split file from the local ``data`` directory.
    """

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5, data_location=None, cost_scale=1., ratio=1.0):
        # The three sources are mutually exclusive: the benchmark loader must
        # not run (and overwrite the data) when a local file was loaded.
        if data_location in _TOY_DATA_LOCATIONS:
            print('=========Data loading=========')
            print('Load data from:', data_location)
            dataset_dict = _read_hdf5(data_location, _TOY_KEY_MAP)
            print('env_max_episode_steps', env._max_episode_steps)
            print('mean_episode_reward', env._max_episode_steps * np.mean(dataset_dict['rewards']))
            print('mean_episode_cost', env._max_episode_steps * np.mean(dataset_dict['vio']))
        elif data_location in _COLLECTED_DATA_LOCATIONS:
            print("New collect data")
            dataset_dict = _read_hdf5(data_location, _COLLECTED_KEY_MAP)
            _print_episode_stats(env, dataset_dict)
            dataset_dict['costs'] = np.where(dataset_dict['costs'] > 0, 1 * cost_scale, -1)
        else:
            # Benchmark
            dataset_dict = _load_benchmark_dataset(env, ratio)
            _print_episode_stats(env, dataset_dict)
            print('data_num', dataset_dict['actions'].shape[0])
            dataset_dict['dones'] = np.logical_or(dataset_dict["terminals"],
                                                  dataset_dict["timeouts"]).astype(np.float32)
            del dataset_dict["terminals"]
            del dataset_dict['timeouts']
            dataset_dict['costs'] = np.where(dataset_dict['costs'] > 0, 1 * cost_scale, -1)

        _balance_safe_unsafe(dataset_dict)

        if clip_to_eps:
            lim = 1 - eps
            dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)

        for k, v in dataset_dict.items():
            dataset_dict[k] = v.astype(np.float32)

        # Replace dones with masks = 1 - dones.
        dataset_dict["masks"] = 1.0 - dataset_dict['dones']
        print(len(dataset_dict["masks"]))
        del dataset_dict['dones']

        super().__init__(dataset_dict)
