import dataclasses as dc
from pathlib import Path
import lightning as L
import torch
import numpy as np
from core.reasoning import ReasoningTask, CoT, reasoners

from core.model import LitLLM, ValueModel, Distrubute, make_llm, LLM
from core.utils.th import probabilities, normalize_tensor, entropies, preprocess_logits, fetch
from .base import RLTuning
from .value_iteration import ValueIterArgs, LitValueIteration, multictx_td_error, unictx_td_error
from .utils import estimate_values, estimate_probs

from typing import Literal



@dc.dataclass
class PPOArgs(ValueIterArgs):
    """Hyper-parameters of `LitPPO`, extending the value-iteration arguments."""

    # Half-width of the probability-ratio clipping range [1 - clip, 1 + clip];
    # `None` disables clipping (plain ratio * advantage objective).
    clip: float | None = 0.1
    # Weight of the KL penalty towards the reference policy.
    kl_coef: float = 0.1
    # Weight of the entropy bonus (entropy taken at temperature 1); 0 disables it.
    entropy_coef: float = 0
    # GAE lambda; the backward recursion uses `discount * gae_lambda`.
    gae_lambda: float = 0.98
    # Loss = actor_weight * actor_loss + critic_weight * value_loss.
    actor_weight: float = 1
    critic_weight: float = 1
    # Advantage normalisation: once over all collected data ('epoch'),
    # per training batch ('batch'), or disabled (None).
    norm_adv: Literal['epoch', 'batch'] | None = 'epoch'
    share_adv: bool = False  # tokens in the same step share one advantage (only used in multi-contextual mode)
    # 'gae': generalised advantage estimation over the value errors;
    # 'error': the value error itself, without GAE accumulation.
    adv_method: Literal['gae', 'error'] = 'gae'
    # Load a separate reference model for the KL penalty; when False the KL is
    # measured against the policy that collected the data (`probs`).
    use_ref_model: bool = True
    # NOTE(review): not read in this file; presumably consumed by the base
    # class or the data collector — confirm its meaning there.
    collect_shape: int = 1
    # No advantage is computed for the data of the first `epochs_warmup` and
    # the last `epochs_cooldown` epochs; training steps on batches without an
    # advantage presumably optimise only the value function.
    epochs_warmup: int = 0
    epochs_cooldown: int = 0


class LitPPO(LitValueIteration[PPOArgs]):
    """Lightning module for PPO fine-tuning of the policy LLM.

    The value-function (critic) loss is inherited from `LitValueIteration`.
    This class adds three terms:

    - a clipped-surrogate actor loss;
    - a KL penalty towards a reference policy;
    - an optional entropy bonus.

    The reference policy is a separate model loaded from `checkpoint_dir`
    when `args.use_ref_model` is set. Otherwise it is the old policy whose
    token probabilities (`probs`) were recorded with the collected data.
    """

    # Class-level switches; presumably read by the base class to enable LLM
    # optimisation and the collection of token probabilities — confirm there.
    _llm_trainable = True
    _require_probs = True

    def __init__(self,
        args: PPOArgs,
        impl: reasoners.ThoughtImpl,
        task: ReasoningTask,
        train_samples: list[CoT] | None,
        checkpoint_dir: Path | str,
        tokenizer_dir: Path | str | None = None,
        reflection: dict | None = None,
        trainer_ckpt_path: Path | str | None = None,
        distribute: Distrubute = 'auto',
        init_vf: Path | str | None = None,
    ):
        """Build the module.

        All arguments are forwarded unchanged to `LitValueIteration`.
        `checkpoint_dir`, `tokenizer_dir` and `distribute` are also used to
        load the reference model when `args.use_ref_model` is set.
        """
        super().__init__(
            args, impl, task, train_samples, checkpoint_dir, tokenizer_dir,
            reflection, trainer_ckpt_path, distribute, init_vf,
        )

        # Reference model for the KL penalty. The attribute is always defined
        # so `_process_collected_data` can test for it. `None` means the KL
        # term falls back to the old (data-collecting) policy.
        if self.args.use_ref_model:
            self.ref_llm = make_llm(
                str(checkpoint_dir),
                tokenizer_dir=tokenizer_dir,
                distribute=distribute,
            )
            self.ref_llm.eval()
        else:
            self.ref_llm = None

    @torch.inference_mode()
    def _process_collected_data(self, _first_loading: bool = False):
        """Attach policy-training tensors to the freshly collected data.

        After the base-class processing, and only for epochs outside the
        warm-up / cool-down windows, this stores on
        `self._train_collector.data`:

        - `ref_probs`: token probabilities under the reference model (only
          when a reference model is loaded);
        - `advantage`: a per-token tensor shaped like the value error, or, in
          multi-contextual mode with `share_adv`, one value per sample (`[n]`).

        Args:
            _first_loading: when True the epoch check uses `self.current_epoch`,
                otherwise `self.current_epoch + 1`. Presumably the data is then
                meant for the upcoming epoch — confirm against the caller.

        Raises:
            ValueError: on an unknown `target_method` or `adv_method`.
        """
        super()._process_collected_data(_first_loading)

        epoch = self.current_epoch if _first_loading else self.current_epoch + 1
        require_advantage = (
            self.args.epochs_warmup <= epoch < self.args.epochs - self.args.epochs_cooldown
        )
        if not require_advantage:
            return

        print("⏳ Preparing policy data", end='\r')

        device = self._collect_device
        data = self._train_collector.data
        n = len(data)
        discount = self.args.discount
        vf = self.vf_target
        multi_contextual = self._multi_contextual

        assert n > 0

        # Token probabilities under the reference model, for the KL penalty.
        # Without a reference model nothing is stored; `training_step` then
        # uses the old policy as the reference.
        if self.ref_llm is not None:
            data.ref_probs = estimate_probs(data, self.ref_llm, self.temperature,
                                            self.args.inference_batch_size, device)

        # ---- value error: target - prediction --------------------------------
        target_method = self.args.target_method
        batch_size = self.args.inference_batch_size
        if target_method == 'td':
            if multi_contextual:
                error = multictx_td_error(vf, data, batch_size, discount, device)
            else:
                error = unictx_td_error(vf, data, batch_size, discount, device)
        elif target_method == 'mc':
            targets = data.vf_targets
            assert isinstance(targets, torch.Tensor)
            v = estimate_values(vf, data, batch_size, device)  # [n, len]
            error = targets.unsqueeze(1) - v
        else:
            raise ValueError(f"Unknown target_method: {target_method!r}")

        # ---- advantage ---------------------------------------------------------
        adv_method = self.args.adv_method
        coef = discount * self.args.gae_lambda  # GAE decay; only used by 'gae'
        if not multi_contextual:
            # adv[t] = q(x[:t], x[t]) - v(x[:t])
            #        = q(x[:t], x[t]) - v[t-1]
            #        = r[t-1:].sum() - v[t-1]
            #        = target[t-1] - v[t-1]
            #        = error[t-1]
            adv = torch.zeros_like(error)
            assert adv.ndim == 2
            if adv_method == 'gae':  # Generalised Advantage Estimation
                # Backward recursion adv[t] = error[t-1] + coef * adv[t+1],
                # starting at the shortest prompt length.
                t0 = min(data.prompt_lengths)
                l = adv.size(1)
                assert t0 > 0  # guarantees i - 1 >= 0 in the loop below
                for i in reversed(range(t0, l)):
                    if i == l - 1:
                        # NOTE(review): the last position bootstraps with
                        # error[:, l-1] rather than 0 — confirm this is intended.
                        adv[:, i] = error[:, i-1] + coef * error[:, i]
                    else:
                        adv[:, i] = error[:, i-1] + coef * adv[:, i + 1]
            elif adv_method == 'error':
                adv[:, 1:] = error[:, :-1]
            else:
                raise ValueError(f"Unknown adv_method: {adv_method!r}")
        elif not self.args.share_adv:
            # adv[i, t] = q(s[i, t], a[i, t]) - v(s[i, t])
            #           = q(s[i, t], a[i, t]) - v(x[i, :t])
            #           = r[i:i+steps].sum() - v[i, t-1]
            #           = target[i] - v[i, t-1]
            #           = error[i, t-1]
            adv = torch.zeros_like(error)
            if adv_method == 'gae':  # Generalised Advantage Estimation
                # Walk samples in reverse order. The errors at the end of each
                # later sample's prompt are accumulated (decayed by `coef`) and
                # added to every position of sample i. The sum is reset at
                # terminated/truncated samples. This presumably relies on
                # consecutive samples being consecutive steps of one
                # trajectory — confirm against the collector.
                cum_error: float = 0.
                for i in reversed(range(n)):
                    adv[i, 1:] = error[i, :-1]
                    if data.terminated[i] or data.truncated[i]:
                        cum_error = 0.
                    else:
                        cum_error = cum_error * coef
                        adv[i] += cum_error
                    prompt_length = data.prompt_lengths[i]
                    cum_error = cum_error + float(error[i, prompt_length - 1])
            elif adv_method == 'error':
                adv[:, 1:] = error[:, :-1]
            else:
                raise ValueError(f"Unknown adv_method: {adv_method!r}")
        else:
            # One advantage per sample, built from the value error at the last
            # prompt position of that sample.
            adv = torch.zeros(n, dtype=error.dtype, device=error.device)
            if adv_method == 'gae':  # GAE across consecutive samples
                cum_error = 0.
                for i in reversed(range(n)):
                    prompt_length = data.prompt_lengths[i]
                    e = float(error[i, prompt_length-1])
                    if data.terminated[i] or data.truncated[i]:
                        cum_error = 0.
                    else:
                        cum_error *= coef
                    cum_error += e
                    adv[i] = cum_error
            elif adv_method == 'error':
                # Per-sample error without accumulation (GAE with lambda = 0).
                # The full [n, len] error tensor cannot be stored in the [n]
                # buffer, so select each sample's prompt-end position.
                for i in range(n):
                    adv[i] = error[i, data.prompt_lengths[i] - 1]
            else:
                raise ValueError(f"Unknown adv_method: {adv_method!r}")

        if self.args.norm_adv == 'epoch':
            output_mask = data.output_mask
            assert isinstance(output_mask, torch.Tensor)
            if adv.ndim == 2:
                # Per-token advantages: normalise over output tokens only.
                adv = normalize_tensor(adv, output_mask.to(device=adv.device))
            else:
                adv = normalize_tensor(adv)

        data.advantage = adv.to(device=device)
        print("[✔] Preparing policy data")

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx):
        """Return the combined actor-critic loss for one batch.

        The critic loss comes from `LitValueIteration.training_step`. When the
        batch carries an `advantage` entry, the actor target is added:

            fmin(q * A, clip(q, 1 - clip, 1 + clip) * A)
            - kl_coef * (r - log r - 1)
            + entropy_coef * H

        Here `q = pr / pr_old`. The ratio `r` is `pr_ref / pr`, or `pr_old / pr`
        when the batch has no `ref_probs`. The target is averaged over
        `output_mask`, and its negative is the actor loss.

        Returns:
            `actor_weight * actor_loss + critic_weight * vf_loss`, or only the
            critic loss when the batch has no `advantage`.
        """
        vf_step = super().training_step(batch, batch_idx)

        temperature: float = self.temperature
        tokens: torch.Tensor = batch['tokens']  # (batch_size, len)
        pr_old: torch.Tensor = batch['probs']  # (batch_size, len)
        pr_ref: torch.Tensor | None = batch.get('ref_probs')  # (batch_size, len)
        output_mask: torch.Tensor = batch['output_mask']  # (batch_size, len)
        text_mask: torch.Tensor = batch['text_mask']  # (batch_size, len)
        advantage: torch.Tensor | None = batch.get('advantage')  # (batch_size, len) or (batch_size,)
        truncated: torch.Tensor = batch['truncated']
        vf_loss = vf_step.pop("loss")

        # No advantage in the batch (e.g. warm-up / cool-down epochs, for which
        # `_process_collected_data` does not compute one): critic only.
        if advantage is None:
            return vf_loss

        entropy_coef = self.args.entropy_coef
        kl_coef = self.args.kl_coef

        logits = self.llm.model.forward(tokens)  # (b, len, vocab_size)
        logits = preprocess_logits(logits, regularize=True, shift_next=True)
        pr = probabilities(logits, tokens, temperature, regularize=False, shift_next=False)
        if entropy_coef:
            # we seek to maximize the entropy under temperature=1 to encourage exploration.
            entropy = entropies(logits, 1, regularize=False, shift_next=False)
        else:
            entropy = None

        # ---- clipped surrogate ---------------------------------------------
        q = pr / pr_old  # (b, len)
        assert not torch.any(q[output_mask].isnan())

        if self.args.norm_adv == 'batch':
            if advantage.ndim == 2:
                advantage = normalize_tensor(advantage, output_mask)
            else:
                advantage = normalize_tensor(advantage)

        # Per-sample advantages broadcast over all token positions.
        if advantage.ndim == 1:
            advantage = advantage.unsqueeze(1)

        # ignore advantage of truncated steps.
        if self.args.ignore_truncated:
            advantage = advantage.masked_fill(truncated.unsqueeze(1), 0)

        if (clip := self.args.clip) is None:
            adv_item = q * advantage  # (b, len)
        else:
            clip_q = q.clip(1 - clip, 1 + clip)  # (b, len)
            adv_item = torch.fmin(q * advantage, clip_q * advantage)  # (b, len)

        # ---- KL penalty ----------------------------------------------------
        if pr_ref is None:  # use the old model as the reference model.
            q_ref = 1/q
        else:
            q_ref = pr_ref / pr
            if __debug__:
                assert not torch.any(q_ref[output_mask].isnan())

        # r - log r - 1 is non-negative and zero only at r == 1.
        kl_penalty = q_ref - q_ref.log() - 1  # (b, len)
        actor_target = (adv_item - kl_coef * kl_penalty)  # (b, len)

        if entropy is not None:
            actor_target = actor_target + entropy_coef * entropy
            # `self.log` only accepts single-element values: log the mean
            # entropy over output tokens rather than the (b, len) tensor.
            self.log('entropy', entropy[output_mask].mean(), prog_bar=False)

        actor_loss = -(actor_target[output_mask].mean())
        self.log('Lpi', actor_loss, prog_bar=True)

        w_a = self.args.actor_weight
        w_c = self.args.critic_weight
        loss = (actor_loss if w_a == 1 else w_a * actor_loss) + (vf_loss if w_c == 1 else w_c * vf_loss)
        self.log('L', loss, prog_bar=False)

        # Not called in this method; presumably kept for interactive debugging
        # (e.g. from a breakpoint) to inspect sample `i` of the batch.
        @torch.inference_mode()
        def debug_info(i: int):
            tokenizer = self.llm.tokenizer
            v = vf_step["vf_pred"]
            target_v = vf_step["vf_target"]
            return {
                'text': tokenizer.decode(tokens[i, text_mask[i]]),
                'reward': batch['rewards'][i],
                'adv': None if advantage is None else (
                    advantage[i, output_mask[i]] if advantage.size(1) == output_mask.size(1)
                    else advantage[i]
                ),
                'v': v[i, text_mask[i]],
                'target_v': target_v[i],
                'pr': pr[i, output_mask[i]],
                'quotient': q[i, output_mask[i]]
            }

        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        """Refresh the target value function after every training batch."""
        self.update_vf_target()

    @property
    def warming_up(self):
        """True while the next epoch index is below `args.epochs_warmup`."""
        return self.current_epoch + 1 < self.args.epochs_warmup

    @property
    def cooling_down(self):
        """True once the next epoch index falls in the final `args.epochs_cooldown` epochs."""
        return self.current_epoch + 1 >= self.args.epochs - self.args.epochs_cooldown


@dc.dataclass
class PPO(RLTuning[PPOArgs, LitPPO]):
    """PPO tuning entry point that wires the tuning configuration into `LitPPO`.

    Attributes:
        data_path (Path | str): where to load the data
    """

    def _get_lit_module(self) -> LitPPO:
        """Construct the `LitPPO` module from this run's configuration."""
        return LitPPO(
            args=self.args,
            impl=self.impl,
            task=self.task,
            train_samples=self.train_data,
            checkpoint_dir=self.checkpoint,
            tokenizer_dir=self.tokenizer_dir,
            reflection=self.reflection,
            trainer_ckpt_path=self.trainer_ckpt_path,
            distribute=self.distribute,
        )


def _stopped_mask(lengths: torch.Tensor, maxlen: int):
    ts = torch.arange(maxlen, dtype=torch.int64, device=lengths.device)
    return ts >= (lengths.unsqueeze(-1))
