import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

from tqdm import tqdm
import dill as pickle
import warnings
from heteroskedastic_nns.parallel_model import SplitParallelFF, ParallelFF
from heteroskedastic_nns.model import SingleFF
from heteroskedastic_nns.parallel_ft import ParallelDFT

from matplotlib.colors import LogNorm, SymLogNorm
from matplotlib.ticker import MaxNLocator
import seaborn as sns
import math

# Shared mean-squared-error criterion (torch.nn.MSELoss with its default
# 'mean' reduction). NOTE(review): not referenced by the functions in this
# part of the file -- confirm it is used elsewhere before removing it.
mse = torch.nn.MSELoss()



def mixture_chunk_prediction(preds, n_chunks):
    """Collapse groups of parallel models into moment-matched Gaussian mixtures.

    The model axis (dim 1) of ``preds`` is split with ``torch.chunk`` and each
    chunk is treated as an equally weighted mixture of Gaussians, summarised
    by its first two moments (law of total variance):

        mean     = E[mu]
        variance = E[sigma^2] + E[mu^2] - E[mu]^2

    Args:
        preds: dict with ``"mean"`` and ``"precision"`` tensors whose dim 1
            indexes models (``(n_points, n_models, n_dim)`` per the callers'
            comments in this file).
        n_chunks: requested number of mixtures. ``torch.chunk`` may return
            fewer chunks when ``n_models`` does not divide evenly; one output
            column is produced per chunk actually returned.

    Returns:
        dict with ``"mean"`` and ``"precision"`` tensors of shape
        ``(n_points, n_mixtures, n_dim)``.
    """
    mean_chunks = torch.chunk(preds["mean"], chunks=n_chunks, dim=1)
    var_chunks = torch.chunk(preds["precision"].pow(-1), chunks=n_chunks, dim=1)

    mix_means = []
    mix_precs = []
    for chunk_mean, chunk_var in zip(mean_chunks, var_chunks):
        mu_bar = chunk_mean.mean(dim=1)
        mean_sq_bar = chunk_mean.pow(2).mean(dim=1)
        var_bar = chunk_var.mean(dim=1)

        mix_means.append(mu_bar)
        mix_precs.append((var_bar + mean_sq_bar - mu_bar.pow(2)).pow(-1))

    # Stack along dim 1 so the output has exactly one column per chunk,
    # independent of how many models each chunk contains.
    return {
        "mean": torch.stack(mix_means, dim=1),
        "precision": torch.stack(mix_precs, dim=1),
    }

def ci_obs_avg(x, y, model, tau, device, prec_param=True, mixture=1):
    """Empirical coverage of each model's central ``tau`` Gaussian interval.

    For every model the symmetric interval ``mean +/- z * precision**-0.5``
    with ``z = Phi^{-1}(1 - (1 - tau) / 2)`` is formed, and the fraction of
    targets lying strictly inside it is returned.

    Args:
        x: inputs passed to ``model``.
        y: targets, one row per data point; reshaped to ``(len(y), 1, -1)``
            so it broadcasts against the model axis.
        model: module returning a dict with ``"mean"`` and ``"precision"``
            tensors of shape ``(len(y), num_models, n_dim)``. It is moved to
            ``device``.
        tau: nominal coverage level in (0, 1); Python float or tensor.
        device: device for the model and the quantile.
        prec_param: unused; kept for backward compatibility.
        mixture: if > 1, models are first grouped into ``mixture``
            moment-matched mixtures via ``mixture_chunk_prediction``.

    Returns:
        Tensor of shape ``(num_models, n_dim)`` holding the covered fraction.
    """
    normal_distribution = torch.distributions.Normal(0, 1)

    # as_tensor accepts both floats and tensors; torch.tensor(tensor) would
    # emit a copy-construct UserWarning on every call.
    alpha = 1 - ((1 - tau) / 2)
    quantile = normal_distribution.icdf(torch.as_tensor(alpha)).to(device)

    model = model.to(device)

    # The result comes from boolean comparisons, so no gradient is needed;
    # skipping graph construction saves memory when called in a loop.
    with torch.no_grad():
        preds = model(x)
        if mixture > 1:
            preds = mixture_chunk_prediction(preds, n_chunks=mixture)

        half_width = quantile * preds['precision'].pow(-.5)
        upper = preds['mean'] + half_width
        lower = preds['mean'] - half_width

        # (len(y), 1, n_dim) broadcasts against (len(y), num_models, n_dim).
        targets = y.reshape(len(y), 1, -1)
        covered = (upper > targets) & (lower < targets)

    return covered.float().mean(dim=0)




def p_obs_avg(x, y, model, tau, device, prec_param=True, mixture=1):
    """Fraction of targets strictly below each model's predicted ``tau`` quantile.

    The Gaussian quantile is ``mean + Phi^{-1}(tau) * precision**-0.5``.

    Args:
        x: inputs passed to ``model`` (the model is not moved to ``device``).
        y: targets, one row per data point; reshaped to ``(len(y), 1, -1)``
            so it broadcasts against the model axis.
        model: callable returning a dict with ``"mean"`` and ``"precision"``
            tensors of shape ``(len(y), num_models, n_dim)``.
        tau: quantile level in (0, 1); Python float or tensor.
        device: device for the quantile value.
        prec_param: unused; kept for backward compatibility.
        mixture: if > 1, models are first grouped into ``mixture``
            moment-matched mixtures via ``mixture_chunk_prediction``.

    Returns:
        Tensor of shape ``(num_models, n_dim)``; a calibrated model gives
        values close to ``tau``.
    """
    normal_distribution = torch.distributions.Normal(0, 1)
    # as_tensor avoids the copy-construct warning when tau is already a tensor.
    quantile = normal_distribution.icdf(torch.as_tensor(tau)).to(device)

    # Only boolean comparisons feed the result, so skip the autograd graph.
    with torch.no_grad():
        preds = model(x)
        if mixture > 1:
            preds = mixture_chunk_prediction(preds, n_chunks=mixture)

        q_taus = preds['mean'] + quantile * preds['precision'].pow(-.5)
        targets = y.reshape(len(y), 1, -1)
        below = q_taus > targets

    return below.float().mean(dim=0)

def expected_calibration_error(x, y, model, device, samples=1000, mixture=1):
    """Monte-Carlo estimate of the expected calibration error of ``model``.

    Draws ``samples`` levels ``tau ~ U(0, 1)`` and averages the absolute gap
    ``|p_obs_avg(x, y, model, tau) - tau|`` over them.

    Returns:
        Tensor with the same shape as ``p_obs_avg``'s output (one value per
        model).
    """
    def _abs_gap():
        # One uniformly drawn quantile level and its calibration gap.
        level = torch.rand(1).to(device)
        observed = p_obs_avg(x, y, model, level, device, mixture=mixture)
        return (observed - level).abs()

    total = sum(_abs_gap() for _ in range(samples))
    return total / samples

def expected_ci_coverage(x, y, model, device, samples=1000, mixture=1):
    """Monte-Carlo estimate of the mean absolute interval-coverage error.

    Draws ``samples`` nominal coverage levels ``tau ~ U(0, 1)`` and averages
    ``|ci_obs_avg(x, y, model, tau) - tau|`` over them.

    Returns:
        Tensor with the same shape as ``ci_obs_avg``'s output (one value per
        model).
    """
    # The data do not change between samples: move them to the device once
    # instead of once per sampled level.
    x = x.to(device)
    y = y.to(device)

    eci = 0
    for _ in range(samples):
        tau = torch.rand(1).to(device)
        eci += (ci_obs_avg(x, y, model, tau, device, mixture=mixture) - tau).abs()

    return eci / samples


def p_obs_avg_nfe(y, model, tau, device, prec_param=True, mixture=False):
    """Fraction of targets strictly below each field-theory model's ``tau`` quantile.

    The quantile is ``mu_stack + Phi^{-1}(tau) * exp(log_lam_stack)**-0.5``.

    Args:
        y: targets; expanded to ``(mu_stack.shape[0], mu_stack.shape[1])``, so
            presumably shape ``(n_points, 1)`` -- TODO confirm with callers.
        model: object exposing ``mu_stack`` and ``log_lam_stack`` (log
            precision) tensors, presumably ``(n_points, num_models, 1)``.
        tau: quantile level in (0, 1); Python float or tensor.
        device: device for the quantile value.
        prec_param, mixture: unused; kept for backward compatibility.

    Returns:
        Tensor of shape ``(num_models, 1)`` with the fraction of covered points.
    """
    normal_distribution = torch.distributions.Normal(0, 1)
    # as_tensor avoids the copy-construct warning when tau is already a tensor.
    quantile = normal_distribution.icdf(torch.as_tensor(tau)).to(device)

    # Only boolean comparisons feed the result, so skip the autograd graph.
    with torch.no_grad():
        q_taus = model.mu_stack + quantile * model.log_lam_stack.exp().pow(-.5)
        targets = y.expand(q_taus.shape[0], q_taus.shape[1]).unsqueeze(2)
        below = q_taus > targets

    return below.float().mean(dim=0)

def expected_calibration_error_nfe(y, model, device, samples=1000):
    """Monte-Carlo expected calibration error for a field-theory model stack.

    Averages ``|p_obs_avg_nfe(y, model, tau) - tau|`` over ``samples`` levels
    drawn uniformly from (0, 1).
    """
    levels = (torch.rand(1).to(device) for _ in range(samples))
    total = sum(
        (p_obs_avg_nfe(y, model, level, device) - level).abs()
        for level in levels
    )
    return total / samples


# Midpoint averaging: linear interpolation halfway between consecutive
# samples (the midpoints are only geometrically centred for even spacing).
def average_neighbors(tensor):
    """Return the mean of each adjacent pair along the first axis.

    For an input of length ``n`` the result has length ``n - 1``; entry ``k``
    is ``(tensor[k] + tensor[k + 1]) / 2``.
    """
    head = tensor[:-1]
    tail = tensor[1:]
    return (head + tail) / 2

# TODO: given data observed at (x_obs, y_obs), use linear interpolation to
# estimate the y-values at the points x_inter.

def update_gamma(criteria, model, print_pos=False):
    """Move ``model.gammas`` in place towards the best-scoring gamma.

    ``criteria`` holds one score per candidate gamma; it is sorted ascending,
    so the lowest score counts as best. Indices 0-2 of the sort order are
    used, so exactly three candidates are assumed.

    * best gamma is the largest: the worst candidate moves halfway between
      the best gamma and 1;
    * best gamma is the smallest: the worst candidate moves halfway between
      the best gamma and 0;
    * otherwise both non-best candidates move halfway towards the best.

    Args:
        criteria: tensor of scores, sorted along dim 0.
        model: object whose ``gammas`` tensor is modified in place.
        print_pos: if True, print which case applied ('max', 'min', 'mid').
    """
    _, order = criteria.sort(0)
    gammas = model.gammas
    best = gammas[order][0]
    worst = order[2]

    if best == gammas.max():
        gammas[worst] = (1 + best) / 2
        label = 'max'
    elif best == gammas.min():
        gammas[worst] = best / 2
        label = 'min'
    else:
        runner_up = order[1]
        gammas[worst] = (gammas[worst] + best) / 2
        gammas[runner_up] = (gammas[runner_up] + best) / 2
        label = 'mid'

    if print_pos:
        print(label)

def gam_rho_to_alpha_beta(gamma, rho):
    """Convert the (gamma, rho) parametrisation into (alpha, beta).

        alpha = (1 - rho) * gamma / rho
        beta  = (1 - rho) * (1 - gamma) / rho

    Args:
        gamma: scalar, sequence, array or tensor.
        rho: scalar, sequence, array or tensor broadcastable with ``gamma``.

    Returns:
        ``(alpha, beta)`` tensors. Tensor inputs are copied and detached first,
        so the outputs never carry an autograd link to the inputs.
    """
    def _as_detached_tensor(value):
        # torch.tensor(tensor) performs this same copy but warns every call;
        # clone().detach() is the documented warning-free equivalent.
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    gamma = _as_detached_tensor(gamma)
    rho = _as_detached_tensor(rho)
    alpha = ((1 - rho) * gamma) / rho
    beta = ((1 - rho) * (1 - gamma)) / rho

    return alpha, beta


def mult_approx_grad(point, model, device, eps=0.0001):
    """Central finite-difference derivatives of a model's mean and precision.

    ``point`` is repeated once per coordinate and shifted by ``+/- eps``
    along that coordinate. The identity-matrix shift only broadcasts when
    ``point`` has a single row, i.e. shape ``(1, n_coords)``.

    Args:
        point: tensor of shape ``(1, n_coords)``.
        model: callable returning a dict with ``'mean'`` and ``'precision'``;
            evaluated under ``torch.no_grad()``.
        device: device for the shifted inputs.
        eps: half-width of the central difference.

    Returns:
        ``(mean_grad, prec_grad)``; row ``k`` of each holds the difference
        quotient along coordinate ``k``.
    """
    n_coords = point.shape[1]
    base = point.repeat(n_coords, 1).to(device)
    shift = (torch.eye(n_coords) * eps).to(device)

    plus_pts = base + shift
    minus_pts = base - shift
    with torch.no_grad():
        out_plus = model(plus_pts)
        out_minus = model(minus_pts)

    width = 2 * eps
    return tuple(
        (out_plus[key] - out_minus[key]) / width for key in ('mean', 'precision')
    )


def stack_closed_form_mu(lam_stack, y, alpha, device):
    """Closed-form mean field for every model in a stack, given precisions.

    For each model ``k`` this solves the batched linear system

        (diag(lam_k) - 2 * alpha_k * L) mu_k = lam_k * y

    where ``L`` is the ``n x n`` second-difference matrix with stencil
    ``[1, -2, 1]`` and no wrap-around at the boundaries. ``L`` is negative
    semi-definite, hence the minus sign.
    NOTE(review): confirm the ``2 * alpha`` scaling against the objective this
    is derived from (a nearby derivation note wrote ``alpha / 2``).

    Args:
        lam_stack: precisions with a trailing singleton axis, shape
            ``(n_points, n_models, 1)``.
        y: targets with ``y[:, None, :]`` broadcastable against ``lam_stack``
            (e.g. shape ``(n_points, 1)``).
        alpha: 1-D tensor with one smoothing weight per model.
        device: device on which the system is solved.

    Returns:
        Tensor of shape ``(n_points, n_models, 1)`` (for n_points, n_models > 1;
        ``squeeze()`` also collapses other singleton axes).
    """
    n_points = lam_stack.size()[0]
    n_models = lam_stack.shape[1]

    # Precision-weighted targets, one row per model: (n_models, n_points).
    rhs = ((lam_stack * y[:, None, :]).squeeze().T).to(device)

    # Tridiagonal [1, -2, 1] operator: ones strictly on the sub- and
    # super-diagonal, -2 on the main diagonal.
    band = torch.ones(n_points, n_points)
    lapmat = band.tril(-1).triu(-1) - 2 * torch.eye(n_points) + band.triu(1).tril(1)

    lap_batch = lapmat.repeat(n_models, 1, 1).to(device)
    diag_lam = torch.diag_embed(lam_stack.squeeze(-1).T).to(device)

    system = -(2 * alpha)[:, None, None] * lap_batch + diag_lam
    solution = torch.linalg.solve(system, rhs)  # (n_models, n_points)

    return solution.transpose(0, 1)[:, :, None]



def fit_field_theory(x, y, device, seed, gamma, rho,
            max_epochs, lr, lr_min, lr_max, cycle_mode, base_model_path, step_size_up=1000, opt_scheme=None, noisy_y=False):
    """Fit a stack of ``ParallelDFT`` field-theory models and pickle the result.

    Supported ``opt_scheme`` values:

    * ``None`` -- Adam + CyclicLR on ``mu_stack`` and ``log_lam_stack``,
      minimising ``gamma_rho_const_noise_integral_loss`` for ``max_epochs``
      steps.
    * ``"split"`` -- as above, but on a model built with
      ``split_train=True, split_ratio=0.5``, minimising ``gamma_rho_split_loss``.
    * ``"closedmu"`` -- ``max_epochs // (2 * step_size_up)`` cycles; each sets
      ``mu_stack`` in closed form with ``stack_closed_form_mu`` and then takes
      ``2 * step_size_up`` Adam steps on ``log_lam_stack`` using
      ``gamma_rho_integral_loss``.

    Args:
        x: grid passed to ``ParallelDFT`` as ``grid_discretization``.
        y: targets on that grid.
        device: device for the model.
        seed: value passed to ``torch.manual_seed``.
        gamma, rho: regularisation hyper-parameters (presumably one per model
            in the stack).
        max_epochs: total number of optimiser steps.
        lr: learning rate given to Adam (CyclicLR then drives the per-step rate).
        lr_min, lr_max, cycle_mode, step_size_up: CyclicLR settings.
        base_model_path: prefix for the pickled outputs.
        opt_scheme: ``None``, ``"split"`` or ``"closedmu"``.
        noisy_y: if True, every step trains on ``y + N(0, 1) * x**2``.

    Returns:
        ``(dft, losses)``: the fitted model and the loss dicts captured roughly
        every 10% of the steps.

    Raises:
        ValueError: if ``opt_scheme`` is not one of the supported values.
    """
    print('re')

    if opt_scheme not in (None, "split", "closedmu"):
        raise ValueError(f"unknown opt_scheme {opt_scheme!r}; expected None, 'split' or 'closedmu'")

    epochs = max_epochs
    torch.manual_seed(seed)
    start_offset = 0

    # Snapshot about every 10% of the steps; max() keeps the modulus valid
    # for fewer than 10 epochs.
    log_every = max(epochs // 10, 1)

    def _training_targets():
        # Fresh noise with std proportional to x**2 on every call when requested.
        if noisy_y:
            return y + torch.randn_like(y) * x.pow(2)
        return y

    dft = ParallelDFT(grid_discretization=x, gammas=gamma, rhos=rho, btw_pts=None, init_loc=0.).to(device)
    alphas, _ = gam_rho_to_alpha_beta(dft.gammas, dft.rhos)

    losses = []
    index = 0  # defined even when no optimisation step runs

    if opt_scheme is None:
        optimizer = torch.optim.Adam([dft.mu_stack, dft.log_lam_stack], lr=lr)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

        for index in tqdm(range(epochs)):
            optimizer.zero_grad()
            y_noisy = _training_targets()
            loss = dft.gamma_rho_const_noise_integral_loss(y_noisy)
            loss['loss'].sum().backward()
            optimizer.step()
            scheduler.step()

            if index % log_every == 0:
                losses.append(loss)
                print(index / epochs)

    elif opt_scheme == "split":
        # Replaces the model built above with a split-training variant.
        dft = ParallelDFT(grid_discretization=x, gammas=gamma, rhos=rho, btw_pts=None, init_loc=0., split_train=True, split_ratio=0.5).to(device)
        optimizer = torch.optim.Adam([dft.mu_stack, dft.log_lam_stack], lr=lr)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

        for index in tqdm(range(epochs)):
            y_noisy = _training_targets()
            optimizer.zero_grad()
            loss = dft.gamma_rho_split_loss(y_noisy)
            loss['loss'].sum().backward(retain_graph=True)
            optimizer.step()
            scheduler.step()

            if index % log_every == 0:
                losses.append(loss)
                print(index / epochs)

    else:  # "closedmu"
        cycles = epochs // (step_size_up * 2)
        optimizer = torch.optim.Adam([dft.log_lam_stack], lr=lr)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

        for _ in tqdm(range(cycles)):
            y_noisy = _training_targets()
            # Closed-form update of the means given the current precisions.
            dft.mu_stack.data = stack_closed_form_mu(dft.log_lam_stack.exp(), y_noisy, alphas, device)

            # One full CyclicLR cycle of gradient steps on the log precisions.
            for _ in range(2 * step_size_up):
                optimizer.zero_grad()
                loss = dft.gamma_rho_integral_loss(y_noisy)
                loss['loss'].sum().backward()
                optimizer.step()
                scheduler.step()

                if index % log_every == 0:
                    losses.append(loss)
                    print(index / epochs)

                index += 1

    stem = base_model_path + str(index + start_offset)
    with open(stem + '_parallel_dft.p', 'wb') as fh:
        pickle.dump(dft, fh)
    with open(stem + '_parallel_loss_stats.p', 'wb') as fh:
        pickle.dump(losses, fh)

    return dft, losses
    
    
    

def run_exp(x, y, device, seed, gammas, rhos,
            n_feature, n_output, act_func, prec_act_func, max_epochs,
            lr, momentum, lr_min, lr_max, cycle_mode, base_model_path, per_param_loss=True,
            pre_trained_path=None, start_factor=0.05, total_iters=1000,
            hidden_size=128, hidden_layers=2, step_size_up=1000, clip=1000,
            approx_pts=1000, approx_ball=0.001, mean_warmup=20000, mean_log=True, plots=True,
            beta_nll=False, diag=False, var_param=False):
    """Train a ``ParallelFF`` ensemble full-batch on ``(x, y)`` and save artifacts.

    Loss schedule: ``mean_gam_rho_loss`` for the first ``mean_warmup``
    iterations, then ``gam_rho_loss``. ``beta_nll_loss`` is used throughout
    when ``beta_nll`` is set. Optimisation uses Adam with a CyclicLR schedule
    and gradient-norm clipping at ``clip``. On a NaN/inf loss, training stops
    after writing a checkpoint to ``<base_model_path>checkpoints_broken/``.

    At the end of warm-up (only if ``mean_log``) and at the end of training,
    the function:

    * pickles the gradient-integral list, the logged training stats, the model
      and the per-model failure record;
    * writes a full checkpoint (model, optimizer and scheduler state);
    * optionally plots the fit.

    ``lr``, ``momentum``, ``start_factor``, ``total_iters``, ``approx_pts`` and
    ``approx_ball`` are accepted for backward compatibility but are unused.

    Returns:
        ``(fail_it, ppm)``: the global iteration at which the loss diverged
        (``-1`` if it never did) and the trained model.
    """
    import os

    print('re')
    fail_it = -1

    torch.manual_seed(seed)

    start_offset = 0
    keep_keys = ['loss', 'losses', 'mse', 'log_precision', 'raw_mean_reg', 'raw_prec_reg']
    hidden_sizes = [hidden_size] * hidden_layers

    ppm = ParallelFF(n_feature, n_output, hidden_sizes=hidden_sizes, gammas=gammas, rhos=rhos,
                     activation_func=act_func, precision_activation_func=prec_act_func,
                     per_param_loss=per_param_loss, var_param=var_param, diag=diag)
    if pre_trained_path is not None:
        checkpoint = torch.load(pre_trained_path)
        ppm.load_state_dict(checkpoint['model_state_dict'])
        ppm.train()
        # Checkpoints store the last completed iteration; resume after it.
        start_offset = checkpoint['epoch'] + 1

    ppm = ppm.to(device)

    failed_models = [[] for _ in range(ppm.num_models)]
    epochs = max_epochs
    train_stats = []
    grad_ints = []  # never filled here, but still pickled with the other artifacts

    opt = torch.optim.Adam(ppm.parameters(), lr=lr_max)
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

    if pre_trained_path is not None:
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def _dump_artifacts(it):
        # Pickle the run's bookkeeping objects under base_model_path + str(it).
        stem = base_model_path + str(it)
        for obj, suffix in ((grad_ints, '_grad_ints.p'),
                            (train_stats, '_train_stats.p'),
                            (ppm, '_parallel_model.p'),
                            (failed_models, '_failed_models.p')):
            with open(stem + suffix, 'wb') as fh:
                pickle.dump(obj, fh)

    def _save_checkpoint(path, it):
        # Resumable state; create the target directory (e.g. checkpoints_broken/) if needed.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save({
                    'epoch': it,
                    'model_state_dict': ppm.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    }, path)

    def _plot(it):
        plot_dense_x = torch.linspace(x.min().item(), x.max().item(), 300)[:, None].to(device)
        plot_parallel_model(ppm=ppm, x=x, y=y, stats=train_stats, iteration=it, dense_x=plot_dense_x, path=base_model_path)

    # Log stats about every 2%; max() keeps the modulus valid below 50 epochs.
    log_every = max(epochs // 50, 1)
    i = -1  # last executed iteration; stays -1 when max_epochs == 0

    for i in tqdm(range(epochs)):
        opt.zero_grad()

        if beta_nll:
            stats = ppm.beta_nll_loss(y, ppm(x))
        elif i < mean_warmup:
            stats = ppm.mean_gam_rho_loss(y, ppm(x))
        else:
            stats = ppm.gam_rho_loss(y, ppm(x))
        loss = stats['loss']

        if i % log_every == 0:
            train_stats.append({key: stats[key] for key in keep_keys})

        if i == mean_warmup - 1 and mean_log:
            it = i + start_offset
            if plots:
                _plot(it)
            _dump_artifacts(it)
            _save_checkpoint(base_model_path + 'full_checkpoint_epochs_' + str(it) + '.pt', it)

        # Early termination on divergence.
        if loss.isnan() or loss.isinf():
            fail_it = i + start_offset
            for j, model_loss in enumerate(stats['losses']):
                if model_loss.isnan() or model_loss.isinf():
                    # Record which model failed and at which local iteration.
                    failed_models[j].append(i)
            _save_checkpoint(base_model_path + 'checkpoints_broken/checkpoint_' + str(fail_it) + '.pt', fail_it)
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(ppm.parameters(), clip)
        opt.step()
        scheduler.step()

    last_it = i + start_offset
    _dump_artifacts(last_it)

    PATH = base_model_path + 'full_checkpoint_epochs_' + str(last_it) + '.pt'
    if plots:
        _plot(last_it)

    print(PATH)
    _save_checkpoint(PATH, last_it)

    return fail_it, ppm


def run_split_exp(x, y, x_mean, y_mean, x_prec, y_prec, device, seed, gammas, rhos,
            n_feature, n_output, act_func, prec_act_func, max_epochs, lr_min, lr_max, cycle_mode, base_model_path, per_param_loss=True,
            pre_trained_path=None, split_ratio=0.5,
            hidden_size=128, hidden_layers=2, step_size_up=1000, clip=1000, mean_warmup=20000, mean_log=True, plots=True, continuous_shuffling=False, aug=False, laplace=False):
    """Train a ``SplitParallelFF`` whose mean and precision fit different data.

    The mean network sees ``(x_mean, y_mean)`` and the precision network sees
    ``(x_prec, y_prec)``. The first ``mean_warmup`` iterations use the
    mean-only loss; afterwards the joint split loss is used. Each loss has a
    ``*_laplace`` variant selected by ``laplace``. Optimisation uses Adam with
    CyclicLR and gradient-norm clipping. On a NaN/inf loss, training stops
    after writing a checkpoint to ``<base_model_path>checkpoints_broken/``.

    Per-iteration data options:

    * ``aug``: both halves become the full ``x``, with ``y`` plus shared
      Gaussian noise (std 0.2, drawn with ``x``'s shape).
    * ``continuous_shuffling``: re-seeds the global RNG with ``seed + i`` and
      re-splits ``(x, y)`` 50/50 into the two halves.

    NOTE(review): ``split_ratio`` is currently unused; the shuffle split is
    always 50/50.

    Returns:
        ``(fail_it, spff)``: the global iteration at which the loss diverged
        (``-1`` if it never did) and the trained model.
    """
    import os

    print('rse')
    fail_it = -1

    keep_keys = ['loss', 'losses', 'mse', 'log_precision', 'raw_mean_reg', 'raw_prec_reg']
    torch.manual_seed(seed)

    start_offset = 0
    hidden_sizes = [hidden_size] * hidden_layers

    spff = SplitParallelFF(n_feature, n_output, hidden_sizes=hidden_sizes, gammas=gammas, rhos=rhos,
                           activation_func=act_func, precision_activation_func=prec_act_func,
                           per_param_loss=per_param_loss)
    if pre_trained_path is not None:
        checkpoint = torch.load(pre_trained_path)
        spff.load_state_dict(checkpoint['model_state_dict'])
        spff.train()
        # Checkpoints store the last completed iteration; resume after it.
        start_offset = checkpoint['epoch'] + 1

    spff = spff.to(device)

    failed_models = [[] for _ in range(spff.num_models)]
    epochs = max_epochs
    train_stats = []
    grad_ints = []  # never filled here, but still pickled with the other artifacts

    opt = torch.optim.Adam(spff.parameters(), lr=lr_max)
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

    if pre_trained_path is not None:
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Plot range over both halves. Reducing each half separately allows
    # halves of different lengths (torch.stack required equal shapes).
    min_x = torch.minimum(x_mean.min(), x_prec.min())
    max_x = torch.maximum(x_mean.max(), x_prec.max())

    def _dump_artifacts(it, model_suffix):
        # Pickle the run's bookkeeping objects under base_model_path + str(it).
        stem = base_model_path + str(it)
        for obj, suffix in ((grad_ints, '_grad_ints.p'),
                            (train_stats, '_train_stats.p'),
                            (spff, model_suffix),
                            (failed_models, '_failed_models.p')):
            with open(stem + suffix, 'wb') as fh:
                pickle.dump(obj, fh)

    def _save_checkpoint(path, it):
        # Resumable state; create the target directory (e.g. checkpoints_broken/) if needed.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save({
                    'epoch': it,
                    'model_state_dict': spff.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    }, path)

    # Log stats about every 2%; max() keeps the modulus valid below 50 epochs.
    log_every = max(epochs // 50, 1)
    i = -1  # last executed iteration; stays -1 when max_epochs == 0

    for i in tqdm(range(epochs)):
        opt.zero_grad()

        if aug:
            noise = torch.randn_like(x) / 5
            x_mean = x
            x_prec = x
            y_mean = y + noise
            y_prec = y + noise

        if continuous_shuffling:
            # Re-seed so the split drawn at iteration i is reproducible.
            torch.manual_seed(seed + i)
            inds = torch.randperm(len(x))
            cut_point = int(len(x) * 0.5)
            mean_inds, prec_inds = inds[cut_point:], inds[:cut_point]
            x_mean, y_mean = x[mean_inds], y[mean_inds]
            x_prec, y_prec = x[prec_inds], y[prec_inds]

        if i < mean_warmup:
            # Warm-up: mean-only loss.
            mean_loss = spff.mean_gam_rho_loss_laplace if laplace else spff.mean_gam_rho_loss
            stats = mean_loss(y_mean, spff(x_mean, x_prec)['mean_dict'])
        else:
            full_dict = spff(x_mean, x_prec)
            split_loss = spff.gam_rho_split_loss_laplace if laplace else spff.gam_rho_split_loss
            stats = split_loss(y_mean, y_prec, full_dict['mean_dict'], full_dict['prec_dict'])
        loss = stats['loss']

        if i % log_every == 0:
            train_stats.append({key: stats[key] for key in keep_keys})

        if i == mean_warmup - 1 and mean_log:
            it = i + start_offset
            if plots:
                plot_dense_x = torch.linspace(min_x - .2, max_x + .2, 200)[:, None].to(device)
                extra = {'laplace': True} if laplace else {}
                plot_split_model(spff=spff, x_m=x_mean, y_m=y_mean, x_p=x_prec, y_p=y_prec, stats=train_stats, iteration=it, dense_x=plot_dense_x, bound_sd_y=5., path=base_model_path, **extra)
            _dump_artifacts(it, '_parallel_model.p')
            _save_checkpoint(base_model_path + 'full_checkpoint_epochs_' + str(it) + '.pt', it)

        # Early termination on divergence.
        if loss.isnan() or loss.isinf():
            fail_it = i + start_offset
            for j, model_loss in enumerate(stats['losses']):
                if model_loss.isnan() or model_loss.isinf():
                    # Record which model failed and at which local iteration.
                    failed_models[j].append(i)
            _save_checkpoint(base_model_path + 'checkpoints_broken/checkpoint_' + str(fail_it) + '.pt', fail_it)
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(spff.parameters(), clip)
        opt.step()
        scheduler.step()

    last_it = i + start_offset
    _dump_artifacts(last_it, '_split_model.p')

    PATH = base_model_path + 'full_checkpoint_epochs_' + str(last_it) + '.pt'

    if plots:
        plot_dense_x = torch.linspace(min_x, max_x, 300)[:, None].to(device)
        if laplace:
            plot_split_model(spff=spff, x_m=x_mean, y_m=y_mean, x_p=x_prec, y_p=y_prec, stats=train_stats, iteration=last_it, dense_x=plot_dense_x, bound_sd_y=5., path=base_model_path, laplace=True)
        else:
            plot_split_model(spff=spff, x_m=x_mean, y_m=y_mean, x_p=x_prec, y_p=y_prec, stats=train_stats, iteration=last_it, dense_x=plot_dense_x, path=base_model_path)

    print(PATH)
    _save_checkpoint(PATH, last_it)

    return fail_it, spff


def run_uci_exp(x, y, device, seed, gammas, rhos,
            n_feature, n_output, act_func, prec_act_func, max_epochs, lr_min, lr_max,
            cycle_mode, base_model_path, per_param_loss=True, pre_trained_path=None,
            hidden_size=128, hidden_layers=2, step_size_up=1000, clip=1000, batch_size=None, mean_warmup=1000,
            beta_nll=False, var_param=False, diag=False):
    """Train a ``ParallelFF`` ensemble on UCI-style data and save artifacts.

    With ``batch_size=None`` training is full-batch. The loss switches from
    ``mean_gam_rho_loss`` to ``gam_rho_loss`` after ``mean_warmup``
    iterations, and a ``half_checkpoint`` is written at the switch.

    Otherwise training runs over consecutive mini-batches, including a final
    partial batch. The loss switches at ``epochs / 2`` and the half checkpoint
    is written at ``epochs // 2 - 1``. NOTE(review): mini-batch mode ignores
    ``mean_warmup``. Summed per-epoch losses are pickled to
    ``*_running_losses.p``.

    ``beta_nll_loss`` is used throughout when ``beta_nll`` is set. On a
    NaN/inf loss the whole run stops after writing a checkpoint to
    ``<base_model_path>checkpoints_broken/``.

    Returns:
        ``(fail_it, ppm)``: the global iteration/epoch at which the loss
        diverged (``-1`` if it never did) and the trained model.
    """
    import math
    import os

    print('re')
    fail_it = -1

    torch.manual_seed(seed)

    start_offset = 0
    hidden_sizes = [hidden_size] * hidden_layers

    ppm = ParallelFF(n_feature, n_output, hidden_sizes=hidden_sizes, gammas=gammas, rhos=rhos,
                     activation_func=act_func, precision_activation_func=prec_act_func,
                     per_param_loss=per_param_loss, var_param=var_param, diag=diag)
    if pre_trained_path is not None:
        checkpoint = torch.load(pre_trained_path)
        ppm.load_state_dict(checkpoint['model_state_dict'])
        ppm.train()
        # Checkpoints store the last completed iteration; resume after it.
        start_offset = checkpoint['epoch'] + 1

    ppm = ppm.to(device)

    failed_models = [[] for _ in range(ppm.num_models)]
    epochs = max_epochs
    train_stats = []
    grad_ints = []  # never filled here, but still pickled with the other artifacts

    keep_keys = ['loss', 'losses', 'mse', 'log_precision', 'raw_mean_reg', 'raw_prec_reg']

    opt = torch.optim.Adam(ppm.parameters(), lr=lr_max)
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode, cycle_momentum=False, step_size_up=step_size_up)

    if pre_trained_path is not None:
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def _loss_stats(in_warmup, inputs, targets):
        # Select the loss for the current phase and evaluate it.
        if beta_nll:
            return ppm.beta_nll_loss(targets, ppm(inputs))
        if in_warmup:
            return ppm.mean_gam_rho_loss(targets, ppm(inputs))
        return ppm.gam_rho_loss(targets, ppm(inputs))

    def _dump_artifacts(it):
        # Pickle the run's bookkeeping objects under base_model_path + str(it).
        stem = base_model_path + str(it)
        for obj, suffix in ((grad_ints, '_grad_ints.p'),
                            (train_stats, '_train_stats.p'),
                            (ppm, '_parallel_model.p'),
                            (failed_models, '_failed_models.p')):
            with open(stem + suffix, 'wb') as fh:
                pickle.dump(obj, fh)

    def _save_checkpoint(path, it):
        # Resumable state; create the target directory (e.g. checkpoints_broken/) if needed.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save({
                    'epoch': it,
                    'model_state_dict': ppm.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    }, path)

    def _handle_failure(local_it, stats):
        # Record diverged models at local iteration local_it, checkpoint, and
        # return the global failure index.
        for j, model_loss in enumerate(stats['losses']):
            if model_loss.isnan() or model_loss.isinf():
                failed_models[j].append(local_it)
        global_it = local_it + start_offset
        _save_checkpoint(base_model_path + 'checkpoints_broken/checkpoint_' + str(global_it) + '.pt', global_it)
        return global_it

    # Log stats about every 2%; max() keeps the modulus valid below 50 epochs.
    log_every = max(epochs // 50, 1)
    i = -1  # last executed iteration/epoch; stays -1 when max_epochs == 0

    if batch_size is None:
        for i in tqdm(range(epochs)):
            opt.zero_grad()

            stats = _loss_stats(i < mean_warmup, x, y)
            loss = stats['loss']

            if i % log_every == 0:
                train_stats.append({key: stats[key] for key in keep_keys})

            if i == mean_warmup - 1:
                it = i + start_offset
                _dump_artifacts(it)
                _save_checkpoint(base_model_path + 'half_checkpoint_epochs_' + str(it) + '.pt', it)

            # Early termination on divergence.
            if loss.isnan() or loss.isinf():
                fail_it = _handle_failure(i, stats)
                break

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ppm.parameters(), clip)
            opt.step()
            scheduler.step()
    else:
        n_samples = x.shape[0]
        # Round up so the trailing partial batch is trained on as well;
        # end_ind below clips it to the data.
        num_batches = math.ceil(n_samples / batch_size)
        running_losses = []
        diverged = False

        for i in tqdm(range(epochs)):
            running_loss = 0
            for b in range(num_batches):
                start_ind = b * batch_size
                end_ind = min((b + 1) * batch_size, n_samples)

                batch_x = x[start_ind:end_ind, :]
                batch_y = y[start_ind:end_ind]

                opt.zero_grad()
                stats = _loss_stats(i < (epochs / 2), batch_x, batch_y)
                loss = stats['loss']

                if loss.isnan() or loss.isinf():
                    fail_it = _handle_failure(i, stats)
                    diverged = True
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(ppm.parameters(), clip)
                opt.step()
                scheduler.step()

                running_loss += loss.item()

            if i % log_every == 0:
                train_stats.append({key: stats[key] for key in keep_keys})
                running_losses.append(running_loss)

            # A batch-level break only leaves the batch loop; stop the whole
            # run instead of continuing to train a diverged model.
            if diverged:
                break

            if i == (epochs // 2) - 1:
                it = i + start_offset
                _dump_artifacts(it)
                _save_checkpoint(base_model_path + 'half_checkpoint_epochs_' + str(it) + '.pt', it)

        with open(base_model_path + str(i + start_offset) + '_running_losses.p', 'wb') as fh:
            pickle.dump(running_losses, fh)

    last_it = i + start_offset
    _dump_artifacts(last_it)

    PATH = base_model_path + 'full_checkpoint_epochs_' + str(last_it) + '.pt'

    print(PATH)
    _save_checkpoint(PATH, last_it)

    return fail_it, ppm



def run_split_uci_exp(x_mean, y_mean, x_prec, y_prec, device, seed, gammas, rhos,
            n_feature, n_output, act_func, prec_act_func, max_epochs, lr_min, lr_max,
            cycle_mode, base_model_path, per_param_loss=True, pre_trained_path=None,
            hidden_size=128, hidden_layers=2, step_size_up=1000, clip=1000, batch_size=None, mean_warmup=1000):
    """Train a ``SplitParallelFF`` grid of (gamma, rho) models on split mean/precision data.

    The mean networks see ``(x_mean, y_mean)`` and the precision networks see
    ``(x_prec, y_prec)``.  Early epochs optimise only ``mean_gam_rho_loss``;
    later epochs optimise ``gam_rho_split_loss``.  Optimisation uses Adam with
    a ``CyclicLR`` schedule and gradient-norm clipping.

    Args:
        x_mean, y_mean: Inputs/targets used for the mean networks and loss.
        x_prec, y_prec: Inputs/targets used for the precision networks and loss.
        device: Device the model is moved to.
        seed: Passed to ``torch.manual_seed`` before the model is built.
        gammas, rhos: Regularisation grids forwarded to ``SplitParallelFF``.
        n_feature, n_output, act_func, prec_act_func, per_param_loss:
            Forwarded to ``SplitParallelFF``.
        max_epochs: Number of training epochs (must be >= 1).
        lr_min, lr_max, cycle_mode, step_size_up: ``CyclicLR`` settings.
            Adam is created with ``lr=lr_max``.
        base_model_path: String prefix for every pickle/checkpoint written.
        pre_trained_path: Optional checkpoint from which model, optimizer and
            scheduler state are restored.  Epoch numbering continues after it.
        hidden_size, hidden_layers: Width and number of hidden layers.
        clip: Max gradient norm passed to ``clip_grad_norm_``.
        batch_size: ``None`` for full-batch training.  Otherwise the size of
            the in-order, unshuffled minibatches.
        mean_warmup: Full-batch mode only: number of mean-only epochs.  In
            minibatch mode the first half of ``max_epochs`` is mean-only.

    Returns:
        ``(fail_it, spff)``.  ``fail_it`` is the offset epoch at which a
        NaN/inf loss stopped training, or -1 if training never diverged.
        ``spff`` is the trained model.

    Raises:
        ValueError: if ``max_epochs < 1`` (output names need a last epoch).
    """
    print('re')
    if max_epochs < 1:
        raise ValueError("max_epochs must be >= 1, got {}".format(max_epochs))

    fail_it = -1
    keep_keys = ['loss', 'losses', 'mse', 'log_precision', 'raw_mean_reg', 'raw_prec_reg']
    torch.manual_seed(seed)

    start_offset = 0
    hidden_sizes = [hidden_size for _ in range(hidden_layers)]

    spff = SplitParallelFF(n_feature, n_output, hidden_sizes=hidden_sizes, gammas=gammas, rhos=rhos,
                           activation_func=act_func, precision_activation_func=prec_act_func,
                           per_param_loss=per_param_loss)
    checkpoint = None
    if pre_trained_path is not None:
        checkpoint = torch.load(pre_trained_path)
        spff.load_state_dict(checkpoint['model_state_dict'])
        spff.train()
        start_offset = checkpoint['epoch'] + 1  # correct for off by one

    spff = spff.to(device)

    # failed_models[j] collects the (un-offset) epochs at which sub-model j went non-finite.
    failed_models = [[] for _ in range(spff.num_models)]

    epochs = max_epochs
    train_stats = []
    grad_ints = []  # never filled in this function; still pickled so output files stay the same

    opt = torch.optim.Adam(spff.parameters(), lr=lr_max)
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode,
                                                  cycle_momentum=False, step_size_up=step_size_up)
    if checkpoint is not None:
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Log roughly every 2% of training.  For epochs < 50, ``epochs // 50`` is 0
    # and ``i % 0`` would raise ZeroDivisionError, so fall back to every epoch.
    log_freq = (epochs // 50) if (epochs > 50) else 1

    def _pickle_to(obj, path):
        # ``with`` closes (and flushes) the handle immediately instead of leaking it.
        with open(path, 'wb') as fh:
            pickle.dump(obj, fh)

    def _save_checkpoint(path, epoch):
        # Everything needed to resume later through ``pre_trained_path``.
        torch.save({
            'epoch': epoch,
            'model_state_dict': spff.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, path)

    def _dump_run_state(epoch, ckpt_name):
        # Pickle the run artefacts, write a resumable checkpoint and return its path.
        prefix = base_model_path + str(epoch)
        _pickle_to(grad_ints, prefix + '_grad_ints.p')
        _pickle_to(train_stats, prefix + '_train_stats.p')
        _pickle_to(spff, prefix + '_parallel_model.p')
        _pickle_to(failed_models, prefix + '_failed_models.p')
        ckpt_path = base_model_path + ckpt_name + str(epoch) + '.pt'
        _save_checkpoint(ckpt_path, epoch)
        return ckpt_path

    def _record_divergence(i, stats):
        # Note which sub-models went non-finite, snapshot the broken state and
        # return the offset epoch at which training failed.
        for j, model_loss in enumerate(stats['losses']):
            if model_loss.isnan() or model_loss.isinf():
                failed_models[j].append(i)
        epoch = i + start_offset
        _save_checkpoint(base_model_path + 'checkpoints_broken/checkpoint_' + str(epoch) + '.pt', epoch)
        return epoch

    def _optimise(loss):
        # One clipped optimizer step followed by one scheduler step.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(spff.parameters(), clip)
        opt.step()
        scheduler.step()

    if batch_size is None:
        for i in tqdm(range(epochs)):
            opt.zero_grad()

            if i < mean_warmup:
                # warm-up: only the mean networks contribute to the loss
                stats = spff.mean_gam_rho_loss(y_mean, spff(x_mean, x_prec)['mean_dict'])
            else:
                full_dict = spff(x_mean, x_prec)
                stats = spff.gam_rho_split_loss(y_mean, y_prec, full_dict['mean_dict'], full_dict['prec_dict'])

            loss = stats['loss']

            if i % log_freq == 0:
                train_stats.append({key: stats[key] for key in keep_keys})

            if i == mean_warmup - 1:
                _dump_run_state(i + start_offset, 'half_checkpoint_epochs_')

            # early termination if the loss breaks
            if loss.isnan() or loss.isinf():
                fail_it = _record_divergence(i, stats)
                break

            _optimise(loss)
    else:
        n_rows = x_mean.shape[0]
        # Trailing rows that do not fill a whole batch are skipped.  At least
        # one batch always runs so ``stats`` exists even when batch_size > n_rows.
        num_batches = max(n_rows // batch_size, 1)

        running_losses = []
        for i in tqdm(range(epochs)):
            running_loss = 0
            diverged = False

            for b in range(num_batches):
                start_ind = b * batch_size
                end_ind = min((b + 1) * batch_size, n_rows)

                batch_x_m = x_mean[start_ind:end_ind, :]
                batch_y_m = y_mean[start_ind:end_ind]

                batch_x_p = x_prec[start_ind:end_ind, :]
                batch_y_p = y_prec[start_ind:end_ind]

                opt.zero_grad()

                if i < (epochs / 2):
                    stats = spff.mean_gam_rho_loss(batch_y_m, spff(batch_x_m, batch_x_p)['mean_dict'])
                else:
                    # same call shape as the full-batch branch: mean and precision dicts separately
                    full_dict = spff(batch_x_m, batch_x_p)
                    stats = spff.gam_rho_split_loss(batch_y_m, batch_y_p, full_dict['mean_dict'], full_dict['prec_dict'])

                loss = stats['loss']

                # early termination if the loss breaks
                if loss.isnan() or loss.isinf():
                    fail_it = _record_divergence(i, stats)
                    diverged = True
                    break

                _optimise(loss)
                running_loss += loss.item()

            # the ``break`` above only leaves the batch loop; stop training entirely
            if diverged:
                break

            if i % log_freq == 0:
                train_stats.append({key: stats[key] for key in keep_keys})
                running_losses.append(running_loss)

            if i == (epochs // 2) - 1:
                _dump_run_state(i + start_offset, 'half_checkpoint_epochs_')

        _pickle_to(running_losses, base_model_path + str(i + start_offset) + '_running_losses.p')

    final_path = _dump_run_state(i + start_offset, 'full_checkpoint_epochs_')
    print(final_path)

    return fail_it, spff

def run_single_exp(x, y, device, seed, gamma, rho,
            n_feature, n_output, act_func, prec_act_func, max_epochs, lr_min, lr_max,
            cycle_mode, base_model_path, per_param_loss=True, pre_trained_path=None,
            hidden_size=128, hidden_layers=2, step_size_up=1000, clip=1000, batch_size=None, mean_warmup=1000):
    """Train a single ``SingleFF`` mean/precision model.

    Early epochs optimise only ``mean_gam_rho_loss``; later epochs optimise
    ``gam_rho_loss``.  Optimisation uses Adam with a ``CyclicLR`` schedule and
    gradient-norm clipping.

    Args:
        x, y: Training inputs/targets.
        device: Device the model is moved to.
        seed: Passed to ``torch.manual_seed`` before the model is built.
        gamma, rho: Regularisation values forwarded to ``SingleFF`` as
            ``gammas``/``rhos``.
        n_feature, n_output, act_func, prec_act_func, per_param_loss:
            Forwarded to ``SingleFF``.
        max_epochs: Number of training epochs (must be >= 1).
        lr_min, lr_max, cycle_mode, step_size_up: ``CyclicLR`` settings.
            Adam is created with ``lr=lr_max``.
        base_model_path: String prefix for every pickle/checkpoint written.
        pre_trained_path: Optional checkpoint from which model, optimizer and
            scheduler state are restored.  Epoch numbering continues after it.
        hidden_size, hidden_layers: Width and number of hidden layers.
        clip: Max gradient norm passed to ``clip_grad_norm_``.
        batch_size: ``None`` for full-batch training.  Otherwise the size of
            the in-order, unshuffled minibatches.
        mean_warmup: Full-batch mode only: number of mean-only epochs.  In
            minibatch mode the first half of ``max_epochs`` is mean-only.

    Returns:
        ``(fail_it, sff)``.  ``fail_it`` is the offset epoch at which a NaN/inf
        loss stopped training, or -1 if training never diverged.  ``sff`` is
        the trained model.

    Raises:
        ValueError: if ``max_epochs < 1`` (output names need a last epoch).
    """
    print('rse')
    if max_epochs < 1:
        raise ValueError("max_epochs must be >= 1, got {}".format(max_epochs))

    fail_it = -1
    keep_keys = ['loss', 'losses', 'mse', 'log_precision', 'raw_mean_reg', 'raw_prec_reg']
    torch.manual_seed(seed)

    start_offset = 0
    hidden_sizes = [hidden_size for _ in range(hidden_layers)]

    sff = SingleFF(n_feature, n_output, hidden_sizes=hidden_sizes, gammas=gamma, rhos=rho,
                   activation_func=act_func, precision_activation_func=prec_act_func,
                   per_param_loss=per_param_loss)
    checkpoint = None
    if pre_trained_path is not None:
        checkpoint = torch.load(pre_trained_path)
        sff.load_state_dict(checkpoint['model_state_dict'])
        sff.train()
        start_offset = checkpoint['epoch'] + 1  # correct for off by one

    sff = sff.to(device)

    epochs = max_epochs
    train_stats = []
    grad_ints = []  # never filled in this function; still pickled so output files stay the same

    opt = torch.optim.Adam(sff.parameters(), lr=lr_max)
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max, mode=cycle_mode,
                                                  cycle_momentum=False, step_size_up=step_size_up)
    if checkpoint is not None:
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Log roughly every 2% of training.  For epochs < 50, ``epochs // 50`` is 0
    # and ``i % 0`` would raise ZeroDivisionError, so fall back to every epoch.
    log_freq = (epochs // 50) if (epochs > 50) else 1

    def _pickle_to(obj, path):
        # ``with`` closes (and flushes) the handle immediately instead of leaking it.
        with open(path, 'wb') as fh:
            pickle.dump(obj, fh)

    def _save_checkpoint(path, epoch):
        # Everything needed to resume later through ``pre_trained_path``.
        torch.save({
            'epoch': epoch,
            'model_state_dict': sff.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, path)

    def _dump_run_state(epoch, ckpt_name):
        # Pickle the run artefacts, write a resumable checkpoint and return its path.
        prefix = base_model_path + str(epoch)
        _pickle_to(grad_ints, prefix + '_grad_ints.p')
        _pickle_to(train_stats, prefix + '_train_stats.p')
        _pickle_to(sff, prefix + '_parallel_model.p')
        ckpt_path = base_model_path + ckpt_name + str(epoch) + '.pt'
        _save_checkpoint(ckpt_path, epoch)
        return ckpt_path

    def _record_divergence(i):
        # Snapshot the broken state and return the offset epoch at which training failed.
        epoch = i + start_offset
        _save_checkpoint(base_model_path + 'checkpoints_broken/checkpoint_' + str(epoch) + '.pt', epoch)
        return epoch

    def _optimise(loss):
        # One clipped optimizer step followed by one scheduler step.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(sff.parameters(), clip)
        opt.step()
        scheduler.step()

    if batch_size is None:
        for i in tqdm(range(epochs)):
            opt.zero_grad()

            if i < mean_warmup:
                # warm-up: only the mean part contributes to the loss
                stats = sff.mean_gam_rho_loss(y, sff(x))
            else:
                stats = sff.gam_rho_loss(y, sff(x))

            loss = stats['loss']

            if i % log_freq == 0:
                train_stats.append({key: stats[key] for key in keep_keys})

            if i == mean_warmup - 1:
                _dump_run_state(i + start_offset, 'half_checkpoint_epochs_')

            # early termination if the loss breaks
            if loss.isnan() or loss.isinf():
                fail_it = _record_divergence(i)
                break

            _optimise(loss)
    else:
        n_rows = x.shape[0]
        # Trailing rows that do not fill a whole batch are skipped.  At least
        # one batch always runs so ``stats`` exists even when batch_size > n_rows.
        num_batches = max(n_rows // batch_size, 1)

        running_losses = []
        for i in tqdm(range(epochs)):
            running_loss = 0
            diverged = False

            for b in range(num_batches):
                start_ind = b * batch_size
                end_ind = min((b + 1) * batch_size, n_rows)

                batch_x = x[start_ind:end_ind, :]
                batch_y = y[start_ind:end_ind]

                opt.zero_grad()

                if i < (epochs / 2):
                    stats = sff.mean_gam_rho_loss(batch_y, sff(batch_x))
                else:
                    stats = sff.gam_rho_loss(batch_y, sff(batch_x))

                loss = stats['loss']

                # early termination if the loss breaks
                if loss.isnan() or loss.isinf():
                    fail_it = _record_divergence(i)
                    diverged = True
                    break

                _optimise(loss)
                running_loss += loss.item()

            # the ``break`` above only leaves the batch loop; stop training entirely
            if diverged:
                break

            if i % log_freq == 0:
                train_stats.append({key: stats[key] for key in keep_keys})
                running_losses.append(running_loss)

            if i == (epochs // 2) - 1:
                _dump_run_state(i + start_offset, 'half_checkpoint_epochs_')

        _pickle_to(running_losses, base_model_path + str(i + start_offset) + '_running_losses.p')

    final_path = _dump_run_state(i + start_offset, 'full_checkpoint_epochs_')
    print(final_path)

    return fail_it, sff


def grad_est(model, pt, eps, device):
    """Estimate d(mean)/dx and d(precision)/dx at ``pt`` by central differences.

    ``model`` is called at ``pt - eps`` and then at ``pt + eps``, and must
    return a mapping with ``'mean'`` and ``'precision'`` entries.  ``device``
    is accepted for interface compatibility but is not used.

    Returns:
        ``{'mgrad': ..., 'pgrad': ...}``, i.e. ``(f(pt + eps) - f(pt - eps)) / (2 * eps)``
        for the mean and the precision outputs respectively.
    """
    below = model(pt - eps)
    above = model(pt + eps)
    span = 2 * eps

    pairs = (("mgrad", "mean"), ("pgrad", "precision"))
    return {out_key: (above[src_key] - below[src_key]) / span for out_key, src_key in pairs}


def plot_result(ax, x, y, true_mu, x_plot, mup, sdp):
    """Draw data, the true mean curve and each fitted mean curve on ``ax``.

    Args:
        ax: Axes-like object providing ``scatter`` and ``plot``.
        x, y: Observed points, drawn as blue dots.
        true_mu: True mean evaluated at ``x_plot``, drawn as a dashed black line.
        x_plot: Grid of x positions shared by the curves.
        mup: Iterable of fitted mean curves on ``x_plot`` (drawn in orange).
        sdp: Iterable paired with ``mup``.  Its values are not drawn, but
            ``zip`` stops at the shorter of ``mup``/``sdp``.
    """
    grid = x_plot.squeeze().cpu().detach()

    ax.scatter(x.squeeze().cpu().detach(), y.squeeze().cpu(), color="tab:blue", marker='.')
    ax.plot(grid, true_mu.squeeze().cpu(), '--', color="black")

    for mean_curve, _unused_sd in zip(mup, sdp):
        ax.plot(grid, mean_curve.squeeze().cpu().detach(), color="tab:orange")


def plot_sd_res(ax, x, res, x_plot, sd_plot):
    """Scatter residuals against ``x`` and overlay each predicted SD curve.

    Args:
        ax: Axes-like object providing ``scatter`` and ``plot``.
        x, res: Inputs and residuals, drawn as blue dots.
        x_plot: x positions shared by every curve in ``sd_plot``.
        sd_plot: Iterable of SD curves on ``x_plot``, each drawn in orange.
    """
    def _host(t):
        # squeeze, move to CPU and drop autograd history before handing to matplotlib
        return t.squeeze().cpu().detach()

    ax.scatter(_host(x), _host(res), color="tab:blue", marker='.')

    for curve in sd_plot:
        ax.plot(_host(x_plot), _host(curve), color="tab:orange")

def plot_split_model(spff, x_m, y_m, x_p, y_p, stats, iteration, dense_x=None, path=None, show_plots=False, bound_mn_y=None, bound_sd_y=None, laplace=False):
    """Plot every (gamma, rho) sub-model of a split model on a grid of axes.

    Grids have rows for ``spff.unique_gammas`` and columns for
    ``spff.unique_rhos``.  Three figures are drawn:

    * the mean/precision data with the predicted mean and predictive bands;
    * the absolute residuals of the split loss with the predicted SD curve;
    * the logged per-model losses, only when ``stats`` is given.

    Args:
        spff: Split model, called as ``spff(x_mean, x_prec)`` and returning
            ``'mean_dict'``/``'prec_dict'`` outputs.
        x_m, y_m: Mean-set inputs/targets.
        x_p, y_p: Precision-set inputs/targets.  ``x_p`` must match
            ``x_m``'s shape when ``dense_x`` is None, because they are stacked.
        stats: List of logged stat dicts with a ``'losses'`` entry, or None.
        iteration: Appended to figure titles and saved file names.
        dense_x: Optional inputs for the smooth curves.  Defaults to the
            pooled, sorted ``x_m``/``x_p`` values.
        path: If given, figures are saved under ``path + '/plots/'``.
        show_plots: Call ``plt.show()`` before closing the figures.
        bound_mn_y, bound_sd_y: Optional y-limits for the mean / residual plots.
        laplace: Use ``gam_rho_split_loss_laplace`` and Laplace-style scales
            instead of the Gaussian ones.
    """
    print('entered psm')
    gammas = spff.unique_gammas
    rhos = spff.unique_rhos
    n_g, n_r = len(gammas), len(rhos)

    if dense_x is not None:
        cts_x = dense_x.sort()[0]
    else:
        # pool the (x, y) inputs from the mean and precision sets, sorted, as an (N, 1) column
        cts_x = torch.stack((x_m, x_p)).flatten().sort()[0][:, None]

    plot_loss = stats is not None
    if plot_loss:
        loss_grid = torch.stack([l['losses'] for l in stats]).view(len(stats), n_g, n_r).cpu().detach()

    a_labs = spff.gammas.view(n_g, n_r).cpu().detach()
    b_labs = spff.rhos.view(n_g, n_r).cpu().detach()

    size = 4

    def _grid():
        # Rows index gammas and columns index rhos.  figsize is (width, height),
        # so width follows the column count.  squeeze=False keeps a 2-D axes
        # array even for 1 x N / N x 1 grids, so axs[i][j] always works.
        fig, axs = plt.subplots(n_g, n_r, figsize=(n_r * size, n_g * size), sharex=True, sharey=False, squeeze=False)
        fig.tight_layout(pad=2.75)
        return fig, axs

    fig_res_sds, axs_res_sds = _grid()
    fig_mn, axs_mn = _grid()
    axs_dict = {"res_sds": axs_res_sds,
                "mns": axs_mn}
    if plot_loss:
        fig_loss, axs_loss = _grid()
        axs_dict["loss"] = axs_loss

    with torch.no_grad():
        # the same points feed both sub-networks, so mean/prec dicts line up
        cts_plot_vals = spff(cts_x, cts_x)

        full_dict = spff(x_m, x_p)
        loss_fn = spff.gam_rho_split_loss_laplace if laplace else spff.gam_rho_split_loss
        loss = loss_fn(y_m, y_p, full_dict['mean_dict'], full_dict['prec_dict'])

    cts_pm_mns = cts_plot_vals['mean_dict']['mean'].view((len(cts_x), n_g, n_r)).cpu().detach()
    cts_prec = cts_plot_vals['prec_dict']['precision'].view((len(cts_x), n_g, n_r))
    if not laplace:
        # Gaussian: sd = precision ** -0.5
        cts_pm_sds = cts_prec.pow(-.5).cpu().detach()
    else:
        # sqrt(2) / precision, presumably the Laplace sd with scale b = 1 / precision -- TODO confirm
        cts_pm_sds = cts_prec.pow(-1).cpu().detach() * math.sqrt(2)

    resids_m = loss['mean_loss']['residuals'].view((len(x_m), n_g, n_r)).cpu()
    resids_p = loss['prec_loss']['residuals'].view((len(x_p), n_g, n_r)).cpu()

    cts_x_plot = cts_x.cpu().detach().flatten()

    # each value of mean/prec ratio
    for i in range(n_g):

        # each value of likelihood weighting
        for j in range(n_r):

            if j == 0:
                for axs in axs_dict.values():
                    axs[i][j].set_ylabel(r"gamma: {:.3E}".format(a_labs[i][0]))

            if i == n_g - 1:
                for axs in axs_dict.values():
                    axs[i][j].set_xlabel(r"rho: {:.3E}".format(b_labs[0][j]))

            cts_mns = cts_pm_mns[:, i, j]
            cts_sds = cts_pm_sds[:, i, j]
            ax_mn = axs_dict['mns'][i][j]
            ax_rs = axs_dict['res_sds'][i][j]

            ax_mn.scatter(x_m.cpu(), y_m.cpu(), marker='+')
            ax_mn.scatter(x_p.cpu(), y_p.cpu(), marker='.')

            ci = cts_sds
            if not laplace:
                # +/- 1 and 2 predictive SDs
                ax_mn.fill_between(cts_x_plot, (cts_mns - ci).flatten(), (cts_mns + ci).flatten(), color='b', alpha=.2)
                ax_mn.fill_between(cts_x_plot, (cts_mns - 2 * ci).flatten(), (cts_mns + 2 * ci).flatten(), color='b', alpha=.1)
            else:
                # NOTE(review): ln(5) * b would be the 80% Laplace interval; here it
                # multiplies sqrt(2) * b, i.e. roughly a 90% band -- confirm intent.
                l_scale = math.log(5)
                ax_mn.fill_between(cts_x_plot, (cts_mns - l_scale * ci).flatten(), (cts_mns + l_scale * ci).flatten(), color='b', alpha=.1)

            ax_mn.set_ylim(-5, 5)
            ax_mn.plot(cts_x_plot, cts_mns)

            if bound_mn_y is not None:
                ax_mn.set_ylim(-bound_mn_y, bound_mn_y)

            ax_rs.scatter(x_m.cpu(), resids_m[:, i, j].abs().cpu(), marker='+')
            ax_rs.scatter(x_p.cpu(), resids_p[:, i, j].abs().cpu(), marker='.')
            ax_rs.plot(cts_x_plot, cts_sds, c='orange')

            if plot_loss:
                axs_dict['loss'][i][j].plot(loss_grid[:, i, j])

            if bound_sd_y is not None:
                ax_rs.set_ylim(-.1, bound_sd_y)

        print(i)

    fig_res_sds.suptitle('Synthetic: Pred SDs over Residuals ' + str(iteration), size=50)
    fig_res_sds.subplots_adjust(top=0.95)

    fig_mn.suptitle('Synthetic: Means ' + str(iteration), size=50)
    fig_mn.subplots_adjust(top=0.95)

    if plot_loss:
        fig_loss.suptitle('losses ' + str(iteration), size=50)
        fig_loss.subplots_adjust(top=0.95)

    if path is not None:
        fig_mn.savefig(path + '/plots/mean_' + str(iteration) + '.png')
        fig_res_sds.savefig(path + '/plots/res_sd_' + str(iteration) + '.png')
        if plot_loss:
            fig_loss.savefig(path + '/plots/loss_' + str(iteration) + '.png')

    if show_plots:
        plt.show()

    plt.close('all')


def plot_parallel_model(ppm, x, y, stats, iteration, dense_x=None, path=None, show_plots=False, bound_mn_y=None, bound_sd_y=None, laplace=False):
    """Plot every (gamma, rho) sub-model of a parallel model on a grid of axes.

    Grids have rows for ``ppm.unique_gammas`` and columns for
    ``ppm.unique_rhos``.  Three figures are drawn:

    * the data with the predicted mean and +/-1, 2 SD bands;
    * absolute residuals with the predicted SD curve;
    * the logged per-model losses, only when ``stats`` is given.

    Args:
        ppm: Model returning ``'mean'`` and ``'precision'`` entries.
        x, y: Training inputs/targets.
        stats: List of logged stat dicts with a ``'losses'`` entry, or None.
        iteration: Appended to figure titles and saved file names.
        dense_x: Optional inputs for the smooth curves.  Defaults to ``x.sort()[0]``.
        path: If given, figures are saved under ``path + '/plots/'``.
        show_plots: Call ``plt.show()`` before closing the figures.
        bound_mn_y, bound_sd_y: Optional y-limits for the mean / residual plots.
        laplace: Accepted for interface compatibility; not used here.
    """
    gammas = ppm.unique_gammas
    rhos = ppm.unique_rhos
    n_g, n_r = len(gammas), len(rhos)

    # NOTE: Tensor.sort() sorts along the last dimension, so an (N, 1) input keeps its row order.
    pts_x = x.sort()[0]
    cts_x = pts_x if dense_x is None else dense_x.sort()[0]

    plot_loss = stats is not None
    if plot_loss:
        loss_grid = torch.stack([l['losses'] for l in stats]).view(len(stats), n_g, n_r).cpu().detach()

    a_labs = ppm.gammas.view(n_g, n_r).cpu().detach()
    b_labs = ppm.rhos.view(n_g, n_r).cpu().detach()

    size = 4

    def _grid():
        # Rows index gammas and columns index rhos.  figsize is (width, height),
        # so width follows the column count.  squeeze=False keeps a 2-D axes
        # array even for 1 x N / N x 1 grids, so axs[i][j] always works.
        fig, axs = plt.subplots(n_g, n_r, figsize=(n_r * size, n_g * size), sharex=True, sharey=False, squeeze=False)
        fig.tight_layout(pad=2.75)
        return fig, axs

    fig_res_sds, axs_res_sds = _grid()
    fig_mn, axs_mn = _grid()
    axs_dict = {"res_sds": axs_res_sds,
                "mns": axs_mn}
    if plot_loss:
        fig_loss, axs_loss = _grid()
        axs_dict["loss"] = axs_loss

    with torch.no_grad():
        # Predict at x in its original order so each residual pairs a prediction
        # with its own target.  Predicting at a re-sorted x would misalign them
        # whenever sort() actually reorders.
        obs_mns = ppm(x)['mean'].view((len(x), n_g, n_r)).cpu().detach()
        cts_plot_vals = ppm(cts_x)

    cts_pm_mns = cts_plot_vals['mean'].view((len(cts_x), n_g, n_r)).cpu().detach()
    # sd = precision ** -0.5
    cts_pm_sds = cts_plot_vals['precision'].view((len(cts_x), n_g, n_r)).pow(-.5).cpu().detach()

    x_plot = x.cpu().detach().flatten()
    cts_x_plot = cts_x.cpu().detach().flatten()
    y_plot = y.cpu().detach().flatten()

    # each value of reg for the mean network
    for i in range(n_g):

        # each value of reg for the prec/sd network
        for j in range(n_r):

            if j == 0:
                for axs in axs_dict.values():
                    axs[i][j].set_ylabel(r"alpha: {:.2E}".format(a_labs[i][0]))

            if i == n_g - 1:
                for axs in axs_dict.values():
                    axs[i][j].set_xlabel(r"beta: {:.2E}".format(b_labs[0][j]))

            cts_mns = cts_pm_mns[:, i, j]
            cts_sds = cts_pm_sds[:, i, j]
            resids = (obs_mns[:, i, j] - y_plot).abs()

            ax_mn = axs_dict['mns'][i][j]
            ax_mn.fill_between(cts_x_plot, (cts_mns - cts_sds).flatten(), (cts_mns + cts_sds).flatten(), color='b', alpha=.2)
            ax_mn.fill_between(cts_x_plot, (cts_mns - 2 * cts_sds).flatten(), (cts_mns + 2 * cts_sds).flatten(), color='b', alpha=.1)
            ax_mn.scatter(x_plot, y_plot)
            ax_mn.plot(cts_x_plot, cts_mns, c='tab:orange')
            if bound_mn_y is not None:
                ax_mn.set_ylim(-bound_mn_y, bound_mn_y)

            ax_rs = axs_dict['res_sds'][i][j]
            ax_rs.scatter(x_plot, resids)
            ax_rs.plot(cts_x_plot, cts_sds, c='tab:orange')

            if plot_loss:
                axs_dict['loss'][i][j].plot(loss_grid[:, i, j])

            if bound_sd_y is not None:
                ax_rs.set_ylim(-.1, bound_sd_y)

        print(i)

    fig_res_sds.suptitle('Synthetic: Pred SDs over Residuals ' + str(iteration), size=50)
    fig_res_sds.subplots_adjust(top=0.95)

    fig_mn.suptitle('Synthetic: Means ' + str(iteration), size=50)
    fig_mn.subplots_adjust(top=0.95)

    if plot_loss:
        fig_loss.suptitle('losses ' + str(iteration), size=50)
        fig_loss.subplots_adjust(top=0.95)

    if path is not None:
        fig_mn.savefig(path + '/plots/mean_' + str(iteration) + '.png')
        fig_res_sds.savefig(path + '/plots/res_sd_' + str(iteration) + '.png')
        if plot_loss:
            fig_loss.savefig(path + '/plots/loss_' + str(iteration) + '.png')

    if show_plots:
        plt.show()

    plt.close('all')

def make_heatmap(title, pd_df, xtick, ytick, xlab, ylab, save_path, save=True, symlognorm=True, figsize=(4, 3)):
    """Draw a seaborn heatmap of ``pd_df``, optionally save it, then show and close it.

    The colour range always contains [0, 1]: ``vmin = min(data min, 0)`` and
    ``vmax = max(data max, 1)``.  NaN cells are ignored when computing them.

    Args:
        title: Figure title.
        pd_df: 2-D array-like (ndarray, DataFrame, ...) passed to ``sns.heatmap``.
        xtick, ytick: Tick labels for the columns / rows.
        xlab, ylab: Axis labels.
        save_path: Output file for ``plt.savefig`` when ``save`` is true.
        save: Whether to save the figure (at 300 dpi).
        symlognorm: Use a symmetric-log colour norm with 3 colour-bar ticks,
            instead of a linear scale with 10 ticks.
        figsize: Matplotlib figure size.
    """
    # Compute limits over the raw values.  A DataFrame's ``.min()`` is a
    # per-column Series, which ``min(..., 0.)`` cannot compare.  A single NaN
    # cell (e.g. a diverged model) would otherwise turn the limits into NaN.
    values = np.asarray(pd_df, dtype=float)
    vmin = min(float(np.nanmin(values)), 0.)
    vmax = max(float(np.nanmax(values)), 1.)

    plt.figure(figsize=figsize)
    plt.title(title)
    if symlognorm:
        norm = SymLogNorm(linthresh=0.03, linscale=0.03, vmin=vmin, vmax=vmax, base=10)
        n_ticks = 3
    else:
        norm = None
        n_ticks = 10

    sns.heatmap(pd_df, annot=False, xticklabels=xtick, yticklabels=ytick, norm=norm, vmin=vmin, vmax=vmax,
                cbar_kws={'ticks': MaxNLocator(n_ticks), 'format': '%.e'})

    plt.xlabel(xlab)
    plt.ylabel(ylab)
    if save:
        plt.savefig(save_path, dpi=300)

    plt.show()
    plt.close()

def num_grad(x, gw):
    """Central finite difference of ``x`` with grid spacing ``gw`` and periodic ends.

    ``torch.roll`` is called without ``dims``, so ``x`` is treated as one
    flattened sequence and its shape is restored afterwards.  Element ``k``
    becomes ``(x[k+1] - x[k-1]) / (2 * gw)``.  The first and last elements
    wrap around to the opposite end.
    """
    ahead = torch.roll(x, -1)   # ahead[k] == x[k + 1]
    behind = torch.roll(x, 1)   # behind[k] == x[k - 1]
    numerator = 0.5 * behind - 0.5 * ahead
    return -numerator / gw

def vec_num_grad(x, gw):
    """Central finite difference of ``x`` along dimension 0 with periodic ends.

    Row ``k`` of the result is ``(x[k+1] - x[k-1]) / (2 * gw)``.  The first and
    last rows wrap around to the opposite end.  Other dimensions are
    differenced independently.
    """
    ahead = torch.roll(x, -1, 0)   # ahead[k] == x[k + 1]
    behind = torch.roll(x, 1, 0)   # behind[k] == x[k - 1]
    numerator = 0.5 * behind - 0.5 * ahead
    return -numerator / gw

class H_SA_SGHMC(Optimizer):
    """Stochastic Gradient Hamiltonian Monte-Carlo sampler with scale adaptation during burn-in.

    During each parameter's first ``burn_in_period`` steps, three running
    estimates are updated: the gradient mean (``g``), the squared-gradient
    average (``V_hat``) and an adaptive averaging window (``tau``).
    ``V_hat ** -0.5`` then preconditions every update.  The "[1]" equation
    references in the comments presumably point to Springenberg et al.,
    "Bayesian Optimization with Robust Bayesian Neural Networks" (2016) --
    TODO confirm.

    Args:
        params: Iterable of parameters (or param-group dicts) to sample.
        lr: Step size (must be >= 0).
        base_C: Friction term (must be >= 0).
        burn_in_period: Number of per-parameter steps during which
            ``tau``/``g``/``V_hat`` are updated.
        momentum_sample_freq: The momentum is resampled whenever a parameter's
            step count is a multiple of this value.
        addnoise: If False, no Gaussian noise is injected into the momentum.
    """

    def __init__(self, params, lr=1e-2, base_C=0.05, burn_in_period=150, momentum_sample_freq=1000, addnoise=True):
        self.eps = 1e-6  # numerical floor in the tau / preconditioner divisions
        self.burn_in_period = burn_in_period
        self.momentum_sample_freq = momentum_sample_freq
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if base_C < 0:
            raise ValueError("Invalid friction term: {}".format(base_C))

        defaults = dict(
            lr=lr,
            base_C=base_C,
            addnoise=addnoise,
        )
        super(H_SA_SGHMC, self).__init__(params, defaults)

    def step(self, closure=None):
        """Simulate discretized Hamiltonian dynamics for one step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, as in ``torch.optim.Optimizer.step``.  It is called
                once, with gradients enabled, before the update.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:  # iterate over blocks -> the ones defined in defaults. We dont use groups.
            for p in group["params"]:  # these are weight and bias matrices
                if p.grad is None:
                    continue
                state = self.state[p]  # define dict for each individual param
                if len(state) == 0:
                    state["iteration"] = 0
                    state["tau"] = torch.ones_like(p)
                    state["g"] = torch.ones_like(p)
                    state["V_hat"] = torch.ones_like(p)
                    state["v_momentum"] = torch.zeros_like(
                        p)  # p.data.new(p.data.size()).normal_(mean=0, std=np.sqrt(group["lr"])) #

                state["iteration"] += 1  # per-parameter step count; drives burn-in and momentum resampling

                base_C, lr = group["base_C"], group["lr"]
                tau, g, V_hat = state["tau"], state["g"], state["V_hat"]

                d_p = p.grad.data

                # update parameters during burn-in
                if state["iteration"] <= self.burn_in_period:  # We update g first as it makes most sense
                    tau.add_(-tau * (g ** 2) / (
                                V_hat + self.eps) + 1)  # specifies the moving average window, see Eq 9 in [1] left
                    tau_inv = 1. / (tau + self.eps)
                    g.add_(-tau_inv * g + tau_inv * d_p)  # average gradient see Eq 9 in [1] right
                    V_hat.add_(-tau_inv * V_hat + tau_inv * (d_p ** 2))  # gradient variance see Eq 8 in [1]

                V_sqrt = torch.sqrt(V_hat)
                V_inv_sqrt = 1. / (V_sqrt + self.eps)  # preconditioner

                if (state["iteration"] % self.momentum_sample_freq) == 0:  # equivalent to var = M under momentum reparametrisation
                    state["v_momentum"] = torch.normal(mean=torch.zeros_like(d_p),
                                                       std=torch.sqrt((lr ** 2) * V_inv_sqrt))
                v_momentum = state["v_momentum"]

                if group['addnoise']:
                    noise_var = (2. * (lr ** 2) * V_inv_sqrt * base_C - (lr ** 4))
                    noise_std = torch.sqrt(torch.clamp(noise_var, min=1e-16))  # clamp keeps sqrt real
                    # sample random epsilon
                    noise_sample = torch.normal(mean=torch.zeros_like(d_p), std=torch.ones_like(d_p) * noise_std)
                    # update momentum (Eq 10 right in [1])
                    v_momentum.add_(- (lr ** 2) * V_inv_sqrt * d_p - base_C * v_momentum + noise_sample)
                else:
                    # update momentum (Eq 10 right in [1])
                    v_momentum.add_(- (lr ** 2) * V_inv_sqrt * d_p - base_C * v_momentum)

                # update theta (Eq 10 left in [1])
                p.data.add_(v_momentum)

        return loss

# NOTE(review): the triple-quoted string below appears to be disabled legacy
# plotting code (plot_result_vi and an older plot_sd_res).  As a bare string
# expression it is never executed; delete it once it is confirmed unneeded.
'''


def plot_result_vi(ax, x, y, true_mu, x_plot, mup, sdp):
  ax.scatter(x.squeeze().cpu().detach(), y.squeeze().cpu(), color="tab:blue")
  ax.plot(x.squeeze().cpu().detach(), true_mu.squeeze().cpu(), '--', color="black")

  #axa = ax.twinx()
  #if m is not None:
  #  ax.title('mn: ' + str(m.mean_prior_std) + '; std: ' + str(m.var_prior_std))  
  #ax.plot(x_plot.squeeze().detach(), tup[0].squeeze().detach(), color="tab:blue")
  #ax.plot(x_plot.squeeze().detach(), tup[1].squeeze().detach(), color="tab:orange")
  for mu_, std_ in zip(mup, sdp):
    ax.plot(x_plot.squeeze().cpu().detach(), mu_.squeeze().cpu().detach(), color="tab:blue")
    ax.plot(x_plot.squeeze().cpu().detach(), std_.squeeze().cpu().detach(), color="tab:orange")

def plot_sd_res(ax, x, res, x_plot, sd_plot):
  ax.scatter(x.squeeze().cpu().detach(), res.squeeze().cpu().detach(), color="tab:blue")
  
  for sd in sd_plot:
    ax.plot(x_plot.squeeze().cpu().detach(), sd.squeeze().cpu().detach(), color="tab:orange")

'''