import dataclasses as dc
from pathlib import Path
import lightning as L
import torch
import numpy as np
from core.reasoning import ReasoningTask, CoT, reasoners

from core.model import LitLLM, ValueModel, Distrubute, make_llm, LLM
from core.utils.th import probabilities, normalize_tensor, entropies, preprocess_logits, fetch
from .base import RLTuning, RLArgs, LitRL
from .utils import Collector, ReplayDataset, estimate_values

from typing import Literal, NamedTuple, ClassVar


@dc.dataclass
class ValueIterArgs(RLArgs):
    """
    Arguments for Value Iteration.

    Configures how the value model (critic) is fitted: how regression targets
    are built, which positions receive extra loss weight, and whether a separate
    target network is used for bootstrapping.
    """

    # Discount factor applied to future rewards (MC returns) and to bootstrapped
    # next-state values (TD targets).
    discount: float = 1
    # Extra loss weight on the squared error at the last prompt token.
    prompt_value_weight: float = 0
    # Extra loss weight on the squared error at the last token of the sequence.
    output_value_weight: float = 0
    # Normalize collected rewards (via ``normalize_tensor``) before building targets.
    norm_reward: bool = False
    # 'mc': regress onto Monte-Carlo returns; 'td': regress onto one-step TD targets.
    target_method: Literal['mc', 'td'] = 'mc'
    # In the multi-contextual setting, zero the loss of truncated samples.
    ignore_truncated: bool = True
    # Keep a separate target value network for TD bootstrapping.
    use_target_vf: bool = False
    # Soft-update rate of the target network; required when ``use_target_vf`` is set.
    target_update_rate: float | None = None

    def __post_init__(self):
        """
        Validate option combinations at construction time.

        Without this check, a missing ``target_update_rate`` only surfaces after
        the first training batch, and an unknown ``target_method`` only surfaces
        inside the training loop.

        Raises:
            ValueError: if ``target_method`` is not 'mc' or 'td', or if
                ``use_target_vf`` is set without a ``target_update_rate`` in [0, 1].
        """
        # Preserve any initialization/validation defined by the parent arguments class.
        parent_post_init = getattr(super(), '__post_init__', None)
        if parent_post_init is not None:
            parent_post_init()

        if self.target_method not in ('mc', 'td'):
            raise ValueError(
                f"target_method must be 'mc' or 'td', got {self.target_method!r}"
            )
        if self.use_target_vf:
            rate = self.target_update_rate
            if rate is None:
                raise ValueError(
                    "target_update_rate must be set when use_target_vf is True"
                )
            if not 0 <= rate <= 1:
                raise ValueError(f"target_update_rate must be in [0, 1], got {rate}")


class LitValueIteration[Args: ValueIterArgs](LitRL[Args, Collector]):
    """
    The RL process that involves an actor and a critic.
    """
    
    _llm_trainable: ClassVar[bool] = False
    _vf_trainable: ClassVar[bool] = True
    _require_logits: ClassVar[bool] = False
    _require_probs: ClassVar[bool] = False
    _log_vf_error: bool = True

    def __init__(self,
        args: Args,
        impl: reasoners.ThoughtImpl,
        task: ReasoningTask,
        train_samples: list[CoT] | None,
        checkpoint_dir: Path | str,
        tokenizer_dir: Path | str | None = None,
        reflection: dict | None = None,
        trainer_ckpt_path: Path | str | None = None,
        distribute: Distrubute = 'auto',
        init_vf: Path | str | None = None,
    ):
        super().__init__(
            args, impl, task, train_samples, checkpoint_dir, tokenizer_dir,
            reflection, trainer_ckpt_path, distribute,
        )

        self._multi_contextual = self._reasoner.impl_info.multi_contextual

        fabric = self.llm.fabric
        with fabric.init_module(empty_init=(fabric.world_size > 1)):
            self.vf = ValueModel(self.llm.model.config)
            # To reduce computational cost, the target network is optional.
            if self.args.use_target_vf:
                self._vf_target = ValueModel(self.llm.model.config)
            else:
                self._vf_target = None

        if init_vf is not None:
            init_vf = Path(init_vf)
            if init_vf.is_file():
                fabric.load_raw(init_vf, self.vf)
            elif (vf_path := init_vf / 'vf.pth').exists() and vf_path.is_file():
                fabric.load_raw(vf_path, self.vf)
            else:
                raise FileNotFoundError(init_vf)
        else:
            self.vf.init_weights(self.llm)

        if self._vf_target is not None:
            self._vf_target.load_state_dict(self.vf.state_dict())
            self._vf_target.eval()

    @property
    def vf_target(self):
        return self.vf if self._vf_target is None else self._vf_target

    @property
    def _trainable_modules(self):
        modules: dict[str, torch.nn.Module] = {}
        if self._llm_trainable:
            modules.update(super()._trainable_modules)
        if self._vf_trainable:
            modules["vf"] = self.vf
        return modules
    
    def _get_train_collector(self) -> Collector:
        return Collector(
            self._collect_device,
            require=Collector.Require(
                content=True,
                truncated=True,
                terminated=True,
                logits=self._require_logits,
                probs=self._require_probs,
                rewards=True,
            ),
            allow_abortion=self.args.enable_abortion,
        )
    
    @torch.no_grad()
    def update_vf_target(self, hard=False):
        rate = self.args.target_update_rate
        if self._vf_target is not None:
            assert rate is not None
            for name, p in self._vf_target.named_parameters():
                p_ = self.vf.get_parameter(name).to(dtype=p.dtype, device=p.device)
                if hard:  # use hard update
                    p[:] = p_
                else:
                    p[:] = (1 - rate) * p + rate * p_
        else:
            return
    
    @torch.inference_mode()
    def _process_collected_data(self, _first_loading: bool = False):
        super()._process_collected_data(_first_loading)

        print("⏳ Preparing critic (value model) data", end='\r')

        device = self._collect_device
        data = self._train_collector.data
        n = len(data)
        discount = self.args.discount
        multi_contextual = self._multi_contextual
        target_method = self.args.target_method

        assert n > 0

        # normalize rewards if needed
        if self.args.norm_reward:
            r = torch.tensor(data.rewards, dtype=torch.float64)
            r = normalize_tensor(r)
            data.rewards = r.tolist()
        
        # compute value targets
        if target_method == 'td':
            # TD targets for learning value function is computed in each training step.
            # If multi-contextual, the next prompts are needed.
            if multi_contextual:
                tokens_next = data.tokens[1:]
                tokens_next.append(
                    torch.full_like(data.tokens[0], self._reasoner.session.pad_index)
                )
                prompt_lengths_next = data.prompt_lengths[1:]
                prompt_lengths_next.append(0)
                for i in range(n):
                    if data.terminated[i] or data.truncated[i]:
                        prompt_lengths_next[i] = 0
                data.tokens_next = tokens_next
                data.prompt_lengths_next = prompt_lengths_next
        
        if target_method == 'mc' or self._log_vf_error:
            _dtype = self._reward_model.dtype

            if multi_contextual:
                returns = torch.zeros(n, dtype=_dtype, device=device)
                cum_r: float = 0.
                for i in reversed(range(n)):
                    if data.terminated[i] or data.truncated[i]:
                        cum_r = 0.
                    cum_r = discount * cum_r + data.rewards[i]
                    returns[i] = cum_r
            else:
                returns = torch.tensor(data.rewards, _dtype, device)  # [n]

            data.returns = returns
        else:
            assert False

        # useful masks
        ctxlen = self.context_length
        temp = torch.arange(ctxlen, dtype=torch.int64, device=device)
        lengths = torch.tensor(data.lengths, device=device)
        prompt_lengths = torch.tensor(data.prompt_lengths, device=device)
        text_mask = temp < lengths.unsqueeze(1)
        output_mask = text_mask & (temp >= prompt_lengths.unsqueeze(-1))
        data.text_mask = text_mask
        data.output_mask = output_mask
        del temp, lengths, prompt_lengths

        print("[✔] Preparing critic (value model) data")

    def _debug_data(self, i: int):
        d = self._train_collector.data[i]
        length = d['lengths']
        prompt_length = d['prompt_lengths']
        tokens = d['tokens']
        prompt = self.llm.tokenizer.decode(tokens[:prompt_length])
        output = self.llm.tokenizer.decode(tokens[prompt_length:length])
        d['_prompt'] = prompt
        d['_output'] = output
        return d

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx):

        device = self.device
        for k, v in batch.items():
            batch[k] = v.to(device=device)

        return_dict: dict[str, torch.Tensor] = {}

        T: torch.Tensor = batch['lengths']  # (batch_size)
        T0: torch.Tensor = batch['prompt_lengths']  # (batch_size)
        tokens: torch.Tensor = batch['tokens']  # (batch_size, len)
        text_mask: torch.Tensor = batch['text_mask']  # (batch_size, len)
        rewards: torch.Tensor = batch['rewards']
        truncated: torch.Tensor = batch['truncated']
        
        self.llm.train()
        self.vf.train()

        # compute critic (value model) target
        v = self.vf.forward(tokens)
        target_method = self.args.target_method
        target_v: torch.Tensor
        if target_method == 'mc':
            target_v = batch['returns'].unsqueeze(1)
            assert target_v.ndim == 2
        elif target_method == 'td':
            if self._multi_contextual:
                tokens_next: torch.Tensor = batch['tokens_next']
                prompt_lengths_next: torch.Tensor = batch['prompt_lengths_next']
                max_next_prompt_length = int(prompt_lengths_next.max())
                with torch.no_grad():
                    if max_next_prompt_length == 0:  # all truncated or terminated
                        target_v = rewards.unsqueeze(1)
                    else:
                        no_next_prompt = (prompt_lengths_next == 0)
                        tokens_next = tokens_next[:, :max_next_prompt_length]
                        v_next = self.vf_target.forward(tokens_next)  # [batchsize, len]
                        _t = (prompt_lengths_next - 1).clamp_min(0)
                        v_next = fetch(v_next, _t, keep_dim=True)  # [batchsize, 1]
                        v_next[no_next_prompt] = 0
                        target_v = self.args.discount * v_next + rewards.unsqueeze(1)
            else:
                with torch.no_grad():
                    if self._vf_target is None:
                        v_ = v.detach()
                    else:
                        v_ = self._vf_target.forward(tokens)
                    target_v = unictx_td_target(
                        v_, rewards, T,
                        done_mask=~text_mask,
                        discount=self.args.discount,
                    )
        else:
            assert False
        
        sqe = 0.5 * torch.square(target_v - v)  # square error
        if self.args.ignore_truncated and self._multi_contextual:
            sqe = sqe.masked_fill(truncated.unsqueeze(1), 0)

        if self._log_vf_error:  # log the monte-carlo error
            returns = batch['returns']
            assert returns.ndim == 1
            with torch.no_grad():
                error_mc = torch.abs(returns.unsqueeze(-1) - v)
                self.log("Evf", error_mc[text_mask].mean(), prog_bar=True)
                self.log("Evp", fetch(error_mc, T - 1).mean(), prog_bar=True)
                self.log("Evo", fetch(error_mc, T0 - 1).mean(), prog_bar=True)

        return_dict["vf_pred"] = v
        return_dict["vf_target"] = target_v
        
        vf_loss = sqe[text_mask].mean()
        if (w := self.args.output_value_weight):
            vf_loss_out = sqe.gather(1, (T - 1).unsqueeze(1)).mean()
            return_dict["vf_loss_output"] = vf_loss_out
            self.log("Lvo", vf_loss_out, prog_bar=True)
            vf_loss = vf_loss + w * vf_loss_out
        if (w := self.args.prompt_value_weight):
            vf_loss_prompt = sqe.gather(1, (T0 - 1).unsqueeze(1)).mean()
            return_dict["vf_loss_prompt"] = vf_loss_prompt
            self.log("Lvp", vf_loss_prompt, prog_bar=True)
            vf_loss = vf_loss + w * vf_loss_prompt

        return_dict["loss"] = vf_loss
        self.log('Lvf', vf_loss, prog_bar=True)

        @torch.inference_mode()
        def debug_info(i: int):
            tokenizer = self.llm.tokenizer
            return {
                'text': tokenizer.decode(tokens[i, text_mask[i]]),
                'reward': batch['rewards'][i],
                'v': v[i, text_mask[i]],
                'target_v': target_v[i],
            }

        return return_dict

    def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
        self.update_vf_target()


def _stopped_mask(lengths: torch.Tensor, maxlen: int):
    ts = torch.arange(maxlen, dtype=torch.int64, device=lengths.device)
    return ts >= (lengths.unsqueeze(-1))


@dc.dataclass
class ValueEstimation(RLTuning[ValueIterArgs, LitValueIteration]):
    """RL tuning entry point whose Lightning module is ``LitValueIteration``."""

    def _get_lit_module(self):
        """Construct the ``LitValueIteration`` module from this tuning configuration."""
        return LitValueIteration(
            args=self.args,
            impl=self.impl,
            task=self.task,
            train_samples=self.train_data,
            checkpoint_dir=self.checkpoint,
            tokenizer_dir=self.tokenizer_dir,
            reflection=self.reflection,
            trainer_ckpt_path=self.trainer_ckpt_path,
            distribute=self.distribute,
        )


def unictx_td_target(
    v: torch.Tensor,
    r: torch.Tensor,
    lengths: torch.Tensor,
    done_mask: torch.Tensor | None,
    discount: float,
):
    b = v.size(0)
    td_target = torch.zeros_like(v)

    if done_mask is None:
        ts = torch.arange(v.size(-1), dtype=torch.int64, device=v.device)
        done_mask = ts >= (lengths.unsqueeze(-1))

    next_v = v.masked_fill(done_mask, 0)[:, 1:]
    td_target[:, :-1] = discount * next_v  # target[t] = discount * value[t+1]
    if r.ndim == 1:  # outcome supervision
        td_target[range(b), lengths - 1] = r  # target[T-1] = reward
    # elif r.ndim == 2:  # process supervision
    #     td_target += r  # target[t] = reward[t] + discount * value[t+1]
    else:
        raise IndexError

    return td_target


def unictx_td_error(
    vf: ValueModel,
    data: ReplayDataset,
    batch_size: int,
    discount: float,
    device: torch.device | str | None,
):
    """
    Per-token TD error of ``vf`` on ``data`` in the single-context setting.

    Each sequence carries one outcome reward, used as the target at its last
    token (see ``unictx_td_target``). Entries at or past a sequence's length
    are zero.

    Returns:
        Tensor shaped like the value estimates, ``[n_seq, context_length]``.
    """
    values = estimate_values(vf, data, batch_size, device)  # [n_seq, context_length]
    assert values.ndim == 2
    target_device = values.device

    rewards = torch.tensor(data.rewards, dtype=values.dtype, device=target_device)
    lengths = torch.tensor(data.lengths, dtype=torch.int64, device=target_device)
    past_end = _stopped_mask(lengths, values.size(1))
    targets = unictx_td_target(values, rewards, lengths, past_end, discount=discount)
    return torch.where(past_end, 0, targets - values)


def multictx_td_error(
    vf: ValueModel,
    data: ReplayDataset,
    batch_size: int,
    discount: float,
    device: torch.device | str | None,
):
    """
    Per-token TD error of ``vf`` on ``data`` in the multi-contextual setting.

    Every token of a sample is compared against that sample's single one-step
    TD target (see ``multictx_td_target``). Entries at or past a sample's
    length are zero.
    """
    values = estimate_values(vf, data, batch_size, device)
    dev = values.device

    def to_tensor(items, dtype):
        return torch.tensor(items, dtype=dtype, device=dev)

    # An episode ends at a sample that is either terminated or truncated.
    episode_end = to_tensor(data.terminated, torch.bool) | to_tensor(data.truncated, torch.bool)
    targets = multictx_td_target(
        values,
        to_tensor(data.rewards, values.dtype),
        episode_end,
        to_tensor(data.prompt_lengths, torch.int64),
        discount,
    )
    past_end = _stopped_mask(to_tensor(data.lengths, torch.int64), values.size(1))
    return torch.where(past_end, 0, targets.unsqueeze(1) - values)


def multictx_td_target(
    v: torch.Tensor,  # [b, len]
    r: torch.Tensor,  # [b]
    done: torch.Tensor,
    prompt_lengths: torch.Tensor,
    discount: float,
):
    """
    One-step TD target per sample, where consecutive samples are consecutive steps.

    A sample's state value is read at its last prompt token
    (``prompt_lengths - 1``). The target of sample ``i`` is
    ``r[i] + discount * value(i + 1)``. The bootstrapped term is zero for the
    final sample and wherever ``done[i]`` is True.

    Returns:
        Tensor of shape ``[b]``.
    """
    state_values = fetch(v, prompt_lengths - 1)  # [b]
    # Shift by one step: sample i bootstraps from sample i + 1; the last has no successor.
    successor_values = torch.cat((state_values[1:], torch.zeros_like(state_values[:1])))
    successor_values.masked_fill_(done, 0)
    return r + discount * successor_values  # [b]
