"""Module containing training functions for the various models evaluated in the DECAF paper."""

import os
import pickle

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from sklearn.neural_network import MLPClassifier

from data import DataModule
from models.DECAF import DECAF
import tensorflow as tf
from models.FairGAN import Medgan

# Directory (relative to the current working directory) where trained models
# and intermediate data files produced by the functions below are cached.
models_dir = "cache"



def train_decaf(train_dataset, dag_seed, biased_edges=None, h_dim=200, lr=0.5e-3,
                batch_size=64, lambda_privacy=0, lambda_gp=10, d_updates=10,
                alpha=2, rho=2, weight_decay=1e-2, grad_dag_loss=False, l1_g=0,
                l1_W=1e-4, p_gen=-1, use_mask=True, epochs=50, model_name='decaf'):
    """Train (or load from cache) a DECAF model and return a synthetic dataset.

    The trained model is cached as ``<models_dir>/<model_name>.pkl``. If that
    file already exists it is loaded instead of training, regardless of the
    hyperparameters passed here, so use a distinct ``model_name`` per
    configuration.

    Args:
        train_dataset: pandas DataFrame of training data. Its ``.values`` are
            wrapped in a ``DataModule``; its index and columns are reused for
            the returned DataFrame.
        dag_seed: DAG seed forwarded to the ``DECAF`` constructor.
        biased_edges: Forwarded to ``DECAF.gen_synthetic``. ``None`` (the
            default) means an empty dict.
        h_dim, lr, batch_size, lambda_privacy, lambda_gp, d_updates, alpha,
        rho, weight_decay, grad_dag_loss, l1_g, l1_W, p_gen, use_mask:
            Hyperparameters forwarded unchanged to ``DECAF``.
        epochs: ``max_epochs`` for the Lightning trainer.
        model_name: Base file name of the cached model.

    Returns:
        pandas DataFrame with the same index and columns as ``train_dataset``
        holding the generated rows. The protected/label columns of the credit
        (``ethnicity``/``approved``) or adult (``sex``/``income``) dataset are
        rounded.
    """
    # Fresh dict per call instead of a mutable default shared across calls.
    if biased_edges is None:
        biased_edges = {}

    # torch.save below fails if the cache directory does not exist yet.
    os.makedirs(models_dir, exist_ok=True)
    model_filename = os.path.join(models_dir, f'{model_name}.pkl')

    dm = DataModule(train_dataset.values)

    # NOTE: the model is constructed even when a cached copy is loaded below.
    # This is kept on purpose. Construction may consume RNG state (e.g. weight
    # initialisation), so skipping it could change the samples of seeded runs.
    model = DECAF(
        dm.dims[0],
        dag_seed=dag_seed,
        h_dim=h_dim,
        lr=lr,
        batch_size=batch_size,
        lambda_privacy=lambda_privacy,
        lambda_gp=lambda_gp,
        d_updates=d_updates,
        alpha=alpha,
        rho=rho,
        weight_decay=weight_decay,
        grad_dag_loss=grad_dag_loss,
        l1_g=l1_g,
        l1_W=l1_W,
        p_gen=p_gen,
        use_mask=use_mask,
    )

    if os.path.exists(model_filename):
        # The cache holds a whole pickled module, not just a state dict.
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects such
        # files. Older PyTorch (< 1.13) does not know the keyword at all.
        # torch.load unpickles arbitrary objects: only load trusted files.
        try:
            model = torch.load(model_filename, weights_only=False)
        except TypeError:
            model = torch.load(model_filename)
    else:
        trainer = pl.Trainer(max_epochs=epochs, logger=False)
        trainer.fit(model, dm)
        torch.save(model, model_filename)

    # Generate synthetic data following the model's generation order.
    synth_dataset = (
        model.gen_synthetic(
            dm.dataset.x,
            gen_order=model.get_gen_order(),
            biased_edges=biased_edges,
        )
        .detach()
        .numpy()
    )
    # Cast the last column through int8. The float->int cast truncates toward
    # zero, and the result is stored back into the (float) array.
    synth_dataset[:, -1] = synth_dataset[:, -1].astype(np.int8)

    synth_dataset = pd.DataFrame(synth_dataset,
                                 index=train_dataset.index,
                                 columns=train_dataset.columns)

    if 'approved' in synth_dataset.columns:
        # Binarise columns for credit dataset
        synth_dataset['ethnicity'] = np.round(synth_dataset['ethnicity'])
        synth_dataset['approved'] = np.round(synth_dataset['approved'])
    else:
        # Binarise columns for adult dataset
        synth_dataset['sex'] = np.round(synth_dataset['sex'])
        synth_dataset['income'] = np.round(synth_dataset['income'])

    return synth_dataset

def train_fairgan(train_dataset, embedding_dim=128, random_dim=128,
                  generator_dims=(128, 128), discriminator_dims=(128, 128),
                  bn_decay=0.99, l2_scale=0.001, batch_size=100,
                  pretrain_epochs=50, train_epochs=10, model_name='fairgan'):
    """Train (or load from cache) a FairGAN/Medgan model and return synthetic data.

    Training is skipped when the checkpoint ``<models_dir>/<model_name>.meta``
    already exists. A synthetic feature matrix is then generated, requesting
    as many samples as ``train_dataset`` has rows. Its ``'income'`` column is
    filled in by an ``MLPClassifier`` fitted on the real data.

    This function is specific to the adult dataset. It hardcodes the
    ``'income'`` label and the ``adult.npy`` data file name. It also assumes
    ``'income'`` is the LAST column of ``train_dataset``: the predicted labels
    are appended as the last column and the result is labelled with
    ``train_dataset.columns``.

    Side effect: disables TensorFlow eager execution for the whole process
    and resets the default TF graph.

    Args:
        train_dataset: pandas DataFrame of real training data (see above).
        embedding_dim, random_dim, generator_dims, discriminator_dims,
        bn_decay, l2_scale: Forwarded to the ``Medgan`` constructor.
        batch_size: Batch size for pretraining, training and generation.
        pretrain_epochs, train_epochs: Forwarded to ``Medgan.train``.
        model_name: Base name of the checkpoint / output files in ``models_dir``.

    Returns:
        pandas DataFrame of synthetic rows with ``train_dataset.columns``
        (default RangeIndex).
    """
    tf.compat.v1.disable_eager_execution()

    # Medgan writes its checkpoint into models_dir, which may not exist yet.
    os.makedirs(models_dir, exist_ok=True)

    data = train_dataset.values

    # Medgan.train reads its training data from a path. The matrix is pickled
    # despite the .npy suffix. Presumably Medgan's loader can unpickle it
    # (e.g. np.load with allow_pickle); confirm before changing the format.
    data_filename = os.path.join(models_dir, 'adult.npy')
    with open(data_filename, 'wb') as data_file:
        pickle.dump(data, data_file)

    # The last column (presumably the 'income' label, re-attached below) is
    # excluded from Medgan's input dimension.
    input_dim = data.shape[1] - 1
    n_samples = data.shape[0]

    tf.compat.v1.reset_default_graph()
    mg = Medgan(dataType='count',
                inputDim=input_dim,
                embeddingDim=embedding_dim,
                randomDim=random_dim,
                generatorDims=generator_dims,
                discriminatorDims=discriminator_dims,
                compressDims=(),
                decompressDims=(),
                bnDecay=bn_decay,
                l2scale=l2_scale)

    out_file = os.path.join(models_dir, model_name)

    # A '.meta' file next to out_file marks a previously saved checkpoint.
    if not os.path.exists(out_file + '.meta'):
        mg.train(dataPath=data_filename,
                 modelPath='',
                 outPath=out_file,
                 pretrainEpochs=pretrain_epochs,
                 nEpochs=train_epochs,
                 discriminatorTrainPeriod=2,
                 generatorTrainPeriod=1,
                 pretrainBatchSize=batch_size,
                 batchSize=batch_size,
                 saveMaxKeep=0)

    tf.compat.v1.reset_default_graph()
    synth_data = mg.generateData(nSamples=n_samples,
                                 modelFile=out_file,
                                 batchSize=batch_size,
                                 outFile=out_file)

    # Medgan only generates features, so a classifier fitted on the real data
    # supplies the label for each synthetic row.
    label_col = 'income'
    features = train_dataset.drop(columns=[label_col])
    mlp = MLPClassifier()
    mlp.fit(features, train_dataset[label_col])
    # Give the generated matrix the training feature names. Without them,
    # sklearn warns about missing feature names at predict time.
    income = mlp.predict(pd.DataFrame(synth_data, columns=features.columns))
    synth_data = np.append(synth_data, income.reshape((len(income), 1)), axis=1)

    return pd.DataFrame(synth_data,
                        columns=train_dataset.columns)
