import os 
import gym
import torch
import cv2
import pickle
import imageio
import numpy as np
from torch.utils.data import Dataset, DataLoader
import h5py

def load_data_from_h5(file_path: str):
    """Read an offline-RL dataset stored in an HDF5 file into a dict of arrays.

    Args:
        file_path: Path to an HDF5 file holding ``observations``, ``actions``,
            ``dones`` and ``rewards`` datasets.

    Returns:
        dict with keys ``"actions"``, ``"observations"``, ``"terminals"``
        (taken from the ``dones`` dataset) and ``"rewards"``. Observations
        have their axes reordered with ``transpose(0, 3, 1, 2)`` — presumably
        (N, H, W, C) -> (N, C, H, W); confirm against the file layout.
    """
    with h5py.File(file_path, 'r') as h5_file:
        raw_obs = np.array(h5_file['observations'][:])
        acts = np.array(h5_file['actions'][:])
        dones = np.array(h5_file["dones"])
        rews = np.array(h5_file["rewards"])

    # Everything is already materialised as numpy arrays, so the file can be
    # closed before building the result.
    return {
        "actions": acts,
        "observations": raw_obs.transpose(0, 3, 1, 2),
        "terminals": dones,
        "rewards": rews,
    }
    
# Custom Dataset class
class ImagePairDataset(Dataset):
    """Dataset of consecutive image pairs aligned with per-step RL data.

    ``image_dir`` holds one sub-directory per trajectory, named by an integer
    index, each containing frames named ``<step>.<ext>``. Within a trajectory,
    frame ``i`` is paired with frame ``i + 1``; the last frame is paired with
    itself.

    Where the per-step data (actions, observations, terminals, rewards) comes
    from depends on a substring of ``image_dir``:
      * ``"kitchen"`` -> ``gym.make(env_name).get_dataset()``
      * ``"Grid"``    -> ``<image_dir>/trajectory_data.pkl``
      * ``"Crafter"`` -> the HDF5 file at ``crafter_h5_path``

    NOTE(review): sample ``idx`` indexes both ``image_pairs`` and the state
    arrays, which assumes one state per frame in the same order — confirm.
    """

    def __init__(self, image_dir, env_name, no_lang=True,
                 crafter_h5_path="hrl/data/Crafter-partial.h5"):
        """
        Args:
            image_dir: Root directory of per-trajectory frame folders.
            env_name: Gym environment id (only used for "kitchen" data).
            no_lang: Stored as ``self.no_lang``; not read inside this class.
            crafter_h5_path: HDF5 file loaded for "Crafter" data.

        Raises:
            ValueError: if ``image_dir`` names no supported environment.
        """
        self.image_dir = image_dir
        # Optional cap on trajectories / terminal indices kept; None keeps all.
        self.truncate = None

        self.state_dataset = self._load_state_dataset(env_name, crafter_h5_path)
        self.actions = self.state_dataset["actions"]
        self.states = self.state_dataset["observations"]
        self.terminals = self.state_dataset["terminals"]
        self.rewards = self.state_dataset["rewards"]
        # Indices of the steps flagged as terminal.
        self.terminal_point = np.flatnonzero(self.terminals)[:self.truncate]
        self.idx = list(range(self.rewards.shape[0]))

        self.image_pairs = self.load_image_pairs()
        self.max_length = 1
        self.no_lang = no_lang

    def _load_state_dataset(self, env_name, crafter_h5_path):
        """Return the per-step data dict for the environment family of image_dir."""
        if "kitchen" in self.image_dir:
            env = gym.make(env_name)
            return env.get_dataset()
        if "Grid" in self.image_dir:
            pkl_path = os.path.join(self.image_dir, "trajectory_data.pkl")
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(pkl_path, "rb") as file:
                return pickle.load(file)
        if "Crafter" in self.image_dir:
            return load_data_from_h5(crafter_h5_path)
        raise ValueError(
            f"Unsupported image_dir {self.image_dir!r}: expected it to contain "
            "'kitchen', 'Grid' or 'Crafter'"
        )

    @staticmethod
    def _sorted_trajectory_names(entries, truncate=None):
        """Keep integer-named entries, sort them numerically, cap at ``truncate``."""
        # Non-numeric entries (e.g. trajectory_data.pkl) are not trajectories.
        names = [name for name in entries if name.isdecimal()]
        names.sort(key=int)
        return names[:truncate]

    @staticmethod
    def _consecutive_pairs(frames):
        """Pair each frame with its successor; the last frame pairs with itself."""
        if not frames:
            return []
        pairs = list(zip(frames, frames[1:]))
        pairs.append((frames[-1], frames[-1]))
        return pairs

    @staticmethod
    def _read_image(path):
        """Read one frame with cv2.imread, raising if it cannot be read."""
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def load_image_pairs(self):
        """Load every trajectory's frames and return their consecutive pairs.

        Returns:
            list of ``(img, next_img)`` tuples of arrays from ``cv2.imread``
            (BGR channel order, no conversion applied), concatenated over
            trajectories in numeric order.

        Raises:
            FileNotFoundError: if a frame cannot be decoded.
        """
        image_pairs = []
        trajectories = self._sorted_trajectory_names(
            os.listdir(self.image_dir), self.truncate
        )
        for traj in trajectories:
            traj_path = os.path.join(self.image_dir, traj)
            frame_names = sorted(
                os.listdir(traj_path), key=lambda x: int(x.split('.')[0])
            )
            # Read each frame once; pairs share the arrays (copied later by
            # torch.tensor in __getitem__).
            frames = [self._read_image(os.path.join(traj_path, name))
                      for name in frame_names]
            image_pairs.extend(self._consecutive_pairs(frames))
        return image_pairs

    def __len__(self):
        """Number of image pairs."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            ``(img1, img2, state, smooth_state, action, reward, terminal,
            index)`` where ``img1``/``img2`` are float32 ``(C, H, W)`` tensors
            and ``smooth_state`` is the next step's state, or the current one
            when this step is terminal or the last sample.
        """
        img1, img2 = self.image_pairs[idx]
        # Images are stored (H, W, C); permute to (C, H, W).
        img1 = torch.tensor(img1, dtype=torch.float32).permute(2, 0, 1)
        img2 = torch.tensor(img2, dtype=torch.float32).permute(2, 0, 1)

        state = torch.tensor(self.states[idx])
        action = torch.tensor(self.actions[idx])
        reward = torch.tensor(self.rewards[idx])
        terminal = torch.tensor(self.terminals[idx])
        # Don't step across an episode boundary or past the final sample.
        has_next = not self.terminals[idx] and idx != self.idx[-1]
        smooth_state = torch.tensor(self.states[idx + 1] if has_next else self.states[idx])
        return img1, img2, state, smooth_state, action, reward, terminal, self.idx[idx]