import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

import numpy as np


#Modified from https://raw.githubusercontent.com/nikhilbarhate99/PPO-PyTorch/master/PPO.py
################################## set device ##################################

print("============================================================================================")

# Choose the compute device once, at import time: the first CUDA GPU when
# torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Hand back any cached, currently unused GPU memory held by torch.
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

print("============================================================================================")




################################## PPO Policy ##################################


class RolloutBuffer:
    """Flat, time-ordered storage for rollout transitions.

    Every attribute is a plain Python list holding one entry per recorded
    step; entries that share an index belong to the same transition.
    """

    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
        self.next_states = []

    def _columns(self):
        """Return all per-step lists so bulk edits keep them the same length."""
        return (self.actions, self.states, self.logprobs,
                self.rewards, self.is_terminals, self.next_states)

    def clear(self):
        """Empty every list in place (outside references to the lists stay valid)."""
        for column in self._columns():
            del column[:]

    def remove_n_last(self, n):
        """Drop the ``n`` OLDEST entries, i.e. the first ``n`` items of every list.

        NOTE: despite the method name, nothing is removed from the end of the
        lists; the name is kept for backward compatibility.
        """
        for column in self._columns():
            del column[:n]

    def save(self, dir):
        """Pickle the whole buffer object to the file path ``dir``."""
        import pickle
        with open(dir, 'wb') as file:
            pickle.dump(self, file)

    def load(self, dir):
        """Replace this buffer's lists with those of the buffer pickled at ``dir``.

        SECURITY: ``pickle.load`` executes arbitrary code embedded in the file.
        Only load files that were written by ``save`` from a trusted run.
        """
        import pickle
        with open(dir, 'rb') as file:
            buffer = pickle.load(file)
        self.actions = buffer.actions
        self.states = buffer.states
        self.next_states = buffer.next_states
        self.logprobs = buffer.logprobs
        self.rewards = buffer.rewards
        self.is_terminals = buffer.is_terminals

    def sample_episodes(self, n, ep_len, last_n=False):
        """Gather up to ``n`` whole episodes of ``ep_len`` consecutive steps.

        The buffer is treated as back-to-back episodes of exactly ``ep_len``
        steps; a trailing partial episode is ignored.

        Args:
            n: number of episodes wanted; fewer are returned when the buffer
                holds fewer complete episodes.
            ep_len: number of steps per episode.
            last_n: when True take the most recent episodes (newest first),
                otherwise draw episodes at random without replacement.

        Returns:
            ``[states, actions, logprobs, rewards, next_states, done_inds]``
            restricted to steps where the agent is alive; ``done_inds`` are
            the positions (within the filtered tensors) of terminal steps.

        Assumes every stored reward / is_terminal is a NumPy array for a
        single arena (the ``squeeze(1)`` calls rely on it) - TODO confirm.
        """
        n_timesteps = len(self.states)
        n_ep = n_timesteps // ep_len
        # sample episode numbers
        if last_n:
            ep_no = torch.arange(n_ep-1, max(n_ep-1-n, -1), -1)
        else:
            ep_no = torch.randperm(n_ep)[:n]
        n = len(ep_no)  # fewer than requested when the buffer holds fewer episodes
        # step indices of every chosen episode, one episode after another;
        # converted to plain ints so the lists are indexed without tensors
        ep_inds = (ep_no.unsqueeze(1) * ep_len + torch.arange(ep_len)).flatten().tolist()

        # fetch the corresponding entries and stack them into tensors
        states = torch.stack([self.states[ii] for ii in ep_inds], dim=0).to(device)
        actions = torch.stack([self.actions[ii] for ii in ep_inds], dim=0).to(device)
        next_states = torch.stack([self.next_states[ii] for ii in ep_inds], dim=0).to(device)
        logprobs = torch.stack([self.logprobs[ii] for ii in ep_inds], dim=0).to(device)
        rewards = torch.stack([torch.from_numpy(self.rewards[ii]) for ii in ep_inds], dim=0).to(device)
        is_terminals = torch.stack([torch.from_numpy(self.is_terminals[ii]) for ii in ep_inds], dim=0).to(device)

        # Mask out steps where the agent was already dead on the previous step
        # (it died and stays dead while the episode keeps running).  The
        # "previous step" is looked up per episode: the sampled episodes are
        # not adjacent in time, so the last step of one episode must never
        # hide a death on the first step of the next one.
        terminal = is_terminals.bool().reshape(n, ep_len, -1)
        was_terminal = torch.zeros_like(terminal)
        was_terminal[:, 1:] = terminal[:, :-1]
        save_for_training = ~(was_terminal & terminal)
        save_for_training = save_for_training.reshape(is_terminals.shape).squeeze(1)

        states = states.squeeze(1)
        actions = actions.squeeze(1)
        next_states = next_states.squeeze(1)
        logprobs = logprobs.squeeze(1)
        rewards = rewards.squeeze(1)
        is_terminals = is_terminals.squeeze(1)

        # only take those where the agent is alive
        states = states[save_for_training]
        actions = actions[save_for_training]
        logprobs = logprobs[save_for_training]
        next_states = next_states[save_for_training]
        rewards = rewards[save_for_training]
        is_terminals = is_terminals[save_for_training]
        done_inds = torch.where(is_terminals == 1.0)[0]
        # exactly one terminal step is expected to survive per sampled episode
        assert done_inds.shape[0] == n, "expected exactly one terminal step per sampled episode"
        return [states, actions, logprobs, rewards, next_states, done_inds]

class MultiheadNet(nn.Module):
    """Actor for multi-discrete actions: a shared MLP trunk plus one head per sub-action.

    Args:
        in_dim: size of the input (state) vector.
        out_dim: 1-D sequence (NumPy array, list, ...) with the number of
            choices of every sub-action; one head is created per entry.
    """

    def __init__(self, in_dim, out_dim):
        super(MultiheadNet, self).__init__()
        self.shared_net = nn.Sequential(
                            nn.Linear(in_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh())

        # Kept as a tensor attribute because code elsewhere in this file does
        # tensor arithmetic on ``out_dim``.
        self.out_dim = torch.as_tensor(out_dim)
        # nn.Linear gets a plain int rather than a 0-dim tensor.
        self.module_heads = nn.ModuleList([
            nn.Sequential(nn.Linear(64, int(od)), nn.LogSoftmax(-1))
            for od in self.out_dim
        ])

    def forward(self, x):
        """Return a list with one ``Categorical`` distribution per head.

        Raises:
            ValueError: when a head produces logits that ``Categorical``
                rejects (e.g. NaN); debugging information is printed first.
        """
        x = self.shared_net(x)
        probs = [mh(x) for mh in self.module_heads]
        try:
            dists = [Categorical(logits = action_probs.clamp(-1e10, 1e10)) for action_probs in probs]
        except ValueError:
            # Dump what is needed to diagnose the bad logits, then propagate
            # the real error instead of failing later on an unbound ``dists``.
            print("ERROR HAPPENED; THIS IS DEBUG")
            print("action_probs", probs)
            print("x ", x)
            print("shared net", np.sum([par.mean().item() for par in self.shared_net.parameters()]))
            for asd, mh in enumerate(self.module_heads):
                print(f"mh{asd}", np.sum([par.mean().item() for par in mh.parameters()]))
            raise

        return dists

    def get_logprobs(self, actions, dists):
        """Return the joint log-probability of ``actions`` under ``dists``.

        Column ``ii`` of ``actions`` is scored by ``dists[ii]``; the per-head
        log-probabilities are summed, giving one value per row.
        """
        per_head = [dist.log_prob(actions[:, ii]) for ii, dist in enumerate(dists)]
        return torch.stack(per_head, dim=1).sum(1)



class ActorCritic(nn.Module):
    """Actor and critic networks used by PPO.

    ``action_space`` selects the actor:
      * ``"continuous"``    - MLP producing the mean of a Gaussian whose
        diagonal covariance is the fixed ``action_var``;
      * ``"multidiscrete"`` - a ``MultiheadNet`` returning one categorical
        distribution per sub-action;
      * anything else       - MLP ending in a softmax over ``action_dim``
        discrete actions.

    The critic is always an MLP mapping a state to a single value.
    """

    def __init__(self, state_dim, action_dim, action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.action_space = action_space

        if action_space == "continuous":
            self.action_dim = action_dim
            # variance (std squared) shared by every action dimension
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)

        # actor
        if action_space == "continuous" :
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                        )
        elif action_space == "multidiscrete":
            self.actor = MultiheadNet(state_dim, action_dim)
        else:
            self.actor = nn.Sequential(
                            nn.Linear(state_dim, 64),
                            nn.Tanh(),
                            nn.Linear(64, 64),
                            nn.Tanh(),
                            nn.Linear(64, action_dim),
                            nn.Softmax(dim=-1)
                        )

        # critic
        self.critic = nn.Sequential(
                        nn.Linear(state_dim, 64),
                        nn.Tanh(),
                        nn.Linear(64, 64),
                        nn.Tanh(),
                        nn.Linear(64, 1)
                    )

    def set_action_std(self, new_action_std):
        """Reset the Gaussian variance to ``new_action_std ** 2`` (continuous only).

        For any other action space only a warning is printed.
        """
        if self.action_space == "continuous":
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def forward(self):
        """Not used; call ``act`` or ``evaluate`` instead."""
        raise NotImplementedError

    def act(self, state, greedy=False):
        """Choose an action for ``state``.

        Args:
            state: input tensor for the actor (the multidiscrete branch
                indexes dimension 1, so it needs a batch dimension).
            greedy: when True return the most likely action (argmax for the
                categorical cases, the mean for the Gaussian) instead of a
                sample.  Defaults to False so one-argument calls keep working.

        Returns:
            ``(action, action_logprob)``, both detached from the graph.
        """
        if self.action_space == "multidiscrete":
            dists = self.actor(state)
            if greedy:
                action = [dist.probs.argmax(1) for dist in dists]
            else:
                action = [dist.sample() for dist in dists]
            action = torch.stack(action, dim=1)
            action_logprob = self.actor.get_logprobs(action, dists)
            return action.detach(), action_logprob.detach()

        if self.action_space == "continuous":
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
            # ``dist.mean`` has the same shape a sample would have
            action = dist.mean if greedy else dist.sample()
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)
            action = dist.probs.argmax(-1) if greedy else dist.sample()

        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Score ``action`` taken in ``state`` under the current networks.

        Returns:
            ``(action_logprobs, state_values, dist_entropy)``.  In the
            multidiscrete case ``dist_entropy`` is a single scalar (the sum
            over heads of the batch-mean entropy); in the other cases it is
            the entropy returned by the distribution for each input.
        """
        if self.action_space == "multidiscrete":
            dists = self.actor(state)
            action_logprobs = self.actor.get_logprobs(action, dists)
            state_values = self.critic(state)
            # entropy computed from probs * log-probs, averaged over the batch per head
            dist_entropy = torch.stack([-(dist.probs * dist.logits).sum(-1).mean() for dist in dists]).sum()
            return action_logprobs, state_values, dist_entropy

        if self.action_space == "continuous":
            action_mean = self.actor(state)

            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)

            # For Single Action Environments.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            action_probs = self.actor(state)
            dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy


class PPO:
    """Clipped-objective PPO agent with GAE(lambda) advantages.

    Two copies of the actor-critic are kept: ``policy`` is optimised in
    ``update`` while ``policy_old`` is the one used by ``select_action``; the
    weights of ``policy`` are copied into ``policy_old`` at the end of every
    update.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, action_space, action_std_init=0.6, lamb=0.95, bs = 4096, entropy_coef = 0.01):

        self.action_space = action_space

        if action_space == "continuous":
            self.action_std = action_std_init

        self.gamma = gamma              # discount factor
        self.eps_clip = eps_clip        # PPO ratio clipping range
        self.K_epochs = K_epochs        # optimisation epochs per update
        self.lambd = lamb               # GAE lambda
        self.entropy_coef = entropy_coef

        self.bs = bs                    # minibatch size

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])

        self.policy_old = ActorCritic(state_dim, action_dim, action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the action std on both policies (continuous action space only)."""
        if self.action_space == "continuous":
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)

        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Subtract ``action_std_decay_rate`` from the action std, not going below ``min_action_std``."""
        print("--------------------------------------------------------------------------------------------")

        if self.action_space == "continuous":
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)

        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")

        print("--------------------------------------------------------------------------------------------")

    def select_action(self, state, greedy=False):
        """Pick an action for ``state`` with the old policy.

        Args:
            state: observation convertible by ``torch.FloatTensor``.
            greedy: forwarded to ``ActorCritic.act`` for every action space.

        Returns:
            continuous: flat NumPy array; multidiscrete: detached tensor;
            otherwise a Python scalar.

        For the continuous and plain discrete cases the state, action and
        log-probability are appended to the buffer.  The multidiscrete case
        records nothing here - presumably the caller fills the buffer itself
        (TODO confirm).
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            # ``greedy`` is always passed: ``act`` takes it as an argument
            action, action_logprob = self.policy_old.act(state, greedy)

        if self.action_space == "multidiscrete":
            return action.detach()

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)

        if self.action_space == "continuous":
            return action.detach().cpu().numpy().flatten()
        return action.item()

    def update(self, empty_buffer=True):
        """Run ``K_epochs`` of minibatch PPO updates on the buffered rollout.

        Args:
            empty_buffer: clear the rollout buffer afterwards when True.

        Returns:
            1-D tensor of 15 diagnostics, averaged over epochs:
            0 policy loss, 1 value loss, 2 entropy term of the loss, 3 total
            loss, 4 entropy, 5 clipped count divided by the batch size,
            6 entropy relative to ``actor.out_dim``, 7 residual variance of
            the value function in the last epoch, 8 the same in the first
            epoch, 9-11 mean/min/max of the returns, 12-14 mean/min/max of
            the value estimates.

        Assumes every buffered reward / is_terminal is a NumPy array with one
        entry per arena and every state is (n_arenas, state_dim) - TODO confirm.
        """
        n_arenas = self.buffer.rewards[0].shape
        rews = torch.from_numpy(np.stack(self.buffer.rewards, axis=1)).to(device)
        state_buffer = torch.stack(self.buffer.states, axis=1)
        state_buffer = state_buffer.view(-1, self.buffer.states[0].shape[1]).to(device)
        with torch.no_grad():
            value_estimates = self.policy.critic(state_buffer)
            value_estimates = value_estimates.squeeze(1).view(n_arenas[0], -1)

        # TD residuals for GAE.  The next-state value is only bootstrapped
        # when the step is not terminal: after a terminal step the next
        # buffer entry does not continue the same trajectory.
        not_done = 1.0 - torch.from_numpy(
            np.stack(self.buffer.is_terminals, axis=1).astype(np.float32)).to(device)
        deltas = rews[:, :-1] + self.gamma * value_estimates[:, 1:] * not_done[:, :-1] - value_estimates[:, :-1]
        deltas = torch.cat((deltas, rews[:, -1:] - value_estimates[:,-1:]), axis=1) #delta at time T
        deltas = [deltas[:,dd].cpu().numpy() for dd in range(deltas.shape[1])]

        # Backward pass over time: discounted returns and GAE-lambda advantages.
        rewards = []
        advantages = []
        discounted_reward = np.zeros(n_arenas, dtype=np.float32)
        advantage = np.zeros(n_arenas, dtype=np.float32)
        for delta, reward, is_terminal in zip(reversed(deltas), reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            # Copy before masking: the running arrays are the very objects
            # stored for the following time step, so zeroing them in place
            # would corrupt the values already recorded for that step.
            discounted_reward = discounted_reward.copy()
            discounted_reward[is_terminal] = 0.0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.append(discounted_reward)

            advantage = advantage.copy()
            advantage[is_terminal] = 0.0
            advantage = delta + (self.gamma * self.lambd * advantage)
            advantages.append(advantage)
        # collected newest-first; restore chronological order
        rewards.reverse()
        advantages.reverse()

        rewards = np.stack(rewards, axis=1)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)

        advantages = np.stack(advantages, axis=1)
        advantages = torch.tensor(advantages, dtype=torch.float32).to(device)
        is_terminals = torch.from_numpy(np.stack(self.buffer.is_terminals, axis=0).astype(np.float32)).to(device)

        # Mask for the steps that are "not good": the agent was already dead
        # on the previous step and is still dead while the episode continues.
        terminal = is_terminals.bool()
        was_terminal = torch.zeros_like(terminal)
        was_terminal[1:] = terminal[:-1]
        save_for_training = ~(was_terminal & terminal)

        # convert list to tensor
        old_states = torch.stack(self.buffer.states, dim=0).detach().to(device)
        old_actions = torch.stack(self.buffer.actions, dim=0).detach().to(device)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).detach().to(device)

        if len(old_states.shape) > 2:
            # arenas first, then flatten (arena, time) into one batch axis and
            # keep only the steps where the agent is not dead
            keep = save_for_training.permute(1,0).flatten()
            old_states = old_states.permute(1,0,2).flatten(0, 1)[keep]
            old_actions = old_actions.permute(1,0,2).flatten(0, 1)[keep]
            old_logprobs = old_logprobs.permute(1,0).flatten()[keep]
            rewards = rewards.flatten()[keep]
            advantages = advantages.flatten()[keep]

        # Optimize policy for K epochs
        N_train = old_states.shape[0]
        real_bs = min(self.bs, N_train)
        losses = torch.zeros(self.K_epochs, 9)
        for ll in range(self.K_epochs):
            perm = torch.randperm(N_train)
            n_updates = 0
            val_preds = []
            val_reals = []
            for ii in range(0, N_train, self.bs):
                inds = perm[ii:ii+real_bs]
                states_batch = old_states[inds]
                actions_batch = old_actions[inds]
                logprobs_batch = old_logprobs[inds]
                rewards_batch = rewards[inds]
                advantages_batch = advantages[inds]

                # Evaluating old actions and values
                logprobs, state_values, dist_entropy = self.policy.evaluate(states_batch, actions_batch)

                # match state_values tensor dimensions with rewards tensor
                state_values = torch.squeeze(state_values)
                value_targets = rewards_batch.squeeze(0)

                # Finding the ratio (pi_theta / pi_theta__old)
                ratios = torch.exp(logprobs - logprobs_batch.detach())

                # Finding Surrogate Loss
                surr1 = ratios * advantages_batch
                n_clipped = ( (ratios < (1-self.eps_clip)) | (ratios > (1+self.eps_clip)) ).float().sum()
                surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages_batch

                # final loss of clipped objective PPO
                policy_loss = -torch.min(surr1, surr2)
                value_loss = 0.5*self.MseLoss(state_values, value_targets)
                loss = policy_loss + value_loss - self.entropy_coef*dist_entropy

                # take gradient step
                self.optimizer.zero_grad()
                loss.mean().backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 5.0)
                self.optimizer.step()

                # log stuff; everything is detached so no autograd graph is
                # kept alive by (or leaks into) the diagnostics
                relative_entropy = dist_entropy.mean().detach() / -( 1 / self.policy.actor.out_dim).log().sum()
                kl_div_last = -ratios.detach().log().sum() if ll+1 == self.K_epochs else 0.0
                val_preds.append(state_values.detach())
                val_reals.append(value_targets)
                losses[ll] += torch.tensor([policy_loss.mean().detach(),
                    value_loss.mean().detach(),
                    -self.entropy_coef*dist_entropy.mean().detach(),
                    loss.mean().detach(),
                    dist_entropy.mean().detach(),
                    n_clipped,
                    relative_entropy,
                    kl_div_last,
                    0.0
                ])

                n_updates += 1

            residual_variance_vf = (torch.cat(val_preds) - torch.cat(val_reals)).var() / torch.cat(val_reals).var()
            losses[ll] /= n_updates
            losses[ll, 5] /= real_bs
            losses[ll, 7] /= real_bs * n_updates
            losses[ll, 8] = residual_variance_vf

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        if empty_buffer:
            self.buffer.clear()
        # NOTE(review): this replaces the KL estimate accumulated in column 7
        # with the last epoch's residual variance; ``losses[-1, 7]`` may have
        # been intended - confirm before changing, callers log these columns.
        losses[:, 7] = losses[-1, 8]
        losses[:, 8] = losses[0, 8]
        # append statistics of the returns and of the value estimates
        losses = torch.cat((losses,
                            rewards.mean().cpu() * torch.ones_like(losses[:,:1]),
                            rewards.min().cpu() * torch.ones_like(losses[:,:1]),
                            rewards.max().cpu() * torch.ones_like(losses[:,:1]),
                            value_estimates.mean().cpu() * torch.ones_like(losses[:,:1]),
                            value_estimates.min().cpu() * torch.ones_like(losses[:,:1]),
                            value_estimates.max().cpu() * torch.ones_like(losses[:,:1])), dim=1 )

        return losses.mean(0)

    def save(self, checkpoint_path):
        """Write the old policy's ``state_dict`` to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load the checkpoint into both ``policy_old`` and ``policy``.

        The file is read once; tensors are kept on the storage they are
        deserialised to (``map_location`` returns the storage unchanged).
        """
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)
        
        
       

