from abc import ABC, abstractmethod
from bisect import bisect
from numbers import Number

import numpy as np

from .utils import Sched


class Schedule(ABC):
    """Schedule Everything as You Want.

    Every entry of ``assigns`` is a Python statement, given as a string, that
    consumes the name ``value`` (the freshly computed scheduled value) and
    writes it into some target object. The target object is identified by the
    leading identifier of the statement and must be handed to ``__call__`` as
    a keyword argument of that same name.

    Supported cases:
    - model.alpha.fill_(value)
    - model.alpha = value
    - model.tau.data[...] = value
    - optim.param_groups[0]["lr"] = value
    - loss["recon_d"]["weight"] = value

    NOTE(security): the statements are run with ``exec``; they must come from
    trusted configuration only, never from untrusted input.
    """

    def __init__(self, assigns):
        """Store the statements and pre-compute the target name of each one.

        Args:
            assigns: sequence of statement strings; each must mention ``value``.
        """
        assert all("value" in _ for _ in assigns)
        self.assigns = assigns
        self.obj_strs = [Schedule.parse_target_obj_str(_) for _ in assigns]

    @staticmethod
    def parse_target_obj_str(assign: str):
        """Return the leading identifier of ``assign``.

        The identifier is everything before the first ``.`` or ``[``, with
        surrounding whitespace removed (e.g. ``"optim"`` for
        ``'optim.param_groups[0]["lr"] = value'``).
        """
        return assign.split(".")[0].split("[")[0].strip()

    def __call__(self, step_count, **kwds):
        """Compute every scheduled value for ``step_count`` and apply it.

        Args:
            step_count: current step, forwarded to ``calc_value``.
            **kwds: the target objects, keyed by the names used in ``assigns``.

        Raises:
            KeyError: if a target named in a statement is missing from ``kwds``.
        """
        # One explicit namespace is shared by all statements. Relying on the
        # implicit function locals (bare ``exec(...)``) only worked while
        # ``locals()`` was a persistent dict; since Python 3.13 (PEP 667) each
        # bare ``exec`` sees a fresh snapshot, so a name bound by one ``exec``
        # is invisible to the next and the statement fails with NameError.
        namespace = {"self": self, "kwds": kwds, "step_count": step_count}
        for i, (obj_str, assign) in enumerate(zip(self.obj_strs, self.assigns)):
            # Resolve the target before computing the value, so a missing
            # target fails before ``calc_value`` runs.
            namespace[obj_str] = kwds[obj_str]
            namespace["i"] = i
            namespace["value"] = self.calc_value(i, step_count)
            # Module globals stay visible to the statement.
            exec(assign, globals(), namespace)

    @abstractmethod
    def calc_value(self, i, step_count) -> float:
        """Return the value for the ``i``-th statement at ``step_count``."""
        ...


class CosineAnnealing(Schedule):
    """Schedule whose values come from ``Sched.cosine`` over ``total_step`` steps.

    ``base_values`` and ``min_values`` are indexed in parallel with ``assigns``:
    entry ``i`` parameterizes the ``i``-th statement.
    """

    def __init__(self, assigns, base_values, min_values, total_step):
        super().__init__(assigns)
        self.base_values, self.min_values = base_values, min_values
        self.total_step = total_step

    def calc_value(self, i, step_count) -> float:
        """Return ``Sched.cosine`` evaluated for statement ``i`` at ``step_count``."""
        start, floor = self.base_values[i], self.min_values[i]
        return Sched.cosine(start, step_count, self.total_step, floor)


class CosineAnnealingConstant(Schedule):
    """``Sched.cosine`` for the first ``cos_step`` steps, then the minimum value.

    ``base_values`` and ``min_values`` are indexed in parallel with ``assigns``.
    """

    def __init__(self, assigns, base_values, min_values, cos_step):
        super().__init__(assigns)
        self.base_values, self.min_values = base_values, min_values
        self.cos_step = cos_step

    def calc_value(self, i, step_count) -> float:
        """Return the value for statement ``i`` at ``step_count``.

        While ``step_count < cos_step`` the result is delegated to
        ``Sched.cosine``; from ``cos_step`` onwards it is ``min_values[i]``.
        """
        start, floor = self.base_values[i], self.min_values[i]
        if step_count < self.cos_step:
            return Sched.cosine(start, step_count, self.cos_step, floor)
        # Cosine window is over: hold the minimum.
        return floor


class OscillatoryCosineAnnealingConstant(CosineAnnealingConstant):
    """``CosineAnnealingConstant`` that randomly drops to the minimum value.

    On every ``calc_value`` call a number is drawn with ``np.random.rand()``;
    when it is below ``p`` the minimum value replaces the parent's result.
    """

    def __init__(self, assigns, p, base_values, min_values, cos_step):
        super().__init__(assigns, base_values, min_values, cos_step)
        # Probability of oscillating down to the minimum on a given call.
        self.p = p

    def calc_value(self, i, step_count) -> float:
        """Return the parent's value, or ``min_values[i]`` with probability ``p``."""
        scheduled = super().calc_value(i, step_count)
        # The draw happens after the parent computation, once per call.
        drop_to_min = np.random.rand() < self.p
        return self.min_values[i] if drop_to_min else scheduled


class LinearCosineAnnealing(Schedule):
    """Schedule whose values come from ``Sched.linear_cosine``.

    ``base_values`` and ``min_values`` are indexed in parallel with ``assigns``;
    ``warmup_step``, ``total_step`` and ``start_factor`` are shared by all
    statements.
    """

    def __init__(
        self, assigns, base_values, min_values, warmup_step, total_step, start_factor=0
    ):
        super().__init__(assigns)
        self.base_values, self.min_values = base_values, min_values
        self.warmup_step, self.total_step = warmup_step, total_step
        self.start_factor = start_factor

    def calc_value(self, i, step_count) -> float:
        """Return ``Sched.linear_cosine`` evaluated for statement ``i``."""
        peak, floor = self.base_values[i], self.min_values[i]
        # NOTE(review): the literal 1 is presumably the factor reached at the
        # end of the warmup - confirm against Sched.linear_cosine.
        return Sched.linear_cosine(
            peak, step_count, self.warmup_step, self.total_step, self.start_factor, 1, floor
        )


class ConstantCosineAnnealing(Schedule):
    """Constant value for ``const_step`` steps, then ``Sched.cosine``.

    ``const_values``, ``base_values`` and ``min_values`` are indexed in
    parallel with ``assigns``. The cosine phase spans the remaining
    ``total_step - const_step`` steps.
    """

    def __init__(
        self, assigns, const_values, base_values, min_values, const_step, total_step
    ):
        super().__init__(assigns)
        self.const_values = const_values
        self.base_values = base_values
        self.min_values = min_values
        self.const_step = const_step
        self.total_step = total_step

    def calc_value(self, i, step_count) -> float:
        """Return the value for statement ``i`` at ``step_count``.

        Before ``const_step`` this is ``const_values[i]``; afterwards the
        result is delegated to ``Sched.cosine`` over the remaining window.
        """
        base_value = self.base_values[i]
        min_value = self.min_values[i]
        if step_count < self.const_step:
            return self.const_values[i]
        # The cosine window is only ``total_step - const_step`` long, so the
        # step handed to it must be counted from the start of that window.
        # Passing the absolute step would begin part-way through the curve
        # and overrun the window before ``total_step`` is reached.
        # (Assumes Sched.cosine takes the elapsed step within its window, as
        # in CosineAnnealing's call with step_count and total_step - confirm.)
        return Sched.cosine(
            base_value,
            step_count - self.const_step,
            self.total_step - self.const_step,
            min_value,
        )


class Squarewave(Schedule):
    """Piecewise-constant schedule shared by all statements.

    ``points`` are the (sorted) segment boundaries and ``const_values[k]`` is
    the value on ``points[k] <= step_count < points[k + 1]``.

    e.g., const_values [1, 0] and points [0, 500, 1000] mean that the value is
    1 before step 500 and 0 from step 500 on. Steps before the first point use
    the first value; steps at or beyond the last point use the last value.
    """

    def __init__(self, assigns, const_values: list, points: list):
        super().__init__(assigns)
        assert len(const_values) + 1 == len(points)
        assert all(isinstance(_, Number) for _ in const_values)  # not nested
        assert all(isinstance(_, Number) for _ in points)  # not nested
        self.const_values = const_values  # [10, 0.1]
        self.points = points  # [0, total_step // 2, total_step]

    def calc_value(self, i, step_count) -> float:
        """Return the constant of the segment containing ``step_count``.

        ``i`` is ignored: every statement receives the same value.
        """
        idx = bisect(self.points, step_count) - 1
        # Clamp to the valid segments: without this, step_count >= points[-1]
        # indexes one past the end (IndexError) and step_count < points[0]
        # yields -1, silently wrapping around to the last value.
        idx = min(max(idx, 0), len(self.const_values) - 1)
        return self.const_values[idx]
