import abc
import torch
import numpy as np
from nfmc_jax.DLA.posterior import DifferentiableTemperedPosterior
from nfmc_jax.flows.base import TorchFlowInterface


class ParticleOptimizer(abc.ABC):
    """Abstract base for optimizers that move a set of particles in place.

    Plays the role of a neural-network optimizer, except that the tensor being
    optimized is the particle array itself.
    """

    def __init__(self, particles: torch.Tensor, lr: float = 1e-2, **kwargs):
        # Extra keyword arguments are accepted and ignored so subclasses can
        # forward their own **kwargs unchanged.
        self.step_counter: int = 0
        self.lr = lr
        self.particles = particles

    @abc.abstractmethod
    def step(self, grad: torch.Tensor, **kwargs):
        """Apply one update to ``self.particles`` using ``grad``."""
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self):
        """Restore the optimizer's internal state."""
        raise NotImplementedError


class ParticleGradientDescent(ParticleOptimizer):
    """Normalized gradient descent on particles with optional momentum.

    Each particle's gradient row is divided by its Euclidean norm (taken over
    dim 1), so every particle moves by ``lr`` along its negative gradient
    direction. With ``momentum > 0`` the step blends in the gradient seen on
    the previous call.
    """

    def __init__(self, particles: torch.Tensor, momentum: float = 0.0, **kwargs):
        super().__init__(particles, **kwargs)
        self.momentum = momentum
        self.grad_cache = None  # gradient passed to the previous step() call

    @torch.no_grad()
    def step(self, grad: torch.Tensor, **kwargs):
        """Update ``self.particles`` in place.

        Args:
            grad: gradient with the same shape as ``self.particles``; each row
                is treated as one particle's gradient.
        """
        assert grad.shape == self.particles.shape
        # Clamp the norm so a particle with an exactly-zero gradient stays put
        # instead of being overwritten with NaN (0 / 0).
        grad_norm = torch.linalg.norm(grad, dim=1, keepdim=True).clamp_min(
            torch.finfo(grad.dtype).tiny
        )

        update_values = -self.lr * grad / grad_norm
        if self.step_counter > 0:
            # NOTE(review): the cached gradient is divided by the *current*
            # gradient norm, not by its own norm; confirm this is intended.
            momentum_movement = self.momentum * (-self.lr * self.grad_cache / grad_norm)
            update_values = (1 - self.momentum) * update_values + momentum_movement

        self.particles += update_values

        self.step_counter += 1
        # Copy so later in-place changes to the caller's tensor cannot alter
        # the cached gradient.
        self.grad_cache = grad.clone()

    def reset(self):
        """Forget the step count and the cached gradient."""
        self.step_counter = 0
        self.grad_cache = None

class ParticleAdagrad(ParticleOptimizer):
    """Element-wise Adagrad on the particle array.

    ``grad_sq`` holds the running sum of squared gradients and is updated in
    place, so a tensor supplied by the caller accumulates across steps.
    """

    def __init__(self,
                 particles: torch.Tensor,
                 grad_sq: torch.Tensor,
                 lr: float = 1e-2,
                 eps: float = 1e-8,
                 **kwargs):
        super().__init__(particles, **kwargs)
        self.step_counter: int = 0
        self.grad_sq = grad_sq
        self.eps = eps
        self.lr = lr
        self.original_lr = lr  # value restored by reset()

    @torch.no_grad()
    def step(self, grad: torch.Tensor, **kwargs):
        """Accumulate ``grad ** 2``, take one Adagrad step, return the particles."""
        self.grad_sq += grad.square()
        denominator = (self.grad_sq + self.eps).sqrt()
        self.particles -= self.lr * grad / denominator
        self.step_counter += 1
        return self.particles

    def reset(self):
        """Restore the initial learning rate and start a fresh squared-gradient sum."""
        self.step_counter = 0
        self.grad_sq = torch.zeros_like(self.particles)
        self.lr = self.original_lr

class ParticleRMSProp(ParticleOptimizer):
    """Element-wise RMSProp on the particle array.

    Keeps an exponential moving average of squared gradients,
    ``rms = gamma * rms + (1 - gamma) * grad ** 2``, updated on every step, and
    scales each step by ``1 / sqrt(rms + eps)``.

    Args:
        particles: tensor updated in place by ``step``.
        avg_grad: initial value of the moving average of squared gradients
            (e.g. carried over from a previous optimizer). It is copied and
            never modified.
        lr: step size.
        eps: added inside the square root for numerical stability.
        gamma: decay factor of the moving average.
    """

    def __init__(self,
                 particles: torch.Tensor,
                 avg_grad: torch.Tensor,
                 lr: float = 1e-2,
                 eps: float = 1e-8,
                 gamma: float = 0.9,
                 **kwargs):
        super().__init__(particles, **kwargs)
        self.original_lr = lr
        self.lr = lr
        self.eps = eps

        # Store the undecayed average; step() applies the gamma decay on every
        # call. The first step thus yields gamma * avg_grad + (1 - gamma) * g**2,
        # and later steps keep decaying instead of accumulating without bound.
        self.rms_grad = avg_grad.clone() if torch.is_tensor(avg_grad) else avg_grad
        self.gamma = gamma
        self.step_counter: int = 0

    @torch.no_grad()
    def step(self, grad: torch.Tensor, **kwargs):
        """Update the moving average and move ``self.particles`` in place."""
        # RMSProp (Hinton, "Neural Networks for Machine Learning", lecture 6e).
        self.step_counter += 1
        # Out-of-place so no tensor supplied by the caller is ever mutated.
        self.rms_grad = self.gamma * self.rms_grad + (1.0 - self.gamma) * grad ** 2
        self.particles -= self.lr * grad / torch.sqrt(self.rms_grad + self.eps)

    def reset(self):
        """Restore the initial learning rate and zero the moving average."""
        self.lr = self.original_lr
        self.rms_grad = torch.zeros_like(self.particles)
        self.step_counter = 0

class ParticleAdam(ParticleOptimizer):
    """Adam (Kingma & Ba, https://arxiv.org/abs/1412.6980) on the particle array.

    First and second moment estimates are kept per element. They are
    bias-corrected using the number of steps taken since construction or the
    last reset().
    """

    def __init__(self,
                 particles: torch.Tensor,
                 lr: float = 1e-2,
                 beta1: float = 0.99,
                 beta2: float = 0.999,
                 eps: float = 1e-8,
                 **kwargs):
        super().__init__(particles, **kwargs)
        self.original_lr = self.lr = lr
        self.beta1, self.beta2, self.eps = beta1, beta2, eps

        # Moment estimates start at zero, with the same shape as the particles.
        self.m = torch.zeros_like(particles)
        self.v = torch.zeros_like(particles)
        self.step_counter: int = 0

    @torch.no_grad()
    def step(self, grad: torch.Tensor, **kwargs):
        """One Adam update (Algorithm 1 of the paper), applied in place."""
        self.step_counter += 1
        t = self.step_counter
        b1, b2 = self.beta1, self.beta2
        self.m = b1 * self.m + (1 - b1) * grad
        self.v = b2 * self.v + (1 - b2) * (grad ** 2)
        m_hat = self.m / (1 - b1 ** t)
        v_hat = self.v / (1 - b2 ** t)
        self.particles -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)

    def reset(self):
        """Zero both moment estimates and restore the initial learning rate."""
        self.step_counter = 0
        self.v = torch.zeros_like(self.particles)
        self.m = torch.zeros_like(self.particles)
        self.lr = self.original_lr


class ParticleLineSearch(ParticleOptimizer):
    """Per-particle backtracking step with a log-weight acceptance test.

    Each call to ``step`` proposes an Adagrad-style move for every particle:
    ``x - lr * g / sqrt(accumulated_grad + eps)``. A proposal is accepted when
    it increases the log-weight ``logp - logq``. A particle whose proposal is
    rejected has its learning rate divided by ``shrink_factor`` and retries.
    There are at most ``max_halvings`` attempts per step.

    Args:
        particles: particle tensor, updated in place by ``step``.
        accumulated_grad: running sum of squared gradients; updated in place.
        particle_logp: log target density at the current particles. It is
            read here but not updated when particles move.
        particle_logq: log flow density at the current particles. It is read
            here but not updated when particles move.
        posterior: object whose ``clear_cache()`` is called before each attempt.
        interface: stored; not used by this class.
        logp_func: maps proposals to their log target density.
        logq_func: maps proposals to their log flow density.
        lr: step size. Either a scalar or a tensor indexable per particle
            (e.g. shape ``(n, 1)``). A scalar is expanded to one learning rate
            per particle. A tensor is used as-is and is shrunk in place.
        gamma: unused; kept for signature compatibility.
        eps: added inside the square root of the Adagrad denominator.
        max_halvings: maximum number of proposal attempts per step.
        shrink_factor: divisor applied to a particle's learning rate after a
            rejected attempt. The default is 10, despite the name
            ``max_halvings``.
    """

    def __init__(self,
                 particles: torch.Tensor,
                 accumulated_grad: torch.Tensor,
                 particle_logp: torch.Tensor,
                 particle_logq: torch.Tensor,
                 posterior: DifferentiableTemperedPosterior,
                 interface: TorchFlowInterface,
                 logp_func: callable,
                 logq_func: callable,
                 lr: float = 1e-2,
                 gamma: float = 0.9,
                 eps: float = 1e-8,
                 max_halvings: int = 2,
                 shrink_factor: float = 10.0,
                 **kwargs):
        super().__init__(particles, **kwargs)
        self.accumulated_grad = accumulated_grad
        self.particle_logp = particle_logp
        self.particle_logq = particle_logq
        self.posterior = posterior
        self.interface = interface
        self.logp_func = logp_func
        self.logq_func = logq_func
        # step() indexes the learning rate per particle and shrinks it in
        # place, so a scalar lr must become one entry per particle.
        self.lr = self._per_particle_lr(lr, particles)
        # Independent copy: self.lr is modified in place by step().
        self.original_lr = self.lr.clone()
        self.eps = eps
        self.max_halvings = max_halvings
        self.shrink_factor = shrink_factor
        self.step_counter: torch.Tensor = torch.zeros(particles.shape[0])

    @staticmethod
    def _per_particle_lr(lr, particles: torch.Tensor) -> torch.Tensor:
        """Return ``lr`` as a tensor indexable per particle (non-scalar tensors pass through)."""
        if torch.is_tensor(lr) and lr.dim() > 0:
            return lr
        # Trailing singleton dims let lr[idx] broadcast against grad[idx].
        shape = (particles.shape[0],) + (1,) * (particles.dim() - 1)
        return torch.full(shape, float(lr), dtype=particles.dtype, device=particles.device)

    @torch.no_grad()
    def step(self, grad: torch.Tensor, **kwargs):
        """Try to move every particle; see the class docstring for the rule."""
        active_idx = torch.arange(self.particles.shape[0])
        self.accumulated_grad += grad ** 2
        for _ in range(self.max_halvings):

            if len(active_idx) == 0:
                break

            self.step_counter[active_idx] += 1
            self.posterior.clear_cache()
            proposals = self.particles[active_idx] - self.lr[active_idx] * grad[active_idx] / torch.sqrt(
                self.accumulated_grad[active_idx] + self.eps)
            prop_logp = self.logp_func(proposals)
            prop_logq = self.logq_func(proposals)

            # Accept a proposal only if it increases logp - logq.
            # NOTE(review): particle_logp / particle_logq are not refreshed for
            # accepted particles here; confirm the caller updates them between
            # steps.
            delta_logw = (prop_logp - prop_logq) - (self.particle_logp[active_idx] - self.particle_logq[active_idx])
            acc_idx = delta_logw > 0.0
            self.particles[active_idx[acc_idx]] = proposals[acc_idx]
            # Rejected particles retry with a smaller learning rate.
            active_idx = active_idx[~acc_idx]
            self.lr[active_idx] = self.lr[active_idx] / self.shrink_factor
        print(f'<step_counter> = {torch.mean(self.step_counter)}')

    def reset(self):
        """Restore the initial per-particle learning rates and zero the step counts."""
        # Clone so the next step() cannot modify original_lr through self.lr.
        self.lr = self.original_lr.clone()
        # step() indexes step_counter per particle, so it must stay a tensor.
        self.step_counter = torch.zeros(self.particles.shape[0])


class AdaptiveLearningRateParticleOptimizer(ParticleOptimizer):
    """Gradient descent with a per-particle learning rate.

    Each particle's step is ``lr * ||x_i - mean(x)|| * grad_i``. The learning
    rate is therefore scaled by the particle's Euclidean distance from the
    particle mean.
    """

    def __init__(self, particles: torch.Tensor, lr: float = 1.0, **kwargs):
        # Forward lr so that self.lr (the attribute schedulers adjust) holds it.
        super().__init__(particles, lr=lr, **kwargs)
        self.step_counter = 0

    @property
    def base_lr(self):
        """Alias of ``lr``, kept for backward compatibility."""
        return self.lr

    @base_lr.setter
    def base_lr(self, value):
        self.lr = value

    @torch.no_grad()
    def step(self, grad: torch.Tensor, log_posterior_values: torch.Tensor = None, **kwargs):
        """Move ``self.particles`` in place.

        Args:
            grad: gradient with the same shape as ``self.particles``. The
                centroid computation assumes particles are 2-D, (n, d) —
                confirm with callers.
            log_posterior_values: accepted for interface compatibility; unused.
        """
        centroid = self.particles.mean(dim=0).view(1, -1)
        lr_scaling = torch.linalg.norm(self.particles - centroid, dim=1)
        lr = self.lr * lr_scaling
        self.particles -= lr.view(-1, 1) * grad
        self.step_counter += 1

    def reset(self):
        """Zero the step count; the learning rate is left unchanged."""
        self.step_counter = 0


class ParticleStepScheduler(abc.ABC):
    """Base class for learning-rate schedules driving a ParticleOptimizer.

    Subclasses implement ``step`` to adjust ``self.optimizer.lr``.
    """

    def __init__(self, optimizer: ParticleOptimizer, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.optimizer = optimizer
        self.step_counter = 0

    def step(self):
        """Advance the schedule by one step; subclasses must override this."""
        # A regular method (not @staticmethod): calling step() on an instance
        # must reach this NotImplementedError rather than fail with a
        # TypeError about a missing ``self`` argument.
        raise NotImplementedError

    def reset(self):
        """Zero the step count."""
        self.step_counter = 0


class IdentityScheduler(ParticleStepScheduler):
    """Schedule that never changes the optimizer's learning rate."""

    def __init__(self, optimizer: ParticleOptimizer):
        super().__init__(optimizer)

    def step(self):
        """Count the step; the learning rate is not touched."""
        self.step_counter = self.step_counter + 1

    def reset(self):
        """Zero the step count."""
        self.step_counter = 0


class ExponentialDecayScheduler(ParticleStepScheduler):
    """Multiply the optimizer's learning rate by ``decay_rate`` on every step."""

    def __init__(self, optimizer: ParticleOptimizer, decay_rate: float = 0.999):
        super().__init__(optimizer)
        self.decay_rate = decay_rate

    def step(self):
        """Decay the optimizer's learning rate once and count the step."""
        opt = self.optimizer
        opt.lr *= self.decay_rate
        self.step_counter += 1

    def reset(self):
        """Zero the step count; the optimizer's learning rate is left as is."""
        self.step_counter = 0


class CosineAnnealingScheduler(ParticleStepScheduler):
    """Cosine annealing of the optimizer's learning rate.

    On the k-th call to ``step`` (k counted from 0), the learning rate becomes
    ``lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * k / T_max))``.
    ``lr_max`` is the optimizer's learning rate at construction or at the last
    ``reset()``. k is not clamped at ``T_max``, so the cosine keeps cycling
    past it.
    """

    def __init__(self, optimizer: ParticleOptimizer, T_max: int, lr_min: float = 0.0):
        super().__init__(optimizer)
        self.T_max = T_max
        self.lr_min = lr_min
        self.lr_max = self.optimizer.lr

    def step(self):
        """Set the optimizer's learning rate for the current step, then advance."""
        phase = self.step_counter / self.T_max * np.pi
        half_range = 0.5 * (self.lr_max - self.lr_min)
        self.optimizer.lr = self.lr_min + half_range * (1 + np.cos(phase))
        self.step_counter += 1

    def reset(self):
        """Take the optimizer's current learning rate as the new peak and restart."""
        self.lr_max = self.optimizer.lr
        self.step_counter = 0
