from typing import Optional, Sequence, List
import numpy as np
import torch
from torch import autograd
import torch.nn.functional as F

from networks.policies import MLPNestedPolicyPG
from networks.gcn_policy import GCNNestedPolicyPG
from networks.critics import ValueCritic
from infrastructure import pytorch_util as ptu
from torch import nn

from torch.optim.lr_scheduler import CosineAnnealingLR

class MultiDimPGAgent(nn.Module):
    """PPO-style (clipped surrogate) policy-gradient agent with several action heads.

    The actor is a ``GCNNestedPolicyPG`` built from ``ac_dim_list``. A second copy of
    the same architecture (``old_actor``) is synchronised with the actor after every
    :meth:`update` and supplies the reference log-probabilities of the PPO ratio.
    An optional ``ValueCritic`` provides a value baseline (optionally combined with
    GAE), and :meth:`update` can add a behavior-cloning term computed on expert data.
    """

    # Number of action heads (columns of the expert action array and entries of the
    # actor's output on expert observations) supervised by the behavior-cloning term.
    n_bc_action_heads: int = 5

    def __init__(
        self,
        ob_feature_dim: int,
        ac_dim_list: List[int],
        discrete: bool,
        embed_dim: int,
        n_gcn_layers: int,
        n_layers: int,
        layer_size: int,
        gamma: float,
        learning_rate: float,
        eps_clip: float,
        K_epochs: int,
        use_baseline: bool,
        use_reward_to_go: bool,
        baseline_learning_rate: Optional[float],
        baseline_gradient_steps: Optional[int],
        gae_lambda: Optional[float],
        normalize_advantages: bool,
    ):
        """Build the actor, the optional critic, the reference actor and the optimizer.

        Args:
            ob_feature_dim, embed_dim, n_gcn_layers, n_layers, layer_size: network
                hyper-parameters forwarded to the actor (and critic) constructors.
            ac_dim_list: per-head action dimensions, forwarded to the actor.
            discrete: forwarded to the actor.
            gamma: discount factor.
            learning_rate: base learning rate; the actor's Adam uses one tenth of it.
            eps_clip: PPO ratio clipping range.
            K_epochs: number of PPO optimisation epochs per :meth:`update`.
            use_baseline: create a ``ValueCritic`` when True.
            use_reward_to_go: use reward-to-go instead of full-trajectory returns.
            baseline_learning_rate: forwarded to the critic.
            baseline_gradient_steps: stored on the agent when a baseline is used.
            gae_lambda: GAE lambda; ``None`` disables GAE.
            normalize_advantages: standardise advantages within each batch.
        """
        super().__init__()

        # Actor (policy) network.
        self.actor = GCNNestedPolicyPG(
            ac_dim_list, ob_feature_dim, embed_dim, discrete, n_gcn_layers, n_layers, layer_size
        )

        # Critic (baseline) network, only when a baseline is requested.
        if use_baseline:
            self.critic = ValueCritic(
                ob_feature_dim, embed_dim, n_gcn_layers, n_layers, layer_size, baseline_learning_rate
            )
            self.baseline_gradient_steps = baseline_gradient_steps
        else:
            self.critic = None
            print("NO BASELINE=====================")

        # Reference ("old") actor: same architecture, initialised with the actor's
        # weights. Its log-probs are the denominator of the PPO probability ratio.
        self.old_actor = GCNNestedPolicyPG(
            ac_dim_list, ob_feature_dim, embed_dim, discrete, n_gcn_layers, n_layers, layer_size
        )
        self.old_actor.load_state_dict(self.actor.state_dict())

        # NOTE: the actor trains at learning_rate / 10, annealed with a cosine
        # schedule that is stepped once per `update` call.
        self.optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=learning_rate / 10, betas=(0.995, 0.995)
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=1200, eta_min=1e-5)

        # Other agent parameters.
        self.gamma = gamma
        self.use_reward_to_go = use_reward_to_go
        self.gae_lambda = gae_lambda
        self.normalize_advantages = normalize_advantages
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

    def update(
        self,
        obs: Sequence[np.ndarray],
        actions: Sequence[np.ndarray],
        rewards: Sequence[np.ndarray],
        terminals: Sequence[np.ndarray],
        masks: Sequence[np.ndarray],
        obs_expert: Optional[Sequence],
        actions_expert: Optional[Sequence],
        masks_expert: Optional[Sequence],
        lambda_0: Optional[float],
        lambda_1: Optional[float],
        valid: object,
        entropy_weight: float = 5e-5,
    ) -> dict:
        """Run ``K_epochs`` clipped-PPO epochs on one batch of rollouts.

        Each rollout input is a sequence with one entry per trajectory. The entries
        are flattened into a single batch whose size is the total number of steps.

        Args:
            obs, actions, rewards, terminals, masks: per-trajectory rollout data.
            obs_expert, actions_expert, masks_expert: optional expert demonstrations.
                When ``obs_expert`` is not None, a behavior-cloning negative
                log-likelihood term weighted by ``lambda_0 * lambda_1`` is added to
                the loss. ``masks_expert`` is only zipped alongside the others; it
                is not used by the BC term.
            lambda_0, lambda_1: BC weight factors (required when expert data is given).
            valid: forwarded unchanged to ``self.actor.evaluate``.
            entropy_weight: coefficient of the entropy bonus.

        Returns:
            A dict of detached diagnostics: ``surrogate{k}``, ``entropy{k}`` and
            ``avg_ratio{k}`` per epoch. It also holds BC terms when expert data is
            given, plus the critic's averaged info when a critic is present.
        """
        print("entropy update: ", entropy_weight)

        # Step 1: Monte-Carlo Q estimates, one list per trajectory.
        q_per_traj = self._calculate_q_vals(rewards)

        # Flatten the per-trajectory sequences into one batch. `list.extend` keeps
        # this linear; repeated `list + list` is quadratic in the batch size.
        q_flat: list = []
        r_flat: list = []
        o_flat: list = []
        t_flat: list = []
        a_flat: list = []
        m_flat: list = []
        for q, r, o, t, a, m in zip(q_per_traj, rewards, obs, terminals, actions, masks):
            q_flat.extend(q)
            r_flat.extend(r)
            o_flat.extend(o)
            t_flat.extend(t)
            a_flat.extend(a)
            m_flat.extend(m)

        q_values = np.array(q_flat)
        rewards = np.array(r_flat)
        terminals = np.array(t_flat)
        actions = np.array(a_flat)
        obs = o_flat
        masks = m_flat

        # Step 2: advantages from the Q values.
        advantages = ptu.from_numpy(
            self._estimate_advantage(obs, rewards, q_values, terminals)
        )
        info: dict = {}

        # Flatten the expert demonstrations once; they feed the BC term every epoch.
        if obs_expert is not None:
            obs_exp: list = []
            a_flat_exp: list = []
            for o_exp, a_exp, _m_exp in zip(obs_expert, actions_expert, masks_expert):
                a_flat_exp.extend(a_exp)
                obs_exp.extend(o_exp[i] for i in range(len(o_exp)))
            actions_exp = ptu.from_numpy(np.array(a_flat_exp))
            # One LongTensor of expert action indices per head. reshape(-1) (rather
            # than squeeze) keeps the tensor 1-D even for a single expert sample.
            expert_acs = [
                actions_exp[:, j].to(torch.long).reshape(-1)
                for j in range(self.n_bc_action_heads)
            ]

        # Reference log-probs of the pre-update policy. They are constant during
        # this update, so no autograd graph is needed.
        with torch.no_grad():
            old_logprobs, _ = self.old_actor.evaluate(obs, actions, masks)

        critic_info: dict = {}
        for k in range(self.K_epochs):
            logprobs, entropy = self.actor.evaluate(
                obs, actions, masks, r_flat, valid, verbose=False
            )
            ratios = torch.exp(logprobs - old_logprobs)

            # Clipped PPO surrogate objective.
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            clipped = torch.min(surr1, surr2)

            # Diagnostics: number of samples where the clipped term is the smaller one.
            num_clamped = int((surr2 < surr1).sum().item())
            num_pass = surr1.numel() - num_clamped
            print("*******NUM CLAMPED ", num_clamped)
            print("*******NUM PASSED ", num_pass)

            loss = -clipped.mean() - entropy_weight * entropy.mean()

            # Logged values are detached so they do not keep the autograd graph alive.
            debugging_dict = {
                f"surrogate{k}": -clipped.mean().detach(),
                f"entropy{k}": entropy.mean().detach(),
                f"avg_ratio{k}": torch.mean(ratios).detach(),
            }

            if obs_expert is not None:
                logit_exp = self.actor(obs_exp)
                nll_terms = []
                for j, ac in enumerate(expert_acs):
                    # log_softmax is the numerically stable form of log(softmax(.)).
                    log_probs = F.log_softmax(logit_exp[j], dim=-1)
                    rows = torch.arange(len(ac), device=log_probs.device)
                    nll_terms.append(-log_probs[rows, ac].mean())
                loss_bc = sum(nll_terms)
                bc_weight = lambda_0 * lambda_1
                loss = loss + bc_weight * loss_bc
                debugging_dict["BC"] = bc_weight * loss_bc.detach()
                debugging_dict[f"loss_bc_ppo{k}"] = loss_bc.detach()
                for j, nll in enumerate(nll_terms):
                    debugging_dict[f"log_prob_ppo{k}_ac{j}"] = nll.detach()
            info.update(debugging_dict)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.critic is not None:
                critic_info = self._fit_critic(obs, q_values)

        # Sync the reference actor once, after all epochs. old_logprobs was computed
        # once before the loop, so syncing inside the loop had no effect on it.
        self.old_actor.load_state_dict(self.actor.state_dict())

        self.scheduler.step()
        print(f"Step LR = {self.scheduler.get_last_lr()[0]:.6e}")

        if self.critic is not None:
            info.update(critic_info)
        return info

    def update_critic(self, obs, rewards) -> dict:
        """Fit only the critic on per-trajectory ``obs`` and ``rewards``.

        The Q targets come from :meth:`_calculate_q_vals`. Training is done by
        :meth:`_fit_critic` with its default schedule. Returns the averaged critic
        info, or ``{}`` when there is no critic or no data.
        """
        q_per_traj = self._calculate_q_vals(rewards)
        q_flat: list = []
        o_flat: list = []
        for q, o in zip(q_per_traj, obs):
            q_flat.extend(q)
            o_flat.extend(o)
        return self._fit_critic(o_flat, np.array(q_flat))

    def _fit_critic(
        self,
        obs: Sequence,
        q_values: np.ndarray,
        n_passes: int = 4,
        batch_size: int = 128,
    ) -> dict:
        """Regress the critic onto ``q_values`` with shuffled minibatch passes.

        Each of the ``n_passes`` passes reshuffles the data and calls
        ``self.critic.update`` on consecutive full minibatches of ``batch_size``
        samples; a trailing partial minibatch is dropped. If there are fewer than
        ``batch_size`` samples, the whole set is used as one minibatch so that the
        critic is still trained. ``batch_size`` is assumed to be positive.

        Returns:
            The per-key mean of the dicts returned by ``self.critic.update``, or
            ``{}`` when there is no critic, no data, or no passes.
        """
        n = len(obs)
        if self.critic is None or n == 0 or n_passes <= 0:
            return {}
        step = min(batch_size, n)
        critic_infos = []
        for _ in range(n_passes):
            perm = np.random.permutation(n)
            obs_shuff = [obs[i] for i in perm]
            q_shuff = q_values[perm]
            # Starts of every full minibatch, including the last full one.
            for start in range(0, n - step + 1, step):
                end = start + step
                critic_infos.append(
                    self.critic.update(obs_shuff[start:end], q_shuff[start:end])
                )
        return {
            key: np.mean([ci[key] for ci in critic_infos]) for key in critic_infos[0]
        }

    def _calculate_q_vals(self, rewards: Sequence[np.ndarray]) -> Sequence[list]:
        """Monte Carlo estimate of Q(s_t, a_t) for every step of every trajectory.

        Without reward-to-go, every step of a trajectory gets the trajectory's full
        discounted return: Q(s_t, a_t) = sum_{t'=0}^T gamma^t' r_{t'}. With
        reward-to-go, it gets Q(s_t, a_t) = sum_{t'=t}^T gamma^(t'-t) r_{t'}.
        Returns one list per trajectory.
        """
        estimator = (
            self._discounted_reward_to_go if self.use_reward_to_go else self._discounted_return
        )
        return [estimator(r_traj) for r_traj in rewards]

    def _estimate_advantage(
        self,
        obs: Sequence,
        rewards: np.ndarray,
        q_values: np.ndarray,
        terminals: np.ndarray,
    ) -> np.ndarray:
        """Compute advantages, subtracting the critic's value baseline if one is present.

        Operates on the flattened batch: ``rewards``, ``q_values`` and ``terminals``
        are 1-D arrays aligned with the flat ``obs`` list.

        * No critic: the advantages are the Q values.
        * Critic without GAE: A_t = Q_t - V(s_t).
        * Critic with GAE: delta_t = r_t + gamma V(s_{t+1}) - V(s_t) for non-terminal
          steps and delta_t = r_t - V(s_t) at terminal steps, so the value of the
          next trajectory's first state is never bootstrapped across the boundary.
          Then A_t = delta_t + gamma * lambda * A_{t+1} within a trajectory.

        If ``normalize_advantages`` is set, the result is centred and divided by
        its standard deviation (centred only when that deviation is zero).
        """
        if self.critic is None:
            advantages = q_values
        else:
            values = ptu.to_numpy(self.critic(obs)).flatten()
            assert values.shape == q_values.shape

            if self.gae_lambda is None:
                advantages = q_values - values
            else:
                batch_size = len(obs)
                # Dummy V = 0 / A = 0 past the end of the batch.
                values_ext = np.append(values, [0])
                advantages = np.zeros(batch_size + 1)

                for i in reversed(range(batch_size)):
                    if terminals[i]:
                        # Episode ends here: no bootstrap from the next entry.
                        advantages[i] = rewards[i] - values_ext[i]
                    else:
                        delta = rewards[i] + self.gamma * values_ext[i + 1] - values_ext[i]
                        advantages[i] = delta + self.gamma * self.gae_lambda * advantages[i + 1]

                # Remove the dummy entry.
                advantages = advantages[:-1]

        if self.normalize_advantages:
            std = np.std(advantages)
            advantages = advantages - np.mean(advantages)
            if std > 0:
                advantages = advantages / std

        return advantages

    def _discounted_return(self, rewards: Sequence[float]) -> Sequence[float]:
        """Return the full discounted return, repeated once per step.

        Given rewards {r_0, ..., r_T}, every entry of the output is
        sum_{t'=0}^T gamma^t' r_{t'}. All entries are identical because the sum
        does not depend on t.
        """
        total = sum(pow(self.gamma, i) * r for i, r in enumerate(rewards))
        return [total] * len(rewards)

    def _discounted_reward_to_go(self, rewards: Sequence[float]) -> Sequence[float]:
        """Return the discounted reward-to-go for every step.

        Entry t is sum_{t'=t}^T gamma^(t'-t) r_{t'}. It is computed with one
        backward pass using rtg_t = r_t + gamma * rtg_{t+1}, which is linear in
        the trajectory length.
        """
        rtg = [0.0] * len(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + self.gamma * running
            rtg[t] = running
        return rtg