import numpy as np
import torch
from torch.optim import Adam, AdamW
import gym
import time
import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical

device = torch.device('cpu')

# Cumulative sum function from https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/ppo/core.py
def discount_cumsum(x, discount):
    """Return the discounted cumulative sum of ``x`` along axis 0.

    Entry ``i`` of the result equals ``x[i] + discount * x[i+1] +
    discount**2 * x[i+2] + ...``.

    Args:
        x: Sequence/array of per-step values; axis 0 is the step axis.
        discount: Multiplicative discount applied per step.

    Returns:
        Array of the same length as ``x`` holding the discounted sums.
    """
    # The filter below realises y[n] = x[n] + discount * y[n-1]; feeding it the
    # time-reversed input and flipping the output turns that into a
    # "sum over the future" instead of a "sum over the past".
    numerator = [1]
    denominator = [1, float(-discount)]
    backwards = x[::-1]
    filtered = scipy.signal.lfilter(numerator, denominator, backwards, axis=0)
    return filtered[::-1]


def mpi_statistics_scalar(x, with_min_and_max=False):
    """Compute mean and (population) standard deviation over all elements of ``x``.

    Despite the name, no MPI communication happens here; the statistics are
    computed locally with NumPy.

    Args:
        x: Scalar, sequence or array of numbers. It is converted to a float32
            array and every element contributes, whatever the shape.
        with_min_and_max: When true, also return the minimum and maximum.

    Returns:
        ``(mean, std)`` or, if ``with_min_and_max`` is set,
        ``(mean, std, min, max)``. For empty input the mean and std are NaN
        and min/max fall back to ``+inf``/``-inf``.
    """
    x = np.array(x, dtype=np.float32)
    # Use the total element count: ``len(x)`` only counts the first axis, which
    # mismatches ``np.sum`` (all elements) for N-d input and fails on scalars.
    global_sum, global_n = np.sum(x), x.size
    mean = global_sum / global_n

    global_sum_sq = np.sum((x - mean) ** 2)
    std = np.sqrt(global_sum_sq / global_n)  # population std (divides by n)

    if with_min_and_max:
        global_min = np.min(x) if global_n > 0 else np.inf
        global_max = np.max(x) if global_n > 0 else -np.inf
        return mean, std, global_min, global_max

    return mean, std

class PPOOptimizedReplayBuffer:
    """Fixed-capacity circular store of transitions backed by NumPy arrays.

    Writes wrap around to index 0 once ``max_size`` entries have been added.
    The storage itself is plain NumPy (host memory); the module-level
    ``device`` is only recorded on the instance.
    """

    def __init__(self, observation_size, action_size, max_size):
        """Allocate storage for ``max_size`` transitions.

        Args:
            observation_size: Number of columns in each stored state row.
            action_size: Number of columns in each stored action row.
            max_size: Capacity of the buffer (number of rows).
        """
        self.size = max_size
        self.current_idx = 0
        self.filled = 0

        # np.empty leaves rows uninitialised until ``add`` writes them.
        self.states = np.empty((max_size, observation_size))
        self.actions = np.empty((max_size, action_size))
        column = (max_size, 1)
        self.rewards = np.empty(column)
        self.values = np.empty(column)
        self.log_prob = np.empty(column)
        self.next_values = np.empty(column)
        self.terminals = np.empty(column)

        self.device = device

    def add(self, state, action, reward, value, log_prob, next_value, terminal):
        """Write one transition at the current slot and advance the cursor."""
        slot = self.current_idx
        fields = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.values, value),
            (self.log_prob, log_prob),
            (self.next_values, next_value),
            (self.terminals, terminal),
        )
        for storage, item in fields:
            storage[slot] = item

        # Wrap the cursor; ``filled`` saturates at capacity.
        self.current_idx = (slot + 1) % self.size
        self.filled = min(self.size, self.filled + 1)

    def get(self):
        """Return all storage arrays and reset the cursor and fill count.

        The full arrays are returned (not copies, and not trimmed to the
        number of rows written), in the order: states, actions, rewards,
        values, log_prob, next_values, terminals.
        """
        arrays = (
            self.states,
            self.actions,
            self.rewards,
            self.values,
            self.log_prob,
            self.next_values,
            self.terminals,
        )
        self.current_idx = 0
        self.filled = 0
        return arrays

    def clear(self):
        """Reset the cursor and fill count; stored data is left in place."""
        self.current_idx = 0
        self.filled = 0
class PPOBuffer:
    """On-policy rollout storage with GAE-Lambda advantage estimation.

    Steps are appended with ``store``; each finished (or cut-off) trajectory
    is closed with ``finish_path``; once exactly ``size`` steps are stored,
    ``get`` returns the batch as tensors and rewinds the buffer.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        """Allocate zero-filled float32 buffers for ``size`` steps.

        Args:
            obs_dim: Width of one observation row.
            act_dim: Width of one action row.
            size: Number of steps the buffer holds.
            gamma: Reward discount factor.
            lam: GAE lambda.
        """
        f32 = np.float32
        self.obs_buf = np.zeros((size, obs_dim), dtype=f32)
        self.act_buf = np.zeros((size, act_dim), dtype=f32)
        self.adv_buf = np.zeros(size, dtype=f32)
        self.rew_buf = np.zeros(size, dtype=f32)
        self.ret_buf = np.zeros(size, dtype=f32)
        self.val_buf = np.zeros(size, dtype=f32)
        self.logp_buf = np.zeros(size, dtype=f32)
        self.gamma = gamma
        self.lam = lam
        self.ptr = 0
        self.path_start_idx = 0
        self.max_size = size

    def store(self, obs, act, rew, val, logprob):
        """Append one step of agent-environment interaction.

        Fails an assertion if the buffer is already full.
        """
        idx = self.ptr
        assert idx < self.max_size  # there must be room left to store
        self.obs_buf[idx] = obs
        self.act_buf[idx] = act
        self.rew_buf[idx] = rew
        self.val_buf[idx] = val
        self.logp_buf[idx] = logprob
        self.ptr = idx + 1

    def finish_path(self, last_val=0):
        """Close the current trajectory and fill its advantages and returns.

        Args:
            last_val: Value used to bootstrap beyond the last stored step; it
                is appended to both the reward and the value sequences.
        """
        window = slice(self.path_start_idx, self.ptr)
        rewards = np.append(self.rew_buf[window], last_val)
        values = np.append(self.val_buf[window], last_val)

        # GAE-Lambda: discounted sum of one-step TD residuals.
        td_residuals = rewards[:-1] + self.gamma * values[1:] - values[:-1]
        self.adv_buf[window] = discount_cumsum(td_residuals, self.gamma * self.lam)

        # Rewards-to-go, used as regression targets for the value function.
        self.ret_buf[window] = discount_cumsum(rewards, self.gamma)[:-1]

        # The next trajectory starts where this one ended.
        self.path_start_idx = self.ptr

    def get(self):
        """Return the full batch as float32 tensors and rewind the buffer.

        Advantages are normalised (mean subtracted, divided by std + 1e-5)
        before being returned. Fails an assertion unless the buffer is full.

        Returns:
            Dict with keys ``obs``, ``act``, ``ret``, ``adv`` and ``logp``,
            each a tensor placed on the module-level ``device``.
        """
        assert self.ptr == self.max_size  # only a full buffer may be read
        self.ptr = 0
        self.path_start_idx = 0
        # Advantage normalisation trick.
        adv_mean, adv_std = mpi_statistics_scalar(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / (adv_std+1e-5)
        batch = {
            "obs": self.obs_buf,
            "act": self.act_buf,
            "ret": self.ret_buf,
            "adv": self.adv_buf,
            "logp": self.logp_buf,
        }
        return {
            key: torch.as_tensor(array, dtype=torch.float32).to(device)
            for key, array in batch.items()
        }

def build_net(shapes, activation=nn.ReLU, output_activation=nn.Identity, linear=True):
    """Stack layers of the given sizes into an ``nn.Sequential``.

    Args:
        shapes: Layer sizes; consecutive pairs become (in, out) of each layer.
        activation: Module class instantiated after every layer but the last.
        output_activation: Module class instantiated after the last layer.
        linear: If true, layers are ``nn.Linear``; otherwise ``nn.Conv2d``
            with kernel size 3, stride 1 and padding 1.

    Returns:
        ``nn.Sequential`` alternating layer and activation modules (empty
        when ``shapes`` has fewer than two entries).
    """
    final_idx = len(shapes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(shapes[:-1], shapes[1:])):
        if linear:
            transform = nn.Linear(n_in, n_out)
        else:
            transform = nn.Conv2d(n_in, n_out, 3, 1, 1)
        nonlinearity = activation if idx < final_idx else output_activation
        modules.extend([transform, nonlinearity()])
    return nn.Sequential(*modules)


class Critic(nn.Module):
    """Value-function network that maps an observation to a single number."""

    def __init__(self, obs_dim, hidden_sizes, activation, linear=True):
        """Build the value network.

        Args:
            obs_dim: Input size of the first layer.
            hidden_sizes: Iterable of hidden layer sizes.
            activation: Module class used between hidden layers.
            linear: Forwarded to ``build_net`` to choose the layer type.
        """
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, 1]
        self.v_net = build_net(layer_sizes, activation, linear=linear)

    def forward(self, obs):
        """Return the value estimate with the trailing size-1 axis removed."""
        # Dropping the last axis keeps the value shape aligned with per-step
        # scalars such as returns.
        return self.v_net(obs).squeeze(-1)


class CategoricalActor(nn.Module):
    """Discrete-action policy: a network producing logits for a Categorical."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, linear=True):
        """Build the logits network.

        Args:
            obs_dim: Input size of the first layer.
            act_dim: Output size, i.e. the number of logits.
            hidden_sizes: Iterable of hidden layer sizes.
            activation: Module class used between hidden layers.
            linear: Forwarded to ``build_net`` to choose the layer type.
        """
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = build_net(layer_sizes, activation, linear=linear)

    def action_distribution(self, obs):
        """Return the ``Categorical`` distribution induced by ``obs``."""
        return Categorical(logits=self.logits_net(obs))

    def log_prob_action(self, action_distribution, action):
        """Return the log-probability of ``action`` under the distribution.

        NOTE(review): the original comment expected ``action`` to have shape
        (batch_size, 1); the value is passed to ``Categorical.log_prob``
        unchanged, so confirm callers supply a shape that lines up with the
        distribution's batch shape.
        """
        return action_distribution.log_prob(action)

    def forward(self, obs, act, with_logprob=True):
        """Return ``(distribution, log_prob)``; ``log_prob`` is ``None`` when
        ``with_logprob`` is false."""
        dist = self.action_distribution(obs)
        if not with_logprob:
            return dist, None
        return dist, self.log_prob_action(dist, act)


class GaussianActor(nn.Module):
    """Continuous-action policy: a diagonal Gaussian with a learned mean
    network and a state-independent, learnable log standard deviation."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, linear=True, log_std_scale = -0.5):
        """Build the mean network and the log-std parameter.

        Args:
            obs_dim: Input size of the first layer.
            act_dim: Number of action components (output size).
            hidden_sizes: Iterable of hidden layer sizes.
            activation: Module class used between hidden layers.
            linear: Forwarded to ``build_net`` to choose the layer type.
            log_std_scale: Initial value of every log-std entry.
        """
        super().__init__()
        initial_log_std = log_std_scale * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(initial_log_std))
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.mu_net = build_net(layer_sizes, activation, linear=linear)

    def action_distribution(self, obs):
        """Return ``Normal(mu_net(obs), exp(log_std))``."""
        mean = self.mu_net(obs)
        scale = self.log_std.exp()
        return Normal(mean, scale)

    def log_prob_action(self, action_distribution, action):
        """Return the joint log-probability of ``action``.

        ``Normal.log_prob`` is element-wise, so the per-component terms are
        summed over the last axis.
        """
        per_component = action_distribution.log_prob(action)
        return per_component.sum(dim=-1)

    def forward(self, obs, act, with_logprob=True):
        """Return ``(distribution, log_prob)``; ``log_prob`` is ``None`` when
        ``with_logprob`` is false."""
        dist = self.action_distribution(obs)
        if not with_logprob:
            return dist, None
        return dist, self.log_prob_action(dist, act)


class ActorCritic(nn.Module):
    """Container pairing a policy network with a value network."""

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(256, 256), activation=nn.Tanh,
                 linear=True, continuous=True, log_std_scale=-0.5):
        """Create the policy and value networks on the module-level ``device``.

        Args:
            observation_space: Used directly as the networks' input size.
            action_space: Used directly as the policy's output size.
            hidden_sizes: Hidden layer sizes shared by both networks.
            activation: Module class used between hidden layers.
            linear: Layer-type switch forwarded to the policy network.
            continuous: Choose ``GaussianActor`` (true) or
                ``CategoricalActor`` (false).
            log_std_scale: Initial log-std for the Gaussian policy.
        """
        super().__init__()

        # Keep the construction settings around for later inspection.
        self.obs_dim = observation_space
        self.act_dim = action_space
        self.is_continuous = continuous
        self.is_linear = linear
        self.activation = activation
        self.log_std_scale = log_std_scale

        # Policy network.
        if continuous:
            actor = GaussianActor(self.obs_dim, self.act_dim, hidden_sizes, self.activation,
                                  linear=self.is_linear, log_std_scale=self.log_std_scale)
        else:
            actor = CategoricalActor(self.obs_dim, self.act_dim, hidden_sizes, self.activation,
                                     linear=self.is_linear)
        self.policy = actor.to(device)

        # Value network (built with Critic's default layer type).
        critic = Critic(self.obs_dim, hidden_sizes, activation)
        self.value = critic.to(device)

    def act(self, obs, with_log_prob=False):
        """Sample an action for ``obs`` without tracking gradients.

        Args:
            obs: Tensor fed to the policy and value networks.
            with_log_prob: If true, also return the value estimate and the
                log-probability of the sampled action.

        Returns:
            The sampled action as a NumPy array, or the tuple
            ``(action, value, log_prob)`` of NumPy arrays.
        """
        with torch.no_grad():
            dist = self.policy.action_distribution(obs)
            sampled = dist.sample()
            logp = self.policy.log_prob_action(dist, sampled)
            value = self.value(obs)
        action_np = sampled.cpu().numpy()
        if not with_log_prob:
            return action_np
        return action_np, value.cpu().numpy(), logp.cpu().numpy()

class PPO:
    """PPO agent using the clipped surrogate objective.

    Holds an actor-critic model, an on-policy rollout buffer and one
    optimizer each for the policy and the value function.
    """

    def __init__(self, observation_size,action_size, hidden_shape=(256, 256), actor_critic=ActorCritic,
                 continuous=True, seed=0, steps_per_epoch=1000, epochs=50, gamma=0.99, clip_ratio=0.2,
                 pi_lr=3e-4, vf_lr=1e-3, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000,
                 target_kl=0.01, save_freq=10, linear = True, log_std_scale=-0.5, optimizer_name="Adam"):
        """Build the model, the rollout buffer and the optimizers.

        Args:
            observation_size: Observation width passed to the model and buffer.
            action_size: Action width passed to the model and buffer.
            hidden_shape: Hidden layer sizes for the model.
            actor_critic: Factory for the actor-critic model.
            continuous: Whether the policy is continuous (Gaussian).
            seed, epochs, save_freq: Accepted for interface compatibility;
                not used by this class.
            steps_per_epoch: Capacity of the rollout buffer.
            gamma: Discount factor used by the rollout buffer.
            clip_ratio: PPO clipping range for the probability ratio.
            pi_lr: Policy learning rate.
            vf_lr: Value-function learning rate.
            train_pi_iters: Maximum policy gradient steps per update.
            train_v_iters: Value-function gradient steps per update.
            lam: GAE lambda used by the rollout buffer.
            max_ep_len: Stored on the instance; not otherwise used here.
            target_kl: Policy steps stop once the approximate KL exceeds
                1.5 times this value.
            linear: Layer-type switch forwarded to the model.
            log_std_scale: Initial log-std for a Gaussian policy.
            optimizer_name: ``"Adam"`` or ``"AdamW"``.

        Raises:
            ValueError: If ``optimizer_name`` is not a supported name.
        """
        self.obs_dim =observation_size
        self.act_dim = action_size
        self.hidden_size = hidden_shape
        self.is_linear = linear
        self.is_continuous = continuous
        self.steps_per_epoch = steps_per_epoch
        self.gamma = gamma
        self.lam = lam
        self.clip_ratio = clip_ratio
        self.policy_lr = pi_lr
        self.value_lr = vf_lr
        self.train_policy_epochs = train_pi_iters
        self.train_value_epochs = train_v_iters
        self.max_ep_len = max_ep_len
        self.target_KL = target_kl
        self.log_std_scale = log_std_scale

        self.AC = actor_critic(observation_space=self.obs_dim, action_space=self.act_dim,
                               hidden_sizes=self.hidden_size, continuous=self.is_continuous,
                               linear=self.is_linear, log_std_scale=self.log_std_scale).to(device)

        # Honour the caller's gamma/lam; these were previously hard-coded to
        # 0.99/0.95 here, silently ignoring the constructor arguments.
        self.buffer = PPOBuffer(obs_dim=self.obs_dim, act_dim=self.act_dim, size=self.steps_per_epoch,
                                gamma=self.gamma, lam=self.lam)

        # Set up optimizers for policy and value function
        if optimizer_name == "Adam":
            self.policy_optimizer = Adam(self.AC.policy.parameters(), lr=self.policy_lr)
            self.value_optimizer = Adam(self.AC.value.parameters(), lr=self.value_lr)
        elif optimizer_name == "AdamW":
            self.policy_optimizer = AdamW(self.AC.policy.parameters(), lr=self.policy_lr, weight_decay=0.1)
            self.value_optimizer = AdamW(self.AC.value.parameters(), lr=self.value_lr, weight_decay=0.1)
        else:
            # Fail fast instead of leaving the optimizers undefined and
            # crashing later with an AttributeError during training.
            raise ValueError(
                "Unsupported optimizer_name {!r}; expected 'Adam' or 'AdamW'".format(optimizer_name))

    def compute_loss_policy(self, data):
        """Compute the clipped PPO policy loss and diagnostics.

        Args:
            data: Mapping with tensors ``obs``, ``act``, ``adv`` and ``logp``
                (the log-probabilities recorded when the actions were taken).

        Returns:
            ``(loss, info)`` where ``info`` has float entries ``kl``
            (approximate KL), ``ent`` (mean entropy) and ``cf`` (fraction of
            ratios outside the clip range).
        """
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']

        # Policy loss
        pi, logp = self.AC.policy(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + self.clip_ratio) | ratio.lt(1 - self.clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    def compute_loss_value(self, data):
        """Return the mean squared error between value estimates and ``ret``.

        Args:
            data: Mapping with tensors ``obs`` and ``ret``.
        """
        obs, ret = data['obs'], data['ret']
        return ((self.AC.value(obs) - ret) ** 2).mean()

    def pretrain(self, expert_traj, verbose=False, steps = 100000, action_sum = 0, action_mul = 1, batch_size = 32):
        """Pre-train the policy and value networks from recorded transitions.

        Each step samples a random batch, draws actions from the current
        policy, runs clipped policy-gradient steps that use the critic's
        current value estimates in place of advantages, and regresses the
        critic onto the recorded rewards.

        Args:
            expert_traj: Array indexed as ``expert_traj[rows, column]``. Per
                the original note, each row is
                ``[state, action, reward, next_state, done]`` with flattened
                arrays as entries; only columns 0 (state) and 2 (reward) are
                read here.
            verbose: Print running mean losses every 10 steps.
            steps: Number of pre-training steps.
            action_sum, action_mul: Kept for interface compatibility; the
                scaled action they produced was never used, so they have no
                effect.
            batch_size: Number of transitions sampled per step.
        """
        mean_loss_actor = 0
        mean_loss_critic = 0
        for step in range(steps):
            ind = np.random.randint(0, expert_traj.shape[0], size=batch_size)
            state_expert, reward_expert = expert_traj[ind, 0], expert_traj[ind, 2]
            state_expert = np.reshape(np.concatenate(state_expert), (batch_size, -1))
            reward_expert = np.asarray(reward_expert, dtype=np.float32).reshape((batch_size, -1))
            state = torch.FloatTensor(state_expert).to(device)
            # Drop the trailing size-1 axis so the target lines up with a
            # one-value-per-row critic output; subtracting a (B, 1) target
            # from a (B,) prediction would broadcast to (B, B).
            value_target = torch.FloatTensor(reward_expert).to(device).squeeze(-1)

            env_action, value_old, log_prob_old = self.AC.act(state, with_log_prob=True)
            policy_batch = dict(
                obs=state,
                act=torch.FloatTensor(env_action).to(device),
                adv=torch.FloatTensor(value_old).to(device),
                logp=torch.FloatTensor(log_prob_old).to(device),
            )

            loss_act = 0
            for _ in range(self.train_policy_epochs):
                self.policy_optimizer.zero_grad()
                loss_pi, pi_info = self.compute_loss_policy(policy_batch)
                if pi_info['kl'] > 1.5 * self.target_KL:
                    break
                loss_pi.backward()
                # Apply the gradients; without this step the policy was
                # never actually updated during pre-training.
                self.policy_optimizer.step()
                loss_act += loss_pi.item()

            value_batch = dict(obs=state, ret=value_target)
            loss_value = 0
            for _ in range(self.train_value_epochs):
                self.value_optimizer.zero_grad()
                loss_v = self.compute_loss_value(value_batch)
                loss_v.backward()
                self.value_optimizer.step()
                loss_value += loss_v.item()

            # Incremental running means over the steps taken so far.
            mean_loss_actor = (mean_loss_actor * step + loss_act) / (1 + step)
            mean_loss_critic = (mean_loss_critic * step + loss_value) / (1 + step)
            if verbose and step % 10 == 0:
                print("Pretrain Step: {}, Actor loss: {}, Critic loss: {}".format(step, mean_loss_actor, mean_loss_critic))

    def update(self):
        """Run one PPO update on the contents of the rollout buffer.

        Returns:
            Tuple ``(pi_l_old, v_l_old, kl, ent, cf, loss_pi, loss_v)``:
            the policy and value losses before the update, the approximate KL
            and clip fraction from the last policy evaluation, the entropy
            from before the update, and the last policy and value losses that
            were evaluated (each computed before its final optimizer step).
        """
        data = self.buffer.get()

        # Seeding the running values with the pre-update evaluation keeps the
        # return statement valid even when an iteration count is zero.
        loss_pi, pi_info_old = self.compute_loss_policy(data)
        pi_info = pi_info_old
        pi_l_old = loss_pi.item()
        loss_v = self.compute_loss_value(data)
        v_l_old = loss_v.item()

        # Train policy with multiple steps of gradient descent
        for _ in range(self.train_policy_epochs):
            self.policy_optimizer.zero_grad()
            loss_pi, pi_info = self.compute_loss_policy(data)
            # Early stopping once the policy has moved too far.
            if pi_info['kl'] > 1.5 * self.target_KL:
                break
            loss_pi.backward()
            self.policy_optimizer.step()

        # Value function learning
        for _ in range(self.train_value_epochs):
            self.value_optimizer.zero_grad()
            loss_v = self.compute_loss_value(data)
            loss_v.backward()
            self.value_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        return pi_l_old, v_l_old, kl, ent, cf, loss_pi.item(), loss_v.item()

    def get_action(self, obs, with_log_prob=False):
        """Convert ``obs`` to a float tensor on ``device`` and query the model.

        Returns whatever ``self.AC.act`` returns for the given flag.
        """
        return self.AC.act(torch.FloatTensor(obs).to(device), with_log_prob)

    def save(self, fname):
        """Save the model's state dict to ``fname + '.pth'``."""
        torch.save(self.AC.state_dict(), fname+'.pth')

    def load(self, fname):
        """Load the model's state dict from ``fname + '.pth'`` onto ``device``."""
        self.AC.load_state_dict(torch.load(fname + '.pth', map_location=device))
