import hamiltorch
import torch

# Integration scheme passed as `integrator=` to hamiltorch.samplers.leapfrog
# inside sample() in this module.
integrator = hamiltorch.samplers.Integrator.EXPLICIT


def acceptance(h_list_old, h_list_new, temp_list):
    """Return the tempered log acceptance ratio for the Metropolis-Hastings step.

    The three sequences are walked in lockstep; for every position the drop in
    the Hamiltonian (old minus new) is converted to a Python float, divided by
    the matching temperature and added to a running total.

    Parameters
    ----------
    h_list_old : sequence
        Hamiltonian values at the current state, one entry per split.
    h_list_new : sequence
        Hamiltonian values at the proposed state, one entry per split.
    temp_list : sequence
        Temperatures, one per split, dividing the corresponding difference.

    Returns
    -------
    float
        Sum over splits of ``(h_old - h_new) / T``. If the sequences differ in
        length, the surplus entries are ignored (``zip`` semantics). An empty
        input yields ``0.``.

    """
    log_ratio = 0.
    for before, after, temperature in zip(h_list_old, h_list_new, temp_list):
        energy_drop = float(before - after)
        log_ratio += energy_drop / temperature
    return log_ratio

def _split_hamiltonians(params, momentum, log_prob_func_list):
    """Evaluate `hamiltonian` once per split log-probability at (params, momentum)."""
    num_splits = len(log_prob_func_list)
    # Every split uses a diagonal inverse mass of 1/K; `hamiltonian` additionally
    # divides the kinetic term by `normalizing_const` (= K here).
    split_inv_mass = torch.ones_like(params).detach() / num_splits
    return [
        hamiltonian(params, momentum, log_prob_func, inv_mass=split_inv_mass,
                    normalizing_const=float(num_splits))
        for log_prob_func in log_prob_func_list
    ]


def sample(log_prob_func_list, log_prob_func_full, params_init, T_func_list, num_samples=10, num_steps_per_sample=10, step_size=0.1, burn=0, inv_mass=None, debug=False, store_on_GPU = True, pass_grad = None, device = 'cpu'):
    """Run HMC with a per-split, temperature-weighted Metropolis-Hastings test.

    Each iteration draws a momentum with ``hamiltorch.samplers.gibbs``, proposes
    a new state with ``hamiltorch.samplers.leapfrog`` on `log_prob_func_full`,
    and accepts or rejects it using `acceptance` applied to the Hamiltonians of
    the individual entries of `log_prob_func_list`.

    Parameters
    ----------
    log_prob_func_list : list of callables
        Split log-probability functions; each is evaluated through
        `hamiltonian` before and after the proposal.
    log_prob_func_full : callable
        Log-probability handed to the leapfrog integrator.
    params_init : torch.tensor
        Starting point of the chain (cloned, never modified). It is always the
        first element of the returned sample list.
    T_func_list : list of callables
        Each is called as ``T(n)`` with the iteration index and supplies the
        temperature dividing the matching split's energy difference.
    num_samples : int
        Number of iterations to run.
    num_steps_per_sample : int
        Leapfrog steps per proposal.
    step_size : float
        Leapfrog step size.
    burn : int
        States are stored only for iterations with ``n > burn``.
    inv_mass : object
        Accepted for interface compatibility; this sampler does not use it
        (the leapfrog runs with ``inv_mass=None`` and the split Hamiltonians
        use a fixed ``1/len(log_prob_func_list)`` diagonal).
    debug : bool or int
        ``1`` (or ``True``) prints per-step diagnostics; ``2`` changes the
        second return value to the overall acceptance rate.
    store_on_GPU : bool
        If False, stored samples are moved to the CPU and
        ``torch.cuda.empty_cache()`` is called after every iteration.
    pass_grad : object
        Unused; kept for interface compatibility.
    device : str or torch.device
        Device the chain state is moved to after each proposal.

    Returns
    -------
    tuple
        ``(samples, acceptance_list)`` where `samples` is a list of detached
        tensors and `acceptance_list` holds the value returned by `acceptance`
        for every iteration that reached the accept/reject test. With
        ``debug == 2`` the second element is ``1 - num_rejected/num_samples``.

    """
    acceptance_list = []

    params = params_init.clone().requires_grad_()
    if store_on_GPU:
        ret_params = [params.clone()]
    else:
        ret_params = [params.clone().detach().cpu()]

    num_rejected = 0

    hamiltorch.util.progress_bar_init('Sampling', num_samples, 'Samples')
    for n in range(num_samples):
        hamiltorch.util.progress_bar_update(n)

        # Remember the state the proposal starts from: a rejection must return
        # the chain to this state. Restoring from ret_params[-1] would reset the
        # chain to params_init for every rejection during burn-in, because
        # nothing is stored until n > burn.
        current_params = params
        momentum = leapfrog_params = leapfrog_momenta = None
        accepted = False
        try:
            momentum = hamiltorch.samplers.gibbs(params, mass=None)
            h_list_old = _split_hamiltonians(params, momentum, log_prob_func_list)

            leapfrog_params, leapfrog_momenta = hamiltorch.samplers.leapfrog(
                params, momentum, log_prob_func_full, steps=num_steps_per_sample,
                step_size=step_size, inv_mass=None, integrator=integrator)

            params = leapfrog_params[-1].to(device).detach().requires_grad_()
            momentum = leapfrog_momenta[-1].to(device)
            h_list_new = _split_hamiltonians(params, momentum, log_prob_func_list)

            temp_list = [T(n) for T in T_func_list]
            acc = acceptance(h_list_old, h_list_new, temp_list)
            acceptance_list.append(acc)
            rho = min(0., acc)
            if debug == 1:
                print('Step: {}, Current Hamiltoninian: {}, Proposed Hamiltoninian: {}'.format(n,sum(h_list_old),sum(h_list_new)))

            accepted = bool(rho >= torch.log(torch.rand(1)))
            if accepted and debug == 1:
                print('Accept rho: {}'.format(rho))
        except hamiltorch.util.LogProbError:
            # An invalid log-probability anywhere in the step counts as a rejection.
            accepted = False

        if accepted:
            new_sample = leapfrog_params[-1]
        else:
            num_rejected += 1
            params = current_params.to(device)
            new_sample = current_params.detach().to(device)

        if n > burn:
            if store_on_GPU:
                ret_params.append(new_sample)
            else:
                # Store samples on CPU
                ret_params.append(new_sample.cpu())

        if not accepted and debug == 1:
            print('REJECT')

        if not store_on_GPU:
            # Drop per-iteration references before asking CUDA to release its cache.
            momentum = leapfrog_params = leapfrog_momenta = new_sample = None
            torch.cuda.empty_cache()

    acceptance_rate = 1 - num_rejected/num_samples
    hamiltorch.util.progress_bar_end('Acceptance Rate {:.2f}'.format(acceptance_rate)) #need to adapt for burn
    samples = [t.detach() for t in ret_params]
    if debug == 2:
        return samples, acceptance_rate
    return samples, acceptance_list


def hamiltonian(params, momentum, log_prob_func, jitter=0.01, normalizing_const=1., softabs_const=1e6, explicit_binding_const=100, inv_mass=None, ham_func=None):
    """Computes the Hamiltonian as a function of the parameters and the momentum.

    The value returned is ``-log_prob + kinetic / normalizing_const``; note that
    only the kinetic term is divided by `normalizing_const`.

    Parameters
    ----------
    params : torch.tensor
        Parameter vector passed to `log_prob_func`.
    momentum : torch.tensor
        Flat momentum vector (it is used with ``torch.dot`` / ``view``).
    log_prob_func : function or list of functions
        A callable evaluated at `params`, or a list of callables (splitting)
        whose values are summed under ``torch.no_grad()``.
    jitter : float
        Not used by this implementation; kept for signature compatibility.
    normalizing_const : float
        Divides the kinetic energy term.
    softabs_const : float
        Not used by this implementation; kept for signature compatibility.
    explicit_binding_const : float
        Not used by this implementation; kept for signature compatibility.
    inv_mass : torch.tensor or list, optional
        Inverse mass. ``None`` gives ``0.5 * p.p``; a 1-d tensor is treated as
        a diagonal; a 2-d tensor as a full matrix; a list is treated as
        consecutive square blocks applied to consecutive slices of `momentum`.
    ham_func : object
        Not used by this implementation; kept for signature compatibility.

    Returns
    -------
    torch.tensor
        Value of the Hamiltonian. Its shape follows from what `log_prob_func`
        returns and from the kinetic branch taken.

    Raises
    ------
    hamiltorch.util.LogProbError
        If the (running) log-probability contains NaN or inf according to
        ``hamiltorch.util.has_nan_or_inf``.

    """
    if isinstance(log_prob_func, list): # I.e. splitting!
        log_prob = 0
        for split_log_prob_func in log_prob_func:
            # Gradients are not needed for the energy value, so skip building
            # the autograd graph to save memory.
            with torch.no_grad():
                log_prob = log_prob + split_log_prob_func(params)

                if hamiltorch.util.has_nan_or_inf(log_prob):
                    print('Invalid log_prob: {}, params: {}'.format(log_prob, params))
                    raise hamiltorch.util.LogProbError()
    else:
        log_prob = log_prob_func(params)

        if hamiltorch.util.has_nan_or_inf(log_prob):
            print('Invalid log_prob: {}, params: {}'.format(log_prob, params))
            # `util` alone is not defined in this module; the exception lives
            # in hamiltorch.util, which is also what sample() catches.
            raise hamiltorch.util.LogProbError()

    potential = -log_prob
    if inv_mass is None:
        kinetic = 0.5 * torch.dot(momentum, momentum)
    elif isinstance(inv_mass, list):
        # Block-diagonal inverse mass: walk the momentum vector block by block.
        offset = 0
        kinetic = 0
        for block in inv_mass:
            size = block[0].shape[0]
            segment = momentum[offset:offset + size]
            kinetic = kinetic + 0.5 * torch.matmul(segment.view(1,-1), torch.matmul(block, segment.view(-1,1))).view(-1)
            offset += size
    elif len(inv_mass.shape) == 2:
        # Full inverse mass matrix.
        kinetic = 0.5 * torch.matmul(momentum.view(1,-1), torch.matmul(inv_mass, momentum.view(-1,1))).view(-1)
    else:
        # Diagonal inverse mass stored as a vector.
        kinetic = 0.5 * torch.dot(momentum, inv_mass * momentum)

    return potential + kinetic / normalizing_const
