from src.methods.svgd import SVGD
from numpyro.util import fori_collect, ravel_pytree
from numpyro.contrib.einstein.steinvi import SteinVIRunResult, SteinVI
from numpyro.infer.autoguide import AutoDelta

def cyclical_annealing(num_steps: int, num_cycles: int, trans_speed: int):
    """Build a cyclical annealing schedule following eq. 4 of [1].

    The ``num_steps`` steps are split into cycles of length
    ``L = num_steps // num_cycles``.

    - Within each of the first ``num_cycles - 1`` cycles, the value at step
      ``t`` ramps up as ``((t mod L + 1) / ((num_steps + 1) / num_cycles)) ** trans_speed``.
      This lies strictly between 0 and 1.
    - From the start of the last cycle (step ``(num_cycles - 1) * L``) onward,
      the value is held at exactly 1.
    - Remainder steps (when ``num_cycles`` does not divide ``num_steps``) and
      steps past ``num_steps`` therefore also map to 1.
    - With ``num_cycles == 1`` the schedule is constantly 1.

    **References**

    [1] Francesco D'Angelo and Vincent Fortuin. "Annealed Stein Variational
        Gradient Descent." 2021.

    :param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1].
    :param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1].
    :param trans_speed: Speed of transition between two phases. Corresponds to $p$ in eq. 4 of [1].
    :return: A function mapping a step index ``t`` to the annealing coefficient.
        It uses only arithmetic and comparisons on ``t`` (no Python branching),
        so it can also be applied to traced array values.
    :raises ValueError: If ``num_cycles < 1`` or ``num_steps < num_cycles``,
        i.e. when the cycle length would be zero.
    """
    if num_cycles < 1:
        raise ValueError(f"num_cycles must be >= 1, got {num_cycles}")
    if num_steps < num_cycles:
        raise ValueError(
            f"num_steps ({num_steps}) must be >= num_cycles ({num_cycles})"
        )

    norm = float(num_steps + 1) / float(num_cycles)
    cycle_len = num_steps // num_cycles
    last_start = (num_cycles - 1) * cycle_len

    def cycle_fn(t):
        # A boolean flag instead of ``t // last_start``. The quotient reaches 2
        # once t >= 2 * last_start, which flips the sign of the ramp. It also
        # divides by zero when num_cycles == 1 (last_start == 0).
        in_last_cycle = t >= last_start
        ramp = (((t % cycle_len) + 1) / norm) ** trans_speed
        return (1 - in_last_cycle) * ramp + in_last_cycle

    return cycle_fn


class ASVGD(SVGD):
    """SVGD with a cyclically annealed loss temperature (Annealed SVGD).

    Wraps the parent's run-loop callbacks so that a step counter ``t`` is
    carried alongside the parent's loop state. Before each parent step,
    ``self.loss_temperature`` is set to
    ``gamma(t) / num_stein_particles``, where ``gamma`` is the schedule
    built by ``cyclical_annealing``.
    """

    # Schedule hyperparameters: C (number of cycles) and p (transition speed)
    # in eq. 4 of the A-SVGD paper. These class-level defaults reproduce the
    # previous hard-coded values; override them on a subclass or an instance
    # to change the schedule.
    annealing_num_cycles = 10
    annealing_trans_speed = 10

    def setup_run(
        self,
        rng_key,
        num_steps,
        args,
        init_state,
        kwargs):
        """Build the run-loop callbacks, adding a step counter to the state.

        Delegates to ``SVGD.setup_run`` with the same arguments. Each returned
        callback is wrapped so the loop state becomes ``(t, parent_state)``.

        :param rng_key: Forwarded unchanged to the parent.
        :param num_steps: Forwarded to the parent; also sets the length of the
            annealing schedule.
        :param args: Forwarded unchanged to the parent.
        :param init_state: Forwarded unchanged to the parent.
        :param kwargs: Forwarded unchanged to the parent.
        :return: ``(step, diagnostic, collect, extract_state, info_init)``,
            mirroring the parent's return value. ``info_init`` is
            ``(0, parent_init)``.
        """
        schedule = cyclical_annealing(
            num_steps,
            num_cycles=self.annealing_num_cycles,
            trans_speed=self.annealing_trans_speed,
        )

        (parent_step, parent_diagnostic, parent_collect,
         parent_extract_state, parent_init) = super().setup_run(
            rng_key,
            num_steps,
            args,
            init_state,
            kwargs,
        )

        def step(info):
            t, parent_info = info
            # NOTE(review): assumes the parent step reads self.loss_temperature
            # each time it runs (or is traced). If the step is jit-compiled
            # once, confirm the per-step value is actually picked up.
            self.loss_temperature = schedule(t) / float(self.num_stein_particles)
            return (t + 1, parent_step(parent_info))

        def diagnostic(info):
            _, parent_info = info
            return parent_diagnostic(parent_info)

        def collect(info):
            _, parent_info = info
            return parent_collect(parent_info)

        def extract_state(info):
            _, parent_info = info
            return parent_extract_state(parent_info)

        info_init = (0, parent_init)
        return step, diagnostic, collect, extract_state, info_init
