from __future__ import annotations
from typing import *
import math

import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data as data_utils
import torch.distributions as dists
import numpy as np

from laplace import Laplace
from laplace.curvature import BackPackGGN
from laplace.marglik_training import marglik_training


# Public API: the names exported by a star-import of this module.
# NOTE(review): ConstantTSStrategy is defined in this file but is not listed
# here -- confirm whether leaving it out is intentional.
__all__ = ['RandomStrategy', 'ConstantUCBStrategy', 'NeuralUCBStrategy', 'MarglikUCBStrategy']


def get_dataloader(X: torch.Tensor, Y: torch.Tensor, batch_size: int = 512, shuffle: bool = True) -> data_utils.DataLoader:
    """
    Wrap a pair of tensors into a minibatch `DataLoader`.

    Params:
    -------
    X: torch.Tensor
        Inputs; paired row-by-row with `Y` through a `TensorDataset`
    Y: torch.Tensor
        Targets; must have the same first dimension as `X`
    batch_size: int
        Number of (x, y) pairs per minibatch
    shuffle: bool
        Forwarded to the `DataLoader`'s `shuffle` argument

    Returns:
    --------
    loader: data_utils.DataLoader
        Loader yielding `(x_batch, y_batch)` tuples
    """
    paired = data_utils.TensorDataset(X, Y)
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle)
    return data_utils.DataLoader(paired, **loader_kwargs)


class BanditStrategy:
    """
    Common interface for contextual-bandit strategies.

    Stores the number of arms; subclasses supply the decision rule by
    overriding `pull_lever`.
    """

    def __init__(self, n_arms: int):
        # Number of levers (arms) the strategy can choose from.
        self.n_arms = n_arms

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Pick one lever given the contexts of all K levers.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)

        Returns:
        --------
        k: int >= 0
            Zero-indexed id of the pulled lever, k in {0, ..., K-1}

        Raises:
        -------
        NotImplementedError
            Always; this base implementation must be overridden.
        """
        raise NotImplementedError


class RandomStrategy(BanditStrategy):
    """
    Baseline strategy: ignore the contexts and pick an arm uniformly at random.
    """

    def __init__(self, n_arms: int):
        super().__init__(n_arms)

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Draw an arm index from {0, ..., n_arms - 1} using NumPy's global RNG.

        The contexts `X` are not used.
        """
        lowest_arm = 0
        return np.random.randint(lowest_arm, self.n_arms)


class LaplaceBanditStrategy(BanditStrategy):
    """
    Base class for strategies whose reward model is a neural network with a
    Laplace approximation (stored in `self.la`) on top of it.

    The constructor records the training set and hyperparameters, measures the
    network's width/depth, and immediately trains the model via `train()`.
    Subclasses implement `pull_lever` on top of `get_prediction`.

    Params:
    -------
    n_arms: int
        Number of levers
    get_net: Callable
        Zero-argument factory returning a fresh network. `get_width_depth`
        iterates over the returned network layer by layer, so it must be
        iterable (e.g. an `nn.Sequential`)
    train_X, train_Y: torch.Tensor
        Initial training contexts and rewards
    lr: float
        Learning rate of the network optimizer
    wd: float
        Weight decay for the post-hoc branch; also used (times the dataset
        size) as the Laplace prior precision there
    n_epochs: int
        Number of training epochs
    online_laplace: bool
        If True, train with `marglik_training`; this forces
        `tune_nn_hypers` to True
    tune_nn_hypers: bool
        If True (post-hoc branch), optimize prior precision and observation
        noise against the log marginal likelihood after fitting
    device: str
        Device the network is trained and evaluated on
    """

    def __init__(
        self,
        n_arms: int,
        get_net: Callable,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        lr: float = 1e-3,
        wd: float = 1e-2,
        n_epochs: int = 500,
        online_laplace: bool = False,
        tune_nn_hypers: bool = False,
        device: str = 'cpu'
    ):
        super().__init__(n_arms=n_arms)

        self.get_net = get_net
        self.train_X = train_X
        self.train_Y = train_Y
        self.lr = lr
        self.wd = wd
        self.n_epochs = n_epochs
        self.device = device
        self.width, self.depth = self.get_width_depth()

        self.tune_nn_hypers = tune_nn_hypers
        self.online_laplace = online_laplace
        if self.online_laplace:
            # The point of online Laplace is to tune hypers
            print('Overriding tune_nn_hypers to True for online Laplace')
            self.tune_nn_hypers = True

        # Pre-train Laplace
        self.train()

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Observe contexts X over K actions (levers), then decide which lever to pull.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)

        Returns:
        --------
        k: int >= 0
            The pulled lever, k in {0, ..., K-1}, i.e. zero-indexed
        """
        raise NotImplementedError()

    def train(self) -> None:
        """
        Train a fresh network on the stored data and (re)build `self.la`.

        With `online_laplace` the whole procedure is delegated to
        `marglik_training`. Otherwise the network is fit with AdamW on an MSE
        loss under a cosine learning-rate schedule, a Laplace approximation is
        fit on the trained network, and, if `tune_nn_hypers` is set, prior
        precision and observation noise are tuned against the log marginal
        likelihood.
        """
        net = self.get_net().to(self.device)
        train_loader = get_dataloader(self.train_X, self.train_Y, batch_size=128, shuffle=True)

        if self.online_laplace:
            # Online always tune hyperparams
            la, _, _, _ = marglik_training(
                net, train_loader, likelihood='regression',
                hessian_structure='kron',
                prior_structure='scalar',
                n_epochs=self.n_epochs, backend=BackPackGGN,
                optimizer_kwargs={'lr': self.lr},
                scheduler_cls=optim.lr_scheduler.CosineAnnealingLR,
                scheduler_kwargs={'T_max': self.n_epochs*len(train_loader)},
                marglik_frequency=25
            )
            self.la = la
        else:
            opt = optim.AdamW(net.parameters(), lr=self.lr, weight_decay=self.wd)
            # The schedule is stepped once per minibatch, hence T_max counts batches
            sch = optim.lr_scheduler.CosineAnnealingLR(opt, self.n_epochs*len(train_loader))

            for _ in range(self.n_epochs):
                for x, y in train_loader:
                    x, y = x.to(self.device), y.to(self.device)
                    opt.zero_grad()
                    output = net(x)
                    loss = F.mse_loss(output, y)
                    loss.backward()
                    opt.step()
                    sch.step()

            net.eval()
            self.la = Laplace(
                net, 'regression',
                # Multiplied by dataset size since we have used minibatching
                # during MAP estimation
                prior_precision=self.wd*len(self.train_X),
                subset_of_weights='all',
                hessian_structure='kron',
                backend=BackPackGGN
            )
            self.la.fit(train_loader)

            if self.tune_nn_hypers:
                # Tune prior precision and observation noise. Both are
                # optimized in log-space so their exponentials stay positive.
                # NOTE(review): this relies on `log_marginal_likelihood`
                # storing the values it is given on `self.la` -- confirm
                # against the laplace documentation.
                log_prior = torch.ones(1, requires_grad=True, device=self.device)
                log_sigma = torch.ones(1, requires_grad=True, device=self.device)
                hyper_optimizer = optim.Adam([log_prior, log_sigma], lr=1e-1)
                for _ in range(100):
                    hyper_optimizer.zero_grad()
                    neg_marglik = -self.la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
                    neg_marglik.backward()
                    hyper_optimizer.step()

    def get_prediction(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict rewards for the given contexts.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)

        Returns:
        --------
        mean: torch.Tensor
            Squeezed predictive mean from `self.la(X, pred_type='glm')`
        var: torch.Tensor
            Squeezed variance from the same call, plus the observation-noise
            variance `self.la.sigma_noise**2`
        """
        # The model lives on `self.device`; contexts may arrive on another
        # device. `.to` is a no-op when X is already there.
        X = X.to(self.device)
        f_mean, f_var = self.la(X, pred_type='glm')
        return f_mean.squeeze(), f_var.squeeze() + self.la.sigma_noise**2

    def condition_on_observations(self, X: torch.Tensor, Y: torch.Tensor, retrain: bool = False) -> None:
        """
        Append new (context, reward) observations to the training set.

        Params:
        -------
        X: torch.Tensor
            New contexts, concatenated to `train_X` along dim 0
        Y: torch.Tensor
            New rewards, concatenated to `train_Y` along dim 0
        retrain: bool
            If True, retrain from scratch via `train()` on the enlarged set
        """
        self.train_X = torch.cat([self.train_X, X], dim=0)
        self.train_Y = torch.cat([self.train_Y, Y], dim=0)
        if retrain:
            self.train()

    def get_width_depth(self) -> Tuple[int, int]:
        """
        Measure the network by pushing the training inputs through it layer by layer.

        Returns:
        --------
        max_width: int
            Largest last-dimension size among the outputs of the `nn.Linear`
            and `nn.Conv2d` layers
        depth: int
            Number of `nn.Linear` / `nn.Conv2d` layers
        """
        net = self.get_net()
        x = self.train_X
        max_width = 0
        depth = 0

        # Only output shapes are needed, so skip building an autograd graph
        # for this pass over the whole training set.
        with torch.no_grad():
            for layer in net:
                x = layer(x)

                if isinstance(layer, (nn.Linear, nn.Conv2d)):
                    max_width = max(max_width, x.shape[-1])
                    depth += 1

        return max_width, depth


class ConstantUCBStrategy(LaplaceBanditStrategy):
    """
    Upper-confidence-bound strategy whose exploration weight `gamma` is fixed.

    The score of each arm is `mean + gamma * sqrt(var)`, using the prediction
    returned by `get_prediction`.
    """
    def __init__(
        self,
        n_arms: int,
        get_net: Callable,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        lr: float = 1e-3,
        wd: float = 1e-2,
        n_epochs: int = 500,
        gamma: float = 0.2,
        device: str = 'cpu'
    ):
        # Plain post-hoc Laplace: neither online training nor hyperparameter tuning
        super().__init__(
            n_arms=n_arms,
            get_net=get_net,
            train_X=train_X,
            train_Y=train_Y,
            lr=lr,
            wd=wd,
            n_epochs=n_epochs,
            online_laplace=False,
            tune_nn_hypers=False,
            device=device,
        )

        # Fixed multiplier on the predictive standard deviation
        self.gamma = gamma

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Return the zero-indexed arm with the highest `mean + gamma * std` score.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)
        """
        pred_mean, pred_var = self.get_prediction(X)
        pred_std = pred_var.squeeze().sqrt()
        scores = pred_mean.squeeze() + self.gamma * pred_std  # (K,)
        return scores.argmax().item()

    def get_gamma(self, t):
        """Return the exploration weight; it does not depend on the time step `t`."""
        return self.gamma


class NeuralUCBStrategy(LaplaceBanditStrategy):
    """
    Upper-confidence-bound strategy with a time-dependent `gamma`,
    following Zhou et al., ICML 2020.
    """
    def __init__(
        self,
        n_arms: int,
        get_net: Callable,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        lr: float = 1e-3,
        wd: float = 1e-2,
        n_epochs: int = 500,
        nu: float = 1,
        delta: float = 0.1,
        S: float = 1,
        device: str = 'cpu'
    ):
        # Plain post-hoc Laplace: neither online training nor hyperparameter tuning
        super().__init__(
            n_arms=n_arms,
            get_net=get_net,
            train_X=train_X,
            train_Y=train_Y,
            lr=lr,
            wd=wd,
            n_epochs=n_epochs,
            online_laplace=False,
            tune_nn_hypers=False,
            device=device,
        )

        # Constants entering the NeuralUCB exploration bonus; C1..C3 are fixed to 1
        self.C1 = 1
        self.C2 = 1
        self.C3 = 1
        self.nu = nu
        self.delta = delta
        self.S = S

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Return the zero-indexed arm maximizing `mean + gamma_t * std`.

        The time step `t` is taken to be the current size of the training set.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)
        """
        step = len(self.train_X)  # 0 training data == first time step
        bonus_scale = self.get_gamma(step)
        pred_mean, pred_var = self.get_prediction(X)
        pred_std = pred_var.squeeze().sqrt()
        scores = pred_mean.squeeze() + bonus_scale * pred_std  # (K,)
        return scores.argmax().item()

    @torch.no_grad()
    def get_gamma(self, t):
        """
        Compute the exploration weight at time step `t`.

        The value combines the network width and depth, the Laplace prior
        precision and log-determinants read from `self.la`, and the
        constants `nu`, `delta`, `S`, `C1`..`C3`, `lr` and `n_epochs`.

        Returns `math.inf` if the computation raises `OverflowError`.

        NOTE(review): only `OverflowError` is handled. A zero prior precision
        (ZeroDivisionError) or a negative square-root argument (ValueError)
        would propagate, and a negative `1 - lr * width * prior_precision`
        raised to a non-integer power yields a complex number in Python.
        """
        width, depth = self.width, self.depth
        try:
            sqrt_log_width = math.sqrt(math.log(width))
            prior_prec = self.la.prior_precision.item()
            logdet_posterior = self.la.log_det_posterior_precision.item()
            logdet_prior = self.la.log_det_prior_precision.item()

            first = math.sqrt(
                1 + self.C1 * width**(-1./6.) * sqrt_log_width * depth**4 * t**(7./6.)
            )

            # Second factor: log-determinant ratio plus a width-dependent
            # term and the confidence term
            logdet_gap = logdet_posterior - logdet_prior
            width_term = self.C2 * width**(-1./6.) * sqrt_log_width * depth**4 \
                * t**(5./3.) * prior_prec**(-1./6.)
            confidence_term = 2*math.log(self.delta)
            second_root = math.sqrt(logdet_gap + width_term - confidence_term)
            second = self.nu * second_root + math.sqrt(prior_prec) * self.S

            # Third term: a component depending on lr and n_epochs plus a
            # width-dependent component
            optim_part = (1 - self.lr * width * prior_prec)**(self.n_epochs/2) \
                * math.sqrt(t/prior_prec)
            approx_part = width**(-1./6.) * sqrt_log_width * depth**(7./2.) * t**(5./3.) \
                * prior_prec**(-5./3.) * (1 + math.sqrt(t/prior_prec))
            third = (prior_prec + self.C3 * t * depth) * (optim_part + approx_part)

            return first * second + third
        except OverflowError:
            return math.inf


class MarglikUCBStrategy(LaplaceBanditStrategy):
    """
    UCB strategy with no separate `gamma`: the bonus is the predictive
    standard deviation itself. The posterior's hyperparameters are tuned with
    the Laplace marginal likelihood, online during training when `online` is
    True and after training otherwise.
    """
    def __init__(
        self,
        n_arms: int,
        get_net: Callable,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        lr: float = 1e-3,
        wd: float = 1e-2,
        n_epochs: int = 500,
        device: str = 'cpu',
        online: bool = False
    ):
        # Hyperparameter tuning is always on; `online` selects online Laplace
        super().__init__(
            n_arms=n_arms,
            get_net=get_net,
            train_X=train_X,
            train_Y=train_Y,
            lr=lr,
            wd=wd,
            n_epochs=n_epochs,
            online_laplace=online,
            tune_nn_hypers=True,
            device=device,
        )

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Return the zero-indexed arm maximizing `mean + std`.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)
        """
        pred_mean, pred_var = self.get_prediction(X)  # both (K,)
        pred_std = pred_var.squeeze().sqrt()
        scores = pred_mean.squeeze() + pred_std  # (K,)
        return scores.argmax().item()

    def get_gamma(self, t):
        """Return the constant weight 1, i.e. the posterior is used unscaled."""
        # Stay faithful with the posterior
        return 1


class ConstantTSStrategy(LaplaceBanditStrategy):
    """
    Thompson Sampling strategy with a constant exploration parameter `nu`.

    For each arm a reward is sampled from a Gaussian with the predictive mean
    and variance `nu**2 * var`; the arm with the largest sample is pulled.
    The parent is constructed with `online_laplace=True` and
    `tune_nn_hypers=True`.
    """
    def __init__(
        self,
        n_arms: int,
        get_net: Callable,
        train_X: torch.Tensor,
        train_Y: torch.Tensor,
        lr: float = 1e-3,
        wd: float = 1e-2,
        n_epochs: int = 500,
        nu: float = 1,
        device: str = 'cpu'
    ):
        super().__init__(n_arms, get_net, train_X, train_Y, lr, wd, n_epochs, True, True, device)

        # Thompson Sampling exploration param: scales the sampling std
        self.nu = nu

    def pull_lever(self, X: torch.Tensor) -> int:
        """
        Sample one reward per arm and return the arm with the largest sample.

        Params:
        -------
        X: torch.Tensor
            Contexts matrix, shape (K, context_dim)

        Returns:
        --------
        k: int >= 0
            The pulled lever, k in {0, ..., K-1}, i.e. zero-indexed
        """
        f_mean, f_var = self.get_prediction(X)  # (K,) and (K,)

        if self.nu == 0:
            # No exploration noise: act greedily on the mean. `Normal`
            # rejects a zero scale, so this case is handled separately.
            return torch.argmax(f_mean).item()

        # `Normal` takes a standard deviation, not a variance: the intended
        # variance nu**2 * f_var corresponds to scale |nu| * sqrt(f_var).
        scale = abs(self.nu) * f_var.sqrt()
        f_sample = dists.Normal(f_mean, scale).sample()  # (K,)
        return torch.argmax(f_sample).item()
