import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from scipy.stats import norm
import pickle
import math 
from copy import deepcopy
import tqdm
from .priors import spike_slab_2GMM, spike_slab_2GMM_pyro, isotropic_gauss_prior
from .BBP import BBP_Bayes_RegNet
from .PyroNN import PyroNN
from scipy.special import logsumexp
import numpy as np 

import pyro
import pyro.distributions as dist
from pyro.distributions import TorchDistribution
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO, Predictive, NUTS, MCMC, HMC

# Smallest positive normal float64 value. Model.setup_params uses it as a floor
# (via max) on a probability term before that term is passed to norm.isf.
MACHINE_TINY = torch.finfo(torch.float64).tiny

class CustomDataset(Dataset):
    """Map-style dataset that pairs inputs ``x`` with targets ``y`` by index.

    ``x`` must expose ``.shape`` (e.g. a tensor or ndarray): its first
    dimension defines the dataset length. ``y`` is indexed with the same index.
    """

    def __init__(self, x, y):
        self.x_data, self.y_data = x, y

    def __len__(self):
        num_rows = self.x_data.shape[0]
        return num_rows

    def __getitem__(self, idx):
        features = self.x_data[idx]
        target = self.y_data[idx]
        return features, target

class Model:
    """Bayesian sparse deep-network regressor configured from theoretical rates.

    On construction, the network width ``W_n``, depth ``L_n`` and the
    spike-and-slab prior hyperparameters are derived from ``(n, d, s, p)`` by
    :meth:`setup_params`. The network is built by :meth:`setup_model`. It is
    trained with :meth:`train_BBP`, :meth:`train_pyro_deterministic` or
    :meth:`run_mcmc`.
    """

    def __init__(self, n, d, s, p, scale_data, batch_size=None, F=None, cuda=False, W=None, L=None, verbose=True):
        """Store the problem constants and derive the network/prior configuration.

        Args:
            n: Number of training samples. Also the default ``batch_size``.
            d: Input dimension, forwarded to the network as ``input_dim``.
            s: Smoothness-type constant of the target class (presumably the
                Besov/Hoelder smoothness -- TODO confirm).
            p: Constant paired with ``s`` in the rate formulas (presumably the
                Besov integrability index -- TODO confirm).
            scale_data: Forwarded unchanged to the network constructor.
            batch_size: Minibatch size for the training loops. Defaults to
                ``n`` (full-batch training).
            F: Forwarded unchanged to the network constructor.
            cuda: Whether tensors/networks are placed on the GPU.
            W: Optional override of the layer width ``W_n``.
            L: Optional override of the depth ``L_n``.
            verbose: If truthy, print the derived configuration.
        """
        self.n = n
        self.batch_size = batch_size or n
        # True division: fractional when batch_size does not divide n.
        self.n_batch = self.n / self.batch_size
        self.d = d
        self.s = s
        self.p = p
        # self.m = math.ceil(s + 1e-6) if math.isfinite(s) else float("Inf")
        self.m = math.ceil(max([s, s + 1 - 1/p]) + 1e-6)
        self.c_d_m = 1 / (1 + 2 * self.d * math.exp(1) * (2 ** self.m * math.exp(self.m)) / math.sqrt(self.m))
        self.omega = d * (1/p - 1/2) if p < 2 else 0
        self.nu = 0.5 * (self.s / self.omega - 1 ) if self.omega > 0  else float("Inf")
        self.F = F
        self.scale_data = scale_data
        self.cuda = cuda
        self.W_n = W
        self.L_n = L
        self.setup_params()
        if verbose:
            self.info(verbose=verbose)

    def setup_params(self):
        """Derive architecture sizes and prior hyperparameters.

        Sets width ``W_n``, depth ``L_n``, total parameter count ``T_n``,
        sparsity budget ``S_n``, sparse ratio ``r_n = S_n / T_n``, the rate
        ``eps_n``, the mixture weights ``pi1``/``pi2`` and the two prior scales
        ``scale_lambda1``/``scale_lambda2``. A width or depth supplied to the
        constructor is kept.
        """
        self.W_0 = 6 * self.d * self.m * (self.m + 2) + 2 * self.d
        alpha = max([self.d/self.p-self.s, 0])
        self.xi = min([max([1, alpha]) * (1/self.nu + 1/self.d), 1.0])
        if self.W_n is not None:
            self.N_n = math.ceil(self.W_n / self.W_0)
        else:
            self.N_n = math.ceil(self.n ** (self.d / (2 * self.s + self.d)))
            self.W_n = math.ceil(self.N_n * self.W_0)
        # NOTE(review): math.log(N_n) is 0 when N_n == 1 (e.g. W <= W_0),
        # which makes this a ZeroDivisionError -- confirm whether that can occur.
        self.lambda_n = self.N_n ** (-self.s/self.d - (1/self.nu + 1/self.d) * alpha) / (math.log(self.N_n))
        if self.L_n is None:
            self.L_n = 3 + 2 * max([0, math.ceil(5 + math.log2(3) * max([self.d, self.m])
                - math.log2(self.lambda_n) - math.log2(self.c_d_m)) *\
                math.ceil(math.log2(max([self.d, self.m])))])
        # Total number of parameters in a depth-L_n, width-W_n network.
        self.T_n = (self.d + 1) * self.W_n + \
            (self.L_n - 2) * self.W_n * (self.W_n + 1) + self.W_n + 1
        # Sparsity budget, capped by the total parameter count.
        self.S_n = min([math.ceil((self.L_n - 1) * (self.W_0 ** 2) \
            * self.N_n + self.N_n), self.T_n])
        self.B_n = 10 * self.N_n ** self.xi
        # Convergence rate.
        self.eps_n = self.n ** (-self.s / (2 * self.s + self.d)) * math.log(self.n) ** (1.5)

        # Sparse ratio.
        self.r_n = self.S_n / self.T_n
        self.K_0 = 5
        self.eta_n = math.exp(- self.K_0 * self.n * (self.eps_n**2) / self.S_n)
        # Computed in log space to avoid overflow of B_n ** (L_n - 1).
        self.a_n = math.exp(math.log(self.eps_n) - math.log(72.0) - \
            math.log(self.L_n) - (self.L_n - 1) * math.log(max([self.B_n, 1])) - \
            self.L_n * math.log(self.W_n + 1))

        self.pi2 = self.r_n
        # NOTE(review): pi1 is 0 when S_n == T_n, making pi2 / pi1 below divide by zero.
        self.pi1 = 1 - self.pi2

        self.scale_lambda2 = self.B_n / math.sqrt((2 * (self.K_0 + 1) * self.n * self.eps_n ** 2))
        # MACHINE_TINY floors the probability term so norm.isf never sees a
        # non-positive value from the subtraction.
        self.scale_lambda1 = self.a_n / norm.isf(self.pi2 / self.pi1 *\
            max([MACHINE_TINY, 0.5 * self.eta_n - norm.sf(self.a_n / self.scale_lambda2)])
        )

    def info(self, verbose=True):
        """Optionally print, and return, the derived configuration.

        Args:
            verbose: If truthy, print depth, width, prior scales, mixture
                weights and the cuda flag.

        Returns:
            dict with keys ``"L"``, ``"W"``, ``"scale_param"``
            (``[scale_lambda1, scale_lambda2]``) and ``"mixture"``
            (``[pi1, pi2]``).
        """
        if verbose:
            # print(f"Total Parameter Size: {self.T_n}")
            # print(f"Sparse Ratio: {self.r_n}")
            print(f"Depth of NN: {self.L_n}")
            print(f"Width of Each Layer: {self.W_n}")
            # print(f"Convergence Rate: {self.eps_n}")
            print(f"scale_param: {self.scale_lambda1} and {self.scale_lambda2}")
            print(f"mixture ratio: {self.pi1} and {self.pi2}")
            print(f"is_cuda: {self.cuda}")

        return {"L": self.L_n, "W": self.W_n, "scale_param": [self.scale_lambda1, self.scale_lambda2], "mixture": [self.pi1, self.pi2]}

    def setup_model(self, manual_sigma1=False, lr=1e-3, pyro=False, prior="2GMM"):
        """Build ``self.model`` with the derived width/depth and chosen prior.

        Args:
            manual_sigma1: If truthy, a lower bound for ``scale_lambda1``. The
                larger of this value and the computed one is used.
            lr: Learning rate for the BBP network (unused when ``pyro``).
            pyro: If truthy, build a ``PyroNN``. Otherwise build a
                ``BBP_Bayes_RegNet``. (This parameter shadows the ``pyro``
                module inside this method.)
            prior: ``"2GMM"`` selects the two-component spike-and-slab mixture.
                Any other value selects a single Gaussian ``(mu2, sigma2)``.
        """
        W = self.W_n
        L = self.L_n
        mu1 = 0.0
        mu2 = 0.0
        if manual_sigma1:
            self.scale_lambda1 = max(manual_sigma1, self.scale_lambda1)
            # Report the value actually applied: max() may keep the computed one.
            print(f"scale_lambda1 set to {self.scale_lambda1}")
        sigma1 = self.scale_lambda1
        sigma2 = self.scale_lambda2
        pi = self.pi1

        if pyro:
            device = torch.device("cuda" if self.cuda else "cpu")
            mu1 = torch.tensor(mu1, device=device)
            mu2 = torch.tensor(mu2, device=device)
            sigma1 = torch.tensor(sigma1, device=device)
            sigma2 = torch.tensor(sigma2, device=device)
            pi = torch.tensor(pi, device=device)

            if prior == "2GMM":
                prior = spike_slab_2GMM_pyro(mu1, mu2, sigma1, sigma2, pi)
            else:
                prior = dist.Normal(mu2, sigma2)

            self.model = PyroNN(
                input_dim=self.d,
                W=W,
                L=L,
                scale_data=self.scale_data,
                cuda=self.cuda,
                prior=prior,
                F = self.F
            )
        else:

            if prior == "2GMM":
                prior = spike_slab_2GMM(mu1, mu2, sigma1, sigma2, pi)
            else:
                prior = isotropic_gauss_prior(mu2, sigma2)

            self.model = BBP_Bayes_RegNet(
                lr=lr,
                input_dim=self.d,
                cuda=self.cuda,
                output_dim=1,
                Nbatches=self.n_batch,
                n_hid=W,
                prior_instance=prior,
                scale_data=torch.tensor(self.scale_data),
                F=self.F,
                # BBP's n_layer is passed as depth minus one.
                n_layer=L-1)

    def run_mcmc(self, x, y, algorithm="hmc", nuts_params=None, hmc_params=None, mcmc_params=None):
        """Sample the posterior of ``self.model`` with Pyro MCMC.

        Args:
            x, y: Training data, passed to the model as ``x=`` and ``y=``.
            algorithm: ``"nuts"`` or ``"hmc"``.
            nuts_params: Kwargs for ``NUTS``. Defaults to ``{"max_tree_depth": 5}``.
            hmc_params: Kwargs for ``HMC``. Defaults to adaptive step size and
                diagonal adaptive mass matrix.
            mcmc_params: Kwargs for ``MCMC``. Defaults to 1000 samples, 1000
                warmup steps, 1 chain.

        Raises:
            NotImplementedError: for any other ``algorithm``.
        """
        # Fresh dicts per call instead of shared mutable defaults.
        if nuts_params is None:
            nuts_params = {"max_tree_depth": 5}
        if hmc_params is None:
            hmc_params = {"adapt_step_size": True, "adapt_mass_matrix": True, "full_mass": False}
        if mcmc_params is None:
            mcmc_params = {"num_samples": 1000, "warmup_steps": 1000, "num_chains": 1}

        if algorithm == "nuts":
            self.mcmc_kernel = NUTS(self.model, **nuts_params)
        elif algorithm == "hmc":
            self.mcmc_kernel = HMC(self.model, **hmc_params)
        else:
            raise NotImplementedError(f"Unknown MCMC algorithm: {algorithm!r}")
        self.mcmc = MCMC(
            self.mcmc_kernel,
            **mcmc_params
        )
        self.mcmc.run(x=x, y=y)

    @staticmethod
    def _report_interval(n_epochs, verbose):
        """Return the epoch interval between progress prints (0 = never print).

        Clamped to at least 1 so that ``verbose > n_epochs`` does not produce a
        zero interval, which would make ``% interval`` raise ZeroDivisionError.
        """
        if not verbose:
            return 0
        return max(1, n_epochs // verbose)

    def train_pyro_deterministic(self, x, y, n_epochs=1000, verbose=10, best=False):
        """Train ``self.model`` via its ``fit_deterministic`` step.

        Args:
            x, y: Training inputs/targets, iterated in shuffled minibatches.
            n_epochs: Number of passes over the data.
            verbose: Approximate number of loss printouts (falsy = silent).
            best: If truthy, keep a copy of the model with the lowest epoch
                loss and restore it at the end.
        """
        train_loader = DataLoader(CustomDataset(x, y), batch_size=self.batch_size, shuffle=True)
        report_every = self._report_interval(n_epochs, verbose)
        best_loss = float("inf")
        best_model = None
        for step in tqdm.trange(n_epochs):
            loss = 0
            for x_train, y_train in train_loader:
                loss += self.model.fit_deterministic(x_train, y_train)
            loss /= self.n_batch
            if best and loss < best_loss:
                best_model = deepcopy(self.model)
                best_loss = deepcopy(loss)
            if report_every and (step + 1) % report_every == 0:
                print(f"loss: {loss}")

        # best_model stays None if no epoch ever improved (e.g. NaN loss, 0 epochs).
        if best and best_model is not None:
            self.model = best_model

    def train_BBP(self, x, y, n_epochs=1000, verbose=10, n_samples=10, best=False):
        """Train ``self.model`` with Bayes-by-Backprop via its ``fit`` step.

        Args:
            x, y: Training inputs/targets, iterated in shuffled minibatches.
            n_epochs: Number of passes over the data.
            verbose: Approximate number of loss printouts (falsy = silent).
            n_samples: Currently unused (see note below).
            best: If truthy, keep a copy of the model with the lowest epoch
                loss and restore it at the end.
        """
        train_loader = DataLoader(CustomDataset(x, y), batch_size=self.batch_size, shuffle=True)
        report_every = self._report_interval(n_epochs, verbose)
        # NOTE(review): n_samples is accepted but ignored; one Monte Carlo
        # sample per step is hard-coded. Kept to preserve training behaviour.
        samples = 1
        best_loss = float("inf")
        best_model = None
        for step in tqdm.trange(n_epochs):
            edkl = 0
            mlpdw = 0
            loss = 0
            for x_train, y_train in train_loader:
                edkl_i, mlpdw_i, loss_i = self.model.fit(x_train, y_train, samples=samples, warmup=False)
                edkl += edkl_i
                mlpdw += mlpdw_i
                loss += loss_i
            edkl /= self.n_batch
            mlpdw /= self.n_batch
            loss /= self.n_batch
            if best and loss < best_loss:
                best_model = deepcopy(self.model)
                best_loss = deepcopy(loss)
            if report_every and (step + 1) % report_every == 0:
                print(f"loss: {loss} negative_log_likelihood: {mlpdw}, KL: {edkl}")

        # best_model stays None if no epoch ever improved (e.g. NaN loss, 0 epochs).
        if best and best_model is not None:
            self.model = best_model
