from typing import List, Literal, Tuple

import torch
from torchtyping import TensorType
from src.gfn.containers import Trajectories
from src.gfn.estimators import LogStateFlowEstimator,LogZEstimator,LogitPBEstimator
from src.gfn.losses.base import PFBasedParametrization, TrajectoryDecomposableLoss
from src.gfn.samplers import BackwardDiscreteActionsSampler
from src.gfn.containers.states import States
# Type aliases (torchtyping) used only in the annotations of this module.
# ScoresTensor: a single dimension given as -1 (presumably "any size"), float.
ScoresTensor = TensorType[-1, float]
# LossTensor: intended for the loss value. NOTE(review): in torchtyping an integer
# entry is a dimension size, so `0` may not express "scalar" — confirm intent.
LossTensor = TensorType[0, float]
# LogPTrajectoriesTensor: per-step values laid out as (max_length, n_trajectories).
LogPTrajectoriesTensor = TensorType["max_length", "n_trajectories", float]


class RLParametrization(PFBasedParametrization):
    r"""Forward/backward policy parametrization extended with a log-partition estimator.

    ``logit_PF`` and ``logit_PB`` are forwarded unchanged to
    :class:`PFBasedParametrization`. ``logZ`` is kept as an extra attribute;
    ``TrajectoryRL`` reads it as ``self.parametrization.logZ.tensor`` when
    scoring trajectories and when building its logZ update term.

    :param logit_PF: forward-policy logit estimator.
    :param logit_PB: backward-policy logit estimator.
    :param logZ: learned estimator of log Z.
    """

    def __init__(self, logit_PF, logit_PB, logZ: LogZEstimator):
        # Set before delegating so the attribute exists even if the parent
        # initializer inspects the instance.
        self.logZ = logZ
        super().__init__(logit_PF, logit_PB)

class TrajectoryRL(TrajectoryDecomposableLoss):
    def __init__(
        self,
        parametrization: RLParametrization,
        optimizer: Tuple[torch.optim.Optimizer,
                         torch.optim.Optimizer,
                         torch.optim.Optimizer | None,
                         torch.optim.Optimizer | None],
        logV: LogStateFlowEstimator,
        logVB: LogStateFlowEstimator,
        logit_PG: LogitPBEstimator | None = None,
        log_reward_clip_min: float = -12,
        gamma: float = 1.0,
    ):
        """Actor-critic trajectory loss with separate forward and backward updates.

        :param parametrization: holds ``logit_PF``, ``logit_PB`` and ``logZ``.
        :param optimizer: 4-tuple unpacked, in order, into
            ``A_optimizer`` (stepped on the forward actor + logZ loss in
            ``update_model``), ``V_optimizer`` (critic ``logV``),
            ``B_optimizer`` (stepped on the backward actor loss in
            ``B_update_model``) and ``VB_optimizer`` (critic ``logVB``).
            The last two are annotated as optional, but ``B_update_model``
            needs both.
        :param logV: critic evaluated by ``get_value`` for forward updates.
        :param logVB: critic evaluated by ``get_value(..., backward=True)``.
        :param logit_PG: optional guide estimator. When ``None``,
            ``B_update_model`` uses the forward policy as the guide. Its type
            also selects forward vs. backward masks in ``guide_log_prob`` and
            ``get_pgs``.
        :param log_reward_clip_min: lower clamp applied to terminal log-rewards
            in ``get_scores``.
        :param gamma: discount factor, stored as ``self.gamma``.
            NOTE(review): confirm which code paths read ``self.gamma``.
        """
        self.gamma = gamma
        self.logV = logV
        self.logVB = logVB
        self.logit_PG = logit_PG
        self.log_reward_clip_min = log_reward_clip_min
        (self.A_optimizer, self.V_optimizer,
         self.B_optimizer, self.VB_optimizer) = optimizer
        # fill_value=0. pads per-step tensors at invalid (sink) steps.
        super().__init__(parametrization, fill_value=0.)

    def guide_log_prob(self, states: States, env):
        """Masked log-softmax of the (frozen) forward-policy logits at ``states``.

        The logits come from ``self.parametrization.logit_PF`` under
        ``torch.no_grad()``, so no gradient flows into the guide.

        Two groups of entries are overwritten with ``self.inf_value`` before the
        log-softmax:

        * invalid actions, using ``states.backward_masks`` when ``logit_PG`` is
          a ``LogitPBEstimator`` and ``states.forward_masks`` otherwise;
        * the last action of every state whose ``env.log_reward`` is not
          above ``log(env.R0)``.

        NOTE(review): ``inf_value`` is presumably a large negative number, so
        masked actions get ~zero probability — confirm in the base class.

        :raises ValueError: if some row of logits is entirely NaN.
        :returns: log-probabilities with the same shape as the logits.
        """
        # NOTE(review): logits always come from logit_PF even when logit_PG is a
        # LogitPBEstimator. In that case backward_masks may not match the
        # logits' action dimension — confirm the backward branch is exercised.
        with torch.no_grad():
            logits = self.parametrization.logit_PF(states)
        if torch.any(torch.all(torch.isnan(logits), 1)):
            raise ValueError("NaNs in estimator")
        if isinstance(self.logit_PG, LogitPBEstimator):
            logits[~states.backward_masks] = self.inf_value
        else:
            logits[~states.forward_masks] = self.inf_value
        log_rewards = env.log_reward(states)
        # Build the threshold as a 0-dim tensor on log_rewards' device. A
        # 1-element CPU tensor (torch.tensor([R0])) cannot be compared with CUDA
        # tensors, whereas 0-dim tensors broadcast like scalars.
        log_r0 = torch.as_tensor(env.R0, device=log_rewards.device).log()
        # `~(a > b)` (not `a <= b`) keeps NaN log-rewards in the masked set.
        low_states_index = ~(log_rewards > log_r0)
        logits[low_states_index, -1] = self.inf_value
        return logits.log_softmax(dim=-1)
    def get_pgs(
            self,
            trajectories: Trajectories) -> Tuple[LogPTrajectoriesTensor, LogPTrajectoriesTensor]:
        """Guide log-probabilities along backward trajectories.

        Valid (state, action) pairs are selected with ``backward_state_actions``
        when ``logit_PG`` is a ``LogitPBEstimator``, and with
        ``forward_state_actions`` otherwise.

        :raises ValueError: if ``trajectories`` are not backward trajectories.
        :returns: ``(per_step, per_step_all)``.
            ``per_step`` has the shape of ``trajectories.actions``; at
            ``inter_index`` it holds the value ``action_prob_gather`` picks
            for the taken action.
            ``per_step_all`` has the shape of the corresponding mask tensor and
            holds the guide's full log-distribution at those positions.
            Every other entry of both tensors is ``self.fill_value``.
        """
        if not trajectories.is_backward:
            raise ValueError("Forward trajectories are not supported")

        uses_backward_masks = isinstance(self.logit_PG, LogitPBEstimator)
        if uses_backward_masks:
            step_states, step_actions, idx = self.backward_state_actions(trajectories)
            mask_template = trajectories.states[:-1].backward_masks
        else:
            step_states, step_actions, idx = self.forward_state_actions(trajectories)
            mask_template = trajectories.states[:-1].forward_masks
        per_step_all = torch.full_like(mask_template, fill_value=self.fill_value, dtype=torch.float)

        step_log_dist = self.guide_log_prob(step_states, trajectories.env)
        step_log_taken = self.action_prob_gather(step_log_dist, step_actions)

        per_step = torch.full_like(trajectories.actions, fill_value=self.fill_value, dtype=torch.float)
        per_step[idx] = step_log_taken
        per_step_all[idx] = step_log_dist
        return per_step, per_step_all

    def get_scores(
        self,
        trajectories: Trajectories,
        log_pf_traj: LogPTrajectoriesTensor,
        log_pb_traj: LogPTrajectoriesTensor,
    ) -> ScoresTensor:
        """Per-step scores ``log_pf - log_pb``, with the terminal step carrying the reward.

        At each trajectory's terminating action, the backward term is replaced
        by ``clamp_min(log_reward, log_reward_clip_min) - logZ``. Summing the
        scores over steps therefore gives
        ``sum(log_pf) - sum(log_pb over non-terminal steps) - log R + log Z``.

        The caller's ``log_pb_traj`` is no longer modified. The previous
        in-place write overwrote the argument, and could invalidate autograd if
        that tensor had been saved for a backward pass elsewhere.
        """
        terminal_index = trajectories.is_terminating_action
        log_pb_traj = log_pb_traj.clone()
        # Indexing through the transposes enumerates the selected entries
        # trajectory-major. This presumably lines them up with
        # trajectories.log_rewards (one terminating action per trajectory).
        log_pb_traj.T[terminal_index.T] = (
            trajectories.log_rewards.clamp_min(self.log_reward_clip_min)
            - self.parametrization.logZ.tensor
        )
        return log_pf_traj - log_pb_traj

    def get_value(self, trajectories: Trajectories, backward=False):
        """Critic estimate for each non-sink step, padded with ``self.fill_value``.

        ``self.logVB`` is used when ``backward`` is truthy, otherwise ``self.logV``.
        The result has the shape of ``trajectories.actions``.
        """
        critic = self.logVB if backward else self.logV
        keep = ~trajectories.is_sink_action
        out = torch.full_like(trajectories.actions, dtype=torch.float, fill_value=self.fill_value)
        # states[:-1] drops the one extra trailing sink state so rows align with actions.
        kept_states = trajectories.states[:-1][keep]
        out[keep] = critic(kept_states).squeeze(-1)
        return out


    def surrogate_loss(self, log_pf, log_qf, advantages):
        """Elementwise TRPO-style surrogate ``exp(log_pf - stopgrad(log_qf)) * advantages``.

        When ``log_qf`` is ``log_pf`` itself, the ratio is 1 in value, so the
        output equals ``advantages``. Its gradient is
        ``advantages * grad(log_pf)``, since ``grad(exp(x - c)) = exp(x - c) * grad(x)``.
        """
        ratio = (log_pf - log_qf.detach()).exp()
        return ratio * advantages

    def update_model(self, trajectories: Trajectories):
        """One actor-critic step for the forward policy, logZ and the critic ``logV``.

        Steps:

        1. Score each step with ``get_scores``; the scores are detached and act
           as per-step rewards.
        2. Estimate advantages against the ``logV`` critic.
        3. Update the critic with ``V_optimizer``, fitting ``logV`` to the
           targets ``Qt``.
        4. Update the actor and logZ with ``A_optimizer``.

        If ``V_optimizer`` is an LBFGS instance, step 3 uses a closure that
        re-evaluates the critic loss, as LBFGS requires.

        :returns: the actor + logZ loss that was back-propagated.
        """
        log_pf_traj, _ = self.get_pfs(trajectories)
        log_pb_traj, _ = self.get_pbs(trajectories)
        scores = self.get_scores(trajectories, log_pf_traj, log_pb_traj).detach()
        values = self.get_value(trajectories)
        advantages, Qt = self.estimate_advantages(trajectories, scores, values.detach())
        Z = self.parametrization.logZ.tensor.exp()
        A_loss = self.surrogate_loss(log_pf_traj, log_pf_traj, advantages).sum(0).mean()
        # Z / Z.detach() is 1 in value, and its derivative w.r.t. logZ is also 1,
        # so the gradient reaching logZ equals the mean summed trajectory score.
        Z_diff = (Z / Z.detach()) * scores.sum(0).mean()

        if isinstance(self.V_optimizer, torch.optim.LBFGS):
            def closure():
                self.V_optimizer.zero_grad()
                val = self.get_value(trajectories)
                closure_loss = (Qt - val).pow(2).sum(0).mean()
                closure_loss.backward()
                return closure_loss
            self.V_optimizer.step(closure)
        else:
            # Built only here: the LBFGS path recomputes its own loss in the closure.
            V_loss = (Qt - values).pow(2).sum(0).mean()
            self.optimizer_step(V_loss, self.V_optimizer)

        total_loss = A_loss + Z_diff
        self.optimizer_step(total_loss, self.A_optimizer)
        return total_loss

    def B_update_model(self, trajectories: Trajectories):
        """One actor-critic step for the backward policy and its critic ``logVB``.

        The per-step reward is ``log PB - log PG`` (detached). ``PG`` comes from
        ``get_pgs`` when ``logit_PG`` is set, and from the forward policy
        (``get_pfs``) otherwise.

        The critic is updated with ``VB_optimizer``, using an LBFGS closure
        when that optimizer is LBFGS, matching ``update_model``. The actor is
        updated with ``B_optimizer``.

        :returns: the backward actor loss that was back-propagated.
        """
        log_pb_traj, _ = self.get_pbs(trajectories)
        if self.logit_PG is not None:
            log_pg_traj, _ = self.get_pgs(trajectories)
        else:
            log_pg_traj, _ = self.get_pfs(trajectories)
        scores = (log_pb_traj - log_pg_traj).detach()
        values = self.get_value(trajectories, backward=True)
        advantages, Qt = self.estimate_advantages(trajectories, scores, values.detach())
        A_loss = self.surrogate_loss(log_pb_traj, log_pb_traj, advantages).sum(0).mean()

        if isinstance(self.VB_optimizer, torch.optim.LBFGS):
            # LBFGS.step requires a closure that re-evaluates the loss.
            def closure():
                self.VB_optimizer.zero_grad()
                val = self.get_value(trajectories, backward=True)
                closure_loss = (Qt - val).pow(2).sum(0).mean()
                closure_loss.backward()
                return closure_loss
            self.VB_optimizer.step(closure)
        else:
            V_loss = (Qt - values).pow(2).sum(0).mean()
            self.optimizer_step(V_loss, self.VB_optimizer)

        self.optimizer_step(A_loss, self.B_optimizer)
        return A_loss

    def optimizer_step(self, loss, optimizer):
        """Clear ``optimizer``'s gradients, back-propagate ``loss``, then apply one step."""
        reset_grads = optimizer.zero_grad
        apply_update = optimizer.step
        reset_grads()
        loss.backward()
        apply_update()

    def __call__(self, trajectories: Trajectories) -> Tuple[LossTensor, LossTensor]:
        """Placeholder required by the loss interface; performs no computation.

        Training in this class goes through ``update_model`` and
        ``B_update_model``, which run their own optimizer steps.
        NOTE(review): this returns ``None`` despite the tuple return annotation.
        """
        return None

    @staticmethod
    def kl_log_prob(log_prob_q, log_prob_p):
        """KL(p || q) over the last axis, treating ``p`` as a fixed (detached) target.

        Both arguments are log-probabilities. The result is
        ``sum(p * (log p - log q), dim=-1)``.
        """
        target_log_p = log_prob_p.detach()
        per_action = target_log_p.exp() * (target_log_p - log_prob_q)
        return per_action.sum(-1)
    @staticmethod
    def entropy(log_pf):
        """Shannon entropy ``-sum(p * log p)`` over the last axis, given log-probabilities."""
        weighted = log_pf * log_pf.exp()
        return -weighted.sum(-1)

    def estimate_advantages(self, trajectories: Trajectories, scores, values, gamma=None):
        """Backward-recursive advantage estimation (GAE with lambda = 1) on padded trajectories.

        On each non-sink step ``t``, iterating from the last step to the first:

            delta_t = score_t + gamma * values[t + 1] - values[t]
            A_t     = delta_t + gamma * A_{t + 1}

        ``values[t + 1]`` is taken from the next row of ``values``. At padding
        (sink) steps this is whatever the caller stored there; ``get_value``
        stores ``self.fill_value``. Beyond the last row it is 0.

        :param scores: per-step rewards, rows indexed by step.
        :param values: critic values with the same layout as ``scores``.
        :param gamma: discount. Defaults to ``self.gamma`` (the constructor
            argument), which was previously ignored in favour of a hard-coded 1.0.
        :raises ValueError: if a NaN appears in the TD residuals.
        :returns: ``(advantages, Qt)``. ``Qt = values + advantages`` is the
            critic target, taken before centering. ``advantages`` are
            mean-centred over non-sink steps. Sink steps are 0 in both.
        """
        if gamma is None:
            gamma = getattr(self, "gamma", 1.0)
        masks = ~trajectories.is_sink_action
        next_value = torch.zeros_like(scores[0], dtype=torch.float)
        running_adv = torch.zeros_like(scores[0], dtype=torch.float)
        deltas = torch.zeros_like(scores[0], dtype=torch.float)
        Qt = torch.zeros_like(scores, dtype=torch.float)
        advantages = torch.zeros_like(scores, dtype=torch.float)
        for i in reversed(range(scores.size(0))):
            valid = masks[i]
            deltas[valid] = scores[i][valid] + gamma * next_value[valid] - values[i][valid]
            if torch.any(torch.isnan(deltas)):
                raise ValueError("NaN in scores")
            running_adv[valid] = deltas[valid] + gamma * running_adv[valid]
            next_value = values[i]
            advantages[i][valid] = running_adv[valid]
        Qt[masks] = (values + advantages)[masks]
        # Center advantages over valid steps (variance reduction); Qt is uncentered.
        advantages[masks] = advantages[masks] - advantages[masks].mean()
        return advantages, Qt