
import os
import sys
import glob
import time
from datetime import datetime
import warnings
warnings.simplefilter("ignore", UserWarning)


import torch
import torch.nn as nn
from torch.distributions import Normal
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

import itertools
import math
import numpy as np


from spectral import SpectralScoreEstimator


################################## set device ##################################
# print("============================================================================================")
# # set device to cpu or cuda
# device = torch.device('cpu')
# if(torch.cuda.is_available()):
#     device = torch.device('cuda:0')
#     torch.cuda.empty_cache()
#     print("Device set to : " + str(torch.cuda.get_device_name(device)))
# else:
#     print("Device set to : cpu")
# print("============================================================================================")


################################## PPO Policy ##################################

class RolloutBuffer:
    """Rollout storage made of six parallel Python lists.

    The lists (`actions`, `states`, `nextstates`, `logprobs`, `rewards`,
    `is_terminals`) are filled by the caller; this class only creates and
    empties them.
    """

    # Names of the per-step lists, in creation order.
    _FIELDS = ('actions', 'states', 'nextstates',
               'logprobs', 'rewards', 'is_terminals')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every list in place (existing references stay valid)."""
        for field in self._FIELDS:
            getattr(self, field).clear()


class ReplayBuffer():
    """Fixed-capacity ring buffer of transitions stored in NumPy arrays.

    `data` maps 'state', 'action', 'reward', 'next_state', 'done' (and
    'log_prob' when `action_prob_exist` is truthy) to pre-allocated zero
    arrays with `max_size` rows. Once `max_size` transitions have been
    written, new ones overwrite the oldest slots.
    """

    def __init__(self, action_prob_exist, max_size, state_dim, num_action):
        """Allocate the storage arrays.

        Args:
            action_prob_exist: whether a 'log_prob' column is kept.
            max_size: int, number of transitions the buffer can hold.
            state_dim: int, width of the 'state' / 'next_state' rows.
            num_action: accepted but not used by this class.
        """
        self.max_size = max_size
        self.data_idx = 0
        self.action_prob_exist = action_prob_exist

        capacity = self.max_size
        self.data = {
            'state': np.zeros((capacity, state_dim)),
            'action': np.zeros((capacity)),
            'reward': np.zeros((capacity, 1)),
            'next_state': np.zeros((capacity, state_dim)),
            'done': np.zeros((capacity, 1)),
        }
        if self.action_prob_exist:
            self.data['log_prob'] = np.zeros((capacity, 1))

    def put_data(self, transition):
        """Write one transition dict into the next slot (wrapping around).

        Args:
            transition: mapping with the keys 'state', 'action', 'reward',
                'next_state', 'done' and, if log-probs are kept, 'log_prob'.
        """
        slot = self.data_idx % self.max_size
        for key in ('state', 'action', 'reward', 'next_state'):
            self.data[key][slot] = transition[key]
        # The done flag is stored as 0.0 / 1.0.
        self.data['done'][slot] = float(transition['done'])
        if self.action_prob_exist:
            self.data['log_prob'][slot] = transition['log_prob']

        self.data_idx += 1

    def sample(self, shuffle, batch_size=None):
        """Return stored transitions.

        Args:
            shuffle: if falsy, the internal `data` dict itself is returned
                (all `max_size` rows, not a copy). If truthy, `batch_size`
                distinct rows are drawn from the filled part of the buffer.
            batch_size: number of rows to draw when `shuffle` is truthy.

        Returns:
            dict of arrays keyed like `data`.
        """
        if not shuffle:
            return self.data

        filled = min(self.max_size, self.data_idx)
        picks = np.random.choice(filled, batch_size, replace=False)
        keys = ['state', 'action', 'reward', 'next_state', 'done']
        if self.action_prob_exist:
            keys.append('log_prob')
        return {key: self.data[key][picks] for key in keys}

    def size(self):
        """Number of valid transitions currently stored."""
        return self.data_idx if self.data_idx < self.max_size else self.max_size


class Pool():
    """Ring buffer of state vectors kept in one `(max_size, state_dim)` array."""

    def __init__(self, max_size, state_dim):
        self.max_size = max_size
        self.data_idx = 0
        self.data = np.zeros((self.max_size, state_dim))

    def put_data(self, state_):
        """Store `state_` in the next slot, overwriting the oldest when full."""
        self.data[self.data_idx % self.max_size] = state_
        self.data_idx += 1

    def sample(self, shuffle, batch_size=None):
        """Draw `batch_size` distinct rows from the filled part of the pool.

        Args:
            shuffle: accepted but not consulted; rows are always drawn at
                random.
            batch_size: number of rows to draw.
        """
        filled = min(self.max_size, self.data_idx)
        picks = np.random.choice(filled, batch_size, replace=False)
        return self.data[picks]

    def size(self):
        """Number of valid rows currently stored."""
        return self.data_idx if self.data_idx < self.max_size else self.max_size


def make_mini_batch(*value):
    """Yield shuffled, equally sized mini-batches from parallel arrays.

    Args:
        *value: the mini-batch size followed by one or more indexable arrays;
            the length of the first array defines the full batch size.

    Yields:
        list with one slice per input array, all taken at the same shuffled
        indices. Leftover rows that do not fill a whole mini-batch are
        dropped.
    """
    batch_len, arrays = value[0], value[1:]
    order = np.arange(len(arrays[0]))
    np.random.shuffle(order)
    n_batches = len(order) // batch_len
    for start in range(0, n_batches * batch_len, batch_len):
        picked = order[start:start + batch_len]
        yield [arr[picked] for arr in arrays]


def weights_init(m):
    """Re-initialise an `nn.Linear` in place; other modules are left alone.

    Weights get Xavier-normal values and biases standard-normal values, which
    makes this usable as the callback of `Module.apply`.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight.data)
    nn.init.normal_(m.bias.data)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module selected by `norm_layer`.

    Args:
        input_dim: int, number of features the layer normalises.
        norm_layer: `None` for no normalisation (an `nn.Identity` is
            returned) or the string "batchnorm" for a parameter-free
            `nn.BatchNorm1d` that does not track running statistics.

    Returns:
        nn.Module, the normalisation layer.

    Raises:
        ValueError: if `norm_layer` is not one of the accepted values.
            Previously such a value silently produced `None`, which only
            failed later when the "layer" was called.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # eps=0: no stabilising constant is added to the batch variance.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError("Accepted values: `batchnorm` or None")


class BNNMLP(nn.Module):
    """Fully connected layer with a factorised Gaussian over its weights.

    Every forward pass draws `W = W_mu + W_std * eps` and
    `b = b_mu + b_std * eps` with fresh standard-normal `eps`, so the layer
    is stochastic in both training and evaluation mode.
    """

    def __init__(self, n_in, n_out, W_mu=None, b_mu=None, W_std=None,
                 b_std=None, scaled_variance=False, prior_per='parameter'):
        """Initialization.
        Args:
            n_in: int, the size of the input data.
            n_out: int, the size of the output.
            W_mu, b_mu, W_std, b_std: accepted for interface compatibility;
                the initial values below do not depend on them.
            scaled_variance: bool, if True the sampled weights are divided
                by sqrt(n_in) in the forward pass.
            prior_per: str, indicates whether using different parameters for
                each weight, option `parameter`, or sharing a single
                mean/std for all weights in the layer, option `layer`.
        Raises:
            ValueError: if `prior_per` is neither `parameter` nor `layer`.
        """
        super(BNNMLP, self).__init__()

        self.n_in = n_in
        self.n_out = n_out
        self.scaled_variance = scaled_variance

        # Shapes are real tuples: `(1)` is just the int 1, which
        # `Distribution.sample` cannot use as a sample shape.
        if prior_per == "layer":
            W_shape, b_shape = (1,), (1,)
        elif prior_per == "parameter":
            W_shape, b_shape = (self.n_in, self.n_out), (self.n_out,)
        else:
            raise ValueError("Accepted values: `parameter` or `layer`")

        # Initial posterior means: U(-1/sqrt(n_in), 1/sqrt(n_in)).
        # Sampling with scalar bounds yields exactly the requested shape, so
        # no squeeze is needed. A bare `.squeeze()` would also drop genuine
        # size-1 axes (e.g. n_out == 1), and the resulting 1-D mean would
        # broadcast to an [n_in, n_in] weight in the forward pass.
        # NOTE(review): checkpoints saved from layers with n_in == 1 or
        # n_out == 1 before this fix hold the squeezed shapes and will not
        # load into the corrected layer.
        bound = 1. / math.sqrt(self.n_in)
        m = torch.distributions.uniform.Uniform(-bound, bound)
        W_mu_tmp = m.sample(W_shape)
        b_mu_tmp = m.sample(b_shape)

        self.W_mu = nn.Parameter(
            W_mu_tmp, requires_grad=True)
        self.b_mu = nn.Parameter(
            b_mu_tmp, requires_grad=True)

        # Initial posterior standard deviations: constant 0.01.
        self.W_std = nn.Parameter(
            torch.zeros(W_shape) + 0.01, requires_grad=True)
        self.b_std = nn.Parameter(
            torch.zeros(b_shape) + 0.01, requires_grad=True)

        # Standard-normal prior, kept as plain tensors: they are neither
        # parameters nor buffers, so they are not part of `state_dict()` and
        # are not moved by `Module.to()`.
        self.W_mu_prior = torch.zeros(W_shape)
        self.b_mu_prior = torch.zeros(b_shape)
        self.W_std_prior = torch.ones(W_shape)
        self.b_std_prior = torch.ones(b_shape)

    def forward(self, X):
        """Performs forward pass with one freshly sampled set of weights.
        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
        Returns:
            output: torch.tensor, [batch_size, output_dim], the output data.
        """

        W = self.W_mu + self.W_std * \
            torch.randn((self.n_in, self.n_out), device=self.W_std.device)
        if self.scaled_variance:
            W = W / math.sqrt(self.n_in)
        b = self.b_mu + self.b_std * \
            torch.randn((self.n_out), device=self.b_std.device)

        output = torch.mm(X, W) + b

        return output

    def sample_predict(self, X, n_samples):
        """Makes predictions using a set of sampled weights.
        Args:
            X: torch.tensor, [n_samples, batch_size, input_dim], the input
                data.
            n_samples: int, the number of weight samples used to make
                predictions.
        Returns:
            torch.tensor, [n_samples, batch_size, output_dim], the output data.
        """
        X = X.float()

        # One independent weight matrix / bias vector per sample.
        Ws = self.W_mu + self.W_std * \
            torch.randn([n_samples, self.n_in, self.n_out],
                        device=self.W_std.device)
        if self.scaled_variance:
            Ws = Ws / math.sqrt(self.n_in)
        bs = self.b_mu + self.b_std * \
            torch.randn([n_samples, 1, self.n_out],
                        device=self.b_std.device)

        return torch.matmul(X, Ws) + bs


class BNN(nn.Module):
    """Three-layer network: a deterministic input layer followed by two
    stochastic `BNNMLP` layers, with optional normalisation after each
    hidden layer."""

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 W_mu=None, b_mu=None, W_std=None, b_std=None, scaled_variance=False, norm_layer=None):
        """Initialization.
        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, the sizes of the hidden layers; only
                the first two entries are used.
            activation_fn: str naming one of `cos`, `tanh`, `relu`,
                `softplus`, `leaky_relu`; any other value is stored as-is
                and called as the activation.
            has_continuous_action_space: stored on the instance only.
            W_mu, b_mu, W_std, b_std, scaled_variance: forwarded to both
                `BNNMLP` layers.
            norm_layer: forwarded to `init_norm_layer` for both hidden
                layers.
        """
        super(BNN, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Resolve the activation by name, falling back to the given object.
        named_activations = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                             'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        self.activation_fn = named_activations.get(activation_fn, activation_fn)

        # Arguments shared by the two stochastic layers.
        bayes_kwargs = dict(scaled_variance=scaled_variance)
        bayes_args = (W_mu, b_mu, W_std, b_std)

        # Layers are created in forward order: deterministic input layer,
        # then the stochastic hidden and output layers.
        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.norm_layer1 = init_norm_layer(hidden_dims[0], self.norm_layer)

        self.mid_layer = BNNMLP(hidden_dims[0], hidden_dims[1], *bayes_args, **bayes_kwargs)
        self.norm_layer2 = init_norm_layer(hidden_dims[1], self.norm_layer)

        self.output_layer = BNNMLP(hidden_dims[1], output_dim, *bayes_args, **bayes_kwargs)

    def forward(self, X):
        """Performs forward pass given input data.
        Args:
            X: torch.tensor, reshaped to [batch_size, input_dim].
        Returns:
            torch.tensor, [batch_size, output_dim]; the leading axis is
            squeezed away when batch_size is 1.
        """
        flat = X.view(-1, self.input_dim)

        hidden = self.input_layer(flat)
        hidden = self.activation_fn(self.norm_layer1(hidden))
        hidden = self.mid_layer(hidden)
        hidden = self.activation_fn(self.norm_layer2(hidden))

        return self.output_layer(hidden).squeeze(0)

    def sample_functions(self, X, n_samples):
        """Performs predictions using `n_samples` sets of weights.
        Args:
            X: torch.tensor, reshaped to [batch_size, input_dim].
            n_samples: int, the number of weight samples used to make
                predictions.
        Returns:
            torch.tensor, [n_samples, batch_size, output_dim], the output
            data.
        """
        flat = X.view(-1, self.input_dim)
        # One copy of the batch per weight sample.
        tiled = flat.unsqueeze(0).repeat([n_samples, 1, 1])

        # NOTE(review): the norm layers receive 3-D input here; confirm this
        # is intended when `norm_layer` is "batchnorm".
        hidden = self.input_layer(tiled)
        hidden = self.activation_fn(self.norm_layer1(hidden))
        hidden = self.mid_layer.sample_predict(hidden, n_samples)
        hidden = self.activation_fn(self.norm_layer2(hidden))

        return self.output_layer.sample_predict(hidden, n_samples)

class BNNActorCritic(nn.Module):
    """Actor-critic container: a `BNN` actor producing action means, a
    second `BNN` of the same architecture kept as prior, a state-independent
    log-std parameter and a deterministic MLP critic."""

    def __init__(self, state_dim, action_dim, activation_fn, has_continuous_action_space, device, norm_layer=None):
        """Build the actor, the prior actor and the critic.

        Args:
            state_dim: int, size of a state vector.
            action_dim: int, size of an action vector.
            activation_fn: activation forwarded to both `BNN` networks.
            has_continuous_action_space: stored and forwarded to the `BNN`s.
            device: device the two `BNN` networks are moved to.
            norm_layer: stored on the instance; not forwarded to the `BNN`s.
        """
        super(BNNActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space
        self.device = device
        self.norm_layer = norm_layer

        # actor: state-independent log standard deviation, initialised to -1
        self.action_dim = action_dim
        self.action_logstd = nn.Parameter(-torch.ones(action_dim))
        # NOTE: `Tensor.to` is not in-place and the result is discarded, so
        # this call does not relocate the parameter.
        self.action_logstd.to(self.device)

        n_units = 64
        n_hidden = 2
        hidden_dims = [n_units for _ in range(n_hidden)]

        actor_kwargs = dict(scaled_variance=True)
        self.actor = BNN(state_dim, action_dim, hidden_dims, activation_fn,
                         has_continuous_action_space, **actor_kwargs).to(self.device)
        self.actor_prior = BNN(state_dim, action_dim, hidden_dims, activation_fn,
                               has_continuous_action_space, **actor_kwargs).to(self.device)
        self.actor_prior.eval()

        # critic: plain tanh MLP mapping a state to a scalar value
        critic_layers = [
            nn.Linear(state_dim, n_units),
            nn.Tanh(),
            nn.Linear(n_units, n_units),
            nn.Tanh(),
            nn.Linear(n_units, 1),
        ]
        self.critic = nn.Sequential(*critic_layers)

    def forward(self):
        raise NotImplementedError

    def get_action(self, state):
        """Return `(action_mean, action_std)` with gradients enabled."""
        mean = self.actor(state)
        std = self.action_logstd.exp()
        return mean, std

    def select_action(self, state):
        """Return `(action_mean, action_std)` computed under `torch.no_grad`."""
        with torch.no_grad():
            mean = self.actor(state)
            std = self.action_logstd.exp()
        return mean, std

    def v(self, state):
        """Return the critic's value estimate for `state`."""
        return self.critic(state)


class fkPPO:
    """PPO agent whose actor is a Bayesian network regularised in function
    space.

    The actor loss is the clipped PPO surrogate (with an entropy bonus) plus
    two penalties built from score estimates of sampled actor outputs: one
    against a fixed prior network and one against the previous policy.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic,
                 gamma, K_epochs, eps_clip, has_continuous_action_space, max_ep_len,
                 dist_coeff=1.0, prior_coeff=1.0, n_samples=10, activation_fn='tanh', device='cpu'):
        """Create networks, optimisers and buffers.

        Args:
            state_dim: int, size of a state vector.
            action_dim: int, size of an action vector.
            lr_actor: learning rate of the actor optimiser.
            lr_critic: learning rate of the critic optimiser.
            gamma: discount factor.
            K_epochs: number of optimisation passes per `update` call.
            eps_clip: PPO ratio clipping range.
            has_continuous_action_space: forwarded to the networks.
            max_ep_len: int, the state pool holds `5 * max_ep_len` states.
            dist_coeff: weight of the penalty against the old policy.
            prior_coeff: weight of the penalty against the prior network.
            n_samples: number of function samples drawn per network.
            activation_fn: activation forwarded to the networks.
            device: device used for networks and training tensors.
        """

        self.has_continuous_action_space = has_continuous_action_space

        self.gamma = gamma
        self.gae = False
        self.lambda_ = 0.95
        self.entropy_coef = 0.01

        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = 64
        self.critic_coef = 0.5
        self.max_grad_norm = 0.5

        self.n_samples = n_samples
        self.m_samples = n_samples

        self.device = device

        self.dist_coeff = dist_coeff
        self.prior_coeff = prior_coeff

        self.activation_fn = activation_fn
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_ep_len = max_ep_len

        self.buffer = RolloutBuffer()

        self.pool = Pool(max_size=max_ep_len*5, state_dim=state_dim)

        self.policy = BNNActorCritic(state_dim, action_dim, activation_fn, has_continuous_action_space,
                                     device=self.device).to(self.device)
        # Only the actor network's parameters are optimised here; the
        # policy's `action_logstd` belongs to neither optimiser.
        self.actor_optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor}
        ])
        self.critic_optimizer = torch.optim.Adam([{'params': self.policy.critic.parameters(), 'lr': lr_critic}])

        self.critic_loss = nn.SmoothL1Loss()

        # Frozen copy of the policy, refreshed at the end of every update.
        self.policy_old = BNNActorCritic(state_dim, action_dim, activation_fn, has_continuous_action_space,
                                         device=self.device).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.policy_old.eval()

        # Settings passed to the spectral score estimator.
        self.eta = 0.1  # 0.1
        self.n_eigen_threshold = 2

    def put_data_pool(self, state):
        """Add one state to the state pool."""
        self.pool.put_data(state)

    def update(self, time_step, writer):
        """Run `K_epochs` optimisation passes over the rollout buffer, then
        sync `policy_old` and clear the buffer.

        Args:
            time_step: step value attached to every logged scalar.
            writer: logger exposing `add_scalar(tag, value, step)` and
                `flush()` (presumably a TensorBoard `SummaryWriter`).
        """

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(self.device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(self.device)
        old_rewards = torch.squeeze(torch.stack(self.buffer.rewards, dim=0)).detach().to(self.device)
        old_is_terminals = torch.squeeze(torch.stack(self.buffer.is_terminals, dim=0)).detach().to(self.device)
        num_buffer = len(old_states)

        # Targets are computed once, without gradients, from the current critic.
        with torch.no_grad():
            old_values = self.policy.v(old_states).float().squeeze().to(self.device)
            # Bootstrap value for the state following the last stored step.
            next_value = self.policy.v(self.buffer.nextstates[-1]).to(self.device)
            if self.gae:
                # Generalised advantage estimation, accumulated backwards.
                advantages = torch.zeros_like(old_rewards).to(self.device)
                lastgaelam = 0
                for t in reversed(range(num_buffer)):
                    if t == num_buffer - 1:
                        nextnonterminal = 1.0 - torch.tensor([self.buffer.is_terminals[-1]]).to(self.device)
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - old_is_terminals[t + 1]
                        nextvalues = old_values[t + 1]
                    delta = old_rewards[t] + self.gamma * nextvalues * nextnonterminal - old_values[t]
                    advantages[t] = lastgaelam = delta + self.gamma * self.lambda_ * nextnonterminal * lastgaelam
                returns = advantages + old_values
            else:
                # Discounted returns accumulated backwards, bootstrapped from
                # `next_value` at the end of the buffer.
                returns = torch.zeros_like(old_rewards).to(self.device)
                for t in reversed(range(num_buffer)):
                    if t == num_buffer - 1:
                        nextnonterminal = 1.0 - torch.tensor([self.buffer.is_terminals[-1]]).to(self.device)
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - old_is_terminals[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = old_rewards[t] + self.gamma * nextnonterminal * next_return
                advantages = returns - old_values

        for i in range(self.K_epochs):

            # Evaluate the stored actions under the current policy.
            curr_mu, curr_sigma = self.policy.get_action(old_states)
            curr_dist = torch.distributions.Normal(curr_mu.squeeze(), curr_sigma.squeeze())
            entropy = curr_dist.entropy().sum(-1)
            curr_log_prob = curr_dist.log_prob(old_actions).sum(1, keepdim=False)
            new_values = self.policy.v(old_states).squeeze()

            # policy clipping
            ratio = torch.exp(curr_log_prob - old_logprobs.detach())
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Size of the measurement set: a geometric draw, resampled a
            # bounded number of times while it is not below the buffer size.
            p = 2.0 / num_buffer
            g = torch.distributions.Geometric(p)
            num_geo = g.sample()
            count = 0
            while num_geo >= num_buffer:
                num_geo = g.sample()
                count += 1
                if count > 10:
                    num_geo = num_buffer -1
                    break

            # Cap at 500 states, then enforce a minimum of 2.
            num_geo = min(int(num_geo), 500)
            num_geo = max(num_geo, 2)

            # sample measurement set with size n from the buffered states
            perm = torch.randperm(int(num_buffer))
            idx_geo = perm[:num_geo]
            states_mset = old_states[idx_geo, :]

            # Draw functions from prior policy, old policy, and current policy
            # size of each sample tensor is [n_samples, batch_size, output_dim/action_dim]
            noisy_func_x_rand = self.policy.actor.sample_functions(states_mset, self.n_samples)
            noisy_func_x_rand_prior = self.policy.actor_prior.sample_functions(states_mset, self.n_samples)
            noisy_func_x_rand_old = self.policy_old.actor.sample_functions(states_mset, self.n_samples)

            # estimate entropy surrogate H(q(f)); the estimated gradients are
            # detached so only `noisy_func_x_rand` carries gradient
            estimator = SpectralScoreEstimator(eta=self.eta, n_eigen_threshold=self.n_eigen_threshold)
            dlog_q = estimator.compute_gradients(noisy_func_x_rand)
            entropy_sur = torch.mean(torch.sum(-dlog_q.detach() * noisy_func_x_rand, -1))

            # estimate cross entropy with prior
            cross_entropy_gradients = estimator.compute_gradients(noisy_func_x_rand_prior, noisy_func_x_rand)
            cross_entropy_sur = torch.mean(torch.sum(cross_entropy_gradients.detach() * noisy_func_x_rand, -1))

            fk_prior = -entropy_sur - cross_entropy_sur

            # estimate cross entropy with old policy
            cross_entropy_gradients_old = estimator.compute_gradients(noisy_func_x_rand_old, noisy_func_x_rand)
            cross_entropy_sur_old = torch.mean(torch.sum(cross_entropy_gradients_old.detach() * noisy_func_x_rand, -1))

            fk_old = -entropy_sur - cross_entropy_sur_old

            # Clipped surrogate with entropy bonus, plus the two penalties.
            actor_loss = (-torch.min(surr1, surr2) - self.entropy_coef * entropy).mean()
            actor_loss = actor_loss + self.prior_coeff * fk_prior + self.dist_coeff * fk_old

            critic_loss = self.critic_coef * self.critic_loss(new_values, returns)

            # Every epoch logs under the same `time_step`.
            writer.add_scalar("actor_loss", actor_loss, time_step)
            writer.add_scalar("critic_loss", critic_loss, time_step)
            writer.add_scalar("fk_prior", fk_prior, time_step)
            writer.add_scalar("fk_old", fk_old, time_step)
            writer.add_scalar("action_logstd", torch.exp(self.policy.action_logstd).mean(), time_step)
            writer.flush()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()

    def save(self, checkpoint_path):
        """Write the current policy's state dict to `checkpoint_path`."""
        torch.save(self.policy.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Restore the policy from `checkpoint_path` and re-sync the old
        policy.

        `update` relies on `policy_old` mirroring `policy` when it starts
        (the constructor and the end of `update` both establish this), so
        the old policy is refreshed here as well; otherwise the first update
        after loading would be regularised against stale weights.
        """
        # Tensors are deserialised onto their stored (CPU) storage and then
        # copied into the existing parameters by `load_state_dict`.
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy.load_state_dict(state_dict)
        self.policy_old.load_state_dict(self.policy.state_dict())
        
        
       


