from __future__ import annotations

import dataclasses as dc
import warnings
from pathlib import Path
from typing import Any, Literal, Optional, Mapping, Iterable, Final, cast, overload
import abc

import lightning as L
import torch
import torch.nn as nn
from torch.nn.modules.module import _IncompatibleKeys
import math
import itertools
import functools

from lightning.fabric.strategies import FSDPStrategy  # type: ignore
from lightning.fabric.wrappers import _FabricModule
from litgpt.api import LLM as _LitGPTLLM, Preprocessor as LitPreprocessor
from litgpt.config import Config, name_to_config
from litgpt.model import Block, GPT
from litgpt.prompts import (PromptStyle, has_prompt_style, load_prompt_style,
                            save_prompt_style)
from litgpt.utils import (
    auto_download_checkpoint, 
    copy_config_files,
    save_config,
    check_file_size_on_cpu_and_warn,
    extend_checkpoint_dir,
    get_default_supported_precision,
    parse_devices
)

from .utils.progress import Progress
from .utils import kv
from .utils.buf import TokenBuffer, SequencesLike, StrArray
from .tokenization import Tokenizer

# Literal marker meaning "pick this setting automatically".
Auto = Literal['auto']
# Precision strings that are spelled out explicitly in the `as_fabric` signature.
Precision = Literal["bf16-true", "bf16-mixed", "32-true"]
# Names of supported experiment loggers.
Logger = Literal["wandb", "tensorboard", "csv"]
# Either a ready-made Fabric instance or the 'auto' marker.
# NOTE: "Distrubute" is a misspelling of "Distribute"; the name is kept as-is
# because it is referenced by signatures throughout this module.
Distrubute = L.Fabric | Auto


def load_checkpoint(
    fabric: L.Fabric,
    model: torch.nn.Module,
    state_path: Path,
    strict: bool = True
) -> None:
    """Load the weights stored at ``state_path`` into ``model`` in place.

    Args:
        fabric: The Fabric instance whose strategy decides how to load.
        model: The module that receives the weights.
        state_path: Path of the checkpoint file.
        strict: Forwarded to the underlying loader; when true, key mismatches
            are treated as errors by that loader.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        # Under FSDP the loading is delegated to Fabric.
        fabric.load_raw(state_path, model, strict=strict)
        return

    # Deserialize onto the CPU so that a checkpoint saved from accelerator
    # tensors can still be opened on a host without that accelerator;
    # `load_state_dict` then copies the values into the model's parameters.
    # NOTE(security): `torch.load` relies on pickle -- only load checkpoints
    # from trusted sources.
    state_dict: dict = torch.load(state_path, map_location="cpu")
    # Some checkpoints nest the weights under a top-level "model" key.
    state_dict = state_dict.get("model", state_dict)
    model.load_state_dict(state_dict, strict=strict)


@overload
def as_fabric(distribute: Distrubute, /) -> L.Fabric: ...
@overload
def as_fabric(*,
              num_nodes: int = 1,
              devices: int | str | Auto = 1,
              accelerator: str | Auto = 'auto',
              precision: str | Precision | Auto | None = 'auto',
              **kwargs) -> L.Fabric: ...
def as_fabric(distribute: Distrubute | None = None, /, *,
              num_nodes: int = 1,
              devices: int | str | Auto = 1,
              accelerator: str | Auto = 'auto',
              precision: str | Precision | Auto | None = 'auto',
              **kwargs
) -> L.Fabric:
    """Return a ``lightning.Fabric`` built from a flexible specification.

    Three call styles are supported:

    * ``as_fabric(fabric)`` returns the given Fabric unchanged.
    * ``as_fabric('auto')`` builds a single-device Fabric, using CUDA when it
      is available and the CPU otherwise (MPS falls back to the CPU with a
      warning). The keyword arguments are not consulted in this mode.
    * ``as_fabric(num_nodes=..., devices=..., ...)`` builds a Fabric from the
      keyword arguments; an FSDP strategy wrapping ``Block`` is used as soon
      as more than one device is involved in total.

    Args:
        distribute: An existing Fabric, ``'auto'``, or ``None`` (positional only).
        num_nodes: Number of nodes.
        devices: Device specification, normalized through ``parse_devices``.
        accelerator: Accelerator name forwarded to Fabric.
        precision: Precision string forwarded to Fabric. ``'auto'`` or a
            falsy value selects litgpt's default supported precision.
        **kwargs: Extra keyword arguments forwarded to ``L.Fabric``.

    Returns:
        The resulting Fabric instance.

    Raises:
        ValueError: If ``distribute`` is neither a Fabric, ``'auto'`` nor ``None``.
    """
    if isinstance(distribute, L.Fabric):
        return distribute

    if distribute == 'auto':
        if torch.cuda.is_available():
            accelerator = "cuda"
        else:
            if torch.backends.mps.is_available():
                # MPS is deliberately not selected; fall back to the CPU.
                warnings.warn("MPS is currently not supported. Using CPU instead.", UserWarning)
            accelerator = "cpu"
        return L.Fabric(
            accelerator=accelerator,
            devices=1,
            precision=get_default_supported_precision(training=False),  # type: ignore
        )

    if distribute is not None:
        # Explicit error instead of `assert`, which is stripped under `-O`.
        raise ValueError(
            f"`distribute` must be a Fabric instance, 'auto' or None, got {distribute!r}."
        )

    devices = parse_devices(devices)
    # 'auto' is the default of this function but is not itself a precision
    # value, so it has to be resolved before it reaches Fabric.
    if not precision or precision == 'auto':
        precision = get_default_supported_precision(training=True)

    if devices * num_nodes > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    return L.Fabric(
        accelerator=accelerator,
        strategy=strategy,
        devices=devices,
        num_nodes=num_nodes,
        precision=precision,  # type: ignore
        **kwargs,
    )


class Preprocessor(LitPreprocessor):
    """LitGPT preprocessor restricted to this project's ``Tokenizer`` type."""

    def __init__(self, tokenizer: Tokenizer, device: str = "cpu") -> None:
        """Initialize the parent preprocessor and check the tokenizer type.

        Args:
            tokenizer: The project tokenizer.
            device: Device identifier forwarded to the parent constructor.
        """
        super().__init__(tokenizer, device)

        # Runtime guard; the bare annotation below only narrows the attribute
        # type for static checkers and assigns nothing.
        assert isinstance(tokenizer, Tokenizer)
        self.tokenizer: Tokenizer

    def decode(self, token_ids: torch.Tensor, skip_special_tokens=True) -> str:
        """Decode ``token_ids`` into text using the wrapped tokenizer."""
        tokenizer = self.tokenizer
        return tokenizer.decode(token_ids, skip_special_tokens)


class LLM(nn.Module):
    """Bundle of a GPT model, its preprocessor and the Fabric it runs on.

    ``state_dict`` and ``load_state_dict`` are overridden to delegate to the
    wrapped model, so the keys are those of the model itself.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: Preprocessor,
        fabric: L.Fabric,
        checkpoint_dir: Path | None = None,
    ) -> None:
        """
        Args:
            model: A ``GPT`` instance, or a Fabric-wrapped ``GPT``.
            preprocessor: Preprocessor holding the tokenizer.
            fabric: The Fabric instance the model is set up with.
            checkpoint_dir: Directory the model was loaded from, if any.

        Raises:
            ValueError: If ``model`` is not a (wrapped) GPT, or if the
                tokenizer vocabulary exceeds the model's padded vocabulary.
        """
        super().__init__()

        is_gpt = isinstance(model, GPT) or (
            isinstance(model, _FabricModule)
            and isinstance(model._original_module, GPT)
        )
        if not is_gpt:
            raise ValueError("Require a GPT model.")

        self.model: Final[GPT] = cast(GPT, model)
        self.preprocessor: Final[Preprocessor] = preprocessor
        self.fabric = fabric
        self.checkpoint_dir = checkpoint_dir

        vocab_size = model.config.padded_vocab_size
        assert isinstance(vocab_size, int)
        if preprocessor.tokenizer.vocab_size > vocab_size:
            raise ValueError(
                "The model accepts a maximum vocabulary of %d, "
                "but the vocabulary size of the tokenizer is %d." %
                (vocab_size, preprocessor.tokenizer.vocab_size)
            )
        self.vocab_size: Final[int] = vocab_size

        # Normalize the Fabric device to a `torch.device`.
        device = fabric.device
        if isinstance(device, str):
            device = torch.device(device)
        self.device: Final = device

    def detokenize(
        self,
        seq: TokenBuffer,
        start: torch.Tensor | int = 0,
        stop: torch.Tensor | int | None = None,
        skip_special_tokens: bool = True,
    ) -> StrArray:
        """Decode the tokens in ``seq`` through the preprocessor.

        ``start`` and ``stop`` are forwarded unchanged to ``seq.detokenize``.
        """
        decode = functools.partial(self.preprocessor.decode,
                                   skip_special_tokens=skip_special_tokens)
        return seq.detokenize(decode, start, stop)

    @property
    def tokenizer(self):
        """The tokenizer held by the preprocessor."""
        return self.preprocessor.tokenizer

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the state dict of the wrapped model."""
        if destination is not None:
            return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        else:
            return self.model.state_dict(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load ``state_dict`` into the wrapped model.

        Args:
            state_dict: Mapping keyed like the wrapped model's own state dict.
            strict: Forwarded to the wrapped model.
            assign: Forwarded to the wrapped model only when true, so that
                calls which do not use it behave exactly as before.

        Returns:
            Whatever the wrapped model's ``load_state_dict`` returns.
        """
        if assign:
            return self.model.load_state_dict(state_dict, strict=strict, assign=True)
        return self.model.load_state_dict(state_dict, strict=strict)

    def save(self, out_dir: str | Path) -> None:
        """Write the weights to ``out_dir/lit_model.pth`` plus config files.

        The weights are saved through Fabric when one is set. On global rank
        zero (or without Fabric) the configuration is also written: copied
        from ``checkpoint_dir`` when available, otherwise generated from the
        model config.
        """
        if not isinstance(out_dir, Path):
            out_dir = Path(out_dir)

        save_path = out_dir / "lit_model.pth"
        save_path.parent.mkdir(parents=True, exist_ok=True)

        if self.fabric is None:
            torch.save(self.state_dict(), save_path)
        else:
            self.fabric.save(save_path, self.state_dict())

        if self.fabric is None or self.fabric.global_rank == 0:
            # When the model was initialized with random weights, the
            # checkpoint directory can be None.
            if self.checkpoint_dir is not None:
                copy_config_files(Path(self.checkpoint_dir), save_path.parent)
            else:
                save_config(self.model.config, out_dir)

    def get_value_model(self,
                        path: Path | str | None = None,
                        distribute: Distrubute = "auto",
                        trained_only: bool = False):
        """Build a value model for this LLM via ``make_vf``.

        Args:
            path: Checkpoint directory of the value model. When ``None``,
                this LLM itself is used as the source.
            distribute: Fabric specification. With an explicit ``path``,
                ``'auto'`` is replaced by this LLM's own Fabric; without a
                path it is forwarded to ``make_vf`` unchanged.
            trained_only: Forwarded to ``make_vf``.
        """
        if path is None:
            return make_vf(self, distribute, trained_only)
        else:
            if distribute == 'auto':
                distribute = self.fabric
            return make_vf(path, distribute, trained_only)


@overload
def make_llm(
    config: Config, /,
    *,
    tokenizer_dir: Path | str,
    distribute: Distrubute | dict = "auto",
) -> LLM: ...


@overload
def make_llm(
    checkpoint: Path | str, /,
    init: Optional[Literal["pretrained", "random"]] = "pretrained",
    tokenizer_dir: Optional[Path | str] = None,
    access_token: Optional[str] = None,
    distribute: Distrubute | dict = "auto",
) -> LLM: ...


def make_llm(
    source: Path | str | Config,
    init: Optional[Literal["pretrained", "random"]] = "pretrained",
    tokenizer_dir: Optional[Path | str] = None,
    access_token: Optional[str] = None,
    distribute: Distrubute | dict = "auto",
) -> LLM:
    """
    Build an ``LLM`` from a config, a model name or a checkpoint.

    Derived from `litgpt.api.LLM`.

    Args:
        source: A ``Config`` (no checkpoint is used), or a model
            name / checkpoint path.
        init: ``"pretrained"`` resolves (and possibly downloads) a checkpoint;
            ``"random"`` only looks up the config by name. Ignored when
            ``source`` is a ``Config``.
        tokenizer_dir: Tokenizer directory; required whenever no checkpoint
            directory is available.
        access_token: Forwarded to the checkpoint download.
        distribute: A Fabric specification accepted by ``as_fabric``, or a
            dict of keyword arguments for the LitGPT ``distribute`` method.

    Raises:
        ValueError: For an unknown ``init``, an unknown model name, or a
            missing tokenizer directory.
    """
    # Step 1: resolve the configuration and the optional checkpoint directory.
    if isinstance(source, Config):
        config, checkpoint_dir = source, None
    else:
        model_name = str(source) if isinstance(source, Path) else source
        if init == "pretrained":
            checkpoint_dir = auto_download_checkpoint(
                model_name=model_name,
                access_token=access_token,
                ignore_tokenizer_files=tokenizer_dir is not None
            )
            config = Config.from_file(checkpoint_dir / "model_config.yaml")
        elif init == "random":
            checkpoint_dir = None
            try:
                config = Config.from_name(model_name)
            except ValueError as e:
                print(f"Model name {model_name} is not supported.\n")
                available_models = "\n".join(sorted(name_to_config))
                print(f"Available values:\n{available_models}")
                raise e
        else:
            allowed_init = {"pretrained", "random"}
            raise ValueError(f"Invalid init option: {init}. Must be one of {allowed_init}")

    assert config is not None

    torch.set_float32_matmul_precision("high")

    # Step 2: the tokenizer comes from `tokenizer_dir` when given, otherwise
    # from the checkpoint directory.
    if tokenizer_dir is None and checkpoint_dir is None:
        raise ValueError("Provide a path to a tokenizer directory via the `tokenizer_dir` setting.")
    if tokenizer_dir is None:
        tokenizer = Tokenizer(checkpoint_dir)
    else:
        tokenizer_dir = extend_checkpoint_dir(Path(tokenizer_dir))
        tokenizer = Tokenizer(tokenizer_dir)

    # Step 3: use the prompt style stored with the checkpoint when present,
    # otherwise derive it from the config.
    if checkpoint_dir is not None and has_prompt_style(checkpoint_dir):
        prompt_style = load_prompt_style(checkpoint_dir)
    else:
        prompt_style = PromptStyle.from_config(config)

    # Step 4: build the model right away, unless distribution is deferred to
    # LitGPT through a dict of arguments.
    deferred = isinstance(distribute, dict)
    if deferred:
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        preprocessor = Preprocessor(tokenizer, device=fallback_device)
        model = None
        fabric = None
    else:
        fabric = as_fabric(distribute)
        with fabric.init_module(empty_init=(fabric.world_size > 1)):
            model = GPT(config)
        model.eval()
        preprocessor = Preprocessor(tokenizer, device=fabric.device)  # type: ignore
        if checkpoint_dir is not None:
            checkpoint_path = checkpoint_dir / "lit_model.pth"
            check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
            load_checkpoint(fabric, model, checkpoint_path)
        model = fabric.setup_module(model)

    # Use the LitGPT API for the construction of the LLM.
    temp = _LitGPTLLM(
        model=model,  # type: ignore
        preprocessor=preprocessor,
        prompt_style=prompt_style,
        config=config,
        checkpoint_dir=checkpoint_dir,  # type: ignore
        fabric=fabric,  # type: ignore
        generate_strategy=None,
        kv_cache_initialized=False,
        fixed_kv_cache_size=False
    )
    if deferred:
        temp.distribute(**distribute)

    assert temp.fabric is not None
    assert temp.model is not None
    assert isinstance(temp.preprocessor, Preprocessor)

    return LLM(temp.model, temp.preprocessor, temp.fabric, checkpoint_dir)


class LitLLM(L.LightningModule):
    """
    The LLM used by Lightning Training.

    Wraps an ``LLM`` built by ``make_llm`` and provides optimizer
    configuration plus per-module saving and loading of weights.
    """

    # Maps the optimizer name used in `OptimArgs` to its torch class.
    _OPTIMIZER_CLASSES: dict[str, type[torch.optim.Optimizer]] = {  # type: ignore
        'adam': torch.optim.Adam,  # type: ignore
        'adamw': torch.optim.AdamW  # type: ignore
    }

    @dc.dataclass
    class OptimArgs:
        """Arguments that specify the optimization algorithm."""

        # Key into `_OPTIMIZER_CLASSES`.
        optimizer: Literal['adam', 'adamw'] = 'adam'
        # Keyword arguments passed to the optimizer constructor.
        optimizer_args: dict[str, Any] = dc.field(default_factory=dict)
        # Optional schedule, used as the multiplier function of a `LambdaLR`.
        scheduler: Progress | None = None
        # Returned as the 'interval' / 'frequency' entries of the scheduler config.
        schedule_interval: Literal['step', 'epoch'] = 'epoch'
        schedule_frequency: int = 1

    def __init__(
        self,
        checkpoint_dir: Path | str,
        tokenizer_dir: Path | str | None = None,
        trainer_ckpt_path: Path | str | None = None,
        optim_args: OptimArgs | None = None,
        distribute: Distrubute = 'auto',
        # The following will only be used if the checkpoint directory does not exist.
        model_source: str | Path | Config | None = None,
    ):
        """
        Args:
            checkpoint_dir: Checkpoint directory of the model.
            tokenizer_dir: Optional tokenizer directory.
            trainer_ckpt_path: Stored on the instance as-is (as a ``Path``).
            optim_args: Optimization settings; defaults to ``OptimArgs()``.
            distribute: Fabric specification forwarded to ``make_llm``.
            model_source: Source used to create the checkpoint when
                ``checkpoint_dir`` does not exist yet.

        Raises:
            ValueError: If the checkpoint is missing and no ``model_source``
                is given, or if ``model_source`` is a ``Config`` and no
                ``tokenizer_dir`` is given.
        """
        super().__init__()

        checkpoint_dir = Path(checkpoint_dir)
        if isinstance(tokenizer_dir, str):
            tokenizer_dir = Path(tokenizer_dir)
        if isinstance(trainer_ckpt_path, str):
            trainer_ckpt_path = Path(trainer_ckpt_path)

        # Initialize the checkpoint if it does not exist.
        checkpoint_dir_ = extend_checkpoint_dir(checkpoint_dir)
        if not checkpoint_dir_.exists():
            if model_source is None:
                raise ValueError(f"Checkpoint {checkpoint_dir} does not exist."
                                 " Please provide the source of the model.")
            if isinstance(model_source, Config):
                if tokenizer_dir is None:
                    raise ValueError("Please provide the tokenizer directory.")
                llm = make_llm(model_source, tokenizer_dir=tokenizer_dir)
            else:
                llm = make_llm(
                    model_source or str(checkpoint_dir),
                    "random",
                    tokenizer_dir=tokenizer_dir,
                )
            # Save under "checkpoints/" unless the path already starts there.
            # Slicing `parts` avoids an IndexError for an empty path.
            if checkpoint_dir.parts[:1] == ('checkpoints',):
                checkpoint_dir_ = checkpoint_dir
            else:
                checkpoint_dir_ = Path('checkpoints') / checkpoint_dir
            llm.save(checkpoint_dir_)
            print(f"Sucessfully initialize checkpoint at {checkpoint_dir_}")

        self.llm = make_llm(
            str(checkpoint_dir),
            tokenizer_dir=tokenizer_dir,
            distribute=distribute,
        )
        self.trainer_ckpt_path = trainer_ckpt_path
        self.optim_args = optim_args or self.OptimArgs()

    @classmethod
    def _get_optimizer_and_scheduler(cls, args: OptimArgs, module: nn.Module | Iterable[nn.Module]):
        """Create the optimizer (and scheduler config) described by ``args``.

        Args:
            args: Optimization settings.
            module: One module, or an iterable of modules, whose parameters
                are optimized.

        Returns:
            The optimizer alone when no scheduler is configured, otherwise a
            dict with the keys ``'optimizer'`` and ``'lr_scheduler'``.

        Raises:
            ValueError: If the optimizer name is not supported.
            TypeError: If the scheduler is not a ``Progress`` instance.
        """
        # Get the optimizer class based on the specified optimizer type.
        optimizer_class = cls._OPTIMIZER_CLASSES.get(args.optimizer)
        if optimizer_class is None:
            raise ValueError(f"Optimizer '{args.optimizer}' is not supported.")

        # Collect the parameters of the module(s).
        if isinstance(module, nn.Module):
            parameters = module.parameters()
        else:
            parameters = itertools.chain.from_iterable(m.parameters() for m in module)

        optimizer = optimizer_class(parameters, **args.optimizer_args)

        scheduler_ = args.scheduler
        if scheduler_ is None:
            # Return just the optimizer if no scheduler is specified.
            return optimizer

        if not isinstance(scheduler_, Progress):
            # Explicit error instead of `assert False`, which is stripped
            # under `-O` and would leave `scheduler` unbound.
            raise TypeError(
                f"Unsupported scheduler type: {type(scheduler_).__name__}."
            )

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': args.schedule_interval,
                'frequency': args.schedule_frequency,
            }
        }

    @property
    def _trainable_modules(self) -> dict[str, nn.Module]:
        """Modules that are optimized and saved, keyed by file stem."""
        return {'lit_model': self.llm}

    def configure_optimizers(self):
        """Build the optimizer configuration for all trainable modules."""
        module = self._trainable_modules.values()
        return self._get_optimizer_and_scheduler(self.optim_args, module)

    def save_modules(self, checkpoint: Path | str):
        """Save each trainable module to ``<checkpoint>/<name>.pth``.

        The module's own Fabric is preferred; the LLM's Fabric is the
        fallback, and plain ``torch.save`` is used when neither is set.
        """
        fabric = self.fabric or self.llm.fabric
        checkpoint = Path(checkpoint)
        checkpoint.mkdir(parents=True, exist_ok=True)
        for name, module in self._trainable_modules.items():
            model_path = (checkpoint / name).with_suffix('.pth')
            if fabric is None:
                torch.save(module.state_dict(), model_path)
            else:
                fabric.save(model_path, module.state_dict())

    def _load_submodule_state(self, *names: str, state_dict: Mapping[str, Any], strict: bool = True) -> _IncompatibleKeys:
        """
        Call "load_state_dict" for each model given its name. This helps since LLM has overridden "load_state_dict".

        Keys are matched by the ``"<name>."`` prefix, which is stripped before
        loading and restored on the reported incompatible keys. Keys that match
        none of the given names are not loaded and not reported.
        """

        missing_keys: list[str] = []
        unexpected_keys: list[str] = []
        for name in names:
            prefix = name + '.'
            model = self.get_submodule(name)
            _dict = {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}
            incompatible_keys = model.load_state_dict(_dict, strict=strict)
            missing_keys.extend(prefix + key for key in incompatible_keys.missing_keys)
            unexpected_keys.extend(prefix + key for key in incompatible_keys.unexpected_keys)
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        """Load ``state_dict`` by dispatching to every direct submodule.

        Raises:
            ValueError: If ``assign`` is true, which is not supported.
        """
        names = self._modules.keys()
        if assign:
            raise ValueError("The model does not support `assign = True`.")
        incompatible_keys = self._load_submodule_state(
            *names, state_dict=state_dict, strict=strict
        )
        return incompatible_keys


class ValueModel(GPT):
    """GPT variant whose head emits one scalar per position."""

    def __init__(self, config: Config):
        """
        Adapted from litgpt.model.GPT, with the final layer replaced to output a scalar.
        """

        super().__init__(config)

        # Swap the language-model head for a single-output linear layer.
        self.lm_head = nn.Linear(config.n_embd, 1, bias=True)

    def init_weights(self, source: nn.Module | Mapping[str, Any]):
        """
        Initialize the weights from a pretrained / fine-tuned GPT.

        Only the entries under the ``transformer.`` prefix are copied; the
        scalar head is re-initialized (normal weights, zero bias).

        Args:
            source (nn.Module | Mapping[str, Any]): the source model (LLM/GPT) or its state dict.
        """

        raw_state: dict[str, Any] = (
            source.state_dict() if isinstance(source, nn.Module) else dict(source)
        )
        # The weights may be nested under a top-level 'model' entry.
        weights = raw_state.get('model', raw_state)

        backbone_prefix = 'transformer.'
        backbone_state = {}
        for key, value in weights.items():
            if key.startswith(backbone_prefix):
                backbone_state[key.removeprefix(backbone_prefix)] = value
        self.transformer.load_state_dict(backbone_state)

        head = self.lm_head
        std = math.sqrt(2.0 / 5 / self.config.n_embd)
        torch.nn.init.normal_(head.weight, mean=0.0, std=std)
        if head.bias is not None:
            torch.nn.init.zeros_(head.bias)

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the parent forward pass and squeeze dimension 2 of its output."""
        values = super().forward(idx, input_pos)
        return values.squeeze(2)  # (batch, seq_length)


def make_vf(source: Path | str | LLM, distribute: Distrubute = 'auto', trained_only: bool = False) -> ValueModel:
    """Create a ``ValueModel`` from an ``LLM`` or from a checkpoint directory.

    Trained weights are loaded from ``<checkpoint>/vf.pth`` when that file
    exists. Otherwise, if ``source`` is an ``LLM``, the model is initialized
    from the LLM's weights.

    Args:
        source: An ``LLM`` (its config and checkpoint directory are reused)
            or the path of a checkpoint directory.
        distribute: Fabric specification accepted by ``as_fabric``.
        trained_only: When true, missing trained weights are an error.

    Returns:
        The value model.

    Raises:
        FileNotFoundError: If ``source`` is a path that is not an existing
            directory, or if ``trained_only`` is set and no ``vf.pth`` exists.
    """
    if isinstance(source, LLM):
        checkpoint = source.checkpoint_dir
        config = source.model.config
    else:
        checkpoint = extend_checkpoint_dir(Path(source))
        # `is_dir()` is false both for missing paths and for non-directories.
        # The previous `not exists() and is_dir()` test could never be true.
        if not checkpoint.is_dir():
            raise FileNotFoundError(f"{source} is not a valid checkpoint directory.")
        config = Config.from_file(checkpoint / 'model_config.yaml')

    # Distribute the value model.
    fabric = as_fabric(distribute)
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        vf = ValueModel(config)

    trained_weights_loaded = False
    if checkpoint is not None:
        vf_path = checkpoint / 'vf.pth'
        if vf_path.is_file():
            fabric.load_raw(vf_path, vf)
            trained_weights_loaded = True

    if not trained_weights_loaded:
        if trained_only:
            raise FileNotFoundError("No trained weights found for the value model.")
        # NOTE(review): for a path `source` without `vf.pth` the model is
        # returned without any weights loaded -- confirm this is intended.
        if isinstance(source, LLM):
            vf.init_weights(source)  # initialize from the LLM weights.

    return vf
