import numpy as np
from torch import Tensor
from torch.autograd import Variable
import torch


# Replay buffer that fills in missing data by random imputation
class ReplayBuffer(object):
    """Multi-agent replay buffer with random imputation of masked data.

    Two sets of per-agent numpy buffers are kept:

    * staging buffers (``g_*_buf``) filled one step at a time by
      ``push_imputation``;
    * the main buffers (``obs_buf``, ``act_buf``, ``rew_buf``, ``n_obs_buf``,
      ``don_buf``, ``mask_buf``) filled in batches by ``push`` and read by
      ``sample_mask``.

    Once ``g_max_step`` steps have been staged, the staged data is flattened
    into sliding three-step windows, entries whose mask is 0 are replaced by
    uniform random values in normalised space, and the completed transitions
    are pushed into the main buffers.
    """

    # Number of staged steps consumed per flush; g_buf_to_g_batch() always
    # reads exactly this many rows from the staging buffers.
    # NOTE(review): the default g_max_step (100) differs from this value, so
    # only the first EPISODE_STEPS staged rows are ever used -- confirm with
    # the caller that g_max_step is meant to equal the episode length.
    EPISODE_STEPS = 25

    def __init__(self, max_step, n_agent, o_dims, a_dims, buf_train=0,
                 buf_flush=0, hint_p=0.5, g_max_step=100):
        """Allocate the main and staging buffers.

        Args:
            max_step: Capacity (rows) of the main buffers; cast to ``int``.
            n_agent: Number of agents.
            o_dims: Per-agent observation sizes (length ``n_agent``).
            a_dims: Per-agent action sizes (length ``n_agent``).
            buf_train: Stored threshold; not used by the random-imputation
                path in this class.
            buf_flush: Minimum window score (see ``g_buf_to_g_batch``) a
                window needs to be flushed into the main buffers.
            hint_p: Stored but not used in this class.
            g_max_step: Number of staged steps that triggers a flush; cast to
                ``int``.

        Raises:
            ValueError: If ``o_dims`` or ``a_dims`` does not have ``n_agent``
                entries.
        """
        if len(o_dims) != n_agent or len(a_dims) != n_agent:
            raise ValueError(
                "o_dims and a_dims must each contain n_agent entries")

        self.max_step = int(max_step)
        self.n_agent = n_agent
        # Allocate with the integer capacity so float inputs such as 1e6 work.
        rows = self.max_step
        self.obs_buf = [np.zeros((rows, o_dim)) for o_dim in o_dims]
        self.act_buf = [np.zeros((rows, a_dim)) for a_dim in a_dims]
        self.rew_buf = [np.zeros(rows) for _ in o_dims]
        self.n_obs_buf = [np.zeros((rows, o_dim)) for o_dim in o_dims]
        self.don_buf = [np.zeros(rows) for _ in o_dims]
        self.mask_buf = [np.zeros(rows) for _ in o_dims]

        # index: number of valid rows and next write position (main buffers)
        self.fill_i, self.curr_i = 0, 0

        self.g_max_step = int(g_max_step)
        # The staging buffers only ever hold g_max_step rows, but the batch
        # builder reads EPISODE_STEPS rows, so allocate the larger of the two.
        g_rows = max(self.g_max_step, self.EPISODE_STEPS)
        self.g_obs_buf = [np.zeros((g_rows, o_dim)) for o_dim in o_dims]
        self.g_act_buf = [np.zeros((g_rows, a_dim)) for a_dim in a_dims]
        self.g_rew_buf = [np.zeros(g_rows) for _ in o_dims]
        self.g_mask_buf = [np.zeros(g_rows) for _ in o_dims]
        # Kept as (empty) attributes: next-obs and done flags are rebuilt
        # from the sliding windows instead of being staged.
        self.g_don_buf, self.g_n_obs_buf = [], []

        # index: number of staged rows and next write position (staging)
        self.fill_g, self.curr_g = 0, 0
        self.input_shape = np.array([o_dims, a_dims, [1 for _ in range(n_agent)]])

        # Row 0 holds the running per-column maximum, row 1 the running
        # minimum, over flattened three-step windows.
        self.norm = np.zeros((2, 3*self.input_shape.sum()))
        self.buf_train, self.buf_flush, self.hint_p = buf_train, buf_flush, hint_p
        # Column offsets at which one flattened step is split back into
        # per-agent obs / act / rew pieces (cumulative sizes, last excluded).
        self.total_shape = tuple(np.cumsum(np.ravel(self.input_shape))[:-1])

    def __len__(self):
        """Return the number of valid rows in the main buffers."""
        return self.fill_i

    # Saves obtained data to the replay buffer
    def push_imputation(self, g_obs, g_act, g_rew, g_n_obs, g_don, g_mask, preprocess=False):
        """Stage one environment step and flush once enough steps are staged.

        Args:
            g_obs, g_act, g_rew, g_mask: Per-agent sequences for this step.
            g_n_obs, g_don: Accepted for interface compatibility; not stored.
            preprocess: When true, only the normalisation statistics are
                updated at the end of the staging window; nothing is pushed
                into the main buffers.

        Returns:
            0 while the staging window is still filling, otherwise the value
            returned by ``train_imputation`` (always -1).
        """
        row = int(self.curr_g)
        for a_j in range(self.n_agent):
            self.g_obs_buf[a_j][row] = g_obs[a_j]
            self.g_act_buf[a_j][row] = g_act[a_j]
            self.g_rew_buf[a_j][row] = g_rew[a_j]
            self.g_mask_buf[a_j][row] = g_mask[a_j]
        self.fill_g = min(self.g_max_step, self.fill_g + 1)
        self.curr_g = (self.curr_g + 1) % self.g_max_step
        if self.fill_g != self.g_max_step:
            return 0

        # Staging window complete: refresh statistics, then (unless only
        # preprocessing) impute and move the data to the main buffers.
        imputation_mse_loss = self.train_imputation()
        if not preprocess:
            self.flush_imputation()
        self.fill_g, self.curr_g = 0, 0
        return imputation_mse_loss

    # Flatten obtained data for imputation
    def g_buf_to_g_batch(self):
        """Flatten the staged steps into sliding three-step windows.

        The staged rows are padded with one all-zero row before and after, so
        window ``t`` covers padded rows ``t, t+1, t+2`` (previous, current
        and next step of staged step ``t``).

        Returns:
            Tuple ``(X, M, c_M, don, e)`` of arrays with ``EPISODE_STEPS``
            rows: masked data windows, element-wise mask windows, per-agent
            mask windows, done-flag windows (only the last staged step is
            flagged done), and a per-window score
            ``max(4 * min_agent_mask[t], sum(min_agent_mask[t:t+3]))`` over
            padded rows.
        """
        steps = self.EPISODE_STEPS
        width = self.input_shape.sum()
        agents = range(self.n_agent)

        X_part, M_part = np.zeros((steps + 2, width)), np.zeros((steps + 2, width))
        c_M_part, don_part = np.zeros((steps + 2, self.n_agent)), np.zeros((steps + 2, self.n_agent))
        for t in range(steps):     # for all o, a, r
            masks = [self.g_mask_buf[j][t] for j in agents]
            # Data with masked-out agents zeroed.
            obs = np.concatenate([self.g_obs_buf[j][t] * masks[j] for j in agents])
            act = np.concatenate([self.g_act_buf[j][t] * masks[j] for j in agents])
            rew = np.array([self.g_rew_buf[j][t] * masks[j] for j in agents])
            X_part[t+1] = np.concatenate([obs, act, rew])
            # Element-wise mask with the same layout as the data row.
            obs_m = np.concatenate([np.ones_like(self.g_obs_buf[j][t]) * masks[j] for j in agents])
            act_m = np.concatenate([np.ones_like(self.g_act_buf[j][t]) * masks[j] for j in agents])
            rew_m = np.array([np.ones_like(self.g_rew_buf[j][t]) * masks[j] for j in agents])
            M_part[t+1] = np.concatenate([obs_m, act_m, rew_m])
            c_M_part[t+1] = np.array(masks)
        # Padded row `steps` is the last staged step: mark it as terminal.
        don_part[steps] = 1.0

        # Hoisted out of the loop: minimum mask over agents for every row.
        row_min = c_M_part.min(axis=1)
        X_buf, M_buf, c_M_buf, don_buf = [], [], [], []
        e_part = np.zeros(steps)
        for t in range(steps):
            X_buf.append(X_part[t:t+3].flatten())
            M_buf.append(M_part[t:t+3].flatten())
            c_M_buf.append(c_M_part[t:t+3].flatten())
            don_buf.append(don_part[t:t+3].flatten())
            e_part[t] = max(row_min[t]*4, row_min[t:t+3].sum())
        return (np.array(X_buf), np.array(M_buf), np.array(c_M_buf),
                np.array(don_buf), np.array(e_part))

    # Save completed data
    def batch_update(self, completed_data, don, c_M, e):
        """Split completed windows into transitions and push them.

        Args:
            completed_data: Imputed windows, three flattened steps per row.
            don: Done-flag windows matching ``completed_data`` row-wise.
            c_M: Per-agent mask windows matching ``completed_data`` row-wise.
            e: Unused; kept for interface compatibility.
        """
        _, mask_now, mask_next = np.array_split(c_M, 3, axis=1)
        # 0: current step missing, 1: only current observed, 2: both observed
        # (for 0/1 masks).
        all_mask = (mask_now + mask_next) * mask_now
        all_don = np.array_split(don, 3, axis=1)[1]
        prev_win, curr_win, next_win = np.array_split(completed_data, 3, axis=1)
        # Sanity check kept from the original; it also fails when the first
        # row contains NaN, because the comparison is then false.
        assert (completed_data[0]-np.concatenate([prev_win[0], curr_win[0], next_win[0]])).mean() < 0.0001
        curr_parts = np.split(curr_win, self.total_shape, axis=1)
        next_parts = np.split(next_win, self.total_shape, axis=1)
        # Reshape flattened data: pieces are ordered obs..., act..., rew...
        n = self.n_agent
        all_obs = [curr_parts[idx] for idx in range(n)]
        all_act = [curr_parts[idx + n] for idx in range(n)]
        all_rew = [curr_parts[idx + 2 * n].reshape(-1) for idx in range(n)]
        all_n_obs = [next_parts[idx] for idx in range(n)]
        self.push(all_obs, all_act, all_rew, all_n_obs, all_don, all_mask)

    def flush_imputation(self):
        """Impute the staged windows at random and push them to the buffer.

        Windows whose score is below ``buf_flush`` are skipped.

        Returns:
            -1 when no window qualifies, otherwise ``None``.
        """
        X_, M_, c_M_, don_, e_ = self.g_buf_to_g_batch()

        e = (e_>=self.buf_flush).nonzero()[0]
        if e.shape[0] == 0:
            return -1

        X, M, c_M, don = X_[e], M_[e], c_M_[e], don_[e]

        # Normalized variable for the imputation
        # self.norm[1] is minimum value and self.norm[0] is maximum value
        scale = self.norm[0] - self.norm[1] + 1e-4
        X_norm = (X - self.norm[1]) / scale
        # Random imputation: masked-out entries get uniform noise in [0, 1)
        X_noise = M * X_norm + (1 - M) * np.random.uniform(0., 1., size=X_norm.shape)
        completed_return = M * X_norm + (1 - M) * X_noise
        # Map back to the original value range.
        completed_data = completed_return * scale + self.norm[1]
        self.batch_update(completed_data, don, c_M, e)

    def train_imputation(self):
        """Update the running max/min used to normalise imputed data.

        Random imputation has nothing to train, so no loss is computed.

        Returns:
            Always -1 (placeholder loss value).
        """
        X_ = self.g_buf_to_g_batch()[0]
        # Estimate max and min of variables for the random imputation
        self.norm[0] = np.max(np.vstack([X_, self.norm[0]]), axis=0)
        self.norm[1] = np.min(np.vstack([X_, self.norm[1]]), axis=0)
        return -1

    def push(self, all_obs, all_act, all_rew, all_n_obs, all_don, all_mask):
        """Append a batch of transitions to the main buffers.

        Args:
            all_obs, all_act, all_rew, all_n_obs: Per-agent lists of arrays
                whose first dimension is the batch size.
            all_don, all_mask: Arrays indexed ``[entry, agent]``.

        Raises:
            ValueError: If the batch is larger than the buffer capacity.
        """
        n_ent = len(all_obs[0])
        if n_ent > self.max_step:
            raise ValueError(
                "cannot push %d entries into a buffer of capacity %d"
                % (n_ent, self.max_step))
        # Not enough room before the end: rotate so writing restarts at 0.
        if self.curr_i + n_ent >= self.max_step:
            self.roll()

        s_, e_ = self.curr_i, self.curr_i + n_ent   # start and end row of the write
        for a_j in range(self.n_agent):
            self.obs_buf[a_j][s_:e_] = all_obs[a_j]
            self.act_buf[a_j][s_:e_] = all_act[a_j]
            self.rew_buf[a_j][s_:e_] = all_rew[a_j]
            self.n_obs_buf[a_j][s_:e_] = all_n_obs[a_j]
            self.don_buf[a_j][s_:e_] = all_don.T[a_j]
            self.mask_buf[a_j][s_:e_] = all_mask.T[a_j]
        self.curr_i = (self.curr_i + n_ent) % self.max_step
        self.fill_i = min(self.max_step, self.fill_i + n_ent)

    def roll(self):
        """Rotate all main buffers so the write position becomes row 0.

        Every per-agent buffer, including ``mask_buf``, is shifted by the
        same amount so rows stay aligned across buffers.
        """
        r_idx = self.max_step - self.curr_i
        aligned_buffers = (self.obs_buf, self.act_buf, self.rew_buf,
                           self.n_obs_buf, self.don_buf, self.mask_buf)
        for buf in aligned_buffers:
            for a_j in range(self.n_agent):
                buf[a_j] = np.roll(buf[a_j], r_idx, axis=0)
        self.curr_i = 0
        self.fill_i = self.max_step

    def sample_mask(self, N, idx_a, mask_type=0, device='cuda', norm=True):
        """Sample up to ``N`` stored transitions without replacement.

        Only rows that have actually been written and whose mask for agent
        ``idx_a`` is at least ``mask_type`` are eligible.

        Args:
            N: Maximum number of transitions to draw.
            idx_a: Agent whose mask selects the eligible rows.
            mask_type: Minimum mask value a row needs to be eligible.
            device: Torch device the returned tensors are moved to.
            norm: When true, rewards are standardised with the mean and
                standard deviation of the stored rewards (a zero standard
                deviation is treated as 1 to avoid dividing by zero).

        Returns:
            Tuple of five per-agent lists of float tensors:
            ``(obs, act, rew, next_obs, done)``.
        """
        filled = self.fill_i
        candidates = np.where(self.mask_buf[idx_a][:filled] >= mask_type)[0]
        n_draw = min(candidates.shape[0], N)
        ind = np.random.choice(candidates, size=n_draw, replace=False)
        agents = range(self.n_agent)

        def cast(x):
            # Variable is a deprecated no-op wrapper; a plain tensor already
            # has requires_grad=False.
            return Tensor(x).to(device)

        if norm:
            ret_rew = []
            for a_j in agents:
                stored = self.rew_buf[a_j][:filled]
                mean = stored.mean() if stored.size else 0.0
                std = stored.std() if stored.size else 1.0
                if not std > 0:
                    std = 1.0
                ret_rew.append((self.rew_buf[a_j][ind] - mean) / std)
        else:
            ret_rew = [self.rew_buf[a_j][ind] for a_j in agents]
        return ([cast(self.obs_buf[a_j][ind]) for a_j in agents],
                [cast(self.act_buf[a_j][ind]) for a_j in agents],
                [cast(ret_rew[a_j]) for a_j in agents],
                [cast(self.n_obs_buf[a_j][ind]) for a_j in agents],
                [cast(self.don_buf[a_j][ind]) for a_j in agents],)

