from typing import Literal, Iterable
import copy
import pytorch_lightning as pl
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

from symo.optim2 import Symo
from symo.nsymo import NSymo
from typing import Any

from symo.utils import InverseStepScheduler
from symo.nanogpt import GPT, GPTConfig, symo_group_spec_v2, symo_filtered_spec


NDArray = torch.Tensor


def select_optimizer(name: str):
    """Return the optimizer-configuration function registered under ``name``.

    Args:
        name: Registry key, e.g. ``"adam"`` or ``"adam_symo"``.

    Raises:
        KeyError: If ``name`` is not one of the registered keys.
    """
    registry = {
        "adam_muon": configure_adam_muon,
        "adam_symo": configure_adam_symo,
        "adam_nsymo": configure_adam_nsymo,
        "symo": configure_symo,
        "adam": configure_adam,
    }
    return registry[name]


def select_scheduler(name: str):
    """Return the scheduler-configuration function registered under ``name``.

    Args:
        name: Registry key: ``"warmup_cosine"``, ``"linear"``, ``"inverse"``
            or ``"exp"``.

    Raises:
        KeyError: If ``name`` is not one of the registered keys.
    """
    registry = {
        "warmup_cosine": configure_warmup_cosine_scheduler,
        "linear": configure_linear_scheduler,
        "inverse": configure_inv_scheduler,
        "exp": configure_exp_scheduler,
    }
    return registry[name]


class NanoGPTLitModule(pl.LightningModule):
    """Lightning wrapper around ``GPT`` that trains with manual optimization.

    Automatic optimization is disabled so that ``training_step`` can zero,
    clip and step every configured optimizer itself; the module also steps
    the configured LR schedulers from its batch/epoch-end hooks.

    Args:
        vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias:
            Forwarded unchanged to ``GPTConfig``.
        optimizers: Mapping from optimizer name to that optimizer's keyword
            config. The sorted names joined with ``"_"`` form the key passed
            to ``select_optimizer`` (e.g. ``{"adam": ..., "symo": ...}``
            selects ``"adam_symo"``).
        gradient_clip_val: Value handed to ``clip_gradients`` (norm
            clipping) for each optimizer before it is stepped.
    """

    def __init__(
        self,
        vocab_size: int = 50304,
        block_size: int = 1024,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
        dropout: float = 0.0,
        bias: bool = True,
        optimizers: dict[str, Any] | None = None,
        gradient_clip_val: float | None = None,
    ):
        super().__init__()
        self.save_hyperparameters()
        # training_step drives the optimizers by hand.
        self.automatic_optimization = False
        self.gradient_clip_val = gradient_clip_val

        config = GPTConfig(
            vocab_size=vocab_size,
            block_size=block_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            bias=bias,
        )
        self.model = GPT(config)

    def forward(self, idx, targets=None):
        """Delegate to the wrapped ``GPT`` model."""
        return self.model(idx, targets)

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step over all optimizers.

        Returns:
            The training loss, which is also logged as ``train_loss``.
        """
        opts = self.optimizers()
        # A single optimizer is returned bare; normalize to a list.
        opts = opts if isinstance(opts, Iterable) else [opts]

        for opt in opts:
            opt.zero_grad()

        idx, targets = batch
        _, loss = self(idx, targets)

        self.manual_backward(loss)

        for opt in opts:
            self.clip_gradients(
                opt,
                gradient_clip_val=self.gradient_clip_val,
                gradient_clip_algorithm="norm",
            )
            opt.step()

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def _step_schedulers(
        self,
        interval: Literal["epoch", "step"],
        batch_idx: int | None = None,
    ):
        """Step every scheduler configured for ``interval`` whose frequency is due.

        Args:
            interval: Which scheduler configs to consider.
            batch_idx: Index of the batch that just finished; used as the
                counter for ``"step"`` schedulers. ``global_step`` counts
                optimizer steps (several per batch when multiple optimizers
                are stepped) and is already incremented by the time the
                batch-end hook runs, so it is only used as a fallback when
                ``batch_idx`` is not supplied.
        """
        for config in self.trainer.lr_scheduler_configs:
            if config.interval != interval:
                continue

            if interval == "epoch":
                current_idx = self.current_epoch
            elif batch_idx is not None:
                current_idx = batch_idx
            else:
                current_idx = self.global_step

            # Counters start at 0, hence the +1 before the frequency check.
            if (current_idx + 1) % config.frequency != 0:
                continue

            scheduler = config.scheduler
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Plateau schedulers need the monitored metric; skip quietly
                # when no monitor is configured or the metric is not logged yet.
                if config.monitor:
                    metric = self.trainer.callback_metrics.get(config.monitor)
                    if metric is not None:
                        scheduler.step(metric)
            else:
                scheduler.step()

    def on_train_epoch_end(self):
        """Step epoch-interval schedulers."""
        super().on_train_epoch_end()
        self._step_schedulers(interval="epoch")

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Step step-interval schedulers."""
        super().on_train_batch_end(outputs, batch, batch_idx)
        self._step_schedulers(interval="step", batch_idx=batch_idx)

    def validation_step(self, batch, batch_idx: int):
        """Compute the validation loss and log it per epoch as ``val_loss``."""
        idx, targets = batch
        _, loss = self(idx, targets)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Build optimizers (and schedulers) from the ``optimizers`` hyperparameter.

        Raises:
            ValueError: If no optimizer configuration was provided.
            KeyError: If the combination of optimizer names is not registered
                in ``select_optimizer``.
        """
        if not self.hparams.optimizers:
            raise ValueError(
                "NanoGPTLitModule requires a non-empty `optimizers` mapping, "
                "e.g. {'adam': {'lr': 1e-3}}"
            )

        # The configure functions pop entries, so hand them a private copy.
        optimizers = copy.deepcopy(self.hparams.optimizers)
        key = "_".join(sorted(optimizers))

        configure = select_optimizer(key)
        return configure(self.model, **optimizers)


def configure_inv_scheduler(optimizer, **kwargs):
    """Build an ``InverseStepScheduler`` for ``optimizer`` from ``kwargs``."""
    scheduler = InverseStepScheduler(optimizer, **kwargs)
    return scheduler


def configure_exp_scheduler(optimizer, **kwargs):
    """Build a ``torch`` ``ExponentialLR`` for ``optimizer`` from ``kwargs``."""
    scheduler_cls = optim.lr_scheduler.ExponentialLR
    return scheduler_cls(optimizer, **kwargs)


def configure_linear_scheduler(optimizer, **kwargs):
    """Build a ``torch`` ``LinearLR`` for ``optimizer`` from ``kwargs``."""
    return LinearLR(optimizer, **kwargs)


def configure_warmup_cosine_scheduler(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr: float | NDArray,
):
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=1.0 / (warmup_steps + 1),
        end_factor=1.0,
        total_iters=warmup_steps,
    )

    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_steps - warmup_steps,
        eta_min=min_lr,
    )

    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_steps],
    )

    return scheduler


def configure_adam_symo(model, adam: dict[str, Any], symo: dict[str, Any]):
    """Configure an Adam optimizer and a ``Symo`` optimizer over a parameter split.

    Args:
        model: Model whose named parameters are split between the optimizers.
        adam: Keyword config for ``optim.Adam``; an optional ``lr_scheduler``
            entry configures its scheduler.
        symo: Keyword config for ``Symo``. Optional entries consumed here:
            ``split`` (``"muon"`` selects ``muon_split_parameters``, anything
            else ``symo_split_parameters``), ``lr_scheduler`` and
            ``groups_spec`` (kwargs for ``symo_group_spec_v2``).

    Returns:
        ``[adam_config, symo_config]``, each a dict with an ``"optimizer"``
        entry plus an ``"lr_scheduler"`` entry when one was configured.
    """
    # Private copies: entries are popped below.
    adam_kwargs = copy.deepcopy(adam)
    symo_kwargs = copy.deepcopy(symo)

    if symo_kwargs.pop("split", None) == "muon":
        named_symo_params, adam_params = muon_split_parameters(model)
    else:
        named_symo_params, adam_params = symo_split_parameters(model)

    adam_sched_cfg = adam_kwargs.pop("lr_scheduler", {})
    symo_sched_cfg = symo_kwargs.pop("lr_scheduler", {})

    # Build the group spec for the whole model, then restrict it to the
    # parameters that were assigned to Symo.
    full_spec = symo_group_spec_v2(model, **symo_kwargs.pop("groups_spec", {}))
    filtered_spec = symo_filtered_spec(
        full_spec, list(model.named_parameters()), named_symo_params
    )

    symo_optimizer = Symo(
        [param for _, param in named_symo_params], filtered_spec, **symo_kwargs
    )
    adam_optimizer = optim.Adam(adam_params, **adam_kwargs)

    configs = []
    for optimizer, sched_cfg in (
        (adam_optimizer, adam_sched_cfg),
        (symo_optimizer, symo_sched_cfg),
    ):
        configs.append(
            {"optimizer": optimizer, **configure_lr_scheduler(optimizer, sched_cfg)}
        )
    return configs


def configure_adam_nsymo(model, adam: dict[str, Any], nsymo: dict[str, Any]):
    """Configure an Adam optimizer and an ``NSymo`` optimizer over a parameter split.

    Args:
        model: Model whose named parameters are split between the optimizers.
        adam: Keyword config for ``optim.Adam``; an optional ``lr_scheduler``
            entry configures its scheduler.
        nsymo: Keyword config for ``NSymo``. Optional entries consumed here:
            ``split`` (``"muon"`` selects ``muon_split_parameters``, anything
            else ``symo_split_parameters``), ``lr_scheduler`` and
            ``groups_spec`` (kwargs for ``symo_group_spec_v2``).

    Returns:
        ``[adam_config, nsymo_config]``, each a dict with an ``"optimizer"``
        entry plus an ``"lr_scheduler"`` entry when one was configured.
    """
    # Private copies: entries are popped below.
    adam_kwargs = copy.deepcopy(adam)
    nsymo_kwargs = copy.deepcopy(nsymo)

    use_muon_split = nsymo_kwargs.pop("split", None) == "muon"
    splitter = muon_split_parameters if use_muon_split else symo_split_parameters
    named_nsymo_params, adam_params = splitter(model)

    adam_sched_cfg = adam_kwargs.pop("lr_scheduler", {})
    nsymo_sched_cfg = nsymo_kwargs.pop("lr_scheduler", {})

    # Build the group spec for the whole model, then restrict it to the
    # parameters that were assigned to NSymo.
    spec_kwargs = nsymo_kwargs.pop("groups_spec", {})
    full_spec = symo_group_spec_v2(model, **spec_kwargs)
    all_named_params = list(model.named_parameters())
    filtered_spec = symo_filtered_spec(full_spec, all_named_params, named_nsymo_params)

    nsymo_tensors = [param for _, param in named_nsymo_params]
    nsymo_optimizer = NSymo(nsymo_tensors, filtered_spec, **nsymo_kwargs)
    adam_optimizer = optim.Adam(adam_params, **adam_kwargs)

    adam_sched = configure_lr_scheduler(adam_optimizer, adam_sched_cfg)
    nsymo_sched = configure_lr_scheduler(nsymo_optimizer, nsymo_sched_cfg)

    return [
        {"optimizer": adam_optimizer, **adam_sched},
        {"optimizer": nsymo_optimizer, **nsymo_sched},
    ]


def configure_adam_muon(model, adam: dict[str, Any], muon: dict[str, Any]):
    """Configure an Adam optimizer and a ``optim.Muon`` optimizer over a parameter split.

    Args:
        model: Model whose named parameters are split between the optimizers.
        adam: Keyword config for ``optim.Adam``; an optional ``lr_scheduler``
            entry configures its scheduler.
        muon: Keyword config for ``optim.Muon``. Optional entries consumed
            here: ``split`` and ``lr_scheduler``. ``split`` defaults to
            ``"muon"`` (``muon_split_parameters``, which keeps parameters
            with fewer than two dimensions out of the Muon group); any other
            explicit value selects ``symo_split_parameters``.

    Returns:
        ``[adam_config, muon_config]``, each a dict with an ``"optimizer"``
        entry plus an ``"lr_scheduler"`` entry when one was configured.
    """
    # Private copies: entries are popped below.
    adam = copy.deepcopy(adam)
    muon = copy.deepcopy(muon)

    # Muon is meant for 2-D weights only, so when no split is requested use
    # the one that routes biases and other <2-D parameters to Adam instead of
    # the symo split, which would pass them to Muon.
    split = muon.pop("split", None)
    if split is None:
        split = "muon"
    split_parameters = (
        muon_split_parameters if split == "muon" else symo_split_parameters
    )
    muon_params, adam_params = split_parameters(model)

    adam_lr_config = adam.pop("lr_scheduler", {})
    muon_lr_config = muon.pop("lr_scheduler", {})

    opt_muon = optim.Muon(muon_params, **muon)
    opt_adam = optim.Adam(adam_params, **adam)

    adam_lr = configure_lr_scheduler(opt_adam, adam_lr_config)
    muon_lr = configure_lr_scheduler(opt_muon, muon_lr_config)

    adam_config = {"optimizer": opt_adam} | adam_lr
    muon_config = {"optimizer": opt_muon} | muon_lr
    return [adam_config, muon_config]


def configure_adam(model, adam: dict[str, Any]):
    """Configure a single ``optim.Adam`` over all model parameters.

    Args:
        model: Model whose ``parameters()`` are optimized.
        adam: Keyword config for ``optim.Adam``; an optional ``lr_scheduler``
            entry configures its scheduler.

    Returns:
        A dict with an ``"optimizer"`` entry plus an ``"lr_scheduler"`` entry
        when one was configured.
    """
    adam_kwargs = copy.deepcopy(adam)
    all_params = tuple(model.parameters())
    scheduler_cfg = adam_kwargs.pop("lr_scheduler", {})

    optimizer = optim.Adam(all_params, **adam_kwargs)
    scheduler_entry = configure_lr_scheduler(optimizer, scheduler_cfg)
    return {"optimizer": optimizer, **scheduler_entry}


def configure_symo(model, symo: dict[str, Any]) -> dict:
    """Configure a single ``Symo`` optimizer over all model parameters.

    Args:
        model: Model whose ``parameters()`` are optimized.
        symo: Keyword config for ``Symo``. Optional entries consumed here:
            ``groups_spec`` (kwargs for ``symo_group_spec_v2``) and
            ``lr_scheduler``.

    Returns:
        A dict with an ``"optimizer"`` entry plus an ``"lr_scheduler"`` entry
        when one was configured.
    """
    symo_kwargs = copy.deepcopy(symo)
    spec_kwargs = symo_kwargs.pop("groups_spec", {})
    scheduler_cfg = symo_kwargs.pop("lr_scheduler", {})

    spec = symo_group_spec_v2(model, **spec_kwargs)
    all_params = tuple(model.parameters())

    optimizer = Symo(all_params, groups_spec=spec, **symo_kwargs)
    scheduler_entry = configure_lr_scheduler(optimizer, scheduler_cfg)
    return {"optimizer": optimizer, **scheduler_entry}


def configure_lr_scheduler(
    optimizer: optim.Optimizer,
    config: dict[str, Any] | None,
):
    if not config:
        return {}

    name = config.pop("name", None)
    assert name is not None

    frequency = config.pop("frequency", None)
    interval = config.pop("interval", None)
    monitor = config.pop("monitor", None)

    configure_scheduler = select_scheduler(name)
    scheduler = configure_scheduler(optimizer, **config)
    out = {"scheduler": scheduler}
    out |= {} if frequency is None else {"frequency": frequency}
    out |= {} if interval is None else {"interval": interval}
    out |= {} if monitor is None else {"monitor": monitor}
    return {"lr_scheduler": out}


def muon_split_parameters(model):
    """Split named parameters into a Muon group and an "other" group.

    A parameter goes to the "other" group when its name is one of the
    embedding / head weights listed below or when it has fewer than two
    dimensions; everything else goes to the Muon group.

    Returns:
        ``(muon_params, other_params)``, two lists of ``(name, parameter)``
        pairs in ``model.named_parameters()`` order.
    """
    excluded_names = (
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "transformer.lm_head.weight",
    )

    muon_params = []
    other_params = []
    for name, param in model.named_parameters():
        goes_to_other = name in excluded_names or param.dim() < 2
        bucket = other_params if goes_to_other else muon_params
        bucket.append((name, param))

    return muon_params, other_params


def symo_split_parameters(model):
    """Split named parameters into a Symo group and an "other" group.

    A parameter goes to the "other" group only when its name is one of the
    embedding / head weights listed below; everything else goes to the Symo
    group regardless of its shape.

    Returns:
        ``(symo_params, other_params)``, two lists of ``(name, parameter)``
        pairs in ``model.named_parameters()`` order.
    """
    excluded_names = (
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "transformer.lm_head.weight",
    )

    symo_params = []
    other_params = []
    for name, param in model.named_parameters():
        bucket = other_params if name in excluded_names else symo_params
        bucket.append((name, param))

    return symo_params, other_params


if __name__ == "__main__":
    # Script entry point: fit a small model on the Shakespeare data module
    # using the adam + symo optimizer pair, on CPU.
    from symo.data import ShakespeareDataModule

    datamodule = ShakespeareDataModule()
    datamodule.prepare_data()
    vocab_size = datamodule.vocab_size

    # Sorted keys ("adam", "symo") select the "adam_symo" configuration.
    optimizer_config = {
        "adam": {
            "lr": 1e-3,
            "lr_scheduler": {"name": "inverse"},
        },
        "symo": {"lr": 1e-3},
    }

    model = NanoGPTLitModule(
        vocab_size=vocab_size,
        n_layer=6,
        n_embd=384,
        optimizers=optimizer_config,
    )

    trainer = pl.Trainer(max_epochs=10, accelerator="cpu")
    trainer.fit(model, datamodule)
    print()
