from collections import OrderedDict, deque
import numpy as np
from copy import deepcopy

class Pool(object):
    """Base container for transition arrays and a bounded history of paths.

    ``self.memory`` maps a field name to a 2-D array whose first
    ``history_horizon + 1`` rows are zeros.  ``self.current_path`` collects
    the episode in progress and ``self.paths`` is a ``deque`` bounded by
    ``store_last_n_paths``.  ``store`` and ``sample`` are no-op placeholders
    here.
    """

    def __init__(self, variant):
        """Allocate the zero-initialised memory described by ``variant``.

        Required keys: ``'s_dim'``, ``'a_dim'``, ``'store_last_n_paths'``.
        Optional key: ``'history_horizon'`` (treated as 0 when absent).
        """
        state_width = variant['s_dim']
        action_width = variant['a_dim']

        self.paths = deque(maxlen=variant['store_last_n_paths'])

        self.history_horizon = (
            variant['history_horizon'] if 'history_horizon' in variant.keys() else 0
        )

        # One (name, column count) pair per stored field; every field starts
        # with the same number of zero rows.
        n_rows = self.history_horizon + 1
        layout = (
            ('s', state_width),
            ('a', action_width),
            ('r', 1),
            ('terminal', 1),
            ('s_', state_width),
        )
        self.memory = {}
        for name, width in layout:
            self.memory[name] = np.zeros([n_rows, width])

        self.memory_pointer = 0

        self.current_path = {}

    def reset(self):
        """Set ``current_path[key]`` to an empty list for every memory key."""
        for key in self.memory.keys():
            self.current_path[key] = []

    def store(self, **kwargs):
        """Placeholder; does nothing and returns ``None``."""
        return None

    def sample(self, **kwargs):
        """Placeholder; does nothing and returns ``None``."""
        return None


class LAC_Pool(Pool):
    """Replay memory that accumulates whole paths and samples random batches.

    Transitions are buffered in ``self.current_path`` and flushed into
    ``self.memory`` when a terminal transition is stored.  When
    ``variant['finite_horizon']`` is truthy, two extra fields (``'value'`` and
    ``'r_N_'``) are computed per path from the rewards using
    ``variant['value_horizon']``.
    """

    def __init__(self, variant):
        """Build the pool.

        In addition to the keys read by ``Pool.__init__``, ``variant`` must
        contain ``'memory_capacity'`` and ``'min_memory_size'``;
        ``'finite_horizon'`` is optional and, when truthy, requires
        ``'value_horizon'``.
        """
        super(LAC_Pool, self).__init__(variant)

        if 'finite_horizon' in variant.keys() and variant['finite_horizon']:
            pad_rows = self.history_horizon + 1
            self.memory['value'] = np.zeros([pad_rows, 1])
            self.memory['r_N_'] = np.zeros([pad_rows, 1])
            self.horizon = variant['value_horizon']
        self.reset()
        self.memory_capacity = variant['memory_capacity']
        self.min_memory_size = variant['min_memory_size']

    def store(self, s, a, d, raw_d, r, terminal, s_):
        """Append one transition to the current path; flush it on terminal.

        ``s``, ``a``, ``d``, ``raw_d`` and ``s_`` must be arrays with at least
        one dimension (they are indexed with ``[np.newaxis, :]``); ``r`` and
        ``terminal`` are scalars and are wrapped into length-1 arrays.

        When ``terminal == 1.`` the buffered path is concatenated onto every
        memory field, a snapshot of the path is pushed onto ``self.paths``,
        the path buffer is reset and ``self.memory_pointer`` is updated.

        Returns:
            ``self.memory_pointer``, i.e. ``len(self.memory['s'])`` as of the
            last flush (this count includes the initial zero rows).
        """
        transition = {'s': s, 'a': a, 'd': d, 'raw_d': raw_d,
                      'r': np.array([r]), 'terminal': np.array([terminal]), 's_': s_}
        if len(self.current_path['s']) < 1:
            for key, row in transition.items():
                self.current_path[key] = row[np.newaxis, :]
        else:
            for key, row in transition.items():
                self.current_path[key] = np.concatenate(
                    (self.current_path[key], row[np.newaxis, :]))

        if terminal == 1.:
            if 'value' in self.memory:
                rewards = self.current_path['r']
                path_length = len(rewards)
                # Pad the reward sequence with its last value so every
                # window below stays inside the array.
                padded = np.concatenate(
                    (rewards, rewards[-1, 0] * np.ones([self.horizon + 1, 1])), axis=0)
                value = np.array([padded[i:i + self.horizon, 0].sum()
                                  for i in range(path_length)])
                r_N_ = np.array([padded[i + self.horizon + 1, 0]
                                 for i in range(path_length)])
                # Keep the derived fields in the path so they are flushed by
                # the same loop as everything else (the placeholders left by
                # reset() are empty lists and cannot be concatenated).
                self.current_path['value'] = value[:, np.newaxis]
                self.current_path['r_N_'] = r_N_[:, np.newaxis]

            # Row count before this path is appended; used to pad fields that
            # are seen for the first time so all fields stay row-aligned.
            base_rows = len(self.memory['s'])
            for key, rows in self.current_path.items():
                if key not in self.memory:
                    # 'd' and 'raw_d' are not allocated by Pool.__init__;
                    # register them lazily with matching zero rows.
                    self.memory[key] = np.zeros((base_rows,) + rows.shape[1:])
                self.memory[key] = np.concatenate((self.memory[key], rows), axis=0)

            # Store a snapshot: reset() rebinds entries of current_path in
            # place, which would otherwise empty the stored path as well.
            self.paths.appendleft(dict(self.current_path))
            self.reset()
            self.memory_pointer = len(self.memory['s'])

        return self.memory_pointer

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct rows from the most recent memory.

        Only the last ``memory_capacity`` rows are eligible, and the first
        ``history_horizon + 1`` rows are always skipped.  For every field
        whose key contains the letter ``'s'`` (``'s'`` and ``'s_'``), the row
        at each sampled index is concatenated column-wise with the
        ``history_horizon`` rows preceding it.

        Returns:
            ``None`` while ``memory_pointer < min_memory_size``; otherwise a
            dict mapping each memory key to the sampled array.

        Raises:
            ValueError: from ``np.random.choice`` when ``batch_size`` exceeds
                the number of eligible rows (sampling is without replacement).
        """
        if self.memory_pointer < self.min_memory_size:
            return None

        horizon = self.history_horizon
        n_candidates = min(self.memory_pointer, self.memory_capacity) - 1 - horizon
        first_valid = 1 + horizon + max(0, self.memory_pointer - self.memory_capacity)
        indices = np.random.choice(n_candidates, size=batch_size, replace=False) + first_valid

        batch = {}
        for key, rows in self.memory.items():
            if 's' in key:
                stacked = [rows[indices - i] for i in range(horizon + 1)]
                batch[key] = np.concatenate(stacked, axis=1)
            else:
                batch[key] = rows[indices]
        return batch


class REINFOCE_Pool(Pool):
    """Path-based memory that stores clipped costs and per-path targets.

    Compared with ``Pool`` the ``'r'`` field is removed and the fields
    ``'c'``, ``'c_'``, ``'c_T'``, ``'s_T'``, ``'C'`` and ``'L_target'`` are
    added.  ``sample()`` without arguments returns everything stored so far
    and then clears the memory.
    """

    def __init__(self, variant):
        """Build the pool.

        In addition to the keys read by ``Pool.__init__``, ``variant`` must
        contain ``'constant_baseline'``, ``'c_bar'`` and ``'target_horizon'``.
        """
        super(REINFOCE_Pool, self).__init__(variant)
        pad_rows = self.history_horizon + 1
        self.memory.update({'c': np.zeros([pad_rows, 1]),
                            'c_': np.zeros([pad_rows, 1]),
                            'c_T': np.zeros([pad_rows, 1]),
                            's_T': np.zeros([pad_rows, self.memory['s'].shape[1]]),
                            'C': np.zeros([pad_rows, 1]),
                            'L_target': np.zeros([pad_rows, 1]),
                            })

        self.constant_baseline = variant['constant_baseline']
        self.memory.pop('r')
        self.reset()
        self.c_bar = variant['c_bar']
        self.target_horizon = variant['target_horizon']

    def reset_memory(self):
        """Replace every memory field by ``history_horizon + 1`` zero rows.

        The column count of each field is taken from the array it currently
        holds, so the layout cannot drift from what ``__init__`` allocated.
        """
        pad_rows = self.history_horizon + 1
        self.memory = {key: np.zeros([pad_rows, rows.shape[1]])
                       for key, rows in self.memory.items()}

    def store(self, s, a, norm, norm_, terminal, s_):
        """Append one transition to the current path; flush it on terminal.

        ``s``, ``a`` and ``s_`` must be arrays with at least one dimension;
        ``norm``, ``norm_`` and ``terminal`` are scalars.  ``norm`` and
        ``norm_`` are clipped from above by ``self.c_bar`` and stored as
        ``'c'`` and ``'c_'``.

        When ``terminal == 1.`` the per-path fields ``'C'``, ``'L_target'``,
        ``'c_T'`` and ``'s_T'`` are computed, the path is concatenated onto
        every memory field, a snapshot of it is pushed onto ``self.paths``
        and the path buffer is reset.  ``self.memory_pointer`` is refreshed
        on every call.  Returns ``None``.
        """
        transition = {'s': s,
                      'a': a,
                      'c': np.min((np.array([norm]), np.array([self.c_bar])), axis=0),
                      'c_': np.min((np.array([norm_]), np.array([self.c_bar])), axis=0),
                      'terminal': np.array([terminal]),
                      's_': s_}

        if len(self.current_path['s']) < 1:
            for key, row in transition.items():
                self.current_path[key] = row[np.newaxis, :]
        else:
            for key, row in transition.items():
                self.current_path[key] = np.concatenate(
                    (self.current_path[key], row[np.newaxis, :]))

        if terminal == 1.:
            c = self.current_path['c']
            path_length = len(c)
            last_c = self.current_path['c_'][-1, 0]

            # Final next-cost and final next-state, repeated for every step.
            c_T = last_c * np.ones_like(c)
            s_T = self.current_path['s_'][-1] * np.ones_like(self.current_path['s_'])

            # Sum of baseline-shifted costs from step i + history_horizon + 1
            # to the end of the path; the last step gets 0.
            c_with_baseline = c - self.constant_baseline * np.ones_like(c)
            C = [c_with_baseline[i + self.history_horizon + 1:, 0].sum()
                 for i in range(path_length - 1)]
            C.append(0)
            C = np.array(C)

            # Sum of target_horizon + 1 costs starting at step i; the cost
            # sequence is padded with the final next-cost so every window
            # has full length.
            padded_c = np.concatenate(
                (c, last_c * np.ones([self.target_horizon + 1, 1])), axis=0)
            L_target = np.array([padded_c[i:i + self.target_horizon + 1, 0].sum()
                                 for i in range(path_length)])

            self.current_path['C'] = C[:, np.newaxis]
            self.current_path['L_target'] = L_target[:, np.newaxis]
            self.current_path['c_T'] = c_T
            self.current_path['s_T'] = s_T
            for key, rows in self.current_path.items():
                self.memory[key] = np.concatenate((self.memory[key], rows), axis=0)

            # Store a snapshot: reset() rebinds entries of current_path in
            # place, which would otherwise empty the stored path as well.
            self.paths.appendleft(dict(self.current_path))
            self.reset()
        self.memory_pointer = len(self.memory['s'])
        return

    def sample(self, initial_inds=None):
        """Return stored rows, optionally restricted to ``initial_inds``.

        With ``initial_inds=None`` every row from index
        ``history_horizon + 1`` up to ``memory_pointer`` is returned and the
        memory is cleared afterwards.  Otherwise only the given row indices
        are returned (any integer array-like) and the memory is left as is.

        For every field whose key contains the letter ``'s'`` (``'s'``,
        ``'s_'`` and ``'s_T'``), each selected row is concatenated
        column-wise with the ``history_horizon`` rows preceding it.

        Returns:
            dict mapping each memory key to the selected array.
        """
        if initial_inds is None:
            inds = np.arange(self.history_horizon + 1, self.memory_pointer, dtype=int)
        else:
            # Accept plain sequences too; ``inds - i`` below needs an array.
            inds = np.asarray(initial_inds)

        batch = {}
        for key, rows in self.memory.items():
            if 's' in key:
                stacked = [rows[inds - i] for i in range(self.history_horizon + 1)]
                batch[key] = np.concatenate(stacked, axis=1)
            else:
                batch[key] = rows[inds]

        if initial_inds is None:
            self.reset_memory()
            self.memory_pointer = len(self.memory['s'])

        return batch



