import numpy as np
import torch
import sys, os
import h5py



class ReplayBuffer:
    """Fixed-capacity transition store backed by NumPy arrays.

    Transitions are kept in parallel arrays (state, next state, discretised
    state difference, action, reward, terminal flag, truncate flag) and are
    handed out as torch tensors by :meth:`sample`.  The buffer can be filled
    online through :meth:`store_transition` (ring-buffer semantics) or in bulk
    from an offline dataset through :meth:`store_offline_data`.
    """

    def __init__(self,batch_size,obs_dims=None,action_dims=None,mem_size=None,dataset=None,n_envs=1,
                discrete_action=False,eps=1e-3,discrete_eps=1e-6, discrete_bins=2, normalise_state=False,
                n_steps=1,sample_returns=False, use_data=True):
        """Allocate the buffer.

        If a dataset is given, the buffer is assumed to hold only that dataset;
        combining dataset data with online data should be done with a second
        buffer.

        Args:
            batch_size: Default number of transitions returned by ``sample``.
            obs_dims: Width of one observation vector.
            action_dims: Width of one action vector.
            mem_size: Capacity in transitions.  With a dataset and
                ``use_data=True`` the capacity is at least the dataset size.
            dataset: Optional mapping with an ``'observations'`` array; used
                here for sizing and for the normalisation statistics.
            n_envs: Number of parallel environments (stored on the instance
                only when no dataset is given).
            discrete_action: Store actions as integers instead of floats.
            eps: Added to the observation standard deviation.
            discrete_eps: Dead-zone half-width used by
                ``discretise_diff_state``.
            discrete_bins: Currently ignored, see the note in the body.
            normalise_state: Unused by the constructor.
            n_steps: Stored on the instance as ``n_steps``.
            sample_returns: If true, ``sample`` also yields stored returns.
            use_data: Whether the dataset size participates in sizing.
        """

        self.batch_size = batch_size
        self.action_dtype = int if discrete_action else torch.float
        self.discrete_eps = discrete_eps
        # NOTE(review): the ``discrete_bins`` argument is ignored and three bins
        # are always used.  Left as-is because honouring the argument (default
        # 2) would silently change behaviour for existing callers - confirm.
        self.discrete_bins = 3
        self.sample_returns = sample_returns
        # Always defined: store_offline_data reports it even for buffers that
        # were created without a dataset.
        self.n_steps = n_steps
        # Default device so that sample() works before to() has been called.
        self.device = 'cpu'

        if dataset:
            if use_data:
                data_size = dataset['observations'].shape[0]
                self.mem_size = data_size if mem_size is None else max(mem_size,data_size)
            else:
                self.mem_size = mem_size

            self.reset_buffer(obs_dims, action_dims)

            # Per-dimension statistics used by normalise_states.
            self.mean = dataset['observations'].mean(0,keepdims=True)
            self.std = dataset['observations'].std(0,keepdims=True) + eps
        else:
            # Identity normalisation when there is no dataset to fit on.
            self.mean = 0
            self.std = 1
            self.n_envs = n_envs
            self.mem_size = mem_size
            self.reset_buffer(obs_dims, action_dims)

    def __len__(self):
        """Return the number of valid transitions (capped at capacity)."""
        return min(self.mem_size,self.mem_cntr)

    def reset_buffer(self,obs_dims, action_dims):
        """Re-allocate every memory array with zeros and reset the counter."""

        action_dtype = int if self.action_dtype==int else float

        self.state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.next_state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.diff_state_memory = np.zeros((self.mem_size,obs_dims),dtype=float)
        self.action_memory = np.zeros((self.mem_size,action_dims),dtype=action_dtype)
        self.reward_memory = np.zeros(self.mem_size,dtype=float)
        self.terminal_memory = np.zeros(self.mem_size,dtype=bool)
        self.truncate_memory = np.zeros(self.mem_size,dtype=bool)
        self.mem_cntr = 0

    def to(self, device):
        """Set the device sampled tensors are moved to; returns ``self``."""
        self.device = device
        return self

    def store_transition(self,state,next_state,action,reward,terminal,truncate,true_next_state=None,mem_idxs=None):
        """Write one transition, or a batch of transitions, into the buffer.

        Args:
            state, next_state: Arrays of shape ``(obs_dims,)`` for a single
                transition or ``(n, obs_dims)`` for a batch.
            action, reward, terminal, truncate: Values broadcastable to the
                rows being written.
            true_next_state: If given, the stored state difference is computed
                against this instead of ``next_state``.
            mem_idxs: Explicit row indices to overwrite.  When omitted, rows
                are appended ring-buffer style and the counter is advanced;
                when given, the counter is left untouched.
        """

        # Instances that never had n_envs assigned default to one environment.
        if getattr(self,'n_envs',None) is None:
            self.n_envs = 1

        state = np.asarray(state)
        next_state = np.asarray(next_state)

        if mem_idxs is None:
            # One slot per supplied row; a 1-D state is a single transition.
            n_rows = state.shape[0] if state.ndim > 1 else 1
            mem_idxs = [i%self.mem_size for i in range(self.mem_cntr,self.mem_cntr+n_rows)]
            self.mem_cntr += n_rows

        target_state = next_state if true_next_state is None else np.asarray(true_next_state)
        diff_state = target_state-state

        # Replaces the raw differences with bin labels, in place.
        self.discretise_diff_state(diff_state)

        self.state_memory[mem_idxs] = state
        self.next_state_memory[mem_idxs] = next_state
        self.action_memory[mem_idxs] = action
        self.reward_memory[mem_idxs] = reward
        self.terminal_memory[mem_idxs] = terminal
        self.truncate_memory[mem_idxs] = truncate
        self.diff_state_memory[mem_idxs] = diff_state


    def sample(self,sample_range=None,min_idx=None,max_idx=None,batch_size=None,rng=None,entire=False, batch_idx=None,raw_states=False):
        """Return a batch of transitions as torch tensors.

        Args:
            sample_range: Candidate indices; defaults to
                ``arange(min_idx or 0, min(len(self), max_idx))``.
            min_idx, max_idx: Bounds used only when ``sample_range`` is None.
            batch_size: Overrides ``self.batch_size``.
            rng: NumPy ``Generator``; a fresh one is created when omitted.
            entire: Return every index in ``sample_range`` instead of drawing.
            batch_idx: Explicit indices; skips all index selection.
            raw_states: Unused.

        Returns:
            ``(states, next_states, diff_states, actions, rewards, dones,
            [returns,] batch_idx)``.  ``rewards``, ``dones`` and ``returns``
            carry an extra leading dimension of size one; ``returns`` is only
            present when ``sample_returns`` is set, in which case
            ``calculate_returns`` must have populated ``return_memory``.

        Raises:
            ValueError: From NumPy when more indices are requested than are
                available (drawing is without replacement).
        """

        if batch_idx is None:
            batch_size = self.batch_size if batch_size is None else batch_size

            if sample_range is None:
                lower = 0 if min_idx is None else min_idx
                upper = len(self) if max_idx is None else min(len(self),max_idx)
                sample_range = np.arange(lower,upper)

            if entire:
                batch_idx = sample_range
            else:
                rng = np.random.default_rng() if rng is None else rng
                batch_idx = rng.choice(sample_range,batch_size,replace=False)

        # Falls back to CPU for instances on which the device was never set.
        device = getattr(self,'device','cpu')

        def as_tensor(array, dtype):
            return torch.tensor(array,dtype=dtype).to(device)

        states = as_tensor(self.state_memory[batch_idx],torch.float)
        next_states = as_tensor(self.next_state_memory[batch_idx],torch.float)
        diff_states = as_tensor(self.diff_state_memory[batch_idx],torch.long)
        actions = as_tensor(self.action_memory[batch_idx],self.action_dtype)
        rewards = as_tensor(self.reward_memory[batch_idx],torch.float).unsqueeze(0)
        done_batch = as_tensor(self.terminal_memory[batch_idx],torch.bool).unsqueeze(0)

        output = (states,next_states,diff_states,actions,rewards,done_batch)

        if self.sample_returns:
            returns = as_tensor(self.return_memory[batch_idx],torch.float).unsqueeze(0)
            output = output + (returns,)

        return output + (batch_idx,)

    def store_offline_data(self,dataset,normalise_state=True, env_id=''):
        """Append an offline dataset to the buffer.

        Args:
            dataset: Mapping with ``'observations'``, ``'next_observations'``,
                ``'actions'``, ``'terminals'`` and ``'rewards'`` arrays.  When
                ``normalise_state`` is true its observation entries are
                replaced by their normalised versions.
            normalise_state: Normalise observations with the stored mean/std.
            env_id: Unused.

        Raises:
            ValueError: If the dataset does not fit in the remaining capacity.
        """
        print('loading offline data')

        #If we normalise make sure that diff_state is accounted for - not so easy
        if normalise_state:
            self.normalise_states(dataset)

        dataset_size = dataset['observations'].shape[0]

        new_size = self.mem_cntr+dataset_size
        if new_size > self.mem_size:
            raise ValueError(f'{dataset_size+self.mem_cntr} is larger than {self.mem_size}')

        start = self.mem_cntr
        self.state_memory[start:new_size] = dataset['observations']
        self.next_state_memory[start:new_size] = dataset['next_observations']
        self.action_memory[start:new_size] = dataset['actions']
        self.terminal_memory[start:new_size] = dataset['terminals']
        self.reward_memory[start:new_size] = dataset['rewards']

        self.truncate_memory[start:new_size-1] = False

        # A non-terminal transition whose next state does not line up with the
        # following stored state (first two observation dimensions) is
        # flagged as truncated.
        dist = np.linalg.norm(self.next_state_memory[:-1,:2]-self.state_memory[1:,:2],axis=-1)
        discontinuous = np.logical_and(dist>1e-6,self.terminal_memory[:-1]==0)
        self.truncate_memory[:-1][discontinuous] = True
        # The last loaded transition always closes a trajectory.
        self.truncate_memory[new_size-1] = True
        print(self.truncate_memory.sum())

        # Recomputed for the whole buffer, then replaced by bin labels.
        self.diff_state_memory = self.next_state_memory - self.state_memory
        self.discretise_diff_state(self.diff_state_memory)

        self.mem_cntr = new_size

        # Report how many entries fall in each bin.
        x = self.diff_state_memory[:self.mem_cntr]
        for i in range(self.discrete_bins):
            print(x[x==i].shape)

        if self.sample_returns:
            self.calculate_returns()

        print(self.n_steps)
        print(self.mem_size,self.mem_cntr)

    def discretise_diff_state(self,diff_state):
        """Replace state differences with bin labels, in place.

        With three bins: ``0`` for values below ``-eps``, ``2`` for values
        above ``eps`` and ``1`` for everything in ``[-eps, eps]`` (bounds
        included, so values exactly on the boundary are labelled too).  With
        two bins: ``0`` for negative values, ``1`` otherwise.  NaN entries are
        left untouched.

        Raises:
            ValueError: If ``self.discrete_bins`` is neither 2 nor 3.
        """

        eps = self.discrete_eps
        # Masks are taken from a copy so earlier writes cannot affect later ones.
        raw = diff_state.copy()

        if self.discrete_bins == 3:
            diff_state[raw<-eps] = 0
            diff_state[raw>eps] = 2
            diff_state[np.abs(raw)<=eps] = 1
        elif self.discrete_bins == 2:
            diff_state[raw<0] = 0
            diff_state[raw>=0] = 1
        else:
            raise ValueError('must provide discrete bin 2 or 3')

    def normalise_states(self,dataset,eps=1e-3):
        """Standardise the dataset's observations with the stored mean/std.

        The ``'observations'`` and ``'next_observations'`` entries of
        ``dataset`` are replaced.  ``eps`` is unused; the epsilon is already
        part of ``self.std``.
        """

        dataset['observations'] = (dataset['observations']-self.mean)/self.std
        dataset['next_observations'] = (dataset['next_observations']-self.mean)/self.std

    def stitch_samples(self,samples_1,samples_2):
        """Stack two sampled batches element-wise.

        Each pair of tensors is reshaped to ``(self.batch_size, -1)`` and
        stacked vertically; the stacked tensors are returned as a list.
        """

        samples = []
        for sample_1, sample_2 in zip(samples_1,samples_2):
            first = sample_1.reshape(self.batch_size,-1)
            second = sample_2.reshape(self.batch_size,-1)
            samples.append(torch.vstack([first,second]))

        return samples

    def calculate_returns(self,gamma=.99):
        """Compute discounted returns for every stored transition.

        Trajectories end at each terminal or truncated transition and at the
        last valid transition.  The result is stored in ``self.return_memory``
        (one entry per valid transition).
        """

        print('Calculating returns...')

        n_valid = len(self)
        # Only flags inside the filled region delimit trajectories.
        episode_end = np.logical_or(self.terminal_memory[:n_valid],self.truncate_memory[:n_valid])
        split_idxs = np.flatnonzero(episode_end).tolist()

        # Close the trailing trajectory; also covers data without any flag.
        if not split_idxs or split_idxs[-1] != n_valid-1:
            split_idxs.append(n_valid-1)

        returns = np.zeros(n_valid,dtype=float)
        start = 0
        for end in split_idxs:
            running = 0.0
            # Walk the trajectory backwards to accumulate the discounted sum.
            for i in range(end,start-1,-1):
                running = self.reward_memory[i] + running*gamma
                returns[i] = running
            start = end+1

        self.return_memory = returns
        print('Returns calculated.')

    def reinit_buffer(self, reset_batch_size=5000, rng=None):
        """Shrink the buffer to a random subset of its current transitions.

        Up to ``reset_batch_size`` valid transitions are drawn without
        replacement, kept in their original relative order, and written to the
        front of a freshly reset buffer.

        Args:
            reset_batch_size: Maximum number of transitions to keep.
            rng: NumPy ``Generator``; a fresh one is created when omitted.
        """
        n_valid = len(self)
        n_keep = min(reset_batch_size,n_valid)

        if n_keep > 0:
            rng = np.random.default_rng() if rng is None else rng
            keep_idx = np.sort(rng.choice(np.arange(0,n_valid),n_keep,replace=False))
        else:
            keep_idx = np.zeros(0,dtype=int)

        memory_names = ('state_memory','next_state_memory','diff_state_memory','action_memory',
                        'reward_memory','terminal_memory','truncate_memory')
        # Fancy indexing copies, so the data survives the reset below.
        kept = {name: getattr(self,name)[keep_idx] for name in memory_names}

        # Keep stored returns aligned with the surviving transitions.
        old_returns = getattr(self,'return_memory',None)
        if old_returns is not None and len(old_returns) >= n_valid:
            self.return_memory = old_returns[keep_idx]

        obs_dims = self.state_memory.shape[1]
        action_dims = self.action_memory.shape[1]
        self.reset_buffer(obs_dims,action_dims)

        for name, values in kept.items():
            getattr(self,name)[:n_keep] = values
        self.mem_cntr = n_keep

    def create_filepath(self, env_id):
        """Return ``datasets/<env_id>/buffer_data``, creating the directories."""

        file_name = 'buffer_data'

        file_path = ''
        for path in ['datasets',env_id]:
            file_path = os.path.join(file_path,path)
            # exist_ok avoids the check-then-create race.
            os.makedirs(file_path,exist_ok=True)

        return os.path.join(file_path,file_name)

    def save_buffer(self,config_dict):
        """Write all memory arrays and the counter to an HDF5 file.

        Args:
            config_dict: Mapping whose ``'env_id'`` entry selects the directory.
        """
        env_id = config_dict['env_id']

        file_path = self.create_filepath(env_id)

        print(f'saving replay buffer to {file_path}.... ')

        data_keys = ['state_memory','next_state_memory','diff_state_memory','action_memory',
                    'terminal_memory','truncate_memory','reward_memory','mem_cntr']

        # The context manager closes the file even if a write fails.
        with h5py.File(file_path, 'w') as f:
            dict_group = f.create_group('dict_data')
            for key in data_keys:
                dict_group[key] = getattr(self,key)

    def load_buffer(self, config_dict):
        """Restore memory arrays and the counter from an HDF5 file.

        Loaded arrays are copied into the front of the already allocated
        memories, so the buffer keeps its capacity.  The discretised state
        differences are recomputed from the loaded states afterwards.

        Args:
            config_dict: Mapping whose ``'env_id'`` entry selects the directory.

        Raises:
            ValueError: If a stored array has more rows than the buffer.
        """

        obs_dims = self.state_memory.shape[1]
        action_dims = self.action_memory.shape[1]

        env_id = config_dict['env_id']

        file_path = self.create_filepath(env_id)

        print('loading replay buffer.... ')

        with h5py.File(file_path, 'r') as f:
            dict_group_load = f['dict_data']

            for key in dict_group_load:
                if key == 'mem_cntr':
                    setattr(self,key,int(dict_group_load[key][()]))
                    continue

                try:
                    buffer_memory = getattr(self,key)
                except AttributeError:
                    # NOTE(review): keys with no matching attribute are written
                    # into truncate_memory - presumably a renamed legacy key;
                    # confirm against the files this is meant to read.
                    buffer_memory = getattr(self,'truncate_memory')

                data = dict_group_load[key][:]
                data_size = data.shape[0]
                if data_size > buffer_memory.shape[0]:
                    raise ValueError(f'stored {key} has {data_size} rows but the buffer holds {buffer_memory.shape[0]}')
                buffer_memory[:data_size] = data

        # Loading writes in place, so the array widths must be unchanged.
        assert obs_dims == self.state_memory.shape[1]
        assert action_dims == self.action_memory.shape[1]

        print(f'Buffer loaded, buffer size is {len(self)}')

        self.diff_state_memory = self.next_state_memory - self.state_memory
        self.discretise_diff_state(self.diff_state_memory)
