import time
import torch
import numpy as np
from src import kernels
from src.sliced.Def_Divergence import compute_max_DSSD_eff, max_DSSVGD
from src.sliced.Util import SVGD_AdaGrad_update
from src.sliced.Kernel import SE_kernel, d_SE_kernel, dd_SE_kernel, repulsive_SE_kernel
from src.models.bnn import NeuralNetworkEnsemble
from src.models.mvnmixture import MultivariateNormalMixture
from tqdm import tqdm

# Small constant added to denominators (row norms of G, the Adagrad
# normaliser) and to the median-heuristic kernel bandwidths below, so that a
# zero norm / zero median distance does not cause a division by zero.
FUDGE_FACTOR = 1e-10

class SVGD:
    """Stein variational gradient descent (SVGD) particle sampler.

    Holds ``N`` particles in ``d`` dimensions and moves them along the update
    direction ``phi = phi_G + phi_R``.  ``phi_G`` is the kernel-weighted score
    ("driving") term and ``phi_R`` the kernel-derivative ("repulsive") term.
    Steps are normalised by a running average of ``phi ** 2`` (Adagrad-style).

    ``hSVGD`` uses separate kernels ``k_1`` / ``k_2`` for the two terms and
    ``SSVGD`` computes ``phi`` with sliced kernels.  A plain ``SVGD`` instance
    uses the kernel specification ``self.k`` for both terms.
    """

    def __init__(
            self,
            particles,
            model,
            k=None,
            store_particles_history=False,
            store_force_history=False,
            store_nn_metrics=False,
            test_dataset=None,
            device='cpu'
        ):
        """Store the initial particles and configuration.

        Args:
            particles: ``(N, d)`` numpy array or torch tensor of initial particles.
            model: target model; must provide ``score(...)`` (see ``update``).
            k: kernel specification: a dict, a list of dicts (keys ``family``,
                ``weight``, ``bandwidth_factor``, ...), or None for a single
                RBF kernel.
            store_particles_history: keep every iterate and marginal variances.
            store_force_history: keep per-particle driving/repulsive force norms.
            store_nn_metrics: evaluate ``model`` on ``test_dataset`` every step.
            test_dataset: dataset passed to ``model.evaluate``.
            device: device used by subclasses for their own tensors.
        """
        if isinstance(particles, np.ndarray):
            particles = torch.from_numpy(particles)
        assert isinstance(particles, torch.Tensor), "The particles argument must be either a numpy ndarray or a PyTorch tensor"
        assert particles.dim() == 2, "The particles argument must have exactly 2 axes/dimensions"
        self.N, self.d = particles.shape
        self.model = model
        self.device = device
        self.store_particles_history = store_particles_history
        self.store_force_history = store_force_history
        self.store_nn_metrics = store_nn_metrics
        # Always defined so later evaluation fails in the model, not with an
        # AttributeError on this object.
        self.test_dataset = test_dataset
        self.particles_current = particles.clone().detach()
        if self.store_particles_history:
            # Copy, so later in-place changes to the caller's tensor do not
            # leak into the recorded history.
            self.particles_history = self.particles_current.clone().unsqueeze(0)
        if k is None:
            k = [{'family': 'rbf', 'weight': 1, 'bandwidth_factor': 1, 'preconditioning': None, 'factors': None}]
        elif isinstance(k, dict):
            k = [k]
        self.k = k

    def set_history_tensors(self, num_iterations):
        """Allocate the per-iteration history buffers for an ``update`` run.

        The particle history is extended, so repeated ``update`` calls append
        to it.  All other buffers cover only the upcoming run.
        """
        if self.store_particles_history:
            previous = self.particles_history
            # Index of the last row already filled; new rows follow it.
            self._history_offset = previous.shape[0] - 1
            like = {'dtype': previous.dtype, 'device': previous.device}
            self.particles_history = torch.vstack(
                [previous, torch.zeros(num_iterations, self.N, self.d, **like)]
            )
            self.marginal_variance_history = torch.zeros(num_iterations, self.d, **like)
            self.damv_history = torch.zeros(num_iterations, **like)
        if self.store_force_history:
            like = {'dtype': self.particles_current.dtype, 'device': self.particles_current.device}
            self.force_driving_history = torch.zeros(num_iterations, self.N, **like)
            self.force_repulsive_history = torch.zeros(num_iterations, self.N, **like)
        if self.store_nn_metrics:
            self.ll_array_history = torch.zeros(num_iterations, self.N)
            self.ll_history = torch.zeros(num_iterations, )
            self.rmse_history = torch.zeros(num_iterations, )

    def update_history_tensors(self, iteration, force_driving, force_repulsive):
        """Record step ``iteration`` of the current run into the history buffers.

        The particle row after the step is set to ``particles_current``.
        ``marginal_variance_history[iteration]`` is the per-dimension variance
        of the particles before the step.  Force histories store
        ``||force / N||`` per particle.
        """
        if self.store_particles_history:
            row = getattr(self, '_history_offset', 0) + iteration
            self.particles_history[row + 1] = self.particles_current.detach().clone()
            self.marginal_variance_history[iteration] = torch.var(self.particles_history[row], dim=0)
            self.damv_history[iteration] = torch.mean(self.marginal_variance_history[iteration])
        if self.store_force_history:
            self.force_driving_history[iteration] = torch.norm(force_driving.detach() / self.N, dim=1)
            self.force_repulsive_history[iteration] = torch.norm(force_repulsive.detach() / self.N, dim=1)
        if self.store_nn_metrics:
            self.model.evaluate(self.test_dataset)
            self.ll_array_history[iteration] = self.model.ll_array
            self.ll_history[iteration] = self.model.ll
            self.rmse_history[iteration] = self.model.rmse

    def compute_damv(self):
        """Set per-dimension particle variances and their mean (``self.damv``)."""
        self.marginal_variances = torch.var(self.particles_current, dim=0)
        self.damv = torch.mean(self.marginal_variances).item()

    def evaluate(self):
        """Copy the model's latest ``rmse``, ``ll`` and ``ll_array`` onto self."""
        self.rmse = self.model.rmse
        self.ll = self.model.ll
        self.ll_array = self.model.ll_array

    @staticmethod
    def _mean_inf_norm(force):
        """Mean over particles of the per-particle max-abs force component."""
        return force.norm(p=torch.inf, dim=1).mean().item()

    def compute_initial_forces(self):
        """Record repulsive / driving force magnitudes after the first step."""
        self.parf_initial = self._mean_inf_norm(self.phi_R)
        self.paksg_initial = self._mean_inf_norm(self.phi_G)

    def compute_final_forces(self):
        """Record repulsive / driving force magnitudes after the last step."""
        self.parf_final = self._mean_inf_norm(self.phi_R)
        self.paksg_final = self._mean_inf_norm(self.phi_G)

    def _compute_kernel_forces(self, k_driving, k_repulsive, score):
        """Set ``phi_G``, ``phi_R`` and ``phi`` from kernel specification lists.

        ``k_driving`` components are summed (by ``weight``) into a kernel that
        multiplies ``score``.  ``k_repulsive`` components are summed into the
        repulsive term.  Both terms are averaged over the ``N`` particles.
        """
        kernel = 0
        for k_component in k_driving:
            kernel = kernel + k_component['weight'] * compute_k(k_component, self.particles_current)
        repulsive = 0
        for k_component in k_repulsive:
            repulsive = repulsive + k_component['weight'] * compute_dk(k_component, self.particles_current)
        self.phi_G = torch.matmul(kernel, score) / self.N
        self.phi_R = repulsive / self.N
        self.phi = self.phi_G + self.phi_R

    def update(self, num_iterations, eps=0.01, alpha=0.9, progress=False, **kwargs):
        """Run ``num_iterations`` SVGD steps on ``self.particles_current``.

        Args:
            num_iterations: number of particle updates.
            eps: step size applied to the normalised direction.
            alpha: decay of the running average of ``phi ** 2``.
            progress: show a tqdm progress bar.
            **kwargs:
                data_loader: required when ``self.model`` is a
                    ``NeuralNetworkEnsemble``; passed to ``model.score``.
                g_lr, n_g_update (required) and g_update_every (default 1):
                    slice-optimisation settings, used only by ``SSVGD``.

        Raises:
            ValueError: a ``NeuralNetworkEnsemble`` model without ``data_loader``.
            FloatingPointError: the update direction contains NaN.  The
                offending step is not applied.
        """
        self.num_iterations = num_iterations
        self.set_history_tensors(num_iterations)

        is_sliced = isinstance(self, SSVGD)
        is_nn = isinstance(self.model, NeuralNetworkEnsemble)

        if is_sliced:
            g_lr = kwargs['g_lr']
            n_g_update = kwargs['n_g_update']
            g_update_every = kwargs.get('g_update_every', 1)
            self.slice_optimiser = torch.optim.Adam([self.G], lr=g_lr, betas=(0.5, 0.9))
            kernel_hyper_slicedSVGD = {'bandwidth': None}
            band_scale = 1

        data_loader = kwargs.get('data_loader')
        if is_nn and data_loader is None:
            raise ValueError(
                "update() requires a 'data_loader' keyword argument when the model is a NeuralNetworkEnsemble"
            )

        time_start = time.time()

        for iteration in tqdm(range(num_iterations), disable=not progress):
            particles = self.particles_current.clone().detach().requires_grad_()
            if is_nn:
                score = self.model.score(data_loader)
            else:
                score = self.model.score(particles)

            # Update direction and its driving / repulsive components.
            if is_sliced:
                if iteration % g_update_every == 0:
                    G_n = self.optimise_g(particles, score, n_g_update, kernel_hyper_slicedSVGD, band_scale)
                else:
                    G_n = self.G / (torch.norm(self.G, 2, dim=-1, keepdim=True).to(self.device) + FUDGE_FACTOR)
                self.phi, self.phi_R = max_DSSVGD(
                    particles, None, SE_kernel, repulsive_SE_kernel,
                    r=self.O_r, g=G_n, flag_median=True, median_power=0.5,
                    kernel_hyper=kernel_hyper_slicedSVGD, score=score,
                    bandwidth_scale=band_scale, repulsive_coef=1, flag_repulsive_output=True
                )
                self.phi_G = self.phi - self.phi_R
            elif isinstance(self, hSVGD):
                self._compute_kernel_forces(self.k_1, self.k_2, score)
            else:
                # Plain SVGD: the same kernel for both terms.
                self._compute_kernel_forces(self.k, self.k, score)

            # Fail before a NaN step corrupts the particles.
            if torch.isnan(self.phi).any():
                raise FloatingPointError(
                    'NaN encountered in the SVGD update direction at iteration {}'.format(iteration)
                )

            # Adagrad-style step.  phi is detached so no autograd graph (particles
            # require grad above) is chained through the running average or the
            # stored particles.
            phi = self.phi.detach()
            if iteration == 0:
                historical_grad = phi ** 2
            else:
                historical_grad = alpha * historical_grad + (1 - alpha) * (phi ** 2)
            phi_adagrad = torch.div(phi, torch.sqrt(historical_grad) + FUDGE_FACTOR)
            step = eps * phi_adagrad

            self.particles_current = self.particles_current + step
            self.update_history_tensors(iteration, self.phi_G, self.phi_R)
            if is_nn:
                self.model.update_params(step)

            if iteration == 0:
                self.compute_initial_forces()

        self.time_seconds = time.time() - time_start
        self.compute_damv()
        self.compute_final_forces()
        if is_nn:
            self.model.evaluate(self.test_dataset)
            self.evaluate()
        if isinstance(self.model, MultivariateNormalMixture):
            self.ksd = self.model.kernelized_stein_discrepancy(self.particles_current)
            

class hSVGD(SVGD):
    """SVGD variant with separate kernels for the driving and repulsive terms.

    ``k_1`` is the kernel that multiplies the score (driving term) and ``k_2``
    supplies the repulsive term.  Each may be a single kernel-spec dict, a list
    of such dicts (summed using their ``'weight'`` entries), or None for the
    default RBF kernel.
    """

    def __init__(
            self,
            particles,
            model,
            k_1=None,
            k_2=None,
            store_particles_history=False,
            store_force_history=False,
            store_nn_metrics=False,
            test_dataset=None,
            device='cpu'
        ):
        super().__init__(
            particles,
            model,
            store_particles_history=store_particles_history,
            store_force_history=store_force_history,
            store_nn_metrics=store_nn_metrics,
            test_dataset=test_dataset,
            device=device
        )
        self.k_1 = self._as_kernel_list(k_1)
        self.k_2 = self._as_kernel_list(k_2)
        # The single base-class kernel spec is superseded by k_1 / k_2.
        del self.k

    @staticmethod
    def _as_kernel_list(spec):
        """Normalise a kernel spec (None, dict or list of dicts) to a list."""
        if spec is None:
            return [{'family': 'rbf', 'weight': 1, 'bandwidth_factor': 1, 'preconditioning': None, 'factors': None}]
        if isinstance(spec, dict):
            return [spec]
        return spec

class SSVGD(SVGD):
    """Sliced SVGD.

    ``O_r`` holds the score projection directions and ``G`` the learnable
    slice directions used to project the particles.  Both start as the
    ``(d, d)`` identity on ``self.device``.  ``G`` is row-normalised (L2)
    before every use.
    """

    def __init__(
            self,
            particles,
            model,
            k=None,
            store_particles_history=False,
            store_force_history=False,
            store_nn_metrics=False,
            test_dataset=None,
            device='cpu'
        ):
        super().__init__(
            particles,
            model,
            k=k,
            store_particles_history=store_particles_history,
            store_force_history=store_force_history,
            store_nn_metrics=store_nn_metrics,
            test_dataset=test_dataset,
            device=device
        )
        self.O_r = torch.eye(self.d, device=self.device)  # (d, d) score projections
        self.G = torch.eye(self.d, device=self.device).requires_grad_(True)  # (d, d) slice matrix (particle projections)

    def _normalised_g(self):
        """Return ``G`` with each row scaled to unit L2 norm (autograd-tracked)."""
        return self.G / (torch.norm(self.G, 2, dim=-1, keepdim=True).to(self.device) + FUDGE_FACTOR)

    def optimise_g(self, particles, score, n_g_update, kernel_hyper_slicedSVGD, band_scale):
        """Take ``n_g_update`` optimiser steps on ``G`` that increase the divergence.

        The divergence is the one returned by ``compute_max_DSSD_eff``; the
        steps ascend it by back-propagating ``-divergence``.  This relies on
        ``self.slice_optimiser``, which ``SVGD.update`` creates.

        Returns:
            The detached, row-normalised ``G`` after the final step.  With
            ``n_g_update == 0`` this is the current ``G`` normalised.
        """
        for _ in range(n_g_update):
            self.slice_optimiser.zero_grad()
            divergence, _ = compute_max_DSSD_eff(
                particles.detach(), particles.clone().detach(), None, SE_kernel, d_SE_kernel, dd_SE_kernel,
                r=self.O_r, g=self._normalised_g(), kernel_hyper=kernel_hyper_slicedSVGD,
                score_samples1=score, score_samples2=score.clone(),
                flag_median=True, flag_U=False, median_power=0.5, bandwidth_scale=band_scale)
            (-divergence).backward()
            self.slice_optimiser.step()
        return self._normalised_g().clone().detach()
    
def compute_k(k_component, particles):
    """Evaluate the kernel described by ``k_component`` on ``particles``.

    The bandwidth is ``median ** 2 / log(N) * bandwidth_factor + FUDGE_FACTOR``,
    where ``median`` comes from ``kernels.get_pairwise_dist`` (or its
    ``_by_dim`` variant for ``'rbf_by_dim'``).  The result presumably is an
    ``(N, N)`` kernel matrix: ``SVGD.update`` multiplies it with the score.

    Args:
        k_component: dict with ``'family'`` (``'rbf'``, ``'imq'``,
            ``'rbf_scale_by_dim'`` or ``'rbf_by_dim'``) and
            ``'bandwidth_factor'``; ``'factors'`` is also read for
            ``'rbf_scale_by_dim'``.
        particles: particle tensor whose first axis indexes particles.

    Raises:
        Exception: for an unsupported family.
    """
    family = k_component['family']
    if family not in ('rbf', 'imq', 'rbf_scale_by_dim', 'rbf_by_dim'):
        raise Exception('Invalid family {} provided for kernel'.format(family))

    if family == 'rbf_by_dim':
        dist, median = kernels.get_pairwise_dist_by_dim(particles)
    else:
        dist, median = kernels.get_pairwise_dist(particles)
    bandwidth = (median ** 2 / np.log(particles.shape[0])) * k_component['bandwidth_factor'] + FUDGE_FACTOR

    if family == 'rbf':
        return kernels.rbf(dist, bandwidth)
    if family == 'imq':
        return kernels.imq(dist, bandwidth)
    if family == 'rbf_scale_by_dim':
        return kernels.rbf_scale_by_dim(dist, bandwidth, k_component['factors'])
    return kernels.rbf_by_dim(dist, bandwidth)

def compute_dk(k_component, particles):
    """Evaluate the ``kernels.d_*`` derivative term for ``k_component``.

    ``SVGD.update`` uses the result as the repulsive force.  The bandwidth
    rule is ``median ** 2 / log(N) * bandwidth_factor + FUDGE_FACTOR``.

    Args:
        k_component: dict with ``'family'`` (``'rbf'``, ``'imq'`` or
            ``'rbf_by_dim'``) and ``'bandwidth_factor'``.
        particles: particle tensor whose first axis indexes particles.

    Raises:
        Exception: for an unsupported family (including ``'rbf_scale_by_dim'``,
            which ``compute_k`` accepts but this function does not).
    """
    family = k_component['family']
    if family not in ('rbf', 'imq', 'rbf_by_dim'):
        raise Exception('Invalid family {} provided for kernel'.format(family))

    if family == 'rbf_by_dim':
        dist, median = kernels.get_pairwise_dist_by_dim(particles)
    else:
        dist, median = kernels.get_pairwise_dist(particles)
    bandwidth = (median ** 2 / np.log(particles.shape[0])) * k_component['bandwidth_factor'] + FUDGE_FACTOR

    if family == 'rbf':
        return kernels.d_rbf(dist, particles, bandwidth)
    if family == 'imq':
        return kernels.d_imq(dist, particles, bandwidth)
    return kernels.d_rbf_by_dim(dist, particles, bandwidth)