import dataclasses as dc
from pathlib import Path
import lightning as L
import torch
import numpy as np
from core.reasoning import ReasoningTask, CoT, reasoners

from core.model import Distrubute, make_llm
from core.utils.th import probabilities, preprocess_logits, normalize_tensor
from core.utils.iterate import indices as iterate_indices
from .base import RLTuning, RLArgs, LitRL
from .utils import Collector, ReplayDataset, estimate_probs, InferenceBuffer

from typing import Literal, cast


@dc.dataclass
class GRPOArgs(RLArgs):
    """Hyper-parameters for GRPO training.

    Attributes:
        group_size: size of a rollout group; becomes ``collect_shape``
            ``(group_size,)``.
        clip: clipping range for the probability ratio, which is clipped to
            ``[1 - clip, 1 + clip]``; ``None`` disables clipping.
        kl_coef: weight of the KL penalty in the actor objective.
        discount: discount factor used when accumulating normalized per-step
            rewards into advantages.
        use_ref_model: whether to load a separate reference model (from the
            policy checkpoint) for the KL penalty.
        collect_shape: derived from ``group_size``; not an init argument.
    """

    group_size: int = 8
    # None means "no clipping" (checked explicitly by the training step)
    clip: float | None = 0.1
    kl_coef: float = 0.1
    discount: float = 1
    use_ref_model: bool = True
    collect_shape: tuple[int] = dc.field(init=False, default=NotImplemented)

    def __post_init__(self):
        # derive collect_shape before delegating to the base-class hook
        self.collect_shape = (self.group_size,)
        super().__post_init__()

class GroupReplayDataset(ReplayDataset):
    """Replay dataset that additionally keeps group-normalized rewards."""

    # One normalized reward per collected step, appended alongside `rewards`.
    norm_rewards: list[float]


class GroupCollector(Collector):
    """Collector for grouped rollouts that also records group-normalized rewards.

    Every valid step of every sample in a finished rollout becomes one entry
    of a :class:`GroupReplayDataset`. Besides the raw reward, each entry gets a
    reward normalized across the group dimension(s) (``norm_rewards``).
    """

    _data: GroupReplayDataset
    data: GroupReplayDataset

    def __init__(
        self,
        group_dims: int | tuple[int, ...] = -1, 
        device: torch.device | str | None = None,
        require: Collector.Require = Collector.Require(), 
        *,
        prob_bias: float = 0.01,
        allow_abortion: bool = False,
    ):
        """
        Args:
            group_dims: dimension(s) of the step-buffer shape whose entries
                form one group; rewards are normalized across them. Negative
                values count from the end (they index ``range(ndim)``).
            device, require, prob_bias, allow_abortion: forwarded unchanged to
                ``Collector.__init__``.
        """
        # NOTE(review): the default ``Collector.Require()`` is built once at
        # definition time and shared by every instance that omits ``require``;
        # harmless only if Require objects are never mutated -- confirm.
        super().__init__(device, require, prob_bias=prob_bias, allow_abortion=allow_abortion)
        self._group_dims = (group_dims,) if isinstance(group_dims, int) else group_dims
    
    def _init_data(self):
        # dataset type that carries the extra ``norm_rewards`` list
        return GroupReplayDataset()

    def reset(self):
        super().reset()
        # norm_rewards is appended in step with the base dataset's lists
        self._data.norm_rewards = []

    def after_reasoning(self, thought, outcome):
        """Flatten the finished rollout in ``self.trajectory`` into ``self._data``.

        Each valid step (``buf.mask[idx]``) of each sample index ``idx`` is
        appended as one entry. After a sample's steps are appended, its LAST
        entry is patched with episode-level information: the
        terminated/truncated flags from ``self.host.terminated`` and the raw
        and normalized outcome rewards (added onto that step's reward).

        ``thought`` is not used here; ``outcome`` is passed to the reward
        model's ``outcome_rewards``.

        Raises:
            ValueError: rewards are required but no reward model is set.
        """
        data = self._data
        terminated = self.host.terminated
        trajectory = self.trajectory
        if len(trajectory) == 0:
            return
        # all step buffers share one shape; its dims split into group dims
        # (normalized across) and batch dims (the remaining ones)
        shape = trajectory[0].shape
        assert all(seq.shape == shape for seq in trajectory)
        all_dims = range(len(shape))
        # indexing the range resolves negative group dims to absolute ones
        group_dims = set(all_dims[d] for d in self._group_dims)
        batch_dims = tuple(d for d in all_dims if d not in group_dims)

        # compute outcome rewards and normalize rewards
        outcome_rewards = None
        norm_process_rewards = None  # (n_step, *shape) once computed
        norm_outcome_rewards = None  # shape once computed
        if self.require.rewards:
            if (rm := self.reward_model) is None:
                raise ValueError("Rewards are required while missing a reward model.")
            outcome_rewards = rm.outcome_rewards(
                self.host.ref, self.host.llm.preprocessor, outcome, terminated
            ).to(device=self._device)
            # compute normalized rewards
            norm_process_rewards, norm_outcome_rewards = _normalize_rewards(
                trajectory, terminated, outcome_rewards, shape, batch_dims
            )
        
        # presumably iterates every full index tuple of `shape` (one sample each)
        for idx in iterate_indices(shape):
            _case_is_not_empty = False
            for t, buf in enumerate(trajectory):
                valid = bool(buf.mask[idx])
                if not valid:
                    continue
                _case_is_not_empty = True

                if self.require.content:
                    # clone: indexing may return a view into the shared step buffer
                    tokens = buf.tokens[idx].clone()
                    length = int(buf.lengths[idx])
                    prompt_length = int(buf.prompt_lengths[idx])
                    assert tokens.shape == (self._context_length,)
                    data.tokens.append(tokens)
                    data.lengths.append(length)
                    data.prompt_lengths.append(prompt_length)
                # episode flags are placeholders here; the sample's last entry
                # is overwritten after this loop
                if self.require.terminated:
                    data.terminated.append(False)
                if self.require.truncated:
                    data.truncated.append(False)
                if self.require.logits:
                    assert buf.logits is not None
                    data.logits.append(buf.logits[idx].clone())
                if self.require.probs:
                    assert buf.probs is not None
                    data.probs.append(buf.probs[idx].clone())
                if self.require.rewards:
                    assert buf.process_rewards is not None
                    assert norm_process_rewards is not None
                    data.rewards.append(float(buf.process_rewards[idx]))
                    # normalized rewards carry a leading step axis
                    data.norm_rewards.append(float(norm_process_rewards[(t,) + idx]))
            if _case_is_not_empty:
                # the [-1] entries are this sample's last valid step: attach the
                # episode-level flags and outcome reward to it
                if self.require.truncated:
                    data.truncated[-1] = not bool(terminated[idx])
                if self.require.terminated:
                    data.terminated[-1] = bool(terminated[idx])
                if self.require.rewards:
                    assert outcome_rewards is not None
                    assert norm_outcome_rewards is not None
                    data.rewards[-1] += float(outcome_rewards[idx])
                    data.norm_rewards[-1] += (float(norm_outcome_rewards[idx]))


class LitGRPO(LitRL[GRPOArgs, GroupCollector]):
    """Lightning module that trains the policy LLM with GRPO.

    Rollouts are gathered by a :class:`GroupCollector` (group dimension 1),
    which stores per-step rewards normalized within each group. Advantages are
    discounted sums of those normalized rewards, and the actor maximizes an
    (optionally clipped) importance-ratio surrogate minus a KL penalty
    towards a reference policy.
    """

    def __init__(self,
        args: GRPOArgs,
        impl: reasoners.ThoughtImpl,
        task: ReasoningTask,
        train_samples: list[CoT] | None,
        checkpoint_dir: Path | str,
        tokenizer_dir: Path | str | None = None,
        reflection: dict | None = None,
        trainer_ckpt_path: Path | str | None = None,
        distribute: Distrubute = 'auto',
    ):
        super().__init__(
            args, impl, task, train_samples, checkpoint_dir, tokenizer_dir,
            reflection, trainer_ckpt_path, distribute
        )

        self._multi_contextual = self._reasoner.impl_info.multi_contextual

        # Reference model for the KL penalty, loaded from the policy checkpoint
        # and kept in eval mode. The attribute is always defined (None when
        # disabled) so _process_collected_data can test for it instead of
        # failing with AttributeError.
        if self.args.use_ref_model:
            self.ref_llm = make_llm(
                str(checkpoint_dir),
                tokenizer_dir=tokenizer_dir,
                distribute=distribute,
            )
            self.ref_llm.eval()
        else:
            self.ref_llm = None
    
    def _get_train_collector(self):
        """Build the collector: group dim 1, keeping content, flags, probs and rewards."""
        return GroupCollector(
            1, self._collect_device,
            require=GroupCollector.Require(
                content=True,
                truncated=True,
                terminated=True,
                logits=False,
                probs=True,
                rewards=True,
            ),
            allow_abortion=self.args.enable_abortion,
        )

    @torch.inference_mode()
    def _process_collected_data(self, _first_loading: bool = False):
        """Attach reference probs, advantages and token masks to the collected data."""
        super()._process_collected_data(_first_loading)

        print("⏳ Preparing training epoch", end='\r')

        device = self._collect_device
        data = self._train_collector.data
        n = len(data)
        assert n > 0

        # Reference-model probabilities for the KL penalty. Without a reference
        # model nothing is attached and training_step falls back to the
        # collecting policy (assumes the dataset leaves 'ref_probs' out of the
        # batch when it was never set -- TODO confirm).
        if self.ref_llm is not None:
            data.ref_probs = estimate_probs(data, self.ref_llm, self.temperature,
                                            self.args.inference_batch_size, device)

        # Advantage = discounted sum of future normalized rewards, computed
        # backwards. The collector flags the last step of every episode as
        # terminated or truncated, so the accumulator restarts there.
        discount = self.args.discount
        advs = [0.0] * n
        adv = 0.0
        for t in reversed(range(n)):
            if data.terminated[t] or data.truncated[t]:
                adv = 0.0
            adv = adv * discount + data.norm_rewards[t]
            advs[t] = adv
        data.advantage = advs
        
        # token masks: text = within `lengths`; output = text after the prompt
        ctxlen = self.context_length
        temp = torch.arange(ctxlen, dtype=torch.int64, device=device)
        lengths = torch.tensor(data.lengths, device=device)
        prompt_lengths = torch.tensor(data.prompt_lengths, device=device)
        text_mask = temp < lengths.unsqueeze(1)
        output_mask = text_mask & (temp >= prompt_lengths.unsqueeze(-1))
        data.text_mask = text_mask
        data.output_mask = output_mask
        del temp, lengths, prompt_lengths

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx):
        """Return the GRPO actor loss for one batch of flattened samples.

        Per output token, with ratio ``q = pi / pi_old`` and the sample's
        advantage ``A``::

            target = min(q * A, clip(q, 1 - eps, 1 + eps) * A) - kl_coef * (r - log r - 1)

        where ``r = pi_ref / pi`` (``pi_old / pi`` when the batch carries no
        ``'ref_probs'``). Without ``args.clip`` the first term is ``q * A``.
        The loss is the negated mean of ``target`` over output tokens.
        """
        temperature: float = self.temperature
        tokens: torch.Tensor = batch['tokens']  # (b, len)
        pr_old: torch.Tensor = batch['probs']  # (b, len)
        pr_ref: torch.Tensor | None = batch.get('ref_probs')  # (b, len)
        output_mask: torch.Tensor = batch['output_mask']  # (b, len), bool
        text_mask: torch.Tensor = batch['text_mask']  # (b, len), bool
        advantage: torch.Tensor = batch['advantage']  # (b,)

        kl_coef = self.args.kl_coef
        # Call the module itself rather than .forward() so that registered
        # forward hooks (e.g. ones installed by sharding wrappers) run.
        logits = self.llm.model(tokens)  # (b, len, vocab_size)
        logits = preprocess_logits(logits, regularize=True, shift_next=True)
        pr = probabilities(logits, tokens, temperature, regularize=False, shift_next=False)

        # Only output tokens enter the loss, but every position takes part in
        # autograd: a zero probability used as a divisor or inside log() at a
        # masked-out position yields an inf local gradient, and 0 * inf = NaN
        # would poison the update. Fill masked-out entries with 1 (ratio 1,
        # zero KL); values on output tokens are unchanged.
        def _fill_masked(x: torch.Tensor) -> torch.Tensor:
            return torch.where(output_mask, x, torch.ones_like(x))

        pr = _fill_masked(pr)
        pr_old = _fill_masked(pr_old)

        # importance ratio of the current vs. the data-collecting policy
        q = pr / pr_old  # (b, len)
        assert not torch.any(q[output_mask].isnan())
        
        assert advantage.ndim == 1
        advantage = advantage.unsqueeze(1)  # (b, 1): one advantage per sample

        if (clip := self.args.clip) is None:
            adv_item = q * advantage  # (b, len)
        else:
            clip_q = q.clip(1 - clip, 1 + clip)  # (b, len)
            # pessimistic bound of the clipped / unclipped surrogate
            adv_item = torch.fmin(q * advantage, clip_q * advantage)  # (b, len)
    
        # KL item: r - log(r) - 1 with r = pi_ref / pi
        if pr_ref is None:  # use the old model as the reference model.
            q_ref = 1 / q
        else:
            q_ref = _fill_masked(pr_ref) / pr
            assert not torch.any(q_ref[output_mask].isnan())

        kl_penalty = q_ref - q_ref.log() - 1  # (b, len)
        actor_target = adv_item - kl_coef * kl_penalty  # (b, len)
        actor_loss = -(actor_target[output_mask].mean())
        self.log('Lpi', actor_loss, prog_bar=True)

        # Debugging helper; not called here, but handy from a breakpoint.
        @torch.inference_mode()
        def debug_info(i: int):
            tokenizer = self.llm.tokenizer
            return {
                'text': tokenizer.decode(tokens[i, text_mask[i]]),
                'reward': batch['rewards'][i],
                'adv': None if advantage is None else (
                    advantage[i, output_mask[i]] if advantage.size(1) == output_mask.size(1)
                    else advantage[i]
                ),
                'pr': pr[i, output_mask[i]],
                'quotient': q[i, output_mask[i]]
            }

        return actor_loss


@dc.dataclass
class GRPO(RLTuning[GRPOArgs, LitGRPO]):
    """RL tuning pipeline whose Lightning module is :class:`LitGRPO`.

    Attributes:
        data_path (Path | str): where to load the data. (Not declared in this
            class body; presumably inherited from ``RLTuning`` -- confirm.)
    """

    def _get_lit_module(self):
        """Construct the GRPO Lightning module from this pipeline's settings."""
        return LitGRPO(
            args=self.args,
            impl=self.impl,
            task=self.task,
            train_samples=self.train_data,
            checkpoint_dir=self.checkpoint,
            tokenizer_dir=self.tokenizer_dir,
            reflection=self.reflection,
            trainer_ckpt_path=self.trainer_ckpt_path,
            distribute=self.distribute,
        )


def _normalize_rewards(trajectory: list[InferenceBuffer],
                       terminated: torch.Tensor, outcome_rewards: torch.Tensor,
                       shape: tuple[int, ...], batch_dims: tuple[int, ...]):
    """Normalize process and outcome rewards group by group.

    Args:
        trajectory: step buffers of one rollout; each must have
            ``process_rewards`` (asserted) and a ``mask`` of shape ``shape``.
        terminated: currently unused.
        outcome_rewards: per-sample outcome rewards, indexable like ``shape``.
        shape: the common step-buffer shape.
        batch_dims: dims of ``shape`` that are not group dims; each index
            produced by ``iterate_indices(shape, batch_dims)`` presumably
            selects one whole group.

    Returns:
        ``(norm_process_rewards, norm_outcome_rewards)`` of shapes
        ``(n_step, *shape)`` and that of ``outcome_rewards``. Per group,
        process rewards go through ``normalize_tensor`` together with their
        step masks, outcome rewards without a mask. The inputs are not
        modified (results live in a stacked copy and a clone).
    """
    for step in trajectory:
        assert step.process_rewards is not None

    group_outcomes = outcome_rewards.clone()
    step_rewards = torch.stack([step.process_rewards for step in trajectory])  # [n_step, *shape]
    step_masks = torch.stack([step.mask for step in trajectory])  # [n_step, *shape]

    for group in iterate_indices(shape, batch_dims):
        # process rewards carry an extra leading step axis
        steps_of_group = (slice(None),) + tuple(group)
        group_outcomes[group] = normalize_tensor(group_outcomes[group])
        step_rewards[steps_of_group] = normalize_tensor(
            step_rewards[steps_of_group], step_masks[steps_of_group]
        )
    return step_rewards, group_outcomes
