import pickle

import numpy as np
import torch

from utils import convert_to_tensor

# Default compute device: the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class UnifiedDataset(torch.utils.data.Dataset):
    """In-memory trajectory dataset backed by tensors built with ``convert_to_tensor``.

    Each input pickle file may contain several lists of trajectory dicts
    pickled back-to-back; all of them are read and concatenated. Every
    trajectory dict must provide ``context_states``, ``context_actions``,
    ``context_next_states``, ``context_rewards``, ``query_state`` and
    ``optimal_action``, plus ``goal`` when ``config['goal']`` is truthy.

    Args:
        path: One file path or a list of file paths.
        config: Mapping read for ``shuffle``, ``horizon``, ``goal``,
            ``state_dim`` and ``action_dim``.
        store_gpu: Forwarded to ``convert_to_tensor`` for every stored tensor.

    Note:
        Files are read with ``pickle``, which can execute arbitrary code;
        only load trusted data.
    """

    # Per-sample fields reordered with one shared permutation when ``shuffle`` is set.
    _CONTEXT_KEYS = (
        'context_states',
        'context_actions',
        'context_next_states',
        'context_rewards',
    )
    # Per-sample fields that are always present, in the order ``__getitem__`` emits them.
    _SAMPLE_KEYS = _CONTEXT_KEYS + ('query_states', 'optimal_actions')

    def __init__(self, path, config, store_gpu=True):
        self.shuffle = config['shuffle']
        self.horizon = config['horizon']
        self.goal_exist = config['goal']
        self.config = config
        self.store_gpu = store_gpu

        # Accept a single path or a list of paths.
        paths = path if isinstance(path, list) else [path]

        self.trajs = []
        for p in paths:
            with open(p, 'rb') as f:
                # A file may hold several pickled lists; read until exhausted.
                while True:
                    try:
                        self.trajs.extend(pickle.load(f))
                    except EOFError:
                        break

        def to_tensor(array):
            return convert_to_tensor(array, self.store_gpu)

        def stack(field):
            return np.array([traj[field] for traj in self.trajs])

        self.dataset = {
            'context_states': to_tensor(stack('context_states')),
            'context_actions': to_tensor(stack('context_actions')),
            'context_next_states': to_tensor(stack('context_next_states')),
            'context_rewards': to_tensor(
                self._reward_array([traj['context_rewards'] for traj in self.trajs])
            ),
            'query_states': to_tensor(stack('query_state')),
            'optimal_actions': to_tensor(stack('optimal_action')),
        }
        if self.goal_exist:
            self.dataset['goals'] = to_tensor(stack('goal'))

        self.zeros = to_tensor(
            np.zeros(config['state_dim'] ** 2 + config['action_dim'] + 1)
        )

    @staticmethod
    def _reward_array(rewards):
        """Stack per-trajectory rewards and make sure a trailing feature axis exists.

        Rewards stored as ``(H,)`` per trajectory become ``(N, H, 1)``; rewards that
        already carry the axis (``(H, 1)``) stay ``(N, H, 1)`` instead of gaining a
        second singleton axis.
        """
        array = np.array(rewards)
        if array.ndim < 3:
            array = array[..., None]
        return array

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.dataset['query_states'])

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors (plus ``goals`` when enabled).

        When ``shuffle`` is set, the context fields are indexed with one shared
        ``torch.randperm(horizon)`` along their first per-sample axis, so they
        stay aligned with each other.
        """
        res = {key: self.dataset[key][index] for key in self._SAMPLE_KEYS}
        res['zeros'] = self.zeros
        if self.goal_exist:
            res['goals'] = self.dataset['goals'][index]

        if self.shuffle:
            perm = torch.randperm(self.horizon)
            for key in self._CONTEXT_KEYS:
                res[key] = res[key][perm]

        return res

    def sub_set(self, index):
        """Return the first ``index`` samples of every stored field (tensor slices)."""
        return {key: value[:index] for key, value in self.dataset.items()}

    def concatenate(self, other_dataset, index=None):
        """Append samples from ``other_dataset`` to this dataset in place.

        Args:
            other_dataset: Object exposing ``dataset`` and ``sub_set`` like this class.
            index: When given, only the first ``index`` samples of ``other_dataset``
                are appended; otherwise all of them.

        Fields stored here but absent from ``other_dataset`` are left untouched,
        which leaves them shorter than the other fields.
        """
        if index is None:
            sub_dataset = other_dataset.dataset
        else:
            sub_dataset = other_dataset.sub_set(index)
        for key in list(self.dataset):
            if key in sub_dataset:
                self.dataset[key] = torch.cat([self.dataset[key], sub_dataset[key]], dim=0)

class SubDataset(torch.utils.data.Dataset):
    """In-memory trajectory dataset whose tensor placement comes from ``config['store_gpu']``.

    Only the first pickled object of each input file is read; it must be a
    list of trajectory dicts with ``context_states``, ``context_actions``,
    ``context_next_states``, ``context_rewards``, ``query_state`` and
    ``optimal_action`` (plus ``goal`` when ``config['goal']`` is truthy).

    Args:
        path: One file path or a list of file paths.
        config: Mapping read for ``shuffle``, ``horizon``, ``store_gpu``,
            ``goal``, ``state_dim`` and ``action_dim``.

    Note:
        ``pickle`` can execute arbitrary code; only load trusted files.
    """

    # Fields every sample carries; ``goals`` is stored in addition when enabled.
    _KEYS = (
        'query_states',
        'optimal_actions',
        'context_states',
        'context_actions',
        'context_next_states',
        'context_rewards',
    )
    # Fields reordered with one shared permutation when ``shuffle`` is set.
    _CONTEXT_KEYS = (
        'context_states',
        'context_actions',
        'context_next_states',
        'context_rewards',
    )

    def __init__(self, path, config):
        self.shuffle = config['shuffle']
        self.horizon = config['horizon']
        self.store_gpu = config['store_gpu']
        self.goal_exist = config['goal']
        self.config = config
        # Parameters of a windowed sub-sampling scheme that is currently disabled
        # (every trajectory is kept); retained as public attributes.
        self.duration = 1000
        self.sub_size = 125

        # Accept a single path or a list of paths.
        paths = path if isinstance(path, list) else [path]

        self.trajs = []
        for p in paths:
            with open(p, 'rb') as f:
                self.trajs += pickle.load(f)

        def stack(field):
            return np.array([traj[field] for traj in self.trajs])

        query_states = stack('query_state')
        optimal_actions = stack('optimal_action')
        context_states = stack('context_states')
        context_actions = stack('context_actions')
        context_next_states = stack('context_next_states')
        context_rewards = stack('context_rewards')
        goal_states = stack('goal') if self.goal_exist else None
        print('count: ', len(self.trajs))

        # Give rewards a trailing feature axis unless they already carry one.
        if context_rewards.ndim < 3:
            context_rewards = context_rewards[..., None]

        def to_tensor(array):
            return convert_to_tensor(array, store_gpu=self.store_gpu)

        self.dataset = {
            'query_states': to_tensor(query_states),
            'optimal_actions': to_tensor(optimal_actions),
            'context_states': to_tensor(context_states),
            'context_actions': to_tensor(context_actions),
            'context_next_states': to_tensor(context_next_states),
            'context_rewards': to_tensor(context_rewards),
        }
        # Goals are stored only when the config enables them.
        if self.goal_exist:
            self.dataset['goals'] = to_tensor(goal_states)

        self.zeros = to_tensor(
            np.zeros(config['state_dim'] ** 2 + config['action_dim'] + 1)
        )

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.dataset['query_states'])

    def concatenate(self, dataset, index):
        """Append the first ``index`` samples of ``dataset`` to this dataset in place.

        ``goals`` are appended only when both datasets store them; if only this
        dataset stores goals, they are left unchanged (and become shorter than
        the other fields).
        """
        sub_dataset = dataset.sub_set(index)
        for key in self._KEYS:
            self.dataset[key] = torch.cat((self.dataset[key], sub_dataset[key]), dim=0)
        if 'goals' in self.dataset and 'goals' in sub_dataset:
            self.dataset['goals'] = torch.cat(
                (self.dataset['goals'], sub_dataset['goals']), dim=0
            )

    def sub_set(self, index):
        """Return the first ``index`` samples of each field (including ``goals`` if stored)."""
        res = {key: self.dataset[key][:index] for key in self._KEYS}
        if 'goals' in self.dataset:
            res['goals'] = self.dataset['goals'][:index]
        return res

    def __getitem__(self, index):
        'Generates one sample of data'
        res = {
            key: self.dataset[key][index]
            for key in self._CONTEXT_KEYS + ('query_states', 'optimal_actions')
        }
        res['zeros'] = self.zeros
        if self.goal_exist:
            res['goals'] = self.dataset['goals'][index]

        if self.shuffle:
            # One shared permutation keeps the context fields aligned.
            perm = torch.randperm(self.horizon)
            for key in self._CONTEXT_KEYS:
                res[key] = res[key][perm]

        return res

# # original
# class Dataset(torch.utils.data.Dataset):
#     """Dataset class."""

#     def __init__(self, path, config):
#         self.shuffle = config['shuffle']
#         self.horizon = config['horizon']
#         self.store_gpu = config['store_gpu']
#         self.config = config

#         # if path is not a list
#         if not isinstance(path, list):
#             path = [path]

#         self.trajs = []
#         for p in path:
#             with open(p, 'rb') as f:
#                 self.trajs += pickle.load(f)
            
#         context_states = []
#         context_actions = []
#         context_next_states = []
#         context_rewards = []
#         query_states = []
#         optimal_actions = []
#         count = 0
#         for traj in self.trajs:
#             context_states.append(traj['context_states'])
#             context_actions.append(traj['context_actions'])
#             context_next_states.append(traj['context_next_states'])
#             context_rewards.append(traj['context_rewards'])

#             query_states.append(traj['query_state'])
#             optimal_actions.append(traj['optimal_action'])
#             count += 1
#         print('count: ', count)
#         context_states = np.array(context_states)
#         context_actions = np.array(context_actions)
#         context_next_states = np.array(context_next_states)
#         context_rewards = np.array(context_rewards)
#         if len(context_rewards.shape) < 3:
#             context_rewards = context_rewards[:, :, None]
#         query_states = np.array(query_states)
#         optimal_actions = np.array(optimal_actions)

#         self.dataset = {
#             'query_states': convert_to_tensor(query_states, store_gpu=self.store_gpu),
#             'optimal_actions': convert_to_tensor(optimal_actions, store_gpu=self.store_gpu),
#             'context_states': convert_to_tensor(context_states, store_gpu=self.store_gpu),
#             'context_actions': convert_to_tensor(context_actions, store_gpu=self.store_gpu),
#             'context_next_states': convert_to_tensor(context_next_states, store_gpu=self.store_gpu),
#             'context_rewards': convert_to_tensor(context_rewards, store_gpu=self.store_gpu),
#         }

#         self.zeros = np.zeros(
#             config['state_dim'] ** 2 + config['action_dim'] + 1
#         )
#         self.zeros = convert_to_tensor(self.zeros, store_gpu=self.store_gpu)

#     def __len__(self):
#         'Denotes the total number of samples'
#         return len(self.dataset['query_states'])

#     def __getitem__(self, index):
#         'Generates one sample of data'
#         res = {
#             'context_states': self.dataset['context_states'][index],
#             'context_actions': self.dataset['context_actions'][index],
#             'context_next_states': self.dataset['context_next_states'][index],
#             'context_rewards': self.dataset['context_rewards'][index],
#             'query_states': self.dataset['query_states'][index],
#             'optimal_actions': self.dataset['optimal_actions'][index],
#             'zeros': self.zeros,
#         }

#         if self.shuffle:
#             perm = torch.randperm(self.horizon)
#             res['context_states'] = res['context_states'][perm]
#             res['context_actions'] = res['context_actions'][perm]
#             res['context_next_states'] = res['context_next_states'][perm]
#             res['context_rewards'] = res['context_rewards'][perm]

#         return res
class Dataset(torch.utils.data.Dataset):
    """Trajectory dataset that keeps every field as a float32 NumPy array.

    All trajectories are loaded into memory up front. ``__getitem__`` returns
    NumPy arrays for the per-sample fields plus the ``zeros`` tensor; turning
    them into tensors presumably happens in the DataLoader's collate step
    (verify against the training loop).

    Args:
        path: One file path or a list of file paths. Each file may contain
            several pickled lists of trajectory dicts back-to-back.
        store_gpu: Forwarded to ``convert_to_tensor`` for ``zeros``; ``sub_set``
            places its tensors on the same device as ``zeros``.
        config: Mapping read for ``shuffle``, ``horizon``, ``goal``,
            ``state_dim`` and ``action_dim``.

    Note:
        ``pickle`` can execute arbitrary code; only load trusted files. Unlike
        the tensor-backed datasets in this module, ``context_rewards`` is stored
        without an added trailing axis.
    """

    # Fields reordered with one shared permutation when ``shuffle`` is set.
    _CONTEXT_KEYS = (
        'context_states',
        'context_actions',
        'context_next_states',
        'context_rewards',
    )

    def __init__(self, path, store_gpu, config):
        self.shuffle = config['shuffle']
        self.horizon = config['horizon']
        self.goal_exist = config['goal']
        self.config = config
        self.store_gpu = store_gpu

        # Accept a single path or a list of paths.
        paths = path if isinstance(path, list) else [path]

        self.trajs = []
        for p in paths:
            with open(p, 'rb') as f:
                # A file may hold several pickled lists; read until exhausted.
                while True:
                    try:
                        self.trajs.extend(pickle.load(f))
                    except EOFError:
                        break

        def stack(field):
            return np.array([traj[field] for traj in self.trajs], dtype=np.float32)

        # Every trajectory is kept in memory; each field becomes one stacked array.
        self.context_states = stack('context_states')
        self.context_actions = stack('context_actions')
        self.context_next_states = stack('context_next_states')
        self.context_rewards = stack('context_rewards')
        self.query_states = stack('query_state')
        self.optimal_actions = stack('optimal_action')
        # Without goals this attribute is an empty list.
        self.goal_states = stack('goal') if self.goal_exist else []
        print('count: ', len(self.trajs))
        print(self.query_states.shape, self.context_states.shape)

        self.zeros = convert_to_tensor(
            np.zeros(config['state_dim'] ** 2 + config['action_dim'] + 1),
            store_gpu=self.store_gpu,
        )

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.context_states)

    def _fields(self):
        """Map output keys to the stored arrays (``goals`` only when enabled)."""
        fields = {
            'query_states': self.query_states,
            'optimal_actions': self.optimal_actions,
            'context_states': self.context_states,
            'context_actions': self.context_actions,
            'context_next_states': self.context_next_states,
            'context_rewards': self.context_rewards,
        }
        if self.goal_exist:
            fields['goals'] = self.goal_states
        return fields

    def sub_set(self, index):
        """Return the first ``index`` samples of each field as tensors.

        The tensors are placed on the device of ``self.zeros`` (built with this
        dataset's ``store_gpu`` setting) so other datasets' ``concatenate`` can
        ``torch.cat`` them onto their own storage. ``goals`` is included only
        when goals are enabled.
        """
        target = self.zeros.device
        return {
            key: torch.as_tensor(array[:index], device=target)
            for key, array in self._fields().items()
        }

    def __getitem__(self, index):
        """Return sample ``index`` as NumPy arrays plus the shared ``zeros`` tensor.

        When ``shuffle`` is set, the context fields are indexed with one shared
        permutation of length ``horizon`` along their first per-sample axis.
        """
        res = {key: array[index] for key, array in self._fields().items()}
        res['zeros'] = self.zeros

        if self.shuffle:
            # Draw from torch's RNG, but index the NumPy arrays with a NumPy
            # integer array so the fancy-indexing semantics are unambiguous.
            perm = torch.randperm(self.horizon).numpy()
            for key in self._CONTEXT_KEYS:
                res[key] = res[key][perm]

        return res


class ImageDataset(Dataset):
    """Image-based variant of ``Dataset``.

    In addition to the parent's fields, each trajectory dict must provide
    ``context_images`` (a file path read with ``np.load`` on every access) and
    ``query_image`` (transformed once at construction time).

    Args:
        paths: One file path or a list of file paths, as for ``Dataset``.
        config: Same mapping as for ``Dataset``. It is mutated in place:
            ``store_gpu`` is forced to ``False`` and ``goal`` defaults to
            ``False`` when missing.
        transform: Callable applied to each image; its result must support
            ``.float()`` and ``torch.stack``.
    """

    def __init__(self, paths, config, transform):
        # Samples are assembled on the CPU, so force CPU storage for the parent too.
        config['store_gpu'] = False
        # ``Dataset.__init__`` reads ``config['goal']``.
        config.setdefault('goal', False)
        # ``Dataset`` takes ``store_gpu`` explicitly, before ``config``.
        super().__init__(paths, store_gpu=False, config=config)
        self.transform = transform
        self.config = config

        # The parent stores float32 NumPy arrays; wrap them (without copying)
        # as CPU tensors in the mapping ``__getitem__`` reads from.
        self.dataset = {
            'context_states': torch.from_numpy(self.context_states),
            'context_actions': torch.from_numpy(self.context_actions),
            'context_next_states': torch.from_numpy(self.context_next_states),
            'context_rewards': torch.from_numpy(self.context_rewards),
            'query_states': torch.from_numpy(self.query_states),
            'optimal_actions': torch.from_numpy(self.optimal_actions),
            'context_filepaths': [traj['context_images'] for traj in self.trajs],
            'query_images': torch.stack(
                [self.transform(traj['query_image']).float() for traj in self.trajs]
            ),
        }

    def __getitem__(self, index):
        'Generates one sample of data'
        # Context images are loaded from disk and transformed on every access.
        filepath = self.dataset['context_filepaths'][index]
        context_images = torch.stack(
            [self.transform(image) for image in np.load(filepath)]
        ).float()

        res = {
            'context_images': context_images,
            'context_states': self.dataset['context_states'][index],
            'context_actions': self.dataset['context_actions'][index],
            'context_next_states': self.dataset['context_next_states'][index],
            'context_rewards': self.dataset['context_rewards'][index],
            'query_images': self.dataset['query_images'][index],
            'query_states': self.dataset['query_states'][index],
            'optimal_actions': self.dataset['optimal_actions'][index],
            'zeros': self.zeros,
        }

        if self.shuffle:
            # One shared permutation keeps images and context fields aligned.
            perm = torch.randperm(self.horizon)
            for key in ('context_images', 'context_states', 'context_actions',
                        'context_next_states', 'context_rewards'):
                res[key] = res[key][perm]

        return res
