import dataclasses as dc
from pathlib import Path
from typing import Literal, Any, Final, cast
import re
import lightning as L
import torch
import torch.utils.data as torch_data
import os
import numpy as np
import random
import abc

from litgpt.tokenizer import Tokenizer
from litgpt.utils import extend_checkpoint_dir
from lightning.pytorch.callbacks import ModelCheckpoint

from core.reasoning.task import Split, ReasoningTask
from core.reasoning import CoT, make_reasoner, ReflectionApproach, ThoughtImpl, RewardModel
from core.inference.sampling import Sampling
from core.utils.progress import Progress
from core.utils.th import ListDataset
from core.model import make_llm, LitLLM, Distrubute, Precision, Auto

from .utils import split_data, average_return, Collector


# Allowed values of the reward-supervision mode; `RLArgs.supervision` accepts the
# same literals and is forwarded to `task.reward_model(...)`.
type Supervision = Literal['outcome', 'process', 'success']
# Mapping from names to tensors.
# NOTE(review): not referenced in this module as far as visible — confirm usage elsewhere.
type Data = dict[str, torch.Tensor]
# Optimizer-argument type of `LitLLM`, used as the default factory of `RLArgs.optim`.
OptimArgs = LitLLM.OptimArgs


@dc.dataclass
class RLArgs:
    """Hyperparameters of an RL fine-tuning run.

    The trailing ``init=False`` fields (``collect_scale``, ``collect_batch_size``,
    ``val_batch_size``) are derived in ``__post_init__`` and cannot be passed to
    the constructor.

    Raises:
        ValueError: if ``collect_shape`` contains a non-positive entry or
            ``val_scale`` is smaller than 1.
    """

    epochs: int
    # Number of training *queries* sampled per collection round; each query is
    # expanded into `collect_scale` (= prod(collect_shape)) collected sequences.
    train_epoch_size: int = 1024
    train_epoch_repeat: int = 1  # how many epochs reuse the same data before collecting new data
    train_batch_size: int = 64
    inference_batch_size: int = 256
    context_length: int = 128
    supervision: Literal['outcome', 'process', 'success'] = 'outcome'  # forwarded to `task.reward_model`
    source_of_inputs: Literal["data", "task"] = "data"
    # A `Progress` schedule is evaluated at `epoch / epochs` when collecting training data.
    temperature: float | Progress = 1.0
    enable_abortion: bool = False  # whether to early truncate the reasoning process when detecting step-wise errors.
    top_k: int | None = None
    top_p: float = 1.0
    max_steps: int | None = None
    optim: OptimArgs = dc.field(default_factory=OptimArgs)
    max_norm: float = 1.
    log_interval: int = 10
    max_val_size: int | None = None
    val_scale: int = 1  # how many times each validation query is repeated
    trainer_args: dict = dc.field(default_factory=dict)

    collect_shape: tuple[int, ...] | int = dc.field(default=1)
    """For each training query, how many samples are collected in each epoch, 
    and how are these samples shaped in tensors.
    For example, if `inference_batch_size = 24` and `collect_shape = (3, 2)`, then each
    batch of collected samples come from `24/3/2=4` queries, and will be shaped as `(4, 3, 2)`.
    """

    collect_to_cpu: bool = False
    save_each_n_epochs: int | None = None
    save_top_k: int = 0
    save_last: bool = True
    val_each_n_epochs: int | None = None
    val_temperature: float = -1  # negative: derive from `temperature` in __post_init__

    collect_scale: int = dc.field(init=False)
    collect_batch_size: int = dc.field(init=False)
    val_batch_size: int = dc.field(init=False)
    
    def __post_init__(self):
        # Reject shapes that would otherwise surface as a ZeroDivisionError below
        # (a zero entry) or as an opaque tensor-shape error during collection.
        dims = (self.collect_shape,) if isinstance(self.collect_shape, int) else tuple(self.collect_shape)
        if any(d < 1 for d in dims):
            raise ValueError(
                f"collect_shape must contain only positive integers, got {self.collect_shape!r}."
            )
        if self.val_scale < 1:
            raise ValueError(f"val_scale must be a positive integer, got {self.val_scale!r}.")

        self.collect_scale = int(np.prod(self.collect_shape))
        # Queries per inference batch, so that a batch holds about
        # `inference_batch_size` sequences after expansion by `collect_shape`.
        self.collect_batch_size = max(self.inference_batch_size // self.collect_scale, 1)
        self.val_batch_size = max(self.inference_batch_size // self.val_scale, 1)
        if self.val_temperature < 0:
            # Default: validate at the schedule's end value (presumably the final
            # training temperature) or at the fixed training temperature.
            self.val_temperature = (
                self.temperature.end_value if isinstance(self.temperature, Progress)
                else self.temperature
            )


class param_for_refl:
    """
    Factories for the reflection hyperparameters used in RL training.

    Each factory returns a plain ``dict`` whose keys (``"approach"``,
    ``"reject_coef"``, ``"reflect_temperature"``) are the ones read by
    ``LitRL._init_env`` from its ``reflection`` argument.
    """

    @staticmethod
    def revise_error(
        reject_coef: float = 0,
        reflect_temperature: float = 1,
    ):
        """Return the settings for the ``"self-verify"`` reflection approach.

        Args:
            reject_coef: coefficient handed to the self-verify reward wrapper.
            reflect_temperature: sampling temperature used while reflecting.
        """
        settings = dict(
            reject_coef=reject_coef,
            reflect_temperature=reflect_temperature,
        )
        settings["approach"] = "self-verify"
        return settings



class LitRL[Args: RLArgs, TrainCollector: Collector](LitLLM, abc.ABC):

    _ref_dicts: list[dict[str, Any]]

    def _init_env(self, impl: ThoughtImpl, reflection: dict | None):
        from core.reasoning.reflection.self_verify import SVRewardWrapper
        from core.reasoning import param

        task = self._task
        args = self.args

        rm: RewardModel = task.reward_model(self.args.supervision)
        if reflection is None:
            reflargs = None
        else:
            approach: ReflectionApproach = reflection["approach"]
            if approach == "self-verify":
                rm = SVRewardWrapper(rm, reflection["reject_coef"])
                reflargs = param.reflection.self_verify(
                    reject_mode="retry",
                    reflect_temperature=reflection["reflect_temperature"]
                )
            else:
                raise NotImplementedError(f"RL Training with reflection (approach={approach}) is not supported.")

        reasoner = make_reasoner(
            self.llm, impl, args.context_length,
            task=task,
            max_steps=self.args.max_steps,
            reflection_args=reflargs,
        )
        return reasoner, rm

    def __init__(
        self,
        args: Args,
        impl: ThoughtImpl,
        task: ReasoningTask,
        train_samples: list[CoT] | None,
        checkpoint_dir: Path | str,
        tokenizer_dir: Path | str | None = None,
        reflection: dict | None = None,
        trainer_ckpt_path: Path | str | None = None,
        distribute: Distrubute = 'auto',
    ):
        super().__init__(
            checkpoint_dir,
            tokenizer_dir,
            trainer_ckpt_path,
            args.optim,
            distribute,
        )
        self.args: Final = args
        self._samples = train_samples
        self._task = task
        self._reasoner, self._reward_model = self._init_env(impl, reflection)
        self._reasoner.ref = self._ref  # reference function for reward computing

        self._collect_device = torch.device('cpu') if args.collect_to_cpu else None
        self._val_collector = Collector(
            device=self._collect_device,
            require=Collector.Require(
                content=False,
                truncated=True,
                terminated=True,
                logits=False,
                probs=False,
                rewards=True,
            ),
            allow_abortion=args.enable_abortion,
        )
        self._val_collector.reward_model = self._reward_model
        self._train_collector = self._get_train_collector()
        self._train_collector.reward_model = self._reward_model
        self._repeated_epochs: int = 0
        self._enforce_temperature: float | None = None
    
    @property
    def _inference(self) -> Sampling:
        return cast(Sampling, self._reasoner.inference)

    @property
    def context_length(self):
        return self._reasoner.session.context_length

    @property
    def temperature(self) -> float:
        temp = self._inference.temperature
        assert not isinstance(temp, torch.Tensor)
        return temp
    
    @temperature.setter
    def temperature(self, value: float):
        assert isinstance(self._inference, Sampling)
        self._inference.temperature = value

    def _get_train_temperature(self, epoch: int | None = None)-> float:
        if epoch is None:
            epoch = self.current_epoch
        if isinstance(self.args.temperature, Progress):
            return self.args.temperature(epoch / self.args.epochs)
        else:
            return self.args.temperature
    
    def _ref(self, idx: tuple[int, ...]):
        i = idx[0] % len(self._ref_dicts)
        return self._ref_dicts[i]
    
    @abc.abstractmethod
    def _get_train_collector(self) -> TrainCollector:
        raise NotImplementedError
    
    def get_train_dataloader(self):
        if len(self._train_collector.data) == 0:
            self._collect_train_epoch()
        return torch_data.DataLoader(
            self._train_collector.data,
            self.args.train_batch_size,
            shuffle=True,
        )
    
    def on_validation_epoch_start(self) -> None:
        self.temperature = self.args.val_temperature
        self._val_collector.attach(self._reasoner)
    
    @torch.inference_mode()
    def validation_step(self, batch: list[CoT], batch_idx):
        data = self._task
        reasoner = self._reasoner
        scale = self.args.val_scale

        inputs = [data.style.apply_input(**cot.as_dict()) for cot in batch]
        self._ref_dicts = [data.get_ref_dict(cot) for cot in batch]
        input_tokens = reasoner.preprocess(inputs)
        if scale > 1:
            input_tokens = input_tokens.repeat(scale)

        reasoner.__call__(input_tokens)

        del self._ref_dicts
        return None
    
    def on_validation_epoch_end(self) -> None:
        self._reasoner.close()
        self._val_collector.detach()
        ret = average_return(self._val_collector.data)
        rew = float(np.mean(self._val_collector.data.rewards))
        self.log("rt", ret, prog_bar=True)  # return
        self.log("rw", rew, prog_bar=True)  # reward

    def on_train_epoch_end(self) -> None:
        self._repeated_epochs += 1
        if self._repeated_epochs >= self.args.train_epoch_repeat:
            self._train_collector.reset()
            self._collect_train_epoch()
            self._process_collected_data(False)
            self._repeated_epochs = 0
    
    def on_train_start(self) -> None:
        self._process_collected_data(True)

    def _sample_train_batch(self, size: int):
        if self._samples is None:
            return self._task.samples("train", size, require_thought=False)
        else:
            return random.choices(self._samples, k=size)
    
    @torch.inference_mode()
    def _collect_train_epoch(self):
        data = self._task
        reasoner = self._reasoner
        collector = self._train_collector
        batch_size = self.args.collect_batch_size
        epoch_size = self.args.train_epoch_size
        collect_shape = self.args.collect_shape
        
        if self._enforce_temperature is not None:
            self.temperature = self._enforce_temperature
        else:
            self.temperature = self._get_train_temperature()
        
        self.eval()
        collector.attach(reasoner)

        print()
        for i in range(0, epoch_size, batch_size):
            print(f"⏳ Collecting reasoning samples: {i} / {epoch_size}", end='\r')
            k = min(batch_size, epoch_size - i)
            batch = self._sample_train_batch(k)
            inputs = [data.style.apply_input(**cot.as_dict()) for cot in batch]
            self._ref_dicts = [data.get_ref_dict(cot) for cot in batch]
            input_tokens = reasoner.preprocess(inputs)
            if collect_shape == () or collect_shape == 1:
                pass
            elif isinstance(collect_shape, int):
                input_tokens = input_tokens.repeat(collect_shape)
            else:
                input_tokens = input_tokens.reshape(k, *(1 for _ in collect_shape)).expand(k, *collect_shape)
            reasoner.__call__(input_tokens)
            del self._ref_dicts

        reasoner.close()
        collector.detach()
        print(f"[✔] Collecting reasoning samples: {epoch_size} / {epoch_size}.")

    def _process_collected_data(self, _first_loading: bool = False):
        collector = self._train_collector
        avr_return = average_return(collector.data)
        avr_reward = float(np.mean(collector.data.rewards))
        print(
            f"- total steps: {len(collector.data)}",
            f"- average reward: {avr_reward}",
            f"- average return: {avr_return}",
            f"- temperature: {self.temperature}",
            sep='\n',
        )
        self.log("train_return", avr_return, prog_bar=False)
        self.log("train_reward", avr_reward, prog_bar=False)
        self.log("temperature", self.temperature, prog_bar=False)


@dc.dataclass
class RLTuning[TArg: RLArgs, Tlit: LitRL](abc.ABC):

    task: ReasoningTask
    checkpoint: str
    args: TArg
    impl: ThoughtImpl
    reflection: dict | None = None  # None to disable reflection
    out_checkpoint: str | None = None
    ref_checkpoint: str | None = None
    tokenizer_dir: Path | str | None = None
    trainer_ckpt_path: Path | str | None = None
    val_data_ratio: float = 0.2
    distribute: Distrubute = 'auto'
    
    def __post_init__(self):
        self.style = self.task.style
        
        out_checkpoint = self.out_checkpoint or self.checkpoint
        self.out_dir = extend_checkpoint_dir(Path(out_checkpoint))

        # Try to get input data
        if self.args.source_of_inputs == "data":
            train_data = self.task.get_instances('train')
            try:
                val_data = self.task.get_instances('val')
            except KeyError:
                val_data = []
            if not val_data:
                train_data, val_data = split_data(train_data, 1 - self.val_data_ratio)
            print(
                "Initializing: loaded %d training samples and %d validation samples."
                % (len(train_data), len(val_data))
            )
            self.train_data = ListDataset(train_data)
            self.val_data = ListDataset(val_data)
        elif self.args.source_of_inputs == "task":
            val_size = self.args.max_val_size
            if val_size is None:
                raise ValueError("max validation size must be assigned when inputs are sampled from task.")
            self.train_data = None
            self.val_data = CoTSamplerDataset(self.task, "val", val_size)
        
        # Lit module
        self.lit_module: Tlit = self._get_lit_module()
        if self.llm.checkpoint_dir != self.out_dir:
            self.llm.save(self.out_dir)
    
    @property
    def llm(self):
        return self.lit_module.llm
    
    @abc.abstractmethod
    def _get_lit_module(self) -> Tlit:
        raise NotImplementedError
    
    def get_data_loaders(self):
        train_loader = self.lit_module.get_train_dataloader()
        if self.args.val_each_n_epochs:
            shuffle_val_data = (self.args.max_val_size is not None) \
                and self.args.source_of_inputs == "data"
            val_loader = torch_data.DataLoader(
                self.val_data,
                self.args.val_batch_size,
                shuffle=shuffle_val_data,
                collate_fn=self.val_data.no_collate
            )
        else:
            val_loader = None
        return train_loader, val_loader

    def fit(
        self,
        precision: Precision | None = None,
        devices: list[int] | str | int = "auto",
        num_nodes: int = 1,
        resume: Auto | Path | bool = False,
    ):
        ckpt_path = self._get_resume_ckpt_path(resume)

        # Make sure that the initial data is collected using the resumed model and temperature.
        if ckpt_path is not None:  
            ckpt = torch.load(ckpt_path, weights_only=True)
            self.lit_module._enforce_temperature = self.lit_module._get_train_temperature(ckpt['epoch'])
            self.lit_module.load_state_dict(ckpt['state_dict'])
            del ckpt
        
        train_loader, val_loader = self.get_data_loaders()
        self.lit_module._enforce_temperature = None
        trainer = self._trainer(self.args.epochs, precision, devices, num_nodes)
        trainer.fit(self.lit_module, train_loader, val_loader, ckpt_path=ckpt_path)
    
    def _get_resume_ckpt_path(self, resume: Auto | Path | bool = 'auto') -> Path | None:

        def epoch_idx(name: str):
            ckpt_pattern = r"epoch=(\d+)"
            m = re.search(ckpt_pattern, name)
            if m is None:
                return -1
            else:
                return int(m.group(1))

        if resume is False:
            return None
        elif resume is True or resume == 'auto':
            names = [name for name in os.listdir(self.out_dir) if name.endswith('.ckpt')]
            if len(names) == 0:
                return None
            elif 'last.ckpt' in names:
                return self.out_dir / 'last.ckpt'
            else:
                return self.out_dir / max(names, key=epoch_idx)
        else:
            assert isinstance(resume, Path)
            resume = resume.with_suffix('.ckpt')
            if resume.exists():
                return resume
            elif (self.out_dir / resume).exists():
                return self.out_dir / resume
            else:
                raise FileNotFoundError(str(resume))

    def _trainer_callbacks(self) -> L.Callback | list[L.Callback] | None:
        save_each_n_epochs = self.args.save_each_n_epochs
        if save_each_n_epochs:
            return AutoSave(self, save_each_n_epochs, self.args.save_top_k, self.args.save_last)

    def _trainer(
        self,
        max_epochs: int,
        precision: Precision | None = None,
        devices: list[int] | str | int = "auto",
        num_nodes: int = 1,
    ):
        val_size = self.args.max_val_size
        val_batch_size = self.args.val_batch_size
        if val_size is None:
            limit_val_batches = None
        else:
            limit_val_batches = (val_size / val_batch_size).__ceil__()
        
        return L.Trainer(
            max_epochs=max_epochs,
            precision=precision,
            log_every_n_steps=self.args.log_interval,
            gradient_clip_val=self.args.max_norm,
            devices=devices,
            num_nodes=num_nodes,
            check_val_every_n_epoch=self.args.val_each_n_epochs,
            callbacks=self._trainer_callbacks(),
            default_root_dir=self.out_dir,
            enable_checkpointing=True,
            limit_val_batches=limit_val_batches,
            num_sanity_val_steps=0,
            reload_dataloaders_every_n_epochs=self.args.train_epoch_repeat,
            **self.args.trainer_args
        )

    def save(self, checkpoint_path: Path | str | None = None):
        if checkpoint_path is None:
            checkpoint_path = self.out_dir
        else:
            checkpoint_path = extend_checkpoint_dir(Path(checkpoint_path))

        self.lit_module.save_modules(checkpoint_path)
        print(f"\nThe checkpoint has been saved to {checkpoint_path}")


class AutoSave_old(L.Callback):
    """Legacy checkpointing callback (``RLTuning`` in this module uses ``AutoSave`` instead).

    Every ``interval`` epochs it saves the modules via ``rl.save()`` plus a
    ``epoch-{epoch}.ckpt`` trainer checkpoint. If ``save_top`` is set, it also
    keeps ``top.ckpt`` for the best validation return ("rt") seen so far.
    """

    def __init__(
        self, 
        rl: RLTuning,
        interval: int | None,
        save_top: bool,
    ):
        super().__init__()
        self.interval = interval
        self.rl = rl
        self.save_top = save_top
        self._top_val_return = -np.inf

    def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        epoch = trainer.current_epoch
        interval = self.interval
        if interval and epoch > 0 and epoch % interval == 0:
            self.rl.save()
            trainer.save_checkpoint(self.rl.out_dir / f"epoch-{epoch}.ckpt")

    def on_validation_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # Hooked here rather than in `on_validation_epoch_end`: Lightning runs the
        # callbacks' `on_validation_epoch_end` *before* the LightningModule's, where
        # "rt" is logged, so the metric would be missing (first epoch) or stale.
        if trainer.sanity_checking:
            return
        metric = trainer.callback_metrics.get("rt")
        if metric is None:
            return
        val_return = float(metric)
        if val_return > self._top_val_return:
            self._top_val_return = val_return
            if self.save_top:
                trainer.save_checkpoint(self.rl.out_dir / "top.ckpt")


class AutoSave(ModelCheckpoint):
    """``ModelCheckpoint`` configured for RL tuning runs.

    Checkpoints go to ``rl.out_dir`` as ``{epoch}-{rt:.3f}``. The ``save_top_k``
    best are ranked by the validation return "rt" (higher is better), and
    ``rl.save()`` is additionally called at the start of every ``interval``-th epoch.
    """

    def __init__(
        self,
        rl: RLTuning,
        interval: int,
        save_top_k: int,
        save_last: bool,
    ):
        checkpoint_options = dict(
            dirpath=rl.out_dir,
            filename="{epoch}-{rt:.3f}",
            monitor="rt",  # monitor the best validation return
            save_top_k=save_top_k,
            mode="max",
            every_n_epochs=interval,
            save_last=save_last,
        )
        super().__init__(**checkpoint_options)
        self.rl = rl

    def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        super().on_train_epoch_start(trainer, pl_module)
        period = self._every_n_epochs
        current = trainer.current_epoch
        # Skip when periodic saving is disabled, and at the very first epoch.
        if period <= 0 or current <= 0:
            return
        if current % period == 0:
            self.rl.save()


class CoTSamplerDataset(torch_data.Dataset):

    def __init__(self, task: ReasoningTask, split: Split, max_size: int, **kwargs):
        self.task = task
        self.split = split
        self.max_size: int = max_size
        self.kwargs = kwargs
        
    def __len__(self):
        return self.max_size

    def __getitem__(self, index) -> CoT:
        return self.task.sample(self.split, **self.kwargs)

    @classmethod
    def no_collate[T](cls, x: list[T]) -> list[T]:
        return x
