from typing import Protocol

import plotly.express as px
import torch as t

from auto_encoder.config import AutoEncoderConfig

# Fraction of the total training steps after which the "final stretch" begins.
# CustomScheduler.__init__ computes final_stretch_steps = int(total_steps * 0.85),
# i.e. the final stretch covers the last 15% of training.
FINAL_STRETCH_PROPORTION = 0.85
# Fraction of the total training steps used as the sparsity-lambda warmup length
# (sparsity_lambda_warmup_steps = int(total_steps * 0.1) in CustomScheduler.__init__).
LAMBDA_WARMUP_PROPORTION = 0.1


class Scheduler(Protocol):
    """Structural interface for a step-driven learning-rate scheduler.

    Any object exposing ``step()`` and ``get_lr()`` with these signatures
    satisfies the protocol; no inheritance is required.
    """

    def step(self) -> None:
        """Advance the schedule by one step."""

    def get_lr(self) -> float:
        """Return the learning rate for the current step."""


class CustomScheduler:
    """Step-driven schedule for the learning rate and auxiliary loss coefficients.

    Call :meth:`step` once per training step. It increments the internal step
    counter and writes the scheduled learning rate into every param group of
    the wrapped optimizer. The remaining schedules (sparsity lambda, balancing
    loss coefficients, capacity factor) are exposed as read-only properties
    derived from the current step count.
    """

    def __init__(
        self,
        optimizer: t.optim.Optimizer,
        config: AutoEncoderConfig,
        sparsity_lambda_start_val: float = 0,
        lr_warmup_steps_proportion: float = 0.01,
    ):
        """Capture the schedule parameters from ``config``.

        Args:
            optimizer: Optimizer whose param-group learning rates are driven by
                this scheduler. ``optimizer.defaults["lr"]`` is used as the
                base (peak) learning rate.
            config: Source of ``num_total_steps``, ``resample_steps``,
                ``capacity_factor``, ``topk``, ``topm``, the topk temperature
                bounds and the auxiliary loss coefficients.
            sparsity_lambda_start_val: Sparsity lambda at the start of each
                lambda warmup.
            lr_warmup_steps_proportion: Fraction of the total steps spent on
                learning-rate warmup.
        """
        self.optimizer = optimizer

        total_steps = config.num_total_steps

        self._auxiliary_balancing_loss_coef = config.auxiliary_balancing_loss_coef
        self._expert_importance_loss_coef = config.expert_importance_loss_coef

        self.capacity_factor = config.capacity_factor

        self.total_steps = total_steps
        self.final_stretch_steps = int(total_steps * FINAL_STRETCH_PROPORTION)
        self.resample_steps = config.resample_steps
        # Deliberately left as a float: it is only used as a threshold/divisor.
        self.lr_warmup_steps = lr_warmup_steps_proportion * total_steps
        self.sparsity_lambda_warmup_steps = int(total_steps * LAMBDA_WARMUP_PROPORTION)

        self.lambda_start = sparsity_lambda_start_val
        self.lambda_end = config.auxiliary_l1_sparsity_coef

        self.original_topk = config.topk
        self.topk_temperature = config.stochastic_topk_temperature

        self.original_topm = config.topm

        self.min_topk_temperature = config.min_topk_temperature

        self.step_count = 0

    def step(self):
        """Advance one step and write the scheduled LR into every param group."""
        self.step_count += 1

        lr = self.get_lr()

        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    @property
    def _steps_since_resample(self) -> int:
        """Steps elapsed since the last multiple of ``resample_steps``.

        A falsy ``resample_steps`` (``0`` / ``None``) is treated as "never
        resample" instead of raising on the modulo.
        """
        if not self.resample_steps:
            return self.step_count
        return self.step_count % self.resample_steps

    @property
    def declining_factor(self) -> float:
        """Linear decay multiplier for the final stretch.

        Equals 1 at ``final_stretch_steps`` and 0 at ``total_steps``. It is
        greater than 1 before the final stretch (callers cap it with ``min``)
        and is clamped at 0 so stepping past ``total_steps`` can never yield a
        negative learning rate or negative loss coefficient.
        """
        return max(
            0.0,
            1
            - (self.step_count - self.final_stretch_steps)
            / (self.total_steps - self.final_stretch_steps),
        )

    def get_lr(self) -> float:
        """Return the learning rate for the current step.

        The learning rate starts at 0 and increases linearly to the base rate
        configured on the optimizer over ``lr_warmup_steps``. After warmup it
        stays at the base rate until either a resampling event (every
        ``resample_steps`` steps), where the warmup restarts from 0, or the
        final stretch. During the final stretch (the last 15% of training) it
        decreases linearly to 0. A warmup that overlaps the final stretch is
        capped by the declining rate.
        """
        base_lr = self.optimizer.defaults["lr"]
        ending_decline_lr = base_lr * self.declining_factor
        current_max_lr = min(base_lr, ending_decline_lr)

        steps_since_resample = self._steps_since_resample

        # A zero-length warmup means "no warmup" rather than a division by zero.
        if self.lr_warmup_steps > 0 and steps_since_resample <= self.lr_warmup_steps:
            warmup_lr = (steps_since_resample / self.lr_warmup_steps) * base_lr
            return min(warmup_lr, current_max_lr)

        if self.step_count >= self.final_stretch_steps:
            return ending_decline_lr

        # Constant phase: always the base LR. Reading the value back from the
        # param group would freeze whatever the last warmup step wrote, which is
        # below the base LR whenever lr_warmup_steps is not a whole number.
        return base_lr

    @property
    def sparsity_lambda(self) -> float:
        """Sparsity coefficient, warmed up linearly after every resample.

        Rises from ``lambda_start`` to ``lambda_end`` over
        ``sparsity_lambda_warmup_steps`` steps following each resampling event,
        then stays at ``lambda_end``.
        """
        steps_since_resample = self._steps_since_resample
        warmup_steps = self.sparsity_lambda_warmup_steps

        # warmup_steps is 0 for very short runs; skip the ramp in that case.
        if warmup_steps > 0 and steps_since_resample <= warmup_steps:
            return self.lambda_start + (steps_since_resample / warmup_steps) * (
                self.lambda_end - self.lambda_start
            )

        return self.lambda_end

    @property
    def balancing_loss_coefs(self) -> tuple[float, float]:
        """Return ``(auxiliary_balancing_loss_coef, expert_importance_loss_coef)``.

        Both coefficients are doubled while ``step_count`` is below the lambda
        warmup length (to balance the router out early), held at their
        configured values afterwards, and decayed linearly to 0 over the final
        stretch.
        """
        if self.step_count < self.sparsity_lambda_warmup_steps:
            return (
                self._auxiliary_balancing_loss_coef * 2,
                self._expert_importance_loss_coef * 2,
            )

        # declining_factor exceeds 1 before the final stretch, so cap it there.
        decay = min(self.declining_factor, 1)

        return (
            self._auxiliary_balancing_loss_coef * decay,
            self._expert_importance_loss_coef * decay,
        )

    @property
    def get_capacity_factor(self) -> float:
        """Starts at 1 + capacity_factor and linearly decreases to capacity_factor over the length of training"""
        # Parenthesised so the *remaining fraction* of training is added, not
        # `total_steps - (step_count / total_steps)`. Clamped at 0 past the end.
        remaining_fraction = max(
            0.0, (self.total_steps - self.step_count) / self.total_steps
        )
        return remaining_fraction + self.capacity_factor


class TopK_M_DecreasingScheduler(CustomScheduler):
    """Scheduler variant without resampling that also anneals topk / topm.

    Unlike the parent class, none of the schedules here restart at multiples of
    ``resample_steps``; everything is a function of the global ``step_count``.
    """

    def get_lr(self) -> float:
        """Trapezoid LR schedule with warmup, constant, and declining periods."""
        base_lr = self.optimizer.defaults["lr"]
        ending_decline_lr = base_lr * self.declining_factor

        current_max_lr = min(base_lr, ending_decline_lr)

        # A zero-length warmup means "no warmup" rather than a division by zero.
        if self.lr_warmup_steps > 0 and self.step_count <= self.lr_warmup_steps:
            warmup_lr = (self.step_count / self.lr_warmup_steps) * base_lr
            return min(warmup_lr, current_max_lr)

        if self.step_count >= self.final_stretch_steps:
            return ending_decline_lr

        # Plateau: always the base LR. Reading the value back from the param
        # group would freeze whatever the last warmup step wrote, which is below
        # the base LR whenever lr_warmup_steps is not a whole number.
        return base_lr

    @property
    def sparsity_lambda(self) -> float:
        """Sparsity coefficient: ramp up, drop to a softer decaying value, then 0.

        * Before ``sparsity_lambda_warmup_steps``: linear increase from
          ``lambda_start`` to ``0.2 * lambda_end``.
        * From the end of warmup until the final stretch: linear interpolation
          between ``0.1 * lambda_end`` and ``0.01 * lambda_end``.
        * During the final stretch: exactly ``0.0``.
        """
        lambda_soft_end = self.lambda_end * 0.01
        lambda_soft_mid_right = self.lambda_end * 0.1
        lambda_soft_mid_left = self.lambda_end * 0.2

        if self.step_count >= self.final_stretch_steps:
            return 0.0

        if self.step_count >= self.sparsity_lambda_warmup_steps:
            # NOTE(review): the interpolation fraction is step_count / total_steps
            # (progress through the whole run), so this phase neither starts
            # exactly at lambda_soft_mid_right nor reaches lambda_soft_end before
            # the final stretch zeroes it. Confirm this is intended.
            return lambda_soft_mid_right + (self.step_count / self.total_steps) * (
                lambda_soft_end - lambda_soft_mid_right
            )

        # Linearly increase lambda from lambda_start to lambda_soft_mid_left over
        # the warmup period. warmup_steps cannot be 0 here: the branch above
        # catches every step_count >= warmup_steps.
        return self.lambda_start + (
            self.step_count / self.sparsity_lambda_warmup_steps
        ) * (lambda_soft_mid_left - self.lambda_start)

    def _scheduled_top_count(self, original: int) -> int:
        """Shared topk/topm schedule, parameterised by the configured count.

        Linearly decreases from ``2 * original`` to ``original`` over the first
        50% of training, holds ``original`` until the final stretch, then
        linearly decreases from ``original`` to ``0.9 * original`` over the
        final stretch. Interpolated values are truncated with ``int``.
        """
        if self.step_count >= self.final_stretch_steps:
            steps_since_final_stretch = self.step_count - self.final_stretch_steps
            num_final_stretch_steps = self.total_steps - self.final_stretch_steps

            count = original - (
                steps_since_final_stretch / num_final_stretch_steps
            ) * (original - original * 0.9)

            return int(count)

        halfway_steps = self.total_steps / 2

        if self.step_count >= halfway_steps:
            return original

        count = original * 2.0 - (self.step_count / halfway_steps) * (
            original * 2.0 - original
        )

        return int(count)

    @property
    def topk(self) -> int:
        """Scheduled topk: 2k -> k over the first half, k, then k -> 0.9k at the end."""
        return self._scheduled_top_count(self.original_topk)

    @property
    def topm(self) -> int:
        """Scheduled topm: 2m -> m over the first half, m, then m -> 0.9m at the end."""
        return self._scheduled_top_count(self.original_topm)

    @property
    def stochastic_topk_temperature(self) -> float:
        """Temperature annealed linearly over the whole training run.

        Decreases from ``topk_temperature`` at step 0 to
        ``min_topk_temperature`` at ``total_steps``.
        """
        return self.topk_temperature - (self.step_count / self.total_steps) * (
            self.topk_temperature - self.min_topk_temperature
        )


if __name__ == "__main__":
    # Demo run: step a TopK_M_DecreasingScheduler through a full schedule,
    # record the scheduled values at every step, dump them to CSV and plot them.
    from dataclasses import dataclass

    import pandas as pd
    import torch as t
    import torch.optim as optim
    from collectibles import ListCollection

    # One row of scheduled values, captured after a single scheduler step.
    @dataclass
    class SchedulerParams:
        lr: float
        l1_sparsity_coef: float
        auxiliary_balancing_loss_coef: float
        expert_importance_loss_coef: float

        topk: int
        stochastic_topk_temperature: float

    # Column-wise view over a list of SchedulerParams.
    class ParamsCollection(ListCollection[SchedulerParams]):
        lr: list[float]
        l1_sparsity_coef: list[float]
        auxiliary_balancing_loss_coef: list[float]
        expert_importance_loss_coef: list[float]

        topk: list[int]
        stochastic_topk_temperature: list[float]
        # NOTE(review): SchedulerParams has no `topk_cutoff_prob` field, so this
        # annotation looks stale — confirm before relying on it.
        topk_cutoff_prob: list[float]

    TOTAL_STEPS = 4_000

    optimizer = optim.Adam([t.tensor([1.0])], lr=1e-3)

    config = AutoEncoderConfig(
        num_total_steps=TOTAL_STEPS,
        resample_steps=1000,
        auxiliary_balancing_loss_coef=0.1,
        expert_importance_loss_coef=0.2,
        auxiliary_l1_sparsity_coef=0.5,
        capacity_factor=1.5,
    )

    scheduler = TopK_M_DecreasingScheduler(optimizer, config)

    params_list = []
    for _ in range(TOTAL_STEPS):
        scheduler.step()

        # Both coefficients come from the same property; read it once per step.
        balancing_coef, importance_coef = scheduler.balancing_loss_coefs

        snapshot = SchedulerParams(
            lr=scheduler.get_lr(),
            l1_sparsity_coef=scheduler.sparsity_lambda,
            auxiliary_balancing_loss_coef=balancing_coef,
            expert_importance_loss_coef=importance_coef,
            topk=scheduler.topk,
            stochastic_topk_temperature=scheduler.stochastic_topk_temperature,
        )
        params_list.append(snapshot)

    params_collection = ParamsCollection(params_list)

    # Column order here is the column order of the CSV and of the plot legend.
    column_names = (
        "lr",
        "l1_sparsity_coef",
        "auxiliary_balancing_loss_coef",
        "expert_importance_loss_coef",
        "topk",
        "stochastic_topk_temperature",
    )
    df = pd.DataFrame(
        {name: getattr(params_collection, name) for name in column_names}
    )
    df.to_csv("soft_ae_scheduler_params.csv", index=False)

    px.line(df).show()
