from abc import ABC, abstractmethod
from typing import List, Dict, Any

import numpy as np

from ..defs import Sparsity


# Public API of this module: the schedule classes and the factory function.
__all__ = [
    "SparsitySchedule",
    "ConstantSparsitySchedule",
    "PolynomialSparsitySchedule",
    "SineSparsitySchedule",
    "ListedSparsitySchedule",
    "create_sparsity_schedule",
]


# TODO: support epoch-wise scheduling (schedules currently advance per optimization step).


class SparsitySchedule(ABC):
    '''
    Base class for sparsity (pruning) schedules.

    A schedule keeps an internal step counter (``current_step``) and maps it
    to a target sparsity via :meth:`get_current_sparsity`. Subclasses must
    implement that method; inheriting from ``ABC`` makes ``@abstractmethod``
    actually enforced, so the base class itself cannot be instantiated.
    '''
    def __init__(self, step: int = 0, *args, **kwargs) -> None:
        '''
        Args:
            step: current optimization step (initial value of the counter)
            *args, **kwargs: accepted and ignored
        '''
        self.current_step = step

    @abstractmethod
    def get_current_sparsity(self) -> float:
        '''Return the target sparsity for ``self.current_step``.'''
        raise NotImplementedError

    def step(self) -> None:
        '''Advance the schedule by one optimization step.'''
        self.current_step += 1


class ConstantSparsitySchedule(SparsitySchedule):
    '''
    Applies a constant sparsity from ``init_step`` onwards (zero before it).
    '''
    def __init__(
        self,
        sparsity: Sparsity,
        init_step: int = 0,
        step: int = 0,
    ) -> None:
        '''
        Args:
            sparsity: sparsity imposed from ``init_step`` onwards; must lie in [0, 1]
            init_step: first step at which ``sparsity`` is applied; must be >= 0
            step: current optimization step
        '''
        # sanity checks
        assert 0 <= sparsity <= 1.0, f'sparsity must lie in [0, 1], got {sparsity}'
        assert 0 <= init_step, f'init_step must be non-negative, got {init_step}'
        super().__init__(step)
        self.sparsity = sparsity
        self.init_step = init_step

    def get_current_sparsity(self) -> Sparsity:
        # No pruning before the schedule starts; return a float so the result
        # type is consistent with the other schedules.
        if self.current_step < self.init_step:
            return 0.0
        return self.sparsity


class PolynomialSparsitySchedule(SparsitySchedule):
    '''
    Polynomial interpolation between an initial and a final sparsity:

        s_t = s_f + (s_i - s_f) * (1 - (t - t_i) / (t_f - t_i))^{p}

    The progress term (t - t_i) / (t_f - t_i) is clipped to [0, 1], so the
    schedule yields s_i up to ``init_step`` and s_f from ``final_step`` on.
    '''
    def __init__(
        self,
        init_sparsity: float,
        final_sparsity: float,
        init_step: int = 0,
        final_step: int = -1,
        power: float = 1.0,
        step: int = 0,
    ) -> None:
        '''
        Args:
            init_sparsity: sparsity returned up to and including ``init_step``
            final_sparsity: sparsity returned from ``final_step`` onwards
            init_step: step at which interpolation starts; must be >= 0
            final_step: step at which interpolation ends; must exceed
                ``init_step`` (the default of -1 never passes this check,
                so it has to be given explicitly)
            power: exponent of the interpolation law; must be positive
            step: current optimization step
        '''
        # validate arguments before touching any state
        assert 0 <= init_sparsity <= final_sparsity <= 1.0
        assert 0 <= init_step < final_step
        assert power > 0
        super().__init__(step)
        self.init_sparsity, self.final_sparsity = init_sparsity, final_sparsity
        self.init_step, self.final_step = init_step, final_step
        self.power = power

    def get_current_sparsity(self) -> float:
        span = self.final_step - self.init_step
        progress = np.clip((self.current_step - self.init_step) / span, 0.0, 1.0)
        remaining = (1 - progress) ** self.power
        return self.final_sparsity + (self.init_sparsity - self.final_sparsity) * remaining

    
class SineSparsitySchedule(SparsitySchedule):
    r'''
    Implements the following sparsity schedule:

        s_t = s_i + (s_f - s_i) \sin((\pi / 2) * (t - t_i) / (t_f - t_i))

    The progress term (t - t_i) / (t_f - t_i) is clipped to [0, 1], so the
    schedule yields s_i up to ``init_step`` and s_f from ``final_step`` on.
    '''
    def __init__(
        self,
        init_sparsity: float,
        final_sparsity: float,
        init_step: int = 0,
        final_step: int = -1,
        step: int = 0,
    ) -> None:
        '''
        Args:
            init_sparsity: sparsity returned up to and including ``init_step``
            final_sparsity: sparsity returned from ``final_step`` onwards
            init_step: step at which interpolation starts; must be >= 0
            final_step: step at which interpolation ends; must exceed ``init_step``
            step: current optimization step
        '''
        # sanity checks
        assert 0 <= init_sparsity <= final_sparsity <= 1.0
        assert 0 <= init_step < final_step
        super().__init__(step)
        # both endpoints are read by get_current_sparsity, so they must be stored
        self.init_sparsity = init_sparsity
        self.final_sparsity = final_sparsity
        self.init_step = init_step
        self.final_step = final_step

    def get_current_sparsity(self) -> float:
        t_i, t_f, t = self.init_step, self.final_step, self.current_step
        s_i, s_f = self.init_sparsity, self.final_sparsity
        x = np.clip((t - t_i) / (t_f - t_i), 0.0, 1.0)
        return s_i + (s_f - s_i) * np.sin(np.pi * x / 2)


class ListedSparsitySchedule(SparsitySchedule):
    '''
    Piecewise-constant schedule that walks through an explicit list of sparsities.

    The interval [init_step, final_step) is split evenly among the first
    ``len(sparsities) - 1`` entries. The last entry is returned from
    ``final_step`` onwards, and the first entry is returned before ``init_step``.
    '''
    def __init__(
        self,
        sparsities: List[float],
        init_step: int = 0,
        final_step: int = -1,
        step: int = 0,
    ) -> None:
        '''
        Args:
            sparsities: non-empty sequence of sparsity values, visited in order
            init_step: step at which the schedule starts advancing; must be >= 0
            final_step: step from which the last entry of ``sparsities`` is
                returned; must exceed ``init_step``
            step: current optimization step
        '''
        # sanity checks
        assert 0 <= init_step < final_step
        assert len(sparsities) > 0, 'sparsities must be non-empty'
        super().__init__(step)
        self.init_step = init_step
        self.final_step = final_step
        # copy so later mutation of the caller's sequence cannot alter the schedule
        self.sparsities = list(sparsities)

    def get_current_sparsity(self) -> float:
        t_i, t_f, t = self.init_step, self.final_step, self.current_step
        progress = (np.clip(t, t_i, t_f) - t_i) / (t_f - t_i)
        # use a Python int index: a narrow integer dtype such as int16 would
        # overflow for lists longer than 32768 entries
        current_id = int(np.floor(progress * (len(self.sparsities) - 1)))
        return self.sparsities[current_id]


# Maps the string names accepted by ``create_sparsity_schedule`` to schedule classes.
SCHEDULE_REGISTRY = dict(
    constant=ConstantSparsitySchedule,
    polynomial=PolynomialSparsitySchedule,
    sine=SineSparsitySchedule,
    listed=ListedSparsitySchedule,
)


def create_sparsity_schedule(
    schedule_class: str,
    schedule_kwargs: Dict[str, Any],
) -> 'SparsitySchedule':
    '''
    Instantiate a sparsity schedule by its registry name.

    Args:
        schedule_class: key in ``SCHEDULE_REGISTRY``
            ('constant', 'polynomial', 'sine' or 'listed')
        schedule_kwargs: keyword arguments forwarded to the schedule constructor

    Returns:
        the constructed schedule

    Raises:
        KeyError: if ``schedule_class`` is not a registered schedule name
    '''
    try:
        schedule_cls = SCHEDULE_REGISTRY[schedule_class]
    except KeyError:
        # keep KeyError for existing callers, but say what the valid names are
        raise KeyError(
            f'Unknown sparsity schedule {schedule_class!r}; '
            f'expected one of {sorted(SCHEDULE_REGISTRY)}'
        ) from None
    return schedule_cls(**schedule_kwargs)
