import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.distributions.normal import Normal
import gzip
import os

# Device that Game_model / Ensemble_Model move their tensors to. Hard-coded to
# CUDA, so the code that references it needs a GPU at run time.
device = torch.device('cuda')

# MNIST dataset sizes and gzip file locations. In the visible part of this file
# they are referenced only by the commented-out demo at the bottom.
num_train = 60000 # 60k train examples
num_test = 10000 # 10k test examples
train_inputs_file_path = './MNIST_data/train-images-idx3-ubyte.gz'
train_labels_file_path = './MNIST_data/train-labels-idx1-ubyte.gz'
test_inputs_file_path = './MNIST_data/t10k-images-idx3-ubyte.gz'
test_labels_file_path = './MNIST_data/t10k-labels-idx1-ubyte.gz'
# Clamp range applied to the predicted log standard deviation in
# NewPriorModel.forward before it is exponentiated.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# NOTE(review): `epsilon` is not referenced anywhere in the visible code.
epsilon = 1e-6
# Batch size used by the commented-out MNIST demo at the bottom of the file.
BATCH_SIZE = 100

class Game_model(nn.Module):
    """Four-hidden-layer MLP that outputs a Gaussian (mean and bounded log-variance).

    The input is a row-wise concatenation of width ``state_size + action_size``
    and the output has width ``state_size + reward_size``.  Each instance owns
    its own Adam optimizer.

    Args:
        state_size: Width of the state part of the input / output.
        action_size: Width of the action part of the input.
        reward_size: Width of the reward part of the output.
        hidden_size: Width of every hidden layer.
        learning_rate: Adam learning rate.
        separate_mean_var: If true, use separate linear heads for mean and
            log-variance.  Currently rejected by an ``assert`` (see below).
    """

    def __init__(self, state_size, action_size, reward_size, hidden_size=200, learning_rate=1e-3, separate_mean_var=True):
        super(Game_model, self).__init__()
        self.hidden_size = hidden_size

        # Attribute names (nn1..nn5) are part of the state_dict key layout used
        # by saved checkpoints, so they are kept as-is.
        self.nn1 = nn.Sequential(nn.Linear(state_size + action_size, hidden_size), Swish())
        self.nn2 = nn.Sequential(nn.Linear(hidden_size, hidden_size), Swish())
        self.nn3 = nn.Sequential(nn.Linear(hidden_size, hidden_size), Swish())
        self.nn4 = nn.Sequential(nn.Linear(hidden_size, hidden_size), Swish())

        self.output_dim = state_size + reward_size
        # The separate-head variant is deliberately disabled here; note that the
        # default argument value (True) therefore fails unless asserts are
        # stripped with ``python -O``.
        assert not separate_mean_var
        self.separate_mean_var = separate_mean_var
        print("MEAN VAR", self.separate_mean_var)
        if separate_mean_var:
            self.mean_layer = nn.Linear(hidden_size, self.output_dim)
            self.var_layer = nn.Linear(hidden_size, self.output_dim)
        else:
            # Single head: first half of the columns is the mean, second half
            # the raw log-variance.
            self.nn5 = nn.Linear(hidden_size, self.output_dim * 2)

        # Soft upper / lower bounds for the predicted log-variance.
        # NOTE(review): these are plain tensors, not nn.Parameter, so they are
        # absent from self.parameters() (hence from the optimizer below) and
        # from state_dict(); they stay at their initial values.  Left unchanged
        # because registering them would alter checkpoints and training.
        self.max_logvar = Variable(torch.ones((1, self.output_dim)).type(torch.FloatTensor) / 2, requires_grad=True).to(device)
        self.min_logvar = Variable(-torch.ones((1, self.output_dim)).type(torch.FloatTensor) * 10, requires_grad=True).to(device)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=0.00005)

    def _bound_logvar(self, raw_logvar):
        """Softly squash ``raw_logvar`` into ``[min_logvar, max_logvar]`` via softplus."""
        capped = self.max_logvar - F.softplus(self.max_logvar - raw_logvar)
        return self.min_logvar + F.softplus(capped - self.min_logvar)

    def forward(self, x, ret_log_var=False):
        """Run the network on a 2-D batch ``x``.

        Returns:
            ``(mean, logvar)`` when ``ret_log_var`` is true, otherwise
            ``(mean, variance)`` where ``variance = exp(logvar)``.
        """
        hidden = self.nn4(self.nn3(self.nn2(self.nn1(x))))
        if self.separate_mean_var:
            mean = self.mean_layer(hidden)
            raw_logvar = self.var_layer(hidden)
        else:
            head = self.nn5(hidden)
            mean = head[:, :self.output_dim]
            raw_logvar = head[:, self.output_dim:]
        logvar = self._bound_logvar(raw_logvar)
        if ret_log_var:
            return mean, logvar
        return mean, torch.exp(logvar)

    def loss(self, mean, logvar, labels, inc_var_loss=True):
        """Return a scalar loss for predictions against ``labels``.

        With ``inc_var_loss`` the loss is the mean of the inverse-variance
        weighted squared error plus the mean log-variance; otherwise it is the
        plain mean squared error and ``logvar`` is ignored.
        """
        if not inc_var_loss:
            return F.mse_loss(mean, labels)
        inv_var = torch.exp(-logvar)
        weighted_mse = torch.mean(torch.pow(mean - labels, 2) * inv_var)
        return weighted_mse + torch.mean(logvar)

    def train(self, loss=True):
        """Take one optimizer step on ``loss``, or switch train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on.  To keep both contracts working, a ``bool`` argument is forwarded
        to ``nn.Module.train`` (returning ``self``); any other value is treated
        as a scalar loss tensor and optimised (returning ``None``).
        """
        if isinstance(loss, bool):
            return super(Game_model, self).train(loss)
        self.optimizer.zero_grad()
        # Build a new tensor instead of ``loss += ...`` so the caller's tensor
        # is not modified in place.
        bound_penalty = 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
        total = loss + bound_penalty
        total.backward()
        self.optimizer.step()

class Ensemble_Model():
    """A list of independently trained ``Game_model`` networks.

    Args:
        network_size: Number of member networks.
        elite_size: How many of the lowest-training-loss members are recorded
            in ``elite_model_idxes`` after each ``train`` call.
        state_size, action_size, reward_size, hidden_size, separate_mean_var:
            Forwarded to every ``Game_model``.
    """

    def __init__(self, network_size, elite_size, state_size, action_size, reward_size=1, hidden_size=200, separate_mean_var=True):
        self.network_size = network_size
        self.elite_size = elite_size
        self.state_size = state_size
        self.action_size = action_size
        self.reward_size = reward_size
        self.elite_model_idxes = []
        self._state = {}
        self._epochs_since_update = 0
        # Per-member (epoch, best validation loss); 1e10 acts as "no best yet".
        self._snapshots = {i: (None, 1e10) for i in range(self.network_size)}
        self.model_list = [
            Game_model(state_size, action_size, reward_size, hidden_size, separate_mean_var=separate_mean_var)
            for _ in range(network_size)
        ]

    def train(self, inputs, labels, val_inputs, val_labels, epoch, batch_size=256, max_epochs_since_update=5):
        """Run one pass over ``inputs``/``labels`` for every member network.

        ``inputs``, ``labels``, ``val_inputs`` and ``val_labels`` are NumPy
        arrays (they are converted with ``torch.from_numpy``).  Batches are
        taken in order; no shuffling happens here.

        Returns:
            True when early stopping should trigger, i.e. no member improved
            its validation loss for more than ``max_epochs_since_update``
            consecutive calls; otherwise False.
        """
        self._max_epochs_since_update = max_epochs_since_update
        all_losses = np.zeros(len(self.model_list))
        batch_starts = range(0, inputs.shape[0], batch_size)
        for start_pos in batch_starts:
            stop_pos = start_pos + batch_size
            batch_inputs = torch.from_numpy(inputs[start_pos:stop_pos]).float().to(device)
            batch_labels = torch.from_numpy(labels[start_pos:stop_pos]).float().to(device)
            for idx, model in enumerate(self.model_list):
                mean, logvar = model(batch_inputs, ret_log_var=True)
                model.train(model.loss(mean, logvar, batch_labels))
                # Record the plain MSE of the pre-step prediction as a Python
                # float; grad-tracking (CUDA) tensors cannot be fed to np.array.
                all_losses[idx] += model.loss(mean, logvar, batch_labels, inc_var_loss=False).item()
        if len(batch_starts) > 0:
            all_losses /= len(batch_starts)

        # Rank members by training MSE *before* checkpointing so that the elite
        # indices written by save() belong to this epoch, not the previous one.
        self.elite_model_idxes = np.argsort(all_losses)[:self.elite_size].tolist()

        val_losses = self.evaluate(val_inputs, val_labels)
        break_train = self._save_best(epoch, val_losses)
        print("Losses - Val: {}".format(val_losses))
        return break_train

    def _save_best(self, epoch, holdout_losses, save_dir='saved_models'):
        """Update per-member best losses, checkpoint on improvement, and report early stop.

        A member counts as improved when its loss drops by more than 1%
        relative to its previous best.  If any member improved, the whole
        ensemble is saved once to ``save_dir``.

        Returns:
            True if the number of consecutive non-improving calls now exceeds
            ``self._max_epochs_since_update``, else False.
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            # A best of exactly 0 cannot be improved upon; avoid dividing by it.
            improvement = (best - current) / best if best != 0 else 0.0
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                updated = True
                print('epoch {} | updated {} | improvement: {:.4f} | best: {:.4f} | current: {:.4f}'.format(epoch, i, improvement, best, current))

        if updated:
            # One save covers every member; saving inside the loop only
            # rewrote identical files.
            self.save(save_dir)
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1

        if self._epochs_since_update > self._max_epochs_since_update:
            print('[ BNN ] Breaking at epoch {}: {} epochs since update ({} max)'.format(epoch, self._epochs_since_update, self._max_epochs_since_update))
            return True
        return False

    def evaluate(self, val_inputs, val_labels):
        """Return a list with each member's plain MSE on the given NumPy arrays."""
        with torch.no_grad():
            val_inputs = torch.from_numpy(val_inputs).float().to(device)
            val_labels = torch.from_numpy(val_labels).float().to(device)
            val_losses = []
            for model in self.model_list:
                mean, logvar = model(val_inputs, ret_log_var=True)
                val_losses.append(model.loss(mean, logvar, val_labels, inc_var_loss=False).item())
        return val_losses

    def predict(self, inputs, batch_size=1024):
        """Predict with every member network on a NumPy array ``inputs``.

        Returns:
            ``(ensemble_mean, ensemble_var)``, two NumPy arrays of shape
            ``(network_size, len(inputs), state_size + reward_size)``.  The
            second array holds variances (``exp(logvar)``), because the models
            are called without ``ret_log_var``.
        """
        num_rows = inputs.shape[0]
        out_dim = self.state_size + self.reward_size
        ensemble_mean = np.zeros((self.network_size, num_rows, out_dim))
        ensemble_var = np.zeros((self.network_size, num_rows, out_dim))
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for start in range(0, num_rows, batch_size):
                stop = min(start + batch_size, num_rows)
                batch = torch.from_numpy(inputs[start:stop]).float().to(device)
                for idx in range(self.network_size):
                    pred_mean, pred_var = self.model_list[idx](batch)
                    ensemble_mean[idx, start:stop, :] = pred_mean.detach().cpu().numpy()
                    ensemble_var[idx, start:stop, :] = pred_var.detach().cpu().numpy()

        return ensemble_mean, ensemble_var

    def save(self, save_dir):
        """Write ``model_<i>.pt`` for every member plus ``elite_idx.npy`` into ``save_dir``.

        The directory is created if it does not exist.
        """
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        for idx, model in enumerate(self.model_list):
            torch.save(model.state_dict(), os.path.join(save_dir, 'model_{}.pt'.format(idx)))
        np.save(os.path.join(save_dir, 'elite_idx.npy'), self.elite_model_idxes)

    def load(self, load_dir):
        """Restore member weights and elite indices previously written by ``save``."""
        for idx, model in enumerate(self.model_list):
            state_dict = torch.load(os.path.join(load_dir, 'model_{}.pt'.format(idx)))
            model.load_state_dict(state_dict)
        elite_path = os.path.join(load_dir, 'elite_idx.npy')
        # Earlier versions concatenated without a path separator, producing
        # "<dir>elite_idx.npy" next to the directory; still accept that file.
        legacy_path = load_dir + 'elite_idx.npy'
        if not os.path.exists(elite_path) and os.path.exists(legacy_path):
            elite_path = legacy_path
        self.elite_model_idxes = np.load(elite_path).tolist()


class Swish(nn.Module):
    """Parameter-free activation computing ``x * sigmoid(x)`` elementwise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` gated by its own sigmoid; the output has the shape of ``x``."""
        gate = torch.sigmoid(x)
        return x * gate

class PriorModel(nn.Module):
    """MLP mapping observations to a per-dimension Normal over actions.

    The network emits ``ac_dim * 2`` values per observation, split into a mean
    and a log standard deviation.  The model owns its Adam optimizer and runs
    on CUDA when available, otherwise on the CPU.

    Args:
        obs_dim: Width of an observation vector.
        ac_dim: Width of an action vector.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, obs_dim, ac_dim, hidden_dim=128):
        super(PriorModel, self).__init__()
        self.obs_dim = obs_dim
        self.ac_dim = ac_dim
        self.net = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
                                nn.ReLU(),
                                nn.Linear(hidden_dim, hidden_dim),
                                nn.ReLU(),
                                nn.Linear(hidden_dim, ac_dim * 2)) # * 2 for mean, log std
        # Fall back to the CPU instead of crashing on hosts without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.optim = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, obs):
        """Return ``(mean, std)`` tensors for a batch of observations."""
        obs = obs.to(self.device).float()
        mu, log_std = torch.chunk(self.net(obs), 2, dim=-1)
        return mu, log_std.exp()

    def dist(self, obs):
        """Return a gradient-free ``Normal`` whose mean is squashed with tanh.

        NOTE(review): ``loss`` and ``sample`` use the un-squashed mean, so this
        distribution differs from the one that is trained -- confirm intended.
        """
        mu, std = self(obs)
        mu = torch.tanh(mu) #* self.action_scale + self.action_bias
        return Normal(mu.detach(), std.detach())

    def loss(self, obs, acs):
        """Return the mean negative log-likelihood of ``acs`` given ``obs``.

        Log-probabilities are summed over dimension 1 and averaged over the
        batch.
        """
        acs = acs.to(self.device).float()
        mu, std = self(obs)
        return -Normal(mu, std).log_prob(acs).sum(dim=1).mean()

    def evaluate(self, val_obs, val_acs):
        """Return ``loss`` on NumPy arrays as a Python float, without gradients."""
        with torch.no_grad():
            val_obs = torch.from_numpy(val_obs)
            val_acs = torch.from_numpy(val_acs)
            return self.loss(val_obs, val_acs).item()

    @staticmethod
    def _split_indices(num_samples, train_fraction=0.8):
        """Split ``range(num_samples)`` into disjoint leading-train / trailing-val index arrays."""
        inds = np.arange(num_samples)
        num_train = int(num_samples * train_fraction)
        # Copy the training slice so that shuffling it later cannot touch the
        # validation indices (slices of one array share memory).
        return inds[:num_train].copy(), inds[num_train:]

    def train(self, num_epochs=True, obss=None, acss=None, batch_size=256):
        """Fit the model, or switch train/eval mode when called like ``nn.Module.train``.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on.  When called with only a ``bool`` (as ``nn.Module`` does), it
        forwards to ``nn.Module.train`` and returns ``self``.

        Otherwise the first 80% of ``obss``/``acss`` (NumPy arrays) are used
        for training and the remaining 20% for validation.  The training order
        is shuffled once, before the first epoch.

        Returns:
            The list of mean training losses, one entry per epoch.

        Raises:
            TypeError: If ``obss`` or ``acss`` is missing in fit mode.
        """
        if obss is None and acss is None and isinstance(num_epochs, bool):
            return super(PriorModel, self).train(num_epochs)
        if obss is None or acss is None:
            raise TypeError("train() requires both 'obss' and 'acss' to fit the model")

        train_inds, val_inds = self._split_indices(len(obss))
        num_train = len(train_inds)
        np.random.shuffle(train_inds)
        losses = []
        print("NUM TRAIN", num_train)
        for epoch in range(num_epochs):
            epoch_losses = []
            for start_pos in range(0, num_train, batch_size):
                this_inds = train_inds[start_pos:start_pos + batch_size]
                obs = torch.from_numpy(obss[this_inds]).float()
                acs = torch.from_numpy(acss[this_inds]).float()
                self.optim.zero_grad()
                loss = self.loss(obs, acs)
                loss.backward()
                self.optim.step()
                epoch_losses.append(loss.item())
            # Held-out samples only: these indices never appear in train_inds.
            val_loss = self.evaluate(obss[val_inds], acss[val_inds])
            print("EPOCH: ", epoch, "TRAIN LOSS: ", np.mean(epoch_losses), "VAL LOSS: ", val_loss)
            losses.append(np.mean(epoch_losses))
        return losses

    def sample(self, obs):
        """Draw one action sample per observation from ``Normal(mean, std)``."""
        mu, std = self(obs)
        return Normal(mu, std).sample()

    def save(self, save_dir):
        """Write the state dict to ``<save_dir>/prior.pt``."""
        torch.save(self.state_dict(), os.path.join(save_dir, 'prior.pt'))

    def load(self, save_dir):
        """Load the state dict from ``<save_dir>/prior.pt`` onto this model's device."""
        model_path = os.path.join(save_dir, 'prior.pt')
        self.load_state_dict(torch.load(model_path, map_location=self.device))


class NewPriorModel(PriorModel):
    """PriorModel variant with separate mean / log-std heads and a clamped log-std.

    It reuses every method of ``PriorModel`` except ``__init__`` and
    ``forward``.  Runs on CUDA when available, otherwise on the CPU.

    Args:
        obs_dim: Width of an observation vector.
        ac_dim: Width of an action vector.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, obs_dim, ac_dim, hidden_dim=128):
        # Initialise nn.Module directly: PriorModel.__init__ would also build
        # its own ``net`` stack, which this variant does not use.
        nn.Module.__init__(self)
        self.obs_dim = obs_dim
        self.ac_dim = ac_dim
        self.linear1 = nn.Linear(obs_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, ac_dim)
        self.log_std_linear = nn.Linear(hidden_dim, ac_dim)
        # Fall back to the CPU instead of crashing on hosts without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.optim = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, obs):
        """Return ``(mean, std)``; the log-std is clamped to [LOG_SIG_MIN, LOG_SIG_MAX] before exp."""
        obs = obs.to(self.device).float()
        hidden = F.relu(self.linear1(obs))
        hidden = F.relu(self.linear2(hidden))
        mean = self.mean_linear(hidden)
        log_std = torch.clamp(self.log_std_linear(hidden), min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std.exp()


# def get_data(inputs_file_path, labels_file_path, num_examples):
#     with open(inputs_file_path, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
#         bytestream.read(16)
#         buf = bytestream.read(28 * 28 * num_examples)
#         data = np.frombuffer(buf, dtype=np.uint8) / 255.0
#         inputs = data.reshape(num_examples, 784)

#     with open(labels_file_path, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
#         bytestream.read(8)
#         buf = bytestream.read(num_examples)
#         labels = np.frombuffer(buf, dtype=np.uint8)

#     return np.array(inputs, dtype=np.float32), np.array(labels, dtype=np.int8)

# def main():
#     # Import MNIST train and test examples into train_inputs, train_labels, test_inputs, test_labels
#     train_inputs, train_labels = get_data(train_inputs_file_path, train_labels_file_path, num_train)
#     test_inputs, test_labels = get_data(test_inputs_file_path, test_labels_file_path, num_test)

#     model = Ensemble_Model(5, 3, 5, 779, 5, 50)
#     for i in range(0, 10000, BATCH_SIZE):
#         model.train(Variable(torch.from_numpy(train_inputs[i:i+BATCH_SIZE])), Variable(torch.from_numpy(train_labels[i:i+BATCH_SIZE])))
#     model.predict(Variable(torch.from_numpy(test_inputs[:1000])))

# if __name__ == '__main__':
#     main()
