from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import GibbsStep, GibbsChain
from plot_utils import *
from nice_kernels import NICEResample, NICERelu, NICEDoubleRelu
from nice_transforms import NICE
from logistic_regression_distributions  import AustralianCreditLR
import hamiltorch

import os

# Draw a long NUTS chain from the Australian Credit logistic-regression
# posterior and store it on disk as a NumPy array.

# Destination of the saved chain (np.save appends ".npy" when the name does
# not already end with it).
chain_save_file = './Results/Samples/australiancredit_samples'

# Run lengths, named once so the sampler call and the slice applied before
# saving cannot drift apart.
NUM_SAMPLES = 501000
BURN = 1000

# Create the output directory before sampling: otherwise a missing folder is
# only discovered by np.save after the whole run has finished, and every
# sample is lost.
_save_dir = os.path.dirname(chain_save_file)
if _save_dir:
    os.makedirs(_save_dir, exist_ok=True)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

target = AustralianCreditLR(device=device)
# Length of the sampled vector, taken from the last axis of the predictors.
d = target.predictors.shape[-1]

# Initial state: a standard-normal draw of length d on the chosen device.
start = torch.zeros([d], device=device).normal_()
batch = hamiltorch.sample(
    target.log_prob,
    start,
    num_samples=NUM_SAMPLES,
    burn=BURN,
    sampler=hamiltorch.Sampler.HMC_NUTS,
    desired_accept_rate=0.65,
    store_on_GPU=True,
)
# Stack the returned per-draw tensors into one array on the host.
batch = torch.stack(batch).cpu().numpy()

# Discard the first BURN rows before saving, exactly as the original script
# did. NOTE(review): whether hamiltorch already omits the warm-up draws from
# its return value depends on the installed version -- confirm, since if it
# does this drops BURN additional post-warm-up draws.
np.save(chain_save_file, batch[BURN:, :])
