import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import torch.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import Adam
from imagedata import *
from imageCEVAE import *
import scipy
import pandas as pd
from sklearn.linear_model import LogisticRegression
import pickle
import os
import glob
import re
import functools
from collections.abc import Iterable

def estimate_imageCEVAE_ATE(model, n_samples=100000):
    """Estimate the average treatment effect implied by a trained ImageCEVAE.

    Uses Monte Carlo integration over a standard-normal prior on z:
    p(y=1|do(t)) is approximated by the mean of sigmoid(y_t_nn(z)) over
    ``n_samples`` draws z ~ N(0, I).

    Args:
        model: object exposing ``z_dim``, ``device`` and ``decoder.y0_nn`` /
            ``decoder.y1_nn`` (callables mapping z to outcome logits).
        n_samples: number of Monte Carlo draws (default 100000, as before).

    Returns:
        (ATE, p(y=1|do(t=1)), p(y=1|do(t=0))) as 0-d tensors.
    """
    # This is a pure estimate, so no autograd graph is needed. Skipping it
    # avoids keeping activations for n_samples forward passes in memory.
    with torch.no_grad():
        z = torch.randn(n_samples, model.z_dim).to(model.device)
        py_dot1 = torch.sigmoid(model.decoder.y1_nn(z)).mean()
        py_dot0 = torch.sigmoid(model.decoder.y0_nn(z)).mean()
    ate = py_dot1 - py_dot0
    return ate, py_dot1, py_dot0

def _fit_unpenalized_logreg(X, y):
    """Fit an unregularised LogisticRegression across scikit-learn versions."""
    # scikit-learn >= 1.2 spells "no penalty" as penalty=None, and 1.4 removed
    # the string 'none'. Older releases accept only the string 'none'.
    try:
        return LogisticRegression(penalty=None).fit(X, y)
    except ValueError:
        return LogisticRegression(penalty='none').fit(X, y)


def best_estimate_ate(z, t, y, n_samples=1000000):
    """Return the best ATE and p(y|do(t)) one could estimate if one knew the true z and the generating process.

    Fits one unpenalised logistic regression of y on all columns of z per
    treatment arm. It then averages the predicted p(y=1|z) over fresh draws
    z ~ N(0, I). The standard-normal prior on z is an assumption carried over
    from the original code (TODO confirm against the data generator).

    Args:
        z: 2-D tensor (N, z_dim) of true latents.
        t: (N, 1) tensor of binary treatments.
        y: (N, 1) tensor of binary outcomes.
        n_samples: number of prior draws used for the Monte Carlo average.

    Returns:
        (ATE, p(y=1|do(t=1)), p(y=1|do(t=0))) as floats.
    """
    # Derive column names from z instead of assuming exactly three latents.
    z_cols = ['z{}'.format(i) for i in range(z.shape[1])]
    values = torch.cat([z, t, y], 1).detach().cpu().numpy()
    df = pd.DataFrame(values, columns=z_cols + ['t', 'y'])
    treated = df[df.t == 1]
    control = df[df.t == 0]
    # Fit on plain arrays so prediction on a NumPy sample gets no feature-name warning.
    logreg_t1 = _fit_unpenalized_logreg(treated[z_cols].to_numpy(), treated['y'].to_numpy())
    logreg_t0 = _fit_unpenalized_logreg(control[z_cols].to_numpy(), control['y'].to_numpy())
    z_sample = np.random.randn(n_samples, len(z_cols))
    p_y_dot1_best = logreg_t1.predict_proba(z_sample)[:, 1].mean()
    p_y_dot0_best = logreg_t0.predict_proba(z_sample)[:, 1].mean()
    return p_y_dot1_best - p_y_dot0_best, p_y_dot1_best, p_y_dot0_best

def viz_image_space(generator, dim1=0, dim2=1, gendim=3):
    """Show a 10x10 grid of generated images while sweeping two latent dimensions.

    Latent dimension ``dim1`` varies along the horizontal axis and ``dim2``
    along the vertical axis, both over [-2.5, 2.5]. The top row corresponds to
    dim2 = +2.5. Every other latent dimension is held at 0.

    Args:
        generator: callable taking a (100, gendim, 1, 1) latent tensor. Its
            squeezed output is expected to be (100, H, W) images (H, W are
            read from the output rather than assumed to be 28).
        dim1, dim2: indices of the latent dimensions to sweep.
        gendim: total number of latent dimensions.
    """
    n = 10
    with torch.no_grad():
        z = gendim * [None]
        # dim1 cycles fastest (column index); dim2 steps once per row.
        z[dim1] = torch.linspace(-2.5, 2.5, n)[:, None].repeat(n, 1)
        z[dim2] = torch.linspace(2.5, -2.5, n)[:, None].repeat_interleave(n, 0)
        for i in range(gendim):
            if i != dim1 and i != dim2:
                z[i] = torch.zeros(n ** 2, 1)
        z = torch.cat(z, 1)[:, :, None, None]
        out = generator(z).squeeze().cpu().detach().numpy()
    tile_h, tile_w = out.shape[-2:]
    out_grid = np.zeros((tile_h * n, tile_w * n))
    for i in range(n ** 2):
        row, col = divmod(i, n)
        out_grid[row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w] = out[i]
    plt.figure()
    plt.xlabel("Dim {}".format(dim1))
    plt.ylabel("Dim {}".format(dim2))
    plt.imshow(out_grid, extent=(-2.5, 2.5, -2.5, 2.5))
    plt.show()
    
def viz_image_space_1D(generator):
    """Draw 10 generated images side by side for a 1-D latent swept over [-2.5, 2.5].

    Args:
        generator: callable taking a (10, 1) latent tensor. Its squeezed
            output is expected to be (10, H, W) images. H and W are read from
            the output (previously hard-coded to 28).

    The figure is created but ``plt.show()`` is not called, as before.
    """
    n = 10
    with torch.no_grad():
        z = torch.linspace(-2.5, 2.5, n)[:, None]
        out = generator(z).squeeze().cpu().detach().numpy()
    tile_h, tile_w = out.shape[-2:]
    out_grid = np.zeros((tile_h, tile_w * n))
    for i in range(n):
        out_grid[:, i * tile_w:(i + 1) * tile_w] = out[i]
    # Keep square pixels: the x extent spans 5 units over n*tile_w columns.
    # For 28x28 tiles this gives +-0.25, matching the original extent.
    half_height = 2.5 * tile_h / (tile_w * n)
    plt.figure()
    plt.imshow(out_grid, extent=(-2.5, 2.5, -half_height, half_height))
    
def viz_other_space(generator, dim1=0, dim2=1, gendim=3):
    """Plot a scalar-valued generator as a heat map over two latent dimensions.

    Evaluates ``generator`` on a 100x100 grid over [-2.5, 2.5]^2. ``dim1``
    varies along the horizontal axis and ``dim2`` along the vertical axis, with
    the top row at +2.5. Other latent dimensions are held at 0. ``dim1`` may
    equal ``dim2`` (the later assignment wins), which train_model relies on
    for 1-D latents.

    Args:
        generator: callable mapping a (n*n, gendim) float tensor to n*n scalars.
        dim1, dim2: latent dimensions to sweep.
        gendim: total number of latent dimensions.
    """
    n = 100
    z_range = np.linspace(-2.5, 2.5, n)
    z0, z1 = np.meshgrid(z_range, z_range[::-1])
    z = gendim * [None]
    z[dim1] = z0[:, :, None]
    z[dim2] = z1[:, :, None]
    for i in range(gendim):
        if i != dim1 and i != dim2:
            # Use NumPy zeros so that only NumPy arrays are concatenated.
            z[i] = np.zeros((n, n, 1))
    z = torch.Tensor(np.concatenate(z, 2).reshape(-1, gendim))
    with torch.no_grad():
        y_pred = generator(z)
    plt.figure()
    plt.imshow(y_pred.reshape(n, n).cpu().detach().numpy(),
               extent=(z_range[0], z_range[-1], z_range[0], z_range[-1]))
    plt.xlabel("Dim {}".format(dim1))
    plt.ylabel("Dim {}".format(dim2))
    plt.show()
    
def estimate_image_ty_MI(model, n_sample=100, n_z=100):
    """Estimate the mutual information between generated images and (t, y).

    Draws ``n_sample`` joint samples (image, t, y) from the model's generative
    process. For each sample it estimates log p(img,t,y), log p(img) and
    log p(t,y) by Monte Carlo integration over ``n_z`` independent prior draws
    z ~ N(0, I). It returns the sample mean of
    log p(img,t,y) - log p(img) - log p(t,y), in nats.

    Args:
        model: object exposing ``z_dim``, ``device`` and a ``decoder`` with
            ``lin``, ``ct1``..``ct4`` (image logits head), ``t_nn``, ``y0_nn``
            and ``y1_nn`` (logit heads).
        n_sample: number of (image, t, y) samples to average over.
        n_z: number of prior draws used to integrate out z.

    Returns:
        float: the MI estimate.
    """
    decoder = model.decoder

    def _image_logits(z):
        # Image head: linear layer reshaped to (N, C, 1, 1), then ct1..ct4.
        h = decoder.lin(z)[:, :, None, None]
        for layer in (decoder.ct1, decoder.ct2, decoder.ct3, decoder.ct4):
            h = layer(h)
        return h

    log_n_z = np.log(n_z)
    # Only an estimate is needed, so don't build an autograd graph.
    with torch.no_grad():
        # Artificial sample from the generative model.
        z_sample = torch.randn(n_sample, model.z_dim).to(model.device)
        image_sample = dist.Bernoulli(logits=_image_logits(z_sample)).sample()
        t_sample = dist.Bernoulli(logits=decoder.t_nn(z_sample)).sample()
        y_logits = decoder.y1_nn(z_sample) * t_sample + decoder.y0_nn(z_sample) * (1 - t_sample)
        y_sample = dist.Bernoulli(logits=y_logits).sample()

        # Independent prior draws to integrate over.
        z = torch.randn(n_z, model.z_dim).to(model.device)
        image_dist = dist.Bernoulli(logits=_image_logits(z))
        t_dist = dist.Bernoulli(logits=decoder.t_nn(z))
        y_dist_t0 = dist.Bernoulli(logits=decoder.y0_nn(z))
        y_dist_t1 = dist.Bernoulli(logits=decoder.y1_nn(z))

        imgty_log_prob = np.zeros(n_sample)
        ty_log_prob = np.zeros(n_sample)
        img_log_prob = np.zeros(n_sample)

        for i in range(n_sample):
            # One log-likelihood per prior draw, shape (n_z,). reshape(-1)
            # (not squeeze) keeps it 1-D even when n_z == 1.
            image_log_probs = image_dist.log_prob(image_sample[i]).sum([1, 2, 3]).double()
            t_log_probs = t_dist.log_prob(t_sample[i]).reshape(-1).double()
            y_dist = y_dist_t1 if t_sample[i] == 1 else y_dist_t0
            y_log_probs = y_dist.log_prob(y_sample[i]).reshape(-1).double()
            ty_log_probs = t_log_probs + y_log_probs
            # log of the Monte Carlo mean = logsumexp over draws - log(n_z).
            imgty_log_prob[i] = torch.logsumexp(image_log_probs + ty_log_probs, 0).item() - log_n_z
            ty_log_prob[i] = torch.logsumexp(ty_log_probs, 0).item() - log_n_z
            img_log_prob[i] = torch.logsumexp(image_log_probs, 0).item() - log_n_z

    MI = (imgty_log_prob - img_log_prob - ty_log_prob).mean()
    return MI

def kld_loss(mu, std):
    """KL(N(mu, std^2) || N(0, I)), summed over latent dims and averaged over the batch.

    Closed form from Appendix B of Kingma & Welling, "Auto-Encoding Variational
    Bayes", ICLR 2014 (https://arxiv.org/abs/1312.6114):
        -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    The batch mean (rather than sum) matches the mean-reduced reconstruction losses.

    Args:
        mu: (batch, latent) tensor of posterior means.
        std: (batch, latent) tensor of posterior standard deviations.
    """
    variance = std ** 2
    per_dim = 1 + variance.log() - mu ** 2 - variance
    return -0.5 * per_dim.sum(dim=1).mean()

def run_epoch(model, optimizer, train_loader, loss_scaling, epoch, device):
    """Run one training epoch of the ImageCEVAE over ``train_loader``.

    For each batch the loss is the sum of:
    - the KL term from kld_loss;
    - the image negative log-likelihood (Bernoulli on logits), scaled by
      ``loss_scaling``;
    - the x and y negative log-likelihoods (Normal);
    - the t negative log-likelihood (Bernoulli), only when
      ``model.decoder.p_t_z_nn`` is truthy, otherwise 0.

    Each likelihood term is averaged over the batch and summed over the
    remaining dimensions. One optimizer step is taken per batch. Progress is
    printed on every 20th epoch.

    Returns:
        (total, kld, image, x, t, y) losses, each summed over all batches.
    """
    report = epoch % 20 == 0
    if report:
        print("Epoch {}:".format(epoch))
    names = ("total", "kld", "image", "x", "t", "y")
    sums = dict.fromkeys(names, 0)
    for batch in train_loader:
        image, x, t, y = (batch[key].to(device) for key in ('image', 'X', 't', 'y'))
        (image_mean, image_std, z_mean, z_std,
         x_pred, x_std, t_pred, y_pred, y_std) = model(image, x, t)
        kld = kld_loss(z_mean, z_std)
        image_nll = -dist.Bernoulli(logits=image_mean).log_prob(image).mean(0).sum()*loss_scaling
        x_nll = -dist.Normal(loc=x_pred, scale=x_std).log_prob(x).mean(0).sum()
        t_nll = (-dist.Bernoulli(logits=t_pred).log_prob(t).mean(0).sum()
                 if model.decoder.p_t_z_nn else torch.tensor(0).to(device))
        y_nll = -dist.Normal(loc=y_pred, scale=y_std).log_prob(y).mean(0).sum()
        loss = kld + image_nll + x_nll + t_nll + y_nll
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for name, term in zip(names, (loss, kld, image_nll, x_nll, t_nll, y_nll)):
            sums[name] += term.item()

        # Drop references to this batch's tensors before the next one is loaded.
        del loss, image_mean, image_std, z_mean, z_std, x_pred, x_std, t_pred, y_pred, image, x, t, y
    if report:
        print("Image: {}, x: {}, t: {}, y: {}".format(sums["image"], sums["x"], sums["t"], sums["y"]))
    torch.cuda.empty_cache()
    return tuple(sums[name] for name in names)

def train_model(device, plot_curves, print_logs,
              train_loader, num_epochs, lr_start, lr_end, x_dim, z_dim,
              p_y_zt_nn=False, p_y_zt_nn_layers=3, p_y_zt_nn_width=10, 
              p_t_z_nn=False, p_t_z_nn_layers=3, p_t_z_nn_width=10,
              p_x_z_nn=False, p_x_z_nn_layers=3, p_x_z_nn_width=10, loss_scaling=1, separate_ty=False):
    """Train a fresh ImageCEVAE on ``train_loader`` and plot its loss curves.

    The learning rate decays exponentially from ``lr_start`` to ``lr_end``
    over ``num_epochs`` epochs. The ``p_*`` arguments and ``separate_ty`` are
    passed straight to the ImageCEVAE constructor. ``loss_scaling`` multiplies
    the image reconstruction loss (see run_epoch).

    Args:
        device: device the model is built on and batches are moved to.
        plot_curves: if True and z_dim > 1, show 10 decoded images (t=0)
            every 50 epochs.
        print_logs: if True, print the epoch losses every 50 epochs.
        train_loader: iterable of dict batches with keys 'image', 'X', 't', 'y'.
        num_epochs: number of training epochs (must be >= 1).

    Returns:
        (model, losses): ``losses`` maps "total", "kld", "image", "x", "t", "y"
        to lists of per-epoch losses (summed over batches).

    Raises:
        ValueError: if ``num_epochs`` < 1 (the LR decay exponent is 1/num_epochs).
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1, got {}".format(num_epochs))
    print(loss_scaling)
    loss_keys = ("total", "kld", "image", "x", "t", "y")
    while True:
        model = ImageCEVAE(x_dim, z_dim, device=device, p_y_zt_nn=p_y_zt_nn, p_y_zt_nn_layers=p_y_zt_nn_layers,
            p_y_zt_nn_width=p_y_zt_nn_width, p_t_z_nn=p_t_z_nn, p_t_z_nn_layers=p_t_z_nn_layers, p_t_z_nn_width=p_t_z_nn_width,
            p_x_z_nn=p_x_z_nn, p_x_z_nn_layers=p_x_z_nn_layers, p_x_z_nn_width=p_x_z_nn_width,separate_ty=separate_ty)
        optimizer = Adam(model.parameters(), lr=lr_start)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=(lr_end/lr_start)**(1/num_epochs))

        losses = {key: [] for key in loss_keys}
        epochs_completed = 0

        for epoch in range(num_epochs):
            epoch_losses = run_epoch(model, optimizer, train_loader, loss_scaling, epoch, device)
            # run_epoch returns (total, kld, image, x, t, y), matching loss_keys.
            for key, value in zip(loss_keys, epoch_losses):
                losses[key].append(value)
            scheduler.step()
            epochs_completed += 1

            if print_logs and epoch % 50 == 0:
                epoch_loss, _, epoch_image_loss, epoch_x_loss, epoch_t_loss, epoch_y_loss = epoch_losses
                print("Epoch loss at epoch {}: {}".format(epoch, epoch_loss))
                print("Image: {}, x: {}, t: {}, y: {}".format(epoch_image_loss, epoch_x_loss, epoch_t_loss, epoch_y_loss))
                print()
                if epoch_x_loss > 1e8:
                    # Diagnose a blown-up x loss by plotting decoder.x_nn along latent dim 0.
                    generator = lambda z: model.decoder.x_nn(z.to(device))
                    viz_other_space(generator, dim1=0, dim2=0, gendim=z_dim)

            if plot_curves and z_dim > 1 and epoch % 50 == 0:
                fig, ax = plt.subplots(2, 5, figsize=(15, 5))
                with torch.no_grad():
                    z = torch.randn(10, z_dim).to(device)
                    t = torch.zeros(z.shape[0], 1).to(device)
                    out, _, _, _, _, _, _ = model.decoder(z, t)
                    out = torch.sigmoid(out).squeeze().cpu().detach().numpy()
                    for k in range(10):
                        row, col = divmod(k, 5)
                        ax[row][col].imshow(out[k])
                        ax[row][col].set_title("{:.2f},{:.2f}".format(z[k, 0], z[k, 1]))
                plt.show()

        # Restart from a fresh model only if the epoch loop was left early.
        # Nothing breaks out of it at present, so the first attempt is kept.
        if epochs_completed == num_epochs:
            break

    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].plot(losses['image'])
    ax[0, 1].plot(losses['t'])
    ax[0, 2].plot(losses['x'])
    ax[1, 0].plot(losses['y'])
    # Outliers (>= 1e4) are dropped so early spikes don't flatten the curves.
    ax[1, 1].plot([loss for loss in losses['kld'] if loss < 1e4])
    ax[1, 2].plot([loss for loss in losses['total'] if loss < 1e4])
    ax[0, 0].set_title("image loss")
    ax[0, 1].set_title("t loss")
    ax[0, 2].set_title("x loss")
    ax[1, 0].set_title("y loss")
    ax[1, 1].set_title("kld loss")
    ax[1, 2].set_title("total loss")
    # Title the loss figure itself rather than opening a second, empty figure.
    fig.suptitle("Loss at end of each epoch")
    plt.show()

    return model, losses

def train_decoder(device, model, print_logs, train_loader, num_epochs, lr_start, lr_end):
    """Train the decoder from the source domain, supervised by the true latents.

    Only ``model.decoder``'s parameters are optimised, and they are updated in
    place. Each batch must provide 'image', 'X', 't', 'y' and the ground-truth
    latent 'z'. The decoder is called as ``model.decoder(z, t)`` and fit to
    the image (Bernoulli on logits), x and y (Normal), and t (Bernoulli, only
    when ``model.decoder.p_t_z_nn`` is truthy). The learning rate decays
    exponentially from ``lr_start`` to ``lr_end``.

    Returns:
        ``model.decoder`` after training.

    Raises:
        ValueError: if ``num_epochs`` < 1 (the LR decay exponent is 1/num_epochs).
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1, got {}".format(num_epochs))
    loss_keys = ("total", "image", "x", "t", "y")
    losses = {key: [] for key in loss_keys}
    optimizer = Adam(model.decoder.parameters(), lr=lr_start)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=(lr_end/lr_start)**(1/num_epochs))
    for epoch in range(num_epochs):
        sums = dict.fromkeys(loss_keys, 0)
        for data in train_loader:
            image = data['image'].to(device)
            x = data['X'].to(device)
            t = data['t'].to(device)
            y = data['y'].to(device)
            z = data['z'].to(device)
            image_mean, image_std, x_pred, x_std, t_pred, y_pred, y_std = model.decoder(z, t)
            image_loss = -dist.Bernoulli(logits=image_mean).log_prob(image).mean(0).sum()
            x_loss = -dist.Normal(loc=x_pred, scale=x_std).log_prob(x).mean(0).sum()
            if model.decoder.p_t_z_nn:
                t_loss = -dist.Bernoulli(logits=t_pred).log_prob(t).mean(0).sum()
            else:
                t_loss = torch.tensor(0).to(device)
            y_loss = -dist.Normal(loc=y_pred, scale=y_std).log_prob(y).mean(0).sum()
            loss = image_loss + x_loss + y_loss + t_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for key, term in zip(loss_keys, (loss, image_loss, x_loss, t_loss, y_loss)):
                sums[key] += term.item()
        scheduler.step()
        for key in loss_keys:
            losses[key].append(sums[key])
        if print_logs and epoch % 50 == 0:
            print("Epoch {}:".format(epoch))
            print("Image: {}, x: {}, t: {}, y: {}".format(sums["image"], sums["x"], sums["t"], sums["y"]))
            print()

    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].plot(losses['image'])
    ax[0, 1].plot(losses['t'])
    ax[0, 2].plot(losses['x'])
    ax[1, 0].plot(losses['y'])
    ax[1, 2].plot([loss for loss in losses['total'] if loss < 1e4])
    ax[0, 0].set_title("image loss")
    ax[0, 1].set_title("t loss")
    ax[0, 2].set_title("x loss")
    ax[1, 0].set_title("y loss")
    ax[1, 2].set_title("total loss")
    # No KL term when training the decoder alone, so leave that panel blank.
    ax[1, 1].axis("off")
    # Title the loss figure itself rather than opening a second, empty figure.
    fig.suptitle("Loss at end of each epoch")
    plt.show()
    return model.decoder

def trainconvynet(device, num_epochs, dataset, BATCH_SIZE, lr_start, lr_end):
    """Fit a ConvyNet that predicts y from (image, x, t).

    Minimises the Bernoulli negative log-likelihood of y given the network's
    logits, averaged over each batch. The learning rate decays exponentially
    from ``lr_start`` to ``lr_end`` over ``num_epochs``. The per-epoch summed
    losses are plotted (without ``plt.show()``).

    Returns:
        The trained ConvyNet.
    """
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    net = ConvyNet(device=device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    decay = (lr_end / lr_start) ** (1 / num_epochs)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay)
    history = []
    for epoch in range(num_epochs):
        if not epoch % 5:
            print("Epoch {}".format(epoch))
        running = 0
        for batch in loader:
            image, x, t, y = (batch[key].to(device) for key in ('image', 'X', 't', 'y'))
            logits = net(image, x, t)
            loss = -dist.Bernoulli(logits=logits).log_prob(y).mean(0).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
        history.append(running)
        scheduler.step()
    plt.plot(history)
    return net
            
def expand_parameters(params, iterated):
    """Turn every entry of ``params`` into a list with one value per element of ``iterated``.

    Non-list entries are repeated ``len(iterated)`` times. List entries must
    already have that length (checked with ``assert``) and are shallow-copied.
    The result has shape (len(params), len(iterated)).
    """
    count = len(iterated)

    def _expand(value):
        if isinstance(value, list):
            assert len(value) == count
            return value.copy()
        return count * [value]

    return [_expand(value) for value in params]

def create_or_empty_folder(main_folder, sub_folder):
    """Ensure ./data/{main_folder}/{sub_folder}/ exists and holds no files.

    Missing parent directories, including ./data itself, are created. If the
    sub-folder already exists, a notice is printed and every regular file
    directly inside it is deleted. Nested sub-directories are left untouched.
    Other OS errors (e.g. permissions) propagate instead of being silently
    ignored.
    """
    main_path = "./data/{}/".format(main_folder)
    sub_path = "./data/{}/{}/".format(main_folder, sub_folder)
    os.makedirs(main_path, exist_ok=True)
    try:
        os.mkdir(sub_path)
    except FileExistsError:
        print("Creation of the directory './data/{}/{}/ failed. Trying to empty the same folder.".format(main_folder, sub_folder))
        for f in glob.glob(sub_path + '*'):
            if os.path.isfile(f):
                os.remove(f)

def save_dataparameters(dataparameters, main_folder, sub_folder):
    """Pickle ``dataparameters`` to ./data/{main_folder}/{sub_folder}/params.

    The sub-folder is first created or emptied via create_or_empty_folder, so
    files already in it are removed.
    """
    target = "./data/{}/{}/params".format(main_folder, sub_folder)
    create_or_empty_folder(main_folder, sub_folder)
    with open(target, "wb") as handle:
        pickle.dump(dataparameters, handle)

def load_dataparameters(main_folder, sub_folder):
    """Return the object pickled at ./data/{main_folder}/{sub_folder}/params.

    Uses pickle, so only load files from a trusted source.
    """
    path = "./data/{}/{}/params".format(main_folder, sub_folder)
    with open(path, "rb") as handle:
        params = pickle.load(handle)
    return params

def _samples_to_dataframe(z, x, t, y):
    """Stack generated (z, x, t, y) tensors column-wise into a DataFrame (z0.., x0.., t, y)."""
    columns = (["z" + str(k) for k in range(z.shape[1])]
               + ["x" + str(k) for k in range(x.shape[1])]
               + ["t", "y"])
    values = torch.cat([z, x, t, y], 1).detach().cpu().numpy()
    # Flatten trailing singleton dims but always keep one row per sample
    # (a bare squeeze() would break single-row data).
    return pd.DataFrame(values.reshape(values.shape[0], -1), columns=columns)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``."""
    with open(path, "wb") as file:
        pickle.dump(obj, file)


def create_dfs_datasets(generate_data, dataparameters, param_times, repeat, main_folder, sub_folder, labels):
    """Generate, save and return datasets for each data-parameter combination.

    ``dataparameters`` is a list of argument lists for ``generate_data``, which
    must return ``(z, images, x, t, y, dataset)``. ``labels[i]`` names the i-th
    combination. ``param_times`` is how many runs use each combination.

    If ``repeat`` is True, one dataset is generated per combination and
    shared by all runs; it is saved as ``df_{label}`` / ``dataset_{label}``.
    Otherwise a fresh dataset is generated per run and saved as
    ``df_{label}_{j}`` / ``dataset_{label}_{j}``. Files go to
    ./data/{main_folder}/{sub_folder}/, which is emptied first.

    Returns:
        (dfs, datasets): dicts label -> {run index -> DataFrame / dataset}.
    """
    create_or_empty_folder(main_folder, sub_folder)
    folder = "./data/{}/{}/".format(main_folder, sub_folder)

    dfs = {label: {} for label in labels}
    datasets = {label: {} for label in labels}
    for i, data_params in enumerate(dataparameters):
        print("Step ", i)
        label = labels[i]
        if repeat:
            z, _images, x, t, y, dataset = generate_data(*data_params)
            df = _samples_to_dataframe(z, x, t, y)
            _dump_pickle(df, folder + "df_{}".format(label))
            _dump_pickle(dataset, folder + "dataset_{}".format(label))
            for j in range(param_times):
                dfs[label][j] = df
                datasets[label][j] = dataset
        else:
            for j in range(param_times):
                z, _images, x, t, y, dataset = generate_data(*data_params)
                df = _samples_to_dataframe(z, x, t, y)
                _dump_pickle(df, folder + "df_{}_{}".format(label, j))
                _dump_pickle(dataset, folder + "dataset_{}_{}".format(label, j))
                dfs[label][j] = df
                datasets[label][j] = dataset
    return dfs, datasets

def load_dfs_datasets(main_folder, sub_folder, param_times=None):
    """Load the DataFrames and datasets pickled by create_dfs_datasets.

    Files named ``df_{label}_{j}`` / ``dataset_{label}_{j}`` are stored under
    ``[label][j]``. Files named ``df_{label}`` / ``dataset_{label}`` (written
    when data is shared between runs) are replicated to run indices
    0..param_times-1. Labels come back as strings and must not contain
    underscores. Files that follow neither pattern (e.g. ``params``) are
    ignored. Uses pickle, so only load trusted folders.

    Returns:
        (dfs, datasets): dicts label -> {run index -> object}.

    Raises:
        ValueError: a shared (un-indexed) file is present but ``param_times`` is None.
    """
    folder = "data/{}/{}/".format(main_folder, sub_folder)
    dfs = {}
    datasets = {}
    for filename in os.listdir(folder):
        for prefix, target in (("df", dfs), ("dataset", datasets)):
            indexed = re.fullmatch(prefix + r"_([^_]+)_(\d+)", filename)
            shared = None if indexed else re.fullmatch(prefix + r"_([^_]+)", filename)
            if indexed is None and shared is None:
                continue
            if shared is not None and param_times is None:
                raise ValueError("param_times is required to load the shared file {!r}".format(filename))
            with open(folder + filename, "rb") as file:
                obj = pickle.load(file)
            if indexed is not None:
                target.setdefault(indexed.group(1), {})[int(indexed.group(2))] = obj
            else:
                target[shared.group(1)] = {j: obj for j in range(param_times)}
    return dfs, datasets

def run_model_for_predef_datasets(datasets, param_times, main_folder, sub_folder, BATCH_SIZE, track_function, true_value,
                                  device, train_arguments, labels, data_labels, overwrite=True):
    """Train one ImageCEVAE per (label, run) on pre-generated datasets and save the results.

    The main folder groups related experiments with the same or similar data.
    The sub-folder holds this experiment's results. Run j of experiment
    ``labels[i]`` uses ``datasets[data_labels[i]][j]``, so data can be shared
    across labels or differ per label.

    ``train_arguments`` are train_model's positional arguments after
    ``train_loader`` (num_epochs, ..., separate_ty). Each entry is either one
    value for all labels or a list with one value per label.

    For each (label, j) these files are written to
    ./data/{main_folder}/{sub_folder}/: model_{label}_{j} (state dict),
    loss_{label}_{j}, pydot1_{label}_{j} and pydot0_{label}_{j}.

    Returns:
        (models, losses, pydot1s, pydot0s), each a dict label -> {run index -> value}.
    """
    if overwrite:
        create_or_empty_folder(main_folder, sub_folder)
    prefix = "./data/{}/{}/".format(main_folder, sub_folder)
    os.makedirs(prefix, exist_ok=True)

    train_arguments = expand_parameters(train_arguments, labels)
    train_arguments = list(map(list, zip(*train_arguments)))  # dim (len(labels), len(train_arguments))

    models = {label: {} for label in labels}
    losses = {label: {} for label in labels}
    pydot1s = {label: {} for label in labels}
    pydot0s = {label: {} for label in labels}

    for i, label in enumerate(labels):
        for j in range(param_times):
            dataloader = DataLoader(datasets[data_labels[i]][j], batch_size=BATCH_SIZE)
            # train_model returns only (model, losses); unpacking four values always failed.
            model, loss = train_model(device, False, False, dataloader, *train_arguments[i])
            # Per-epoch p(y=1|do(t)) tracking is disabled in train_model, so
            # store the final estimate as a one-element list instead.
            _, p1, p0 = estimate_imageCEVAE_ATE(model)
            pydot1 = [p1.item()]
            pydot0 = [p0.item()]

            torch.save(model.state_dict(), prefix + "model_{}_{}".format(label, j))
            for kind, value in (("loss", loss), ("pydot1", pydot1), ("pydot0", pydot0)):
                with open(prefix + "{}_{}_{}".format(kind, label, j), "wb") as file:
                    pickle.dump(value, file)
            print("Estimated causal effect: {} true value: {}".format(track_function(model), true_value))
            models[label][j] = model
            losses[label][j] = loss
            pydot1s[label][j] = pydot1
            pydot0s[label][j] = pydot0

    return models, losses, pydot1s, pydot0s

def load_models_losses(main_folder, sub_folder, train_arguments, labels, device):
    """Rebuild the models and load the loss histories saved by run_model_for_predef_datasets.

    Only files named ``model_{label}_{j}`` or ``loss_{label}_{j}`` are read.
    Anything else in the folder (pydot files, ``params``, ...) is skipped.
    Each model is rebuilt with the train_arguments of its label, which must
    contain the 16 values num_epochs ... separate_ty. Its state dict is then
    loaded onto ``device``. Labels must not contain underscores. Uses pickle,
    so only load trusted folders.

    Returns:
        (models, losses): dicts str(label) -> {run index -> model / losses}.
    """
    train_arguments = expand_parameters(train_arguments, labels)
    train_arguments = list(map(list, zip(*train_arguments)))
    # The folder only holds label strings; map them back to argument indices.
    labels_to_index = {str(label): index for index, label in enumerate(labels)}
    folder = "data/{}/{}/".format(main_folder, sub_folder)
    models = {}
    losses = {}
    for filename in os.listdir(folder):
        match = re.fullmatch(r"(model|loss)_([^_]+)_(\d+)", filename)
        if match is None:
            continue
        kind, label, run = match.group(1), match.group(2), int(match.group(3))
        path = folder + filename
        if kind == "model":
            (num_epochs, lr_start, lr_end, x_dim, z_dim, p_y_zt_nn, p_y_zt_nn_layers, p_y_zt_nn_width,
             p_t_z_nn, p_t_z_nn_layers, p_t_z_nn_width, p_x_z_nn, p_x_z_nn_layers, p_x_z_nn_width,
             loss_scaling, separate_ty) = train_arguments[labels_to_index[label]]
            model = ImageCEVAE(x_dim, z_dim, device=device, p_y_zt_nn=p_y_zt_nn, p_y_zt_nn_layers=p_y_zt_nn_layers,
                        p_y_zt_nn_width=p_y_zt_nn_width, p_t_z_nn=p_t_z_nn, p_t_z_nn_layers=p_t_z_nn_layers, 
                        p_t_z_nn_width=p_t_z_nn_width, p_x_z_nn=p_x_z_nn, 
                        p_x_z_nn_layers=p_x_z_nn_layers, p_x_z_nn_width=p_x_z_nn_width, separate_ty=separate_ty)
            # map_location lets checkpoints saved on a GPU load on any device.
            model.load_state_dict(torch.load(path, map_location=device))
            model.eval()
            models.setdefault(label, {})[run] = model
        else:
            with open(path, "rb") as handle:
                losses.setdefault(label, {})[run] = pickle.load(handle)
    return models, losses
        
def load_pydots(main_folder, sub_folder, labels, device):
    """Load the pickled p(y=1|do(t=1)) / p(y=1|do(t=0)) records of an experiment.

    Reads files named ``pydot1_{label}_{j}`` and ``pydot0_{label}_{j}`` from
    data/{main_folder}/{sub_folder}/. Other files (models, losses,
    ``params``, ...) are skipped instead of crashing. ``labels`` and
    ``device`` are kept for interface compatibility and are unused. Uses
    pickle, so only load trusted folders.

    Returns:
        (pydot1s, pydot0s): dicts str(label) -> {run index -> stored object}.
    """
    folder = "data/{}/{}/".format(main_folder, sub_folder)
    pydot1s = {}
    pydot0s = {}
    for filename in os.listdir(folder):
        match = re.fullmatch(r"(pydot[01])_([^_]+)_(\d+)", filename)
        if match is None:
            continue
        target = pydot1s if match.group(1) == "pydot1" else pydot0s
        with open(folder + filename, "rb") as handle:
            target.setdefault(match.group(2), {})[int(match.group(3))] = pickle.load(handle)
    return pydot1s, pydot0s

def run_model_for_data_sets(datasize, param_times,
                            folder, name, 
                            BATCH_SIZE, generate_data, dataparameters, track_function, true_value,
                            device, train_arguments, labels, 
                            post_decoder_training=False, post_decoder_arguments=[], loss_scaling=1,
                           share_data_between_runs=False):
    """Run the model for a parameter sweep and save the results in data/{folder}/.

    Existing files in data/{folder}/ are deleted before starting. For each
    label i and run j, a dataset of ``datasize[i]`` samples is generated by
    ``generate_data(num_samples, *dataparameters)``, unless
    ``share_data_between_runs`` is set, in which case one dataset is
    generated up front and reused. A model is then trained with
    ``train_model``.

    ``train_arguments`` is a list with the following (each entry either a
    single value or a list with one value per label):
        num_epochs, lr_start, lr_end, x_dim, z_dim,
        p_y_zt_nn, p_y_zt_nn_layers, p_y_zt_nn_width,
        p_t_z_nn, p_t_z_nn_layers, p_t_z_nn_width,
        p_x_z_nn, p_x_z_nn_layers, p_x_z_nn_width
    ``datasize`` and ``loss_scaling`` may likewise be per-label lists.

    If ``post_decoder_training`` is True, the decoder alone is then trained
    further with ``train_decoder``. ``post_decoder_arguments`` holds
    (num_epochs, lr_start, lr_end) for that step and is not modified.

    Saved files: model_{name}_{label}_{j} (state dict), data_{name}_{label}_{j}
    (the generate_data output tuple) and loss_{name}_{label}_{j}.

    Returns:
        (datas, models, losses): dicts keyed by label. Their inner dicts are
        left empty, because results are only persisted to disk.

    Raises:
        ValueError: if ``share_data_between_runs`` is combined with an
            iterable ``datasize``.
    """
    os.makedirs("data", exist_ok=True)
    try:
        os.mkdir("data/{}/".format(folder))
    except FileExistsError:
        print("Creation of the directory data/{}/ failed. Trying to empty the same folder.".format(folder))
        for f in glob.glob('data/{}/*'.format(folder)):
            if os.path.isfile(f):
                os.remove(f)
    # Sharing one dataset only makes sense when every label uses the same size.
    if isinstance(datasize, Iterable) and share_data_between_runs:
        raise ValueError("share_data_between_runs requires a single (non-iterable) datasize")

    datasize = expand_parameters([datasize], labels)[0]
    train_arguments = expand_parameters(train_arguments, labels)
    train_arguments = list(map(list, zip(*train_arguments)))  # dim (len(labels), len(train_arguments))
    loss_scaling = expand_parameters([loss_scaling], labels)[0]

    datas = {label: {} for label in labels}
    models = {label: {} for label in labels}
    losses = {label: {} for label in labels}
    if share_data_between_runs:
        z, images, x, t, y, dataset = generate_data(datasize[0], *dataparameters)
    for i, label in enumerate(labels):
        for j in range(param_times):
            num_samples = datasize[i]
            print("Training data size {}, run {}".format(num_samples, j+1))
            if not share_data_between_runs:
                z, images, x, t, y, dataset = generate_data(num_samples, *dataparameters)
            dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
            # Train on the requested device (was hard-coded to 'cuda').
            model, loss = train_model(device, False, False, dataloader, *train_arguments[i], loss_scaling[i])
            if post_decoder_training:
                # Previously referenced undefined names (print_logs, train_loader).
                train_decoder(device, model, False, dataloader, *post_decoder_arguments)

            data = (z, images, x, t, y, dataset)
            torch.save(model.state_dict(), "./data/{}/model_{}_{}_{}".format(folder, name, label, j))
            with open("./data/{}/data_{}_{}_{}".format(folder, name, label, j), "wb") as file:
                pickle.dump(data, file)
            with open("./data/{}/loss_{}_{}_{}".format(folder, name, label, j), "wb") as file:
                pickle.dump(loss, file)
            print("Estimated causal effect: {} true value: {}".format(track_function(model), true_value))

    return datas, models, losses


def load_dfs_models(folder, name, train_arguments, datasize, labels, device):
    """Load data tuples, losses and trained models from data/{folder}/ that match the experiment name.

    Only files named ``{data|loss|model}_{name}_{label}_{run}`` (as written
    by run_model_for_data_sets) are considered. Labels must not contain
    underscores. Each model is rebuilt with its label's 14 train_arguments
    and its state dict is loaded onto ``device``. ``datasize`` is kept for
    interface compatibility and is unused. Uses pickle, so only load trusted
    folders.

    Returns:
        (datas, models, losses): dicts str(label) -> {run index -> object}.
    """
    train_arguments = expand_parameters(train_arguments, labels)
    train_arguments = list(map(list, zip(*train_arguments)))
    # The folder only holds label strings; map them back to argument indices.
    labels_to_index = {str(label): index for index, label in enumerate(labels)}
    # Anchoring on the escaped name stops other experiments' files from matching.
    pattern = re.compile(r"(data|loss|model)_" + re.escape(str(name)) + r"_([^_]+)_(\d+)")

    datas = {}
    models = {}
    losses = {}
    for filename in os.listdir("data/{}".format(folder)):
        match = pattern.fullmatch(filename)
        if match is None:
            continue
        kind, label, run = match.group(1), match.group(2), int(match.group(3))
        path = "data/{}/{}".format(folder, filename)
        if kind == "model":
            (num_epochs, lr_start, lr_end, x_dim, z_dim, p_y_zt_nn, p_y_zt_nn_layers, p_y_zt_nn_width,
             p_t_z_nn, p_t_z_nn_layers, p_t_z_nn_width, p_x_z_nn, p_x_z_nn_layers,
             p_x_z_nn_width) = train_arguments[labels_to_index[label]]
            model = ImageCEVAE(x_dim, z_dim, device=device, p_y_zt_nn=p_y_zt_nn, p_y_zt_nn_layers=p_y_zt_nn_layers,
                p_y_zt_nn_width=p_y_zt_nn_width, p_t_z_nn=p_t_z_nn, p_t_z_nn_layers=p_t_z_nn_layers, 
                p_t_z_nn_width=p_t_z_nn_width, p_x_z_nn=p_x_z_nn, 
                p_x_z_nn_layers=p_x_z_nn_layers, p_x_z_nn_width=p_x_z_nn_width)
            # map_location lets checkpoints saved on a GPU load on any device.
            model.load_state_dict(torch.load(path, map_location=device))
            model.eval()
            models.setdefault(label, {})[run] = model
        else:
            target = datas if kind == "data" else losses
            with open(path, "rb") as handle:
                target.setdefault(label, {})[run] = pickle.load(handle)
    return datas, models, losses