from typing import Any, Callable, Dict, List

import numpy as np  # type: ignore
import torch
from torch import nn
from torch.optim import Optimizer

T = torch.Tensor


class MCMCOptim:
    """Mixin with step-size tuning heuristics for MCMC-style optimizers.

    The host class must provide ``self.param_groups``: a list of dicts that
    each carry a ``'step_size'`` entry. This class only declares that
    attribute; it never creates it.
    """

    def __init__(self) -> None:
        # Constants and running state for ``dual_average_tune``. Names follow
        # Algorithm 5 of the NUTS paper as quoted in that method's docstring.
        self.tune_params = {
            'delta': 0.65,   # target acceptance rate
            't0': 10,        # offset added to the iteration count in the H update
            'gamma': .05,    # divisor of the sqrt(t) * H correction
            'kappa': .75,    # exponent of the averaging weight t**(-kappa)
            # NOTE: the algorithm notes in ``dual_average_tune`` define
            # mu = log(10 * initial_step_size); here mu is fixed at 0.
            'mu': 0.,
            'H': 0,          # running average of (delta - alpha)
            # Running average of log step size. With t == 1 the averaging
            # weight on this value is 1 - 1**(-kappa) == 0, so it is then
            # fully overwritten by the first update.
            'log_eps': 1.,
        }

        self.param_groups: List[Dict[str, Any]]

    def tune(self, accepts: List[bool]) -> None:
        """Rescale every group's ``step_size`` from the mean acceptance rate.

        The table comes from the original code, which quoted PyMC as using
        these values in a switch statement. Rates at or below the low
        thresholds shrink the step size, rates at or above the high
        thresholds grow it, and everything in between uses the 0.9 fallback.

        :param accepts: accept/reject flags of recent proposals; an empty
            list carries no information and leaves the step sizes untouched.
        """
        if len(accepts) == 0:
            # Nothing to average; avoid a ZeroDivisionError.
            return

        avg_acc = sum(accepts) / len(accepts)

        # Ordered range checks. The bounds are inclusive so that a rate equal
        # to a table entry gets exactly that entry's factor.
        if avg_acc <= 0.001:
            scale = 0.1
        elif avg_acc <= 0.05:
            scale = 0.5
        elif avg_acc <= 0.20:
            scale = 0.5
        elif avg_acc >= 0.99:
            scale = 10.0
        elif avg_acc >= 0.75:
            scale = 2.0
        elif avg_acc >= 0.5:
            scale = 1.1
        else:
            scale = 0.9

        for group in self.param_groups:
            group['step_size'] *= scale

    def dual_average_tune(self, accepts: List[bool], t: float, alpha: float) -> None:
        """
        NUTS Sampler p.17 : Algorithm 5
        mu = log(10 * initial_step_size)
        H_m : running difference between target acceptance rate and current acceptance rate
            delta : target acceptance rate
            alpha : (singular) current acceptance rate
        log_eps = mu - t**0.5 / gamma H_m
        running_log_eps = t**(-kappa) log_eps + (1 - t**(-kappa)) running_log_eps

        Updates ``tune_params['H']`` and ``tune_params['log_eps']`` and sets
        every group's ``step_size`` to ``exp(running_log_eps)``.

        :param accepts: unused; kept for interface compatibility.
        :param t: iteration counter of the adaptation; must be positive
            because it is raised to a negative power.
        :param alpha: current acceptance rate, in [0, 1].
        :raises ValueError: if ``alpha`` is outside [0, 1] or ``t`` is not positive.
        """
        if not 0 <= alpha <= 1.:
            raise ValueError(f'{alpha=}')
        if not t > 0:
            raise ValueError(f'{t=}')

        # Read by key rather than relying on the dict's insertion order.
        params = self.tune_params
        delta = params['delta']
        t0 = params['t0']
        gamma = params['gamma']
        kappa = params['kappa']
        mu = params['mu']

        # Running average of the acceptance-rate shortfall.
        weight = 1 / (t + t0)
        H = (1 - weight) * params['H'] + weight * (delta - alpha)

        # Raw log step size for this iteration.
        log_eps_t = mu - t ** 0.5 / gamma * H

        # Averaged log step size; the weight on the newest value decays with t.
        eta = t ** (-kappa)
        log_eps = eta * log_eps_t + (1 - eta) * params['log_eps']

        params['H'] = H
        params['log_eps'] = log_eps

        # Store a plain Python float rather than a numpy scalar.
        step_size = float(np.exp(log_eps))
        for group in self.param_groups:
            group['step_size'] = step_size


class HMCOptim(Optimizer, MCMCOptim):
    """Hamiltonian-dynamics updates exposed through the torch ``Optimizer`` API.

    Each parameter gets a ``'velocity'`` tensor in ``self.state``.
    ``sample_momentum`` must be called before ``step`` or ``leapfrog_step``
    so that this entry exists. Step-size tuning comes from ``MCMCOptim``.
    """

    def __init__(self, model: nn.Module, step_size: float = 0.1, prior_std: float = 1.):
        '''
        :param model: module whose ``parameters()`` are updated
        :param step_size: integrator step size; must be non-negative
        :param prior_std: must be non-negative; stored in the defaults as
            ``weight_decay = 1 / prior_std**2`` (``0`` yields ``weight_decay = 0``).
            NOTE(review): ``weight_decay`` is not read by any update in this
            class — confirm the prior term is applied by the closure/model.
        :raises ValueError: if ``step_size`` or ``prior_std`` is negative
        '''

        if step_size < 0.0:
            raise ValueError("Invalid learning rate: {}".format(step_size))
        # A squared std can never produce a negative weight_decay, so the sign
        # check has to happen on prior_std itself.
        if prior_std < 0.0:
            raise ValueError("Invalid prior_std value: {}".format(prior_std))

        weight_decay = 1 / (prior_std ** 2) if prior_std != 0 else 0
        defaults = dict(step_size=step_size, weight_decay=weight_decay, traj_step=0)

        self.model = model
        params = self.model.parameters()

        super(HMCOptim, self).__init__(params, defaults)  # type: ignore
        # Called explicitly to set up ``tune_params``.
        MCMCOptim.__init__(self)

    def _advance(self, kick_fraction: float, drift: bool = True) -> None:
        """Update velocities by ``kick_fraction`` of a step and optionally move parameters.

        Velocities move against ``p.grad``. Parameters without a gradient are
        skipped.
        """
        for group in self.param_groups:
            step_size = group['step_size']
            for p in group['params']:
                if p.grad is None:
                    # Frozen or unused parameter: nothing to integrate.
                    continue
                grad = p.grad.data
                velocity = self.state[p]['velocity']

                velocity.add_(-kick_fraction * step_size * grad)
                if drift:
                    p.data.add_(velocity, alpha=step_size)

    def step(self) -> None:  # type: ignore
        """Take one full velocity update followed by a full position update.

        Uses the gradients currently stored on the parameters and increments
        each group's ``traj_step`` counter.
        """
        self._advance(kick_fraction=1.0, drift=True)
        for group in self.param_groups:
            group['traj_step'] += 1

    def sample_momentum(self) -> None:
        """Draw a fresh standard-normal velocity per parameter and reset ``traj_step``."""
        for group in self.param_groups:
            group['traj_step'] = 0
            for p in group['params']:
                self.state[p]['velocity'] = 1.0 * torch.randn_like(p)

    def leapfrog_step(self, closure: Callable[..., T]) -> T:
        """
        Leapfrog Integrator can be implemented with closure:
            1) Takes data and computes gradient
            2) moves halfway along the gradient
            3) recomputes the gradient after half step and does another half-step
            4) voila, we're at new sample
        TODO:: let log_prob return data and target as well because for closure we need the same data: https://pytorch.org/docs/stable/optim.html#taking-an-optimization-step
        TODO:: but that's a requirement of prob_model

        Concretely: half velocity update with the current gradients, full
        position update, ``closure()``, then a second half velocity update
        with whatever gradients are on the parameters afterwards.

        :param closure: called once after the position update; expected to
            refresh ``p.grad`` at the new parameters (not enforced here).
        :return: the value returned by ``closure``.

        NOTE(review): velocities move against ``p.grad``, so the gradient is
        presumably that of the negative log-probability — confirm against the
        closure. Unlike ``step``, ``traj_step`` is not incremented here.
        """
        self._advance(kick_fraction=0.5, drift=True)

        log_prob = closure()

        # Second half of the velocity update; positions stay where they are.
        self._advance(kick_fraction=0.5, drift=False)

        return log_prob
