import numpy as np
import torch
import sys, os
import h5py



class ReplayBuffer:
    """Fixed-capacity transition store backed by numpy arrays.

    Transitions are written either one at a time (``store_transition``,
    which wraps around once ``mem_size`` is exceeded) or in bulk from an
    offline dataset (``store_offline_data``).  ``sample`` returns torch
    tensors placed on the device selected with ``to`` (CPU by default).
    """

    def __init__(self,batch_size,obs_dims=None,action_dims=None,mem_size=None,dataset=None,n_envs=1,
                discrete_action=False,eps=1e-3, normalise_state=False, sample_next_state=False,
                sample_returns=False, use_data=True):
        """Allocate the buffer and, if a dataset is given, load it.

        When a dataset is supplied the buffer is assumed to exist for that
        dataset only; mixing offline data with online data is expected to be
        done with a separate buffer.

        Args:
            batch_size: Default number of transitions drawn by ``sample``.
            obs_dims: Width of one observation row.
            action_dims: Width of one action row.
            mem_size: Capacity.  With a dataset and ``use_data=True`` it is
                raised to at least the dataset size (or set to it when None).
            dataset: Mapping with the keys 'observations',
                'next_observations', 'actions', 'terminals' and 'rewards'.
            n_envs: Stored on the instance; not otherwise used by this class.
            discrete_action: Store/sample actions as integers instead of
                floats.
            eps: Added to the per-dimension observation std.
            normalise_state: Forwarded to ``store_offline_data``.
            sample_next_state: When True, ``sample`` never draws the last
                stored index and drops indices flagged as truncated.
            sample_returns: Stored on the instance; not used by this class.
            use_data: When False the capacity is taken from ``mem_size``
                unchanged.
        """

        self.batch_size = batch_size
        self.action_dtype = int if discrete_action else torch.float
        self.sample_returns = sample_returns
        self.n_envs = n_envs

        # Default device so that sample() works before to() has been called.
        self.device = torch.device('cpu')

        # For the one-step approach: presumably selects whether the following
        # state is taken from the dataset - TODO confirm against callers.
        self.sample_next_state = sample_next_state

        if dataset:
            if use_data:
                data_size = dataset['observations'].shape[0]
                self.mem_size = data_size if mem_size is None else max(mem_size,data_size)
            else:
                self.mem_size = mem_size

            self.reset_buffer(obs_dims, action_dims)

            self.mean = dataset['observations'].mean(0,keepdims=True)
            self.std = dataset['observations'].std(0,keepdims=True) + eps

            # NOTE(review): the dataset is copied into the buffer even when
            # use_data is False - confirm that this is intended.
            self.store_offline_data(dataset,normalise_state)
        else:
            self.mean = 0
            self.std = 1
            self.mem_size = mem_size
            self.reset_buffer(obs_dims, action_dims)

    def __len__(self):
        """Return the number of valid transitions (capped at the capacity)."""
        return min(self.mem_size,self.mem_cntr)

    def reset_buffer(self,obs_dims, action_dims):
        """Allocate zero-filled storage arrays and reset the write counter."""

        action_dtype = int if self.action_dtype is int else float

        self.state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.next_state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.diff_state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.action_memory = np.zeros((self.mem_size,action_dims),dtype=action_dtype)
        self.reward_memory = np.zeros(self.mem_size,dtype=float)
        self.terminal_memory = np.zeros(self.mem_size,dtype=bool)
        self.truncate_memory = np.zeros(self.mem_size,dtype=bool)
        self.mem_cntr = 0

    def to(self, device):
        """Select the torch device that sampled tensors are moved to."""
        self.device = device

    def store_transition(self,state,next_state,action,reward,terminal,truncate):
        """Write one transition, overwriting the oldest slot once full."""

        slot = self.mem_cntr % self.mem_size
        self.mem_cntr += 1

        self.state_memory[slot] = state
        self.next_state_memory[slot] = next_state
        self.action_memory[slot] = action
        self.reward_memory[slot] = reward
        self.terminal_memory[slot] = terminal
        self.truncate_memory[slot] = truncate

    def sample(self,sample_range=None,min_idx=None,max_idx=None,batch_size=None,rng=None,entire=False, batch_idx=None,raw_states=False):
        """Return a batch of transitions as torch tensors.

        Args:
            sample_range: Explicit array of candidate indices.  When None it
                is ``arange(min_idx, upper)`` where ``upper`` is ``len(self)``
                (one less when ``sample_next_state`` is set), further capped
                by ``max_idx`` if given.
            min_idx: Lower bound of the default range (0 when None).
            max_idx: Exclusive upper bound of the default range.
            batch_size: Number of indices to draw; defaults to
                ``self.batch_size``.
            rng: Object with a ``choice`` method (e.g. a numpy ``Generator``);
                a fresh ``default_rng()`` is used when None.
            entire: Return every index of the range instead of drawing.
            batch_idx: Indices to use directly; skips all of the above.
            raw_states: Accepted for interface compatibility; not used here.

        Returns:
            ``(states, next_states, actions, rewards, dones, batch_idx)``.
            ``rewards`` and ``dones`` carry a leading axis of size 1.

        Raises:
            ValueError: from ``choice`` when more indices are requested than
                the range holds (sampling is without replacement).
        """

        if batch_idx is None:
            batch_size = self.batch_size if batch_size is None else batch_size

            if sample_range is None:
                lower = 0 if min_idx is None else min_idx

                # The last stored element has no successor in the buffer, so
                # it is excluded when sample_next_state is set.
                upper = len(self)-1 if self.sample_next_state else len(self)
                if max_idx is not None:
                    upper = min(upper,max_idx)

                sample_range = np.arange(lower,upper)

            if entire:
                batch_idx = sample_range
            else:
                generator = np.random.default_rng() if rng is None else rng
                batch_idx = generator.choice(sample_range,batch_size,replace=False)

        if self.sample_next_state:
            # Convert first so that a plain list of indices can be masked too.
            batch_idx = np.asarray(batch_idx)
            batch_idx = batch_idx[~self.truncate_memory[batch_idx]]

        states = torch.tensor(self.state_memory[batch_idx],dtype=torch.float).to(self.device)
        next_states = torch.tensor(self.next_state_memory[batch_idx],dtype=torch.float).to(self.device)
        actions = torch.tensor(self.action_memory[batch_idx],dtype=self.action_dtype).to(self.device)
        rewards = torch.tensor(self.reward_memory[batch_idx],dtype=torch.float).to(self.device).unsqueeze(0)
        done_batch = torch.tensor(self.terminal_memory[batch_idx],dtype=torch.bool).to(self.device).unsqueeze(0)

        return states,next_states,actions,rewards,done_batch,batch_idx

    def store_offline_data(self,dataset,normalise_state=True, env_id=''):
        """Append an offline dataset after the transitions already stored.

        Truncation flags for the appended rows are derived from the data: a
        non-terminal row is flagged when the first two components of its
        next-observation differ (by more than 1e-6 in Euclidean norm) from
        those of the observation stored right after it.  The final appended
        row is always flagged.  Rows stored earlier keep their flags, except
        the one directly before the appended block, which is checked against
        the first appended row.

        Args:
            dataset: Mapping with 'observations', 'next_observations',
                'actions', 'terminals' and 'rewards'.
            normalise_state: Normalise the observations first.  This
                replaces the entries of ``dataset`` in place.
            env_id: Unused.

        Raises:
            AssertionError: if the dataset does not fit in the free capacity.
        """
        print('loading offline data')

        # If we normalise make sure that diff_state is accounted for - not so easy
        if normalise_state:
            self.normalise_states(dataset)

        dataset_size = dataset['observations'].shape[0]

        start = self.mem_cntr
        new_size = start+dataset_size
        assert new_size <= self.mem_size, f'{new_size} is larger than {self.mem_size}'

        if dataset_size == 0:
            # Nothing to append; avoid flagging an unrelated slot as truncated.
            return

        window = slice(start,new_size)
        self.state_memory[window] = dataset['observations']
        self.next_state_memory[window] = dataset['next_observations']
        self.action_memory[window] = dataset['actions']
        self.terminal_memory[window] = dataset['terminals']
        self.reward_memory[window] = dataset['rewards']

        self.truncate_memory[window] = False

        # Compare each row with its successor, starting one row before the
        # appended block so the seam to earlier data is checked as well.
        first = max(start-1,0)
        gap = np.linalg.norm(self.next_state_memory[first:new_size-1,:2]
                             - self.state_memory[first+1:new_size,:2],axis=-1)
        broken = np.logical_and(gap>1e-6,~self.terminal_memory[first:new_size-1])
        self.truncate_memory[first:new_size-1] |= broken

        self.truncate_memory[new_size-1] = True

        self.mem_cntr = new_size

    def normalise_states(self,dataset,eps=1e-3):
        """Standardise the dataset observations with ``self.mean``/``self.std``.

        The 'observations' and 'next_observations' entries of ``dataset`` are
        replaced in place.  ``eps`` is unused here.
        """

        dataset['observations'] = (dataset['observations']-self.mean)/self.std
        dataset['next_observations'] = (dataset['next_observations']-self.mean)/self.std

    def create_filepath(self, env_id):
        """Return ``datasets/<env_id>/buffer_data``, creating the directories."""

        directory = os.path.join('datasets',env_id)
        os.makedirs(directory,exist_ok=True)

        return os.path.join(directory,'buffer_data')

    def save_buffer(self,config_dict):
        """Write the storage arrays and write counter to an HDF5 file.

        Args:
            config_dict: Mapping whose 'env_id' entry selects the directory.
        """
        env_id = config_dict['env_id']

        file_path = self.create_filepath(env_id)

        print(f'saving replay buffer to {file_path}.... ')

        data_keys = ['state_memory','next_state_memory','action_memory','terminal_memory',
                        'truncate_memory','reward_memory','mem_cntr']

        # The context manager closes the file even if a write fails.
        with h5py.File(file_path, 'w') as f:
            dict_group = f.create_group('dict_data')
            for key in data_keys:
                dict_group[key] = getattr(self,key)

    def load_buffer(self, config_dict):
        """Fill the already-allocated arrays from a file written by ``save_buffer``.

        Args:
            config_dict: Mapping whose 'env_id' entry selects the directory.
        """

        obs_dims = self.state_memory.shape[1]
        action_dims = self.action_memory.shape[1]

        env_id = config_dict['env_id']

        file_path = self.create_filepath(env_id)

        print(f'loading replay buffer.... ')

        with h5py.File(file_path, 'r') as f:
            dict_group_load = f['dict_data']

            for key in dict_group_load:
                stored = dict_group_load[key][()]

                if key == 'mem_cntr':
                    self.mem_cntr = int(stored)
                    continue

                try:
                    buffer_memory = getattr(self,key)
                except AttributeError:
                    # Unknown keys are written into truncate_memory;
                    # presumably this covers files saved under an older key
                    # name - TODO confirm.
                    buffer_memory = self.truncate_memory

                buffer_memory[:stored.shape[0]] = stored

        assert obs_dims == self.state_memory.shape[1]
        assert action_dims == self.action_memory.shape[1]

        print(f'Buffer loaded, buffer size is {len(self)}')
