from math import acos
from numpy.random import pareto
import torch
import numpy as np
from torch import optim
from torch.optim import optimizer
from base import BaseAlgo 
import utils 


class Agent(BaseAlgo):
    def __init__(
        self,
        envs,
        model,
        reward_dim,
        pref_dist=None,
        device=None,
        num_frames_per_proc=128,
        discount=0.99,
        lr=0.001,
        gae_lambda=0.95,
        entropy_coef=0.01,
        value_loss_coef=0.5,
        max_grad_norm=0.5,
        adam_eps=1e-8,
        clip_eps=0.2,
        epochs=4,
        batch_size=256,
        preprocess_obss=None,
        action_dim=1,
        obs_clip=None,
        epsilon=None,
        adaptive=True,
    ):
        """Set up the agent and one Adam optimizer each for actor and critic.

        ``envs``, ``model``, ``reward_dim``, ``device``,
        ``num_frames_per_proc``, ``discount``, ``lr``, ``gae_lambda``,
        ``entropy_coef``, ``value_loss_coef``, ``max_grad_norm``,
        ``preprocess_obss``, ``action_dim`` and ``obs_clip`` are forwarded
        positionally, in that order, to the base-class constructor.

        Args:
            pref_dist: Optional object exposing ``sample()``; when set,
                ``update_parameters`` adds a squared-distance penalty between
                the preference weights and a sample drawn from it.
            lr: Learning rate, also used for both Adam optimizers.
            adam_eps: ``eps`` of both Adam optimizers; also used as the
                denominator guard in ``find_preferences_2d``.
            clip_eps: Stored as ``self.clip_eps``.
            epochs: Number of passes ``update_parameters`` makes over the
                batches.
            batch_size: Stored as ``self.batch_size``.
            epsilon: Optional fixed preference weights. When left as ``None``
                the weights are computed per batch by ``find_preferences``.
            adaptive: Passed as ``create_graph`` when
                ``find_preferences_2d`` takes per-objective gradients.
        """
        super().__init__(
            envs,
            model,
            reward_dim,
            device,
            num_frames_per_proc,
            discount,
            lr,
            gae_lambda,
            entropy_coef,
            value_loss_coef,
            max_grad_norm,
            preprocess_obss,
            action_dim,
            obs_clip,
        )

        # Preference handling
        self.pref_dist = pref_dist
        self.epsilon = epsilon
        self.adaptive = adaptive

        # Optimisation hyper-parameters
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.adam_eps = adam_eps

        # Actor and critic are stepped by separate optimizers.
        actor_params = self.model.actor.parameters()
        self.optimizer_actor = torch.optim.Adam(actor_params, lr=lr, eps=adam_eps)
        critic_params = self.model.critic.parameters()
        self.optimizer_critic = torch.optim.Adam(critic_params, lr=lr, eps=adam_eps)

    def simplex_proj(self,eps):
        x = eps.detach().numpy()
        y = -np.sort(-x)
        sum = 0
        ind = []
        for j in range(len(x)):
            sum = sum + y[j]
            if y[j] + (1 - sum) / (j + 1) > 0:
                ind.append(j)
            else:
                ind.append(0)
        rho = np.argmax(ind)
        delta = (1 - (y[:rho+1]).sum())/(rho+1)
        proj = np.clip(x + delta, 0, 1)
        return torch.tensor(proj)


    ## Optimize the Pareto objective for a 2-d reward signal; this can be done analytically
    def find_preferences_2d(self, *losses):
        assert len(losses) == 2

        grads = []
        for loss in losses:
            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            grad = torch.autograd.grad(loss, self.model.actor.parameters(), retain_graph=True, create_graph=self.adaptive)[0]
            torch.nn.utils.clip_grad_norm_(self.model.actor.parameters(), self.max_grad_norm)
            grad = torch.flatten(grad)
            grad = torch.squeeze(grad)
            grads.append(grad)

        total_grad = grads[1] - grads[0]
        nom = torch.dot(total_grad, grads[0])
        den = torch.norm(total_grad) ** 2
        eps = nom/(den + self.adam_eps) 
        eps = torch.clamp(eps, 0, 1)
        pareto_loss = eps*grads[0] + (1-eps)*grads[1]
        pareto_loss = torch.norm(pareto_loss) ** 2

        return [1-eps, eps], pareto_loss

    
    def find_preferences(self, losses):
        if len(losses)==2:
            return self.find_preferences_2d(*losses)

        # Bellow does not work for adaptive setting, really slow cuz of the projection 
        grads = []
        for loss in losses:
            grad = torch.autograd.grad(loss, self.model.get_policy_params(), retain_graph=True)[0]
            grad = torch.flatten(grad)
            grad = torch.squeeze(grad)
            grads.append(grad)

        epsilon = torch.ones(self.reward_dim, requires_grad=True)
        opt = torch.optim.SGD([epsilon], lr=0.1)

        done = False
        while not done:
            # Crete loss for Pareto stationary problem
            pareto_loss = torch.zeros_like(grads[0])
            for i,grad in enumerate(grads):
                pareto_loss += epsilon[i]*grad
            pareto_loss = torch.linalg.norm(pareto_loss) ** 2

            opt.zero_grad()
            pareto_loss.backward()
            opt.step()
            epsilon.data = self.simplex_proj(epsilon)

            if pareto_loss < 1e-1:
                done = True
            
        return epsilon.detach(), pareto_loss

    def update_parameters(self, experience):
        """Performs a train step.

        Runs ``self.epochs`` passes over the batches from
        ``self._get_batches()``. For each batch the per-objective actor
        losses are combined with preference weights (fixed via
        ``self.epsilon`` or computed by ``find_preferences``), an entropy
        bonus is subtracted, and the actor and critic are stepped with their
        own optimizers after gradient clipping.

        Args:
            experience (DictList): A DictList containing the states, value
                functions, advantages etc. Indexed with a batch, it must
                expose ``obss``, ``actions``, ``returns``, ``advantages``
                and ``log_probs``.

        Returns:
            dict: Means over all processed batches for the keys ``entropy``,
            ``critic_loss``, ``actor_loss``, ``loss``, ``epsilon`` (weight of
            the first objective), ``grad_norm`` and ``mu`` (the Pareto
            objective, or -1 when fixed preferences are used).
        """

        # Init logs
        log_entropies = []
        log_actor_loss = []
        log_critic_loss = []
        log_loss = []
        log_grad_norm = []
        log_epsilon = []
        log_mu = []

        # Fixed preferences are loop-invariant: convert them once. An explicit
        # None/empty check avoids the ambiguous truth value of multi-element
        # arrays and tensors.
        fixed_epsilon = None
        if self.epsilon is not None:
            fixed_epsilon = torch.as_tensor(self.epsilon)
            if fixed_epsilon.numel() == 0:
                fixed_epsilon = None

        ## Optimize model params
        for _ in range(self.epochs):
            for batch in self._get_batches():
                # Get experience data for current batch
                sub_batch = experience[batch]
                obss = sub_batch.obss.to(self.device)
                actions = sub_batch.actions.to(self.device)
                returns = sub_batch.returns.to(self.device)
                advantages = sub_batch.advantages.to(self.device)
                log_probs_old = sub_batch.log_probs.to(self.device)

                # Get AC model output
                dist, critic_value = self.model(obss)

                # Calculate entropy
                entropy = dist.entropy().mean()

                # Calculate critic loss
                critic_loss = self._critic_loss(critic_value, returns)

                # Get log probs
                log_probs = dist.log_prob(actions)

                # sum probs across last axis if action dim > 1
                if self.action_dim > 1:
                    log_probs = torch.sum(log_probs, dim=-1)

                # Calculate actor loss(es), one per reward component
                losses = [
                    self._actor_loss(log_probs, log_probs_old, advantages[:, i])
                    for i in range(self.reward_dim)
                ]

                # Compute preferences unless fixed ones were supplied
                if fixed_epsilon is None:
                    epsilon, pareto_loss = self.find_preferences(losses)
                else:
                    epsilon = fixed_epsilon
                    pareto_loss = torch.tensor(-1)

                actor_loss = 0
                for i in range(self.reward_dim):
                    actor_loss += epsilon[i] * losses[i]

                # Entropy bonus; the critic loss is optimised separately below
                actor_loss = actor_loss - self.entropy_coef * entropy

                if self.pref_dist is not None:
                    assert hasattr(self.pref_dist, "sample")
                    sampled_epsilon = self.pref_dist.sample()
                    # Penalise the distance to the sampled preferences
                    for e, se in zip(epsilon, sampled_epsilon):
                        actor_loss += (e - se) ** 2

                # Backprop
                self.optimizer_actor.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.actor.parameters(), self.max_grad_norm)
                self.optimizer_actor.step()

                self.optimizer_critic.zero_grad()
                critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.critic.parameters(), self.max_grad_norm)
                self.optimizer_critic.step()

                # Parameters that received no gradient are skipped instead of
                # raising on ``None.data``.
                grad_norm = sum(
                    p.grad.data.norm(2).item() ** 2
                    for p in self.model.parameters()
                    if p.grad is not None
                ) ** 0.5

                # Store logs (plain floats; no extra graph is built)
                actor_loss_value = actor_loss.item()
                critic_loss_value = critic_loss.item()
                log_entropies.append(entropy.item())
                log_actor_loss.append(actor_loss_value)
                log_critic_loss.append(critic_loss_value)
                log_loss.append(actor_loss_value + critic_loss_value)
                log_grad_norm.append(grad_norm)
                log_epsilon.append(epsilon[0].item())
                log_mu.append(pareto_loss.item())

        # Log vals
        logs = {
            "entropy": np.mean(log_entropies),
            "critic_loss": np.mean(log_critic_loss),
            "actor_loss": np.mean(log_actor_loss),
            "loss": np.mean(log_loss),
            "epsilon": np.mean(log_epsilon),
            "grad_norm": np.mean(log_grad_norm),
            "mu": np.mean(log_mu)
        }
        return logs




