# Code reference: https://github.com/Xingyu-Lin/mbpo_pytorch

import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
import gzip
import itertools
import time
# Target device for the `.to(device)` calls in this file; hard-coded to CUDA,
# so the module cannot be used on a machine without a GPU.
device = torch.device('cuda')

# MNIST dataset sizes and the locations of the gzip-compressed IDX files.
# Presumably these are the arguments intended for get_data() -- confirm against callers.
num_train = 60000  # 60k train examples
num_test = 10000  # 10k test examples
train_inputs_file_path = './MNIST_data/train-images-idx3-ubyte.gz'
train_labels_file_path = './MNIST_data/train-labels-idx1-ubyte.gz'
test_inputs_file_path = './MNIST_data/t10k-images-idx3-ubyte.gz'
test_labels_file_path = './MNIST_data/t10k-labels-idx1-ubyte.gz'

# Not referenced by any definition visible in this file.
BATCH_SIZE = 100


class StandardScaler:
    """Column-wise standardisation with parallel NumPy and torch statistics.

    ``mu`` / ``std`` hold the statistics used for NumPy inputs, while
    ``mu_t`` / ``std_t`` hold copies moved to ``device`` that are used when
    the input is a torch tensor. Until ``fit`` is called the scaler is the
    identity (mean 0, standard deviation 1).
    """

    def __init__(self):
        # Identity statistics; the tensor copies mirror the scalar values.
        self.mu = 0
        self.std = 1
        self.mu_t = torch.tensor(self.mu).to(device)
        self.std_t = torch.tensor(self.std).to(device)

    def fit(self, data):
        """Estimate the per-column mean and standard deviation of ``data``.

        Arguments:
        data (np.ndarray): A numpy array containing the input; statistics are
            taken over axis 0 with the reduced axis kept.

        Returns: None.
        """
        center = np.mean(data, axis=0, keepdims=True)
        spread = np.std(data, axis=0, keepdims=True)
        # Columns that are (numerically) constant would divide by ~0, so they
        # are left unscaled instead.
        degenerate = spread < 1e-12
        spread[degenerate] = 1.0

        self.mu = center
        self.std = spread
        self.mu_t = torch.FloatTensor(center).to(device)
        self.std_t = torch.FloatTensor(spread).to(device)

    def transform(self, data):
        """Standardise ``data`` with the stored statistics.

        Arguments:
        data (np.array or torch.Tensor): The points to be transformed.

        Returns: The transformed data, of the same kind as the input
        (tensor statistics are used for tensors, NumPy statistics otherwise).
        """
        if not torch.is_tensor(data):
            return (data - self.mu) / self.std
        return (data - self.mu_t) / self.std_t

    def inverse_transform(self, data):
        """Undo the transformation performed by ``transform``.

        Arguments:
        data (np.array or torch.Tensor): Standardised points.

        Returns: The points mapped back to the original scale, of the same
        kind as the input.
        """
        if not torch.is_tensor(data):
            return self.std * data + self.mu
        return self.std_t * data + self.mu_t


def init_weights(m):
    """Initialise a linear-style layer with truncated-normal weights and zero bias.

    Intended to be passed to ``nn.Module.apply``. Only modules whose type is
    exactly ``nn.Linear`` or that are ``EnsembleFC`` instances are touched;
    every other module is left unchanged.

    Weights are drawn from N(0, std^2) with ``std = 1 / (2 * sqrt(in_features))``
    and every entry is guaranteed to lie within two standard deviations of the
    mean. The bias, when present, is filled with zeros.

    Arguments:
    m (nn.Module): The module to initialise in place.

    Returns: None.
    """
    def truncated_normal_init(t, mean=0.0, std=0.01):
        """Fill ``t`` in place with N(mean, std^2) samples truncated at +/- 2 std."""
        lower = mean - 2 * std
        upper = mean + 2 * std
        # no_grad: ``t`` is normally a leaf Parameter that requires grad, and
        # the in-place writes below must not be recorded by autograd.
        with torch.no_grad():
            t.normal_(mean=mean, std=std)
            while True:
                outside = (t < lower) | (t > upper)
                num_outside = int(outside.sum())
                if num_outside == 0:
                    break
                # Redraw only the offending entries and write them back into
                # ``t`` itself (same dtype/device), so the parameter really is
                # truncated rather than a discarded copy.
                redraw = torch.empty(num_outside, dtype=t.dtype, device=t.device)
                t[outside] = redraw.normal_(mean=mean, std=std)
        return t

    # Exact type match on nn.Linear is kept on purpose so subclasses (e.g.
    # lazily-initialised linear layers) are not initialised by accident.
    if type(m) is nn.Linear or isinstance(m, EnsembleFC):
        input_dim = m.in_features
        truncated_normal_init(m.weight, std=float(1 / (2 * np.sqrt(input_dim))))
        # Layers built with bias=False carry ``bias = None``.
        if m.bias is not None:
            m.bias.data.fill_(0.0)


class EnsembleFC(nn.Module):
    """Fully connected layer evaluated independently for each ensemble member.

    The layer stores one ``(in_features, out_features)`` weight matrix and,
    optionally, one bias vector per ensemble member and applies them with a
    batched matrix multiply.

    Arguments:
    in_features (int): Size of each input sample.
    out_features (int): Size of each output sample.
    ensemble_size (int): Number of independent members.
    weight_decay (float): Coefficient stored on the layer; it is not applied
        here, it is only read by code that builds a decay loss.
    bias (bool): When False the layer has no additive bias.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, ensemble_size: int, weight_decay: float = 0., bias: bool = True) -> None:
        super(EnsembleFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Intentionally a no-op.

        Parameter storage is left uninitialised by the constructor, so an
        external initialiser (such as ``init_weights`` in this file) must be
        applied before the layer is used.
        """
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply every member's affine map to its own slice of ``input``.

        Arguments:
        input (torch.Tensor): Tensor of shape
            ``(ensemble_size, batch, in_features)`` as required by ``torch.bmm``.

        Returns: Tensor of shape ``(ensemble_size, batch, out_features)``.
        """
        w_times_x = torch.bmm(input, self.weight)
        if self.bias is None:
            # Layer was built with bias=False: there is nothing to add.
            return w_times_x
        # Broadcast each member's bias over the batch dimension.
        return torch.add(w_times_x, self.bias[:, None, :])  # w times x + b

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class EnsembleRewardModel(nn.Module):
    """Ensemble of three-layer MLPs that map a state to a scalar reward.

    Each member is built from ``EnsembleFC`` layers with Swish activations and
    a one-dimensional output. Inputs are standardised with ``self.scaler``
    (identity until the scaler is fitted) in ``get_scalar_reward`` and ``r``.

    Arguments:
    state_size (int): Dimensionality of a single state.
    final_activation (str): ``'clamp'`` limits the output to [-10, 10],
        ``'sigmoid'`` squashes it to (0, 1); any other value leaves the raw
        network output untouched.
    ensemble_size (int): Number of ensemble members.
    hidden_size (int): Width of the two hidden layers.
    learning_rate (float): Unused here (the optimizer is created elsewhere).
    use_decay (bool): Stored on the model; not consulted by this class itself.
    """
    def __init__(self, state_size, final_activation, ensemble_size=5, hidden_size=200, learning_rate=1e-3, use_decay=False):
        super(EnsembleRewardModel, self).__init__()
        self.hidden_size = hidden_size
        self.state_dim = state_size
        self.nn1 = EnsembleFC(state_size, hidden_size, ensemble_size)
        self.nn2 = EnsembleFC(hidden_size, hidden_size, ensemble_size)
        self.use_decay = use_decay

        self.output_dim = 1
        self.nn3 = EnsembleFC(hidden_size, 1, ensemble_size)
        self.network_size = ensemble_size
        # self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        # EnsembleFC leaves its parameters uninitialised, so this call is what
        # actually gives the layers their starting values.
        self.apply(init_weights)
        self.swish = Swish()
        self.final_activation = final_activation
        self.scaler = StandardScaler()

    def forward(self, x, ret_log_var=False):
        """Run every ensemble member on its slice of ``x``.

        Arguments:
        x (torch.Tensor): Tensor of shape ``(ensemble_size, batch, state_size)``.
        ret_log_var (bool): Accepted for interface compatibility; ignored.

        Returns: Tensor of shape ``(ensemble_size, batch, 1)`` with the reward
        after the configured final activation.
        """
        nn1_output = self.swish(self.nn1(x))
        nn2_output = self.swish(self.nn2(nn1_output))
        nn3_output = self.nn3(nn2_output)

        mean = nn3_output[:, :, :self.output_dim]
        if self.final_activation == 'clamp':
            mean = torch.clamp(mean, min=-1.0 * 10, max=10)
        elif self.final_activation == 'sigmoid':
            mean = torch.sigmoid(mean)
        return mean

    def get_decay_loss(self):
        """Return the sum of ``weight_decay * ||W||^2 / 2`` over the EnsembleFC children."""
        decay_loss = 0.
        for m in self.children():
            if isinstance(m, EnsembleFC):
                decay_loss += m.weight_decay * torch.sum(m.weight**2) / 2.
        return decay_loss

    def get_scalar_reward(self, obs):
        """Predict one reward per observation using a randomly chosen member.

        Arguments:
        obs (np.ndarray or torch.Tensor): Observations. A NumPy input is
            reshaped to ``(-1, state_size)``; a tensor input must already be
            two-dimensional.

        Returns: (np.ndarray) One-dimensional array with one reward per
        observation. For each observation the prediction of a member drawn
        uniformly with ``np.random.choice`` is used.
        """
        obs_t = self.scaler.transform(obs)
        # Remember the current mode so evaluation does not silently switch a
        # model that was in eval mode back to training mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if not torch.is_tensor(obs_t):
                    flat = np.asarray(obs_t).reshape(-1, self.state_dim)
                    obs_t = torch.from_numpy(flat).float()

                # Give every ensemble member its own copy of the batch.
                obs_t = obs_t[None, :, :].repeat([self.network_size, 1, 1])
                obs_t = obs_t.to(device)
                reward = self.forward(obs_t).cpu().detach().numpy()
                num_obs = obs_t.shape[1]
                model_idx = np.random.choice(self.network_size, size=(num_obs))
                reward = reward[model_idx, np.arange(num_obs), :].flatten()
        finally:
            self.train(was_training)
        return reward

    def r(self, obs):
        """Return the rewards predicted by every member for ``obs``.

        Unlike ``get_scalar_reward`` this does not disable gradient tracking.

        Arguments:
        obs (torch.Tensor): Two-dimensional batch of observations. The
            ``repeat`` call below uses tensor semantics, so a NumPy array is
            not supported here.

        Returns: Tensor of shape ``(ensemble_size, batch, 1)``.
        """
        obs_t = self.scaler.transform(obs)
        states_ens = obs_t[None, :, :].repeat([self.network_size, 1, 1])
        return self.forward(states_ens)

    # def loss(self, mean, logvar, labels, inc_var_loss=True):
    #     """
    #     mean, logvar: Ensemble_size x N x dim
    #     labels: N x dim
    #     """
    #     assert len(mean.shape) == len(logvar.shape) == len(labels.shape) == 3
    #     inv_var = torch.exp(-logvar)
    #     if inc_var_loss:
    #         # Average over batch and dim, sum over ensembles.
    #         mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2) * inv_var, dim=-1), dim=-1)
    #         var_loss = torch.mean(torch.mean(logvar, dim=-1), dim=-1)
    #         total_loss = torch.sum(mse_loss) + torch.sum(var_loss)
    #     else:
    #         mse_loss = torch.mean(torch.pow(mean - labels, 2), dim=(1, 2))
    #         total_loss = torch.sum(mse_loss)
    #     return total_loss, mse_loss

    # def train(self, loss):
    #     self.optimizer.zero_grad()
    #     loss += 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
    #     if self.use_decay:
    #         loss += self.get_decay_loss()
    #     loss.backward()
    #     self.optimizer.step()


class EnsembleDynamicsModel():
    """Trainer and predictor wrapper around an ``EnsembleModel``.

    The wrapper standardises inputs with a ``StandardScaler``, trains all
    ensemble members on independently shuffled mini-batches, stops early when
    the per-member holdout loss stops improving, and records the indices of
    the best ``elite_size`` members in ``elite_model_idxes``.

    NOTE(review): ``EnsembleModel`` is not defined in the visible part of this
    file. This class assumes that calling it returns a pair of tensors shaped
    ``(network_size, batch, dim)`` and that ``loss`` returns
    ``(total_loss, per_member_mse)`` -- confirm against its definition.
    """

    def __init__(self, network_size, elite_size, state_size, action_size, reward_size=1, hidden_size=200, use_decay=False):
        self.network_size = network_size
        self.elite_size = elite_size
        self.model_list = []
        self.state_size = state_size
        self.action_size = action_size
        self.reward_size = reward_size
        # Before any training the first ``elite_size`` members count as elite;
        # the list is overwritten after every training epoch.
        self.elite_model_idxes = list(range(min(elite_size, network_size)))
        self.ensemble_model = EnsembleModel(state_size, action_size, reward_size, network_size, hidden_size, use_decay=use_decay)
        self.scaler = StandardScaler()

    def _reset_early_stopping(self, max_epochs_since_update):
        """Reset the bookkeeping used by ``_save_best``."""
        self._max_epochs_since_update = max_epochs_since_update
        self._epochs_since_update = 0
        self._state = {}
        self._snapshots = {i: (None, 1e10) for i in range(self.network_size)}

    def _split_and_scale(self, inputs, labels, holdout_ratio):
        """Shuffle, split into train/holdout, fit the scaler on train and scale both inputs."""
        num_holdout = int(inputs.shape[0] * holdout_ratio)
        permutation = np.random.permutation(inputs.shape[0])
        inputs, labels = inputs[permutation], labels[permutation]

        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)
        return train_inputs, train_labels, holdout_inputs, holdout_labels

    def _run_training_epoch(self, train_inputs, train_labels, batch_size):
        """Run one pass over the training data, each member seeing its own shuffle."""
        num_train_rows = train_inputs.shape[0]
        train_idx = np.vstack([np.random.permutation(num_train_rows) for _ in range(self.network_size)])

        for start_pos in range(0, num_train_rows, batch_size):
            idx = train_idx[:, start_pos: start_pos + batch_size]
            train_input = torch.from_numpy(train_inputs[idx]).float().to(device)
            train_label = torch.from_numpy(train_labels[idx]).float().to(device)
            mean, logvar = self.ensemble_model(train_input, ret_log_var=True)
            loss, _ = self.ensemble_model.loss(mean, logvar, train_label)
            # On the wrapped model ``train(loss)`` performs the optimisation step.
            self.ensemble_model.train(loss)

    def _update_elites(self, holdout_losses):
        """Store the indices of the ``elite_size`` members with the lowest holdout loss."""
        sorted_loss_idx = np.argsort(holdout_losses)
        self.elite_model_idxes = sorted_loss_idx[:self.elite_size].tolist()

    def _batched_holdout_losses(self, holdout_inputs, holdout_labels, batch_size):
        """Average the per-member holdout MSE over mini-batches.

        Every batch carries equal weight in the average, including a short
        final batch. With an empty holdout set an all-NaN vector is returned,
        which is what the un-batched evaluation in ``train`` is presumed to
        yield in that case -- TODO confirm against ``EnsembleModel.loss``.
        """
        num_holdout = holdout_inputs.shape[0]
        if num_holdout == 0:
            return np.full(self.network_size, np.nan)

        holdout_mean_losses = None
        with torch.no_grad():
            holdout_idx = np.vstack([np.random.permutation(num_holdout) for _ in range(self.network_size)])
            ctr = 0.0
            for start_pos in range(0, num_holdout, batch_size):
                idx = holdout_idx[:, start_pos: start_pos + batch_size]
                holdout_input = torch.from_numpy(holdout_inputs[idx]).float().to(device)
                holdout_label = torch.from_numpy(holdout_labels[idx]).float().to(device)
                holdout_mean, holdout_logvar = self.ensemble_model(holdout_input, ret_log_var=True)
                _, holdout_mse_losses = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_label, inc_var_loss=False)
                batch_losses = holdout_mse_losses.detach().cpu().numpy()
                if holdout_mean_losses is None:
                    holdout_mean_losses = batch_losses
                else:
                    # Incremental mean over the batches seen so far.
                    holdout_mean_losses = holdout_mean_losses * (ctr / (ctr + 1)) + (batch_losses / (ctr + 1))
                ctr += 1.0
        return holdout_mean_losses

    def train(self, inputs, labels, batch_size=256, holdout_ratio=0., max_epochs_since_update=5):
        """Train the ensemble until the holdout loss stops improving.

        The whole holdout set is moved to ``device`` and evaluated in a single
        forward pass after every epoch.

        Arguments:
        inputs (np.ndarray): Two-dimensional array of model inputs.
        labels (np.ndarray): Two-dimensional array of targets, row-aligned with ``inputs``.
        batch_size (int): Mini-batch size per ensemble member.
        holdout_ratio (float): Fraction of the rows reserved for validation.
        max_epochs_since_update (int): Training stops once more than this many
            consecutive epochs pass without any member improving.

        Returns: ``(0, mean_holdout_mse)`` where the second item is the mean of
        the per-member holdout losses of the last epoch.
        """
        self._reset_early_stopping(max_epochs_since_update)
        train_inputs, train_labels, holdout_inputs, holdout_labels = self._split_and_scale(inputs, labels, holdout_ratio)

        holdout_inputs = torch.from_numpy(holdout_inputs).float().to(device)
        holdout_labels = torch.from_numpy(holdout_labels).float().to(device)
        holdout_inputs = holdout_inputs[None, :, :].repeat([self.network_size, 1, 1])
        holdout_labels = holdout_labels[None, :, :].repeat([self.network_size, 1, 1])

        for epoch in itertools.count():
            e_start = time.time()
            self._run_training_epoch(train_inputs, train_labels, batch_size)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_losses = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels, inc_var_loss=False)
                holdout_mse_losses = holdout_mse_losses.detach().cpu().numpy()
            self._update_elites(holdout_mse_losses)
            if self._save_best(epoch, holdout_mse_losses):
                break
            print('epoch: {}, holdout mse losses: {}'.format(epoch, holdout_mse_losses),time.time()-e_start)
        return 0, holdout_mse_losses.mean()

    def train_low_mem(self, inputs, labels, batch_size=256, holdout_ratio=0., max_epochs_since_update=5,max_epochs=None):
        """Variant of ``train`` that evaluates the holdout set in mini-batches.

        Arguments:
        inputs, labels, batch_size, holdout_ratio, max_epochs_since_update:
            Same meaning as in ``train``.
        max_epochs (int or None): When given, epochs ``0 .. max_epochs``
            (inclusive) are run at most; ``None`` disables the limit.

        Returns: ``(0, mean_holdout_mse)``. The second item is NaN when no
        holdout evaluation took place (empty holdout set or no epoch run).
        """
        self._reset_early_stopping(max_epochs_since_update)
        train_inputs, train_labels, holdout_inputs, holdout_labels = self._split_and_scale(inputs, labels, holdout_ratio)

        # Defined up front so the return statement is valid even when the
        # epoch limit stops the loop before its first iteration.
        holdout_mean_losses = np.full(self.network_size, np.nan)
        for epoch in itertools.count():
            if max_epochs is not None and epoch > max_epochs:
                break
            e_start = time.time()
            self._run_training_epoch(train_inputs, train_labels, batch_size)

            holdout_mean_losses = self._batched_holdout_losses(holdout_inputs, holdout_labels, batch_size)
            self._update_elites(holdout_mean_losses)
            if self._save_best(epoch, holdout_mean_losses):
                break
            print('epoch: {}, holdout mse losses: {}'.format(epoch, holdout_mean_losses),time.time()-e_start)
        return 0, holdout_mean_losses.mean()

    def _save_best(self, epoch, holdout_losses):
        """Update the per-member best losses and report whether to stop training.

        A member counts as improved when its loss dropped by more than 1 %
        relative to its best recorded loss.

        Returns: True once more than ``_max_epochs_since_update`` consecutive
        calls have seen no improvement for any member, otherwise False.
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                # self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self._max_epochs_since_update

    def predict_batch_t(self,inputs,batch_size=1024, factored=True):
        """Predict from tensor inputs that are passed to the model unchanged.

        NOTE(review): slices are taken along dimension 0 of ``inputs`` while
        the results are concatenated along dimension 1. If ``inputs`` is
        ``(network_size, batch, dim)`` this only works as a single chunk
        (``inputs.shape[0] <= batch_size``) -- confirm the intended layout.

        Arguments:
        inputs (torch.Tensor): Model inputs (not replicated per member here).
        batch_size (int): Chunk size along dimension 0.
        factored (bool): When True return the per-member outputs, otherwise
            the moment-matched mean and variance over the members.

        Returns: Pair of tensors ``(mean, var)``.
        """
        inputs = self.scaler.transform(inputs)

        ensemble_mean, ensemble_var = [], []
        for i in range(0, inputs.shape[0], batch_size):
            chunk = inputs[i:i + batch_size].float().to(device)

            b_mean, b_var = self.ensemble_model(chunk, ret_log_var=False)
            ensemble_mean.append(b_mean)
            ensemble_var.append(b_var)
        ensemble_mean = torch.cat(ensemble_mean,dim=1)
        ensemble_var = torch.cat(ensemble_var,dim=1)
        if factored:
            return ensemble_mean, ensemble_var
        # Moment matching: total variance = mean of the member variances plus
        # the variance of the member means.
        mean = torch.mean(ensemble_mean, dim=0)
        var = torch.mean(ensemble_var, dim=0) + torch.mean(torch.square(ensemble_mean - mean[None, :, :]), dim=0)
        return mean, var

    def predict_t(self,inputs,batch_size=1024, factored=True):
        """Predict from a two-dimensional tensor, keeping the results as tensors.

        Gradient tracking is not disabled here.

        Arguments:
        inputs (torch.Tensor): Batch of model inputs, one row per sample.
        batch_size (int): Number of rows evaluated per forward pass.
        factored (bool): When True return the per-member outputs, otherwise
            the moment-matched mean and variance over the members.

        Returns: Pair of tensors ``(mean, var)``.
        """
        inputs = self.scaler.transform(inputs)
        ensemble_mean, ensemble_var = [], []
        for i in range(0, inputs.shape[0], batch_size):
            chunk = inputs[i:i + batch_size].float().to(device)
            b_mean, b_var = self.ensemble_model(chunk[None, :, :].repeat([self.network_size, 1, 1]), ret_log_var=False)
            ensemble_mean.append(b_mean)
            ensemble_var.append(b_var)
        ensemble_mean = torch.cat(ensemble_mean,dim=1)
        ensemble_var = torch.cat(ensemble_var,dim=1)

        if factored:
            return ensemble_mean, ensemble_var
        mean = torch.mean(ensemble_mean, dim=0)
        var = torch.mean(ensemble_var, dim=0) + torch.mean(torch.square(ensemble_mean - mean[None, :, :]), dim=0)
        return mean, var

    def predict(self, inputs, batch_size=1024, factored=True):
        """Predict from a two-dimensional NumPy array and return NumPy arrays.

        Arguments:
        inputs (np.ndarray): Batch of model inputs, one row per sample.
        batch_size (int): Number of rows evaluated per forward pass.
        factored (bool): When True return the per-member outputs, otherwise
            the moment-matched mean and variance over the members.

        Returns: Pair of NumPy arrays ``(mean, var)``.
        """
        inputs = self.scaler.transform(inputs)
        ensemble_mean, ensemble_var = [], []
        # The outputs are detached and copied to NumPy, so no graph is needed.
        with torch.no_grad():
            for i in range(0, inputs.shape[0], batch_size):
                chunk = torch.from_numpy(inputs[i:i + batch_size]).float().to(device)
                b_mean, b_var = self.ensemble_model(chunk[None, :, :].repeat([self.network_size, 1, 1]), ret_log_var=False)
                ensemble_mean.append(b_mean.detach().cpu().numpy())
                ensemble_var.append(b_var.detach().cpu().numpy())
        # hstack joins along axis 1, i.e. the sample axis of (member, sample, dim).
        ensemble_mean = np.hstack(ensemble_mean)
        ensemble_var = np.hstack(ensemble_var)

        if factored:
            return ensemble_mean, ensemble_var
        mean = np.mean(ensemble_mean, axis=0)
        var = np.mean(ensemble_var, axis=0) + np.mean(np.square(ensemble_mean - mean[None, :, :]), axis=0)
        return mean, var


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` applied element-wise."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return x * torch.sigmoid(x)


def get_data(inputs_file_path, labels_file_path, num_examples):
    """Load images and labels from gzip-compressed IDX files.

    Arguments:
    inputs_file_path (str): Path to the gzipped image file; its first 16 bytes
        are skipped as header and each image is read as 28 * 28 bytes.
    labels_file_path (str): Path to the gzipped label file; its first 8 bytes
        are skipped as header and one byte is read per label.
    num_examples (int): Number of images and labels to read.

    Returns: Tuple ``(inputs, labels)``. ``inputs`` is a float32 array of
    shape ``(num_examples, 784)`` with pixel values divided by 255, and
    ``labels`` is an int8 array holding the bytes that were read.
    """
    pixels_per_image = 28 * 28

    with gzip.open(inputs_file_path, 'rb') as image_stream:
        image_stream.read(16)  # skip the image-file header
        image_bytes = image_stream.read(pixels_per_image * num_examples)
    scaled_pixels = np.frombuffer(image_bytes, dtype=np.uint8) / 255.0
    inputs = scaled_pixels.reshape(num_examples, 784).astype(np.float32)

    with gzip.open(labels_file_path, 'rb') as label_stream:
        label_stream.read(8)  # skip the label-file header
        label_bytes = label_stream.read(num_examples)
    labels = np.frombuffer(label_bytes, dtype=np.uint8).astype(np.int8)

    return inputs, labels

