import numpy as np
from torch import Tensor
from torch.autograd import Variable
import torch
import torch.nn as nn
from torch import optim
# import copy

# Modified from https://github.com/jsyoon0823/GAIN
# and https://github.com/shariqiqbal2810/maddpg-pytorch/
class ReplayBuffer(object):
    """Replay buffer for MARL whose missing agent data is imputed with GAIN.

    Transitions are first collected in a short GAIN window of ``g_max_step``
    steps (``push_imputation``). When the window is full, the GAIN generator
    and discriminator are trained on it and, unless ``preprocess`` is set, the
    imputed window is flushed into the main MARL buffer of ``max_step`` rows.
    ``sample`` / ``sample_mask`` draw training batches from that main buffer.

    Args:
        max_step: capacity of the main MARL buffer (cast to ``int``).
        n_agent: number of agents.
        o_dims, a_dims: per-agent observation / action sizes.
        buf_train: minimum window score ``e`` (see ``g_buf_to_g_batch``) for a
            step to be used when training GAIN.
        buf_flush: minimum window score ``e`` for a step to be flushed into
            the MARL buffer.
        hint_p: probability that a hint entry reveals the true agent mask.
        g_max_step: length of the GAIN window. It is treated as one episode:
            the last step of the window is marked done.
        device: torch device for the GAIN networks and their inputs.
    """

    def __init__(self, max_step, n_agent, o_dims, a_dims, buf_train=0, buf_flush=0, hint_p=0.5, g_max_step=100,
                 device='cuda'):
        self.max_step = int(max_step)
        self.n_agent = n_agent
        self.device = device

        # Replay buffer for the training of agents in MARL (one array per agent).
        self.obs_buf, self.act_buf, self.rew_buf = [], [], []
        self.don_buf, self.n_obs_buf, self.mask_buf = [], [], []
        for o_dim, a_dim in zip(o_dims, a_dims):
            # Use the int-cast capacity: np.zeros rejects float shapes such as 1e6.
            self.obs_buf.append(np.zeros((self.max_step, o_dim)))
            self.act_buf.append(np.zeros((self.max_step, a_dim)))
            self.rew_buf.append(np.zeros(self.max_step))
            self.n_obs_buf.append(np.zeros((self.max_step, o_dim)))
            self.don_buf.append(np.zeros(self.max_step))
            self.mask_buf.append(np.zeros(self.max_step))
        self.fill_i, self.curr_i = 0, 0

        # Buffer for the training of GAIN. It holds a single window of
        # g_max_step steps, so it only needs g_max_step rows.
        self.g_max_step = int(g_max_step)
        self.g_obs_buf, self.g_act_buf, self.g_rew_buf = [], [], []
        # g_don_buf / g_n_obs_buf are kept for compatibility; nothing in this class fills them.
        self.g_don_buf, self.g_n_obs_buf, self.g_mask_buf = [], [], []
        for o_dim, a_dim in zip(o_dims, a_dims):
            self.g_obs_buf.append(np.zeros((self.g_max_step, o_dim)))
            self.g_act_buf.append(np.zeros((self.g_max_step, a_dim)))
            self.g_rew_buf.append(np.zeros(self.g_max_step))
            self.g_mask_buf.append(np.zeros(self.g_max_step))
        self.fill_g, self.curr_g = 0, 0

        # Row 0: per-agent observation sizes, row 1: action sizes, row 2: reward sizes (1 each).
        self.input_shape = np.array([o_dims, a_dims, [1 for _ in range(n_agent)]])
        step_dim = int(self.input_shape.sum())  # size of one flattened timestep

        # GAIN uses temporal correlation by concatenating 3 consecutive steps,
        # X_t = (tau_{t-1}, tau_t, tau_{t+1}), so its data width is 3 * step_dim.
        # Likewise the agent-level mask M_t = (m_{t-1}, m_t, m_{t+1}) has width 3 * n_agent.
        self.netD = NetD(3 * step_dim, 3 * n_agent).to(device)
        self.netG = NetG(3 * step_dim, 3 * n_agent).to(device)
        self.opt_D = optim.Adam(self.netD.parameters(), lr=0.001)
        self.opt_G = optim.Adam(self.netG.parameters(), lr=0.001)
        # self.norm[0]: running per-feature maximum, self.norm[1]: running minimum.
        self.norm = np.zeros((2, 3 * step_dim))
        self.buf_train, self.buf_flush, self.hint_p = buf_train, buf_flush, hint_p
        # Split points that cut one flattened step into
        # [obs_0..obs_{n-1}, act_0..act_{n-1}, rew_0..rew_{n-1}].
        self.pp = tuple(np.cumsum(np.ravel(self.input_shape))[:-1])

    def __len__(self):
        """Number of filled rows in the MARL buffer."""
        return self.fill_i

    def push_imputation(self, g_obs, g_act, g_rew, g_n_obs, g_don, g_mask, preprocess=False):
        """Store one timestep (obtained, possibly partially missing data) in the GAIN window.

        When the window becomes full, GAIN is trained on it. Unless
        ``preprocess`` is true, the imputed window is then flushed into the
        MARL buffer. Finally the window is reset.

        ``g_n_obs`` and ``g_don`` are accepted for interface compatibility but
        are not stored.

        Returns:
            The GAIN reconstruction loss when the window was just processed
            (-1 if no step qualified for training); otherwise 0.
        """
        imputation_mse_loss = 0
        g_ = int(self.curr_g)
        for a_j in range(self.n_agent):
            self.g_obs_buf[a_j][g_] = g_obs[a_j]
            self.g_act_buf[a_j][g_] = g_act[a_j]
            self.g_rew_buf[a_j][g_] = g_rew[a_j]
            self.g_mask_buf[a_j][g_] = g_mask[a_j]
        self.fill_g = min(self.g_max_step, self.fill_g + 1)
        self.curr_g = (self.curr_g + 1) % self.g_max_step
        # The window is complete (end of the episode).
        if self.fill_g == self.g_max_step:
            imputation_mse_loss = self.train_gain()
            if not preprocess:
                # Update the MARL replay buffer with the completed data.
                self.flush_gain()
            self.fill_g, self.curr_g = 0, 0
        return imputation_mse_loss

    def g_buf_to_g_batch(self):
        """Flatten the GAIN window into per-step GAIN rows.

        For each step t of the window, row t is X_t = (tau_{t-1}, tau_t, tau_{t+1}),
        the concatenation of three consecutive masked steps. The window is
        padded with an all-zero step on both ends, so the first and last steps
        also have neighbours.

        Returns:
            X: (g_max_step, 3 * step_dim) data; entries of masked agents are zeroed.
            M: (g_max_step, 3 * step_dim) feature-level mask.
            c_M: (g_max_step, 3 * n_agent) agent-level mask.
            don: (g_max_step, 3 * n_agent) done flags; only the last window step is done.
            e: (g_max_step,) window score
                max(4 * a(t-1), a(t-1) + a(t) + a(t+1)), where a(s) is the
                minimum over agents of the mask at step s. It is compared
                against ``buf_train`` / ``buf_flush``.
        """
        T = self.g_max_step
        step_dim = int(self.input_shape.sum())
        agents = range(self.n_agent)
        # Row s + 1 holds step s; rows 0 and T + 1 are zero padding.
        X_part, M_part = np.zeros((T + 2, step_dim)), np.zeros((T + 2, step_dim))
        c_M_part, don_part = np.zeros((T + 2, self.n_agent)), np.zeros((T + 2, self.n_agent))
        for t in range(T):
            m = [self.g_mask_buf[j][t] for j in agents]
            o_ = np.concatenate([self.g_obs_buf[j][t] * m[j] for j in agents])
            a_ = np.concatenate([self.g_act_buf[j][t] * m[j] for j in agents])
            r_ = np.array([self.g_rew_buf[j][t] * m[j] for j in agents])
            X_part[t + 1] = np.concatenate([o_, a_, r_])
            o_ = np.concatenate([np.ones_like(self.g_obs_buf[j][t]) * m[j] for j in agents])
            a_ = np.concatenate([np.ones_like(self.g_act_buf[j][t]) * m[j] for j in agents])
            r_ = np.array([np.ones_like(self.g_rew_buf[j][t]) * m[j] for j in agents])
            M_part[t + 1] = np.concatenate([o_, a_, r_])
            c_M_part[t + 1] = np.array(m)
        # The last step of the window ends the episode.
        don_part[T] = 1.

        agent_min = c_M_part.min(axis=1)  # per-step minimum mask over agents
        X_buf, M_buf, c_M_buf, don_buf = [], [], [], []
        e_part = np.zeros(T)
        for t in range(T):
            # 3 consecutive steps form one input of the generator and the discriminator.
            X_buf.append(X_part[t:t + 3].flatten())
            M_buf.append(M_part[t:t + 3].flatten())
            c_M_buf.append(c_M_part[t:t + 3].flatten())
            don_buf.append(don_part[t:t + 3].flatten())
            e_part[t] = max(agent_min[t] * 4, agent_min[t:t + 3].sum())
        return np.array(X_buf), np.array(M_buf), np.array(c_M_buf), np.array(don_buf), e_part

    def g_batch_update(self, G_unnorm, don, c_M, e):
        """Reshape completed GAIN rows into per-agent transitions and push them.

        Args:
            G_unnorm: (n, 3 * step_dim) completed, de-normalised rows
                (tau_{t-1}, tau_t, tau_{t+1}).
            don: (n, 3 * n_agent) done flags for the same three steps.
            c_M: (n, 3 * n_agent) agent-level masks for the same three steps.
            e: indices of the selected window steps (not used here).
        """
        _, c_M_cur, c_M_next = np.array_split(c_M, 3, axis=1)
        # With 0/1 masks: 0 = agent missing at t, 1 = observed at t but missing
        # at t+1, 2 = observed at both. sample_mask's mask_type filters on this.
        all_mask = (c_M_cur + c_M_next) * c_M_cur
        all_don = np.array_split(don, 3, axis=1)[1]
        _, cur, nxt = np.array_split(G_unnorm, 3, axis=1)
        # Split each step into its per-agent observation, action and reward pieces.
        cur_parts = np.split(cur, self.pp, axis=1)
        nxt_parts = np.split(nxt, self.pp, axis=1)
        n = self.n_agent
        all_obs = [cur_parts[j] for j in range(n)]
        all_act = [cur_parts[n + j] for j in range(n)]
        all_rew = [cur_parts[2 * n + j].reshape(-1) for j in range(n)]
        all_n_obs = [nxt_parts[j] for j in range(n)]
        self.push(all_obs, all_act, all_rew, all_n_obs, all_don, all_mask)

    def _to_tensor(self, x):
        """Copy a NumPy array into a float32 tensor on ``self.device``."""
        return torch.tensor(x, dtype=torch.float32, device=self.device)

    def _normalize(self, X):
        """Min-max normalise X with the running bounds in ``self.norm``."""
        return (X - self.norm[1]) / (self.norm[0] - self.norm[1] + 1e-4)

    def _denormalize(self, X_norm):
        """Inverse of ``_normalize``."""
        return X_norm * (self.norm[0] - self.norm[1] + 1e-4) + self.norm[1]

    def flush_gain(self):
        """Impute the current window with the generator and push it to the MARL buffer.

        Only steps whose window score is at least ``buf_flush`` are pushed.
        Returns -1 when no step qualifies; otherwise None.
        """
        X_, M_, c_M_, don_, e_ = self.g_buf_to_g_batch()
        e = (e_ >= self.buf_flush).nonzero()[0]
        if e.shape[0] == 0:
            return -1
        X, M, c_M, don = X_[e], M_[e], c_M_[e], don_[e]

        X_norm = self._normalize(X)
        # Missing features are replaced by uniform noise before entering the generator.
        X_noise = M * X_norm + (1 - M) * np.random.uniform(0., 1., size=X_norm.shape)
        with torch.no_grad():  # inference only, so no autograd graph is needed
            G_sample = self.netG(self._to_tensor(np.hstack((X_noise, c_M)))).cpu().numpy()
        # Keep observed features; take the generator's values for missing ones.
        G_return = M * X_norm + (1 - M) * G_sample
        self.g_batch_update(self._denormalize(G_return), don, c_M, e)

    def train_gain(self):
        """Run one discriminator and one generator update on the window (GAIN, Algorithm 1).

        Returns:
            The generator's reconstruction (MSE) loss as a CPU tensor, or -1
            when no step's window score reaches ``buf_train``.
        """
        X_, M_, c_M_, don_, e_ = self.g_buf_to_g_batch()
        self.norm[0] = np.max(np.vstack([X_, self.norm[0]]), axis=0)
        self.norm[1] = np.min(np.vstack([X_, self.norm[1]]), axis=0)

        e = (e_ >= self.buf_train).nonzero()[0]
        if e.shape[0] == 0:
            return -1
        X, M, c_M = X_[e], M_[e], c_M_[e]

        X_norm = self._normalize(X)
        X_noise = M * X_norm + (1 - M) * np.random.uniform(0., 1., size=X_norm.shape)
        # Each hint entry reveals the true agent mask with probability hint_p, otherwise 0.5.
        H_var = np.random.uniform(0., 1., size=c_M.shape) < self.hint_p
        Hint = c_M * H_var + 0.5 * (1. - H_var)

        G_input = self._to_tensor(np.hstack((X_noise, c_M)))
        H_input = self._to_tensor(Hint)
        c_M_t = self._to_tensor(c_M)
        M_t = self._to_tensor(M)
        X_obs_t = self._to_tensor(X_norm * M)

        # NOTE(review): unlike GAIN's Algorithm 1, D sees the raw generator output
        # rather than M * X + (1 - M) * G - confirm this is intended.

        ### Training of Discriminator
        self.opt_D.zero_grad()
        with torch.no_grad():  # the generator is not updated in this step
            G_sample = self.netG(G_input)
        # D_prob: probability that each agent-step is observed (its target is c_M).
        D_prob = self.netD(torch.cat((G_sample, H_input), dim=1))
        D_loss = -(c_M_t * torch.log(D_prob + 1e-8) +
                   (1 - c_M_t) * torch.log(1. - D_prob + 1e-8)).mean()
        D_loss.backward()
        self.opt_D.step()

        ### Training of Generator
        self.opt_G.zero_grad()
        G_sample = self.netG(G_input)
        # No detach here: the adversarial loss must back-propagate into the generator.
        D_prob = self.netD(torch.cat((G_sample, H_input), dim=1))
        # The generator wants missing agent-steps to look observed (D_prob -> 1).
        G_loss_1 = -((1 - c_M_t) * torch.log(D_prob + 1e-8)).mean() / ((1 - c_M_t).mean() + 1e-8)
        # Reconstruction error on observed features; the epsilon avoids 0/0 when all are missing.
        G_loss_2 = ((X_obs_t - G_sample * M_t) ** 2).mean() / (c_M_t.mean() + 1e-8)
        G_loss = G_loss_1 + 100. * G_loss_2
        G_loss.backward()
        self.opt_G.step()

        return G_loss_2.detach().cpu()

    def push(self, all_obs, all_act, all_rew, all_n_obs, all_don, all_mask):
        """Append ``len(all_obs[0])`` transitions for every agent to the MARL buffer.

        ``all_don`` and ``all_mask`` are indexed as (entry, agent) arrays.
        """
        n_ent = len(all_obs[0])
        if self.curr_i + n_ent >= self.max_step:
            self.roll()

        s_, e_ = self.curr_i, self.curr_i + n_ent  # start and end index of the write
        for a_j in range(self.n_agent):
            self.obs_buf[a_j][s_:e_] = all_obs[a_j]
            self.act_buf[a_j][s_:e_] = all_act[a_j]
            self.rew_buf[a_j][s_:e_] = all_rew[a_j]
            self.n_obs_buf[a_j][s_:e_] = all_n_obs[a_j]
            self.don_buf[a_j][s_:e_] = all_don.T[a_j]
            self.mask_buf[a_j][s_:e_] = all_mask.T[a_j]

        self.curr_i = (self.curr_i + n_ent) % self.max_step
        self.fill_i = min(self.max_step, self.fill_i + n_ent)

    def roll(self):
        """Rotate every per-agent buffer so the write cursor restarts at row 0.

        mask_buf is rolled too, so masks stay aligned with their transitions.
        """
        r_idx = self.max_step - self.curr_i
        for a_j in range(self.n_agent):
            for buf in (self.obs_buf, self.act_buf, self.rew_buf,
                        self.n_obs_buf, self.don_buf, self.mask_buf):
                buf[a_j] = np.roll(buf[a_j], r_idx, axis=0)
        self.curr_i = 0
        self.fill_i = self.max_step

    def _gather_batch(self, ind, device, norm):
        """Return per-agent (obs, act, rew, next_obs, done) float32 tensors at rows ``ind``.

        With ``norm`` set, rewards are standardised using the mean and std of
        the filled part of each agent's reward buffer. A zero std falls back to 1.
        """
        def cast(x):
            return torch.tensor(x, dtype=torch.float32, device=device)

        agents = range(self.n_agent)
        ret_rew = []
        for a_j in agents:
            rew = self.rew_buf[a_j][ind]
            if norm and self.fill_i > 0:
                filled = self.rew_buf[a_j][:self.fill_i]
                std = filled.std()
                rew = (rew - filled.mean()) / (std if std > 0 else 1.0)
            ret_rew.append(rew)
        return ([cast(self.obs_buf[a_j][ind]) for a_j in agents],
                [cast(self.act_buf[a_j][ind]) for a_j in agents],
                [cast(ret_rew[a_j]) for a_j in agents],
                [cast(self.n_obs_buf[a_j][ind]) for a_j in agents],
                [cast(self.don_buf[a_j][ind]) for a_j in agents],)

    def sample_mask(self, N, idx_a, mask_type=2, device='cuda', norm=True):
        """Sample up to N transitions whose agent-``idx_a`` mask is >= ``mask_type``.

        Args:
            N: maximum batch size (fewer rows are returned if fewer qualify).
            idx_a: agent whose mask selects the rows (completed dataset D_hat_{idx_a}).
            mask_type: minimum mask value; 2 selects rows where the agent is
                observed at both t and t+1 (see ``g_batch_update``).
            device: torch device of the returned tensors.
            norm: standardise rewards (see ``_gather_batch``).
        """
        # Only filled rows are candidates (matters when mask_type <= 0).
        candidates = np.where(self.mask_buf[idx_a][:self.fill_i] >= mask_type)[0]
        n_take = min(candidates.shape[0], N)
        ind = np.random.choice(candidates, size=n_take, replace=False)
        return self._gather_batch(ind, device, norm)

    def sample(self, n_sample, device='cuda', norm=True):
        """Sample ``n_sample`` distinct filled rows uniformly, ignoring masks.

        Raises ValueError (from NumPy) if ``n_sample`` exceeds the filled size.
        """
        ind = np.random.choice(np.arange(self.fill_i), size=n_sample, replace=False)
        return self._gather_batch(ind, device, norm)


# Discriminator networks
class NetD(torch.nn.Module):
    """GAIN discriminator: a three-layer MLP ending in a sigmoid.

    The input is a ``dim_imp``-wide sample concatenated with a ``dim_hint``-wide
    hint vector. The output has ``dim_hint`` entries, each in (0, 1).
    """

    def __init__(self, dim_imp, dim_hint, hidden_dim=128):
        super().__init__()
        in_width = dim_imp + dim_hint
        self.fc1 = nn.Linear(in_width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, dim_hint)
        self.relu = nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        # Xavier-normal weights; biases keep nn.Linear's default initialisation.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        logits = self.fc3(hidden)
        return self.sigmoid(logits)

# Generator Networks
class NetG(torch.nn.Module):
    """GAIN generator: a three-layer MLP ending in a sigmoid.

    The input is a ``dim_imp``-wide (noise-filled) sample concatenated with a
    ``dim_hint``-wide mask. The output is a ``dim_imp``-wide sample with every
    entry in (0, 1).
    """

    def __init__(self, dim_imp, dim_hint, hidden_dim=128):
        super().__init__()
        in_width = dim_imp + dim_hint
        self.fc1 = nn.Linear(in_width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, dim_imp)
        self.relu = nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        # Xavier-normal weights; biases keep nn.Linear's default initialisation.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        logits = self.fc3(hidden)
        return self.sigmoid(logits)
