import copy
import pickle
import sys
import time
import random
import os
import numpy as np
import pandas as pd

import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
#tf.config.set_per_process_memory_growth(True)

# tf.debugging.set_log_device_placement(True)

from keras.layers import Input, Dense, Conv2D, LeakyReLU, Dropout, Flatten, MaxPooling2D, GlobalAveragePooling2D
from tqdm import tqdm

from softTrainXrayGan.csmodels import xrayGenerator, xrayDiscriminator
from trainXrayEncoder.ConvolutionalCondVAE import Encoder, Decoder, ConvCVAE
from softTrainXrayGan.tfFunctionsUtils import apply_gumbel_softmax, map_fill_to_discrete, compare_conditionals_within
from softTrainXrayGan.tfFunctionsUtils import get_joint_distributions_from_samples, penalty_calculation
from sklearn.preprocessing import OneHotEncoder
from softTrainXrayGan.tfFunctionsUtils import calculate_TVD
from softTrainXrayGan.csxray_graph import Experiment
from softTrainXrayGan.tfFunctionsUtils import getdoKey
from softTrainXrayGan.tfFunctionsUtils import load_dataset
from softTrainXrayGan.csxray_graph import set_Xray
from softTrainXrayGan.tfFunctionsUtils import calculate_KL

# NOTE: unlike the original model, the input is not fed again into later layers. Check the main code before updating.
def covidGen(noise_dim, gen_dim):
    """Build the covid_19 label generator: noise vector -> linear output.

    Architecture: two ``Dense(256) -> BatchNormalization -> ReLU`` stages
    followed by a ``Dense(gen_dim)`` layer with no activation.

    Args:
        noise_dim: width of the noise input (named 'cov_noise').
        gen_dim: width of the output layer.

    Returns:
        A ``tf.keras.Model`` named 'covidGen' taking ``[noise]`` as input.
    """
    noise_in = Input(shape=[noise_dim, ], name='cov_noise')
    # Concatenating a single tensor is a value no-op; it is kept so the
    # model's layer list stays the same as before.
    hidden = keras.layers.concatenate([noise_in], axis=1)

    for _ in range(2):
        hidden = Dense(256, activation=None)(hidden)
        hidden = tf.keras.layers.BatchNormalization()(hidden)
        hidden = tf.keras.layers.ReLU()(hidden)

    logits = Dense(gen_dim)(hidden)
    return tf.keras.Model(inputs=[noise_in], outputs=logits, name='covidGen')


def PneumGen(condition_dim, noise_dim, gen_dim):
    """Build the pneumonia label generator conditioned on an input vector.

    The condition and noise inputs are concatenated, passed through two
    ``Dense(256) -> BatchNormalization -> ReLU`` stages, and mapped to a
    ``Dense(gen_dim)`` output with no activation.

    Args:
        condition_dim: width of the conditioning input ('condition_G').
        noise_dim: width of the noise input ('pneum_noise').
        gen_dim: width of the output layer.

    Returns:
        A ``tf.keras.Model`` named 'PneumGen' taking ``[condition, noise]``.
    """
    condition_in = Input(shape=[condition_dim, ], name='condition_G')
    noise_in = Input(shape=[noise_dim, ], name='pneum_noise')

    hidden = keras.layers.concatenate([condition_in, noise_in], axis=1)
    for _ in range(2):
        hidden = Dense(256, activation=None)(hidden)
        hidden = tf.keras.layers.BatchNormalization()(hidden)
        hidden = tf.keras.layers.ReLU()(hidden)

    logits = Dense(gen_dim)(hidden)
    return tf.keras.Model(inputs=[condition_in, noise_in], outputs=logits, name='PneumGen')


def RxrayModel(gen_dim):
    """Build an MLP mapping a ``gen_dim``-wide vector to a 256-wide output.

    Two ``Dense(256) -> LeakyReLU(0.2) -> Dropout(0.5)`` stages followed by a
    linear ``Dense(256)`` layer.

    Args:
        gen_dim: width of the input ('targetRxray').

    Returns:
        A ``tf.keras.Model`` named 'Rxray_disc'.
    """
    target_in = tf.keras.layers.Input(shape=[gen_dim, ], name='targetRxray')
    hidden = target_in
    for _ in range(2):
        hidden = Dense(256, activation=None)(hidden)
        hidden = tf.keras.layers.LeakyReLU(alpha=0.2)(hidden)
        hidden = tf.keras.layers.Dropout(0.5)(hidden)

    features = Dense(256)(hidden)
    return tf.keras.Model(inputs=[target_in], outputs=features, name='Rxray_disc')


def Discriminator(gen_dim):
    """Build an MLP critic that scores a ``gen_dim``-wide vector.

    Two ``Dense(256) -> LeakyReLU(0.2) -> Dropout(0.5)`` stages followed by a
    ``Dense(1)`` output with no activation (an unbounded score).

    Args:
        gen_dim: width of the input ('target').

    Returns:
        A ``tf.keras.Model`` named 'Discriminator' producing shape (batch, 1).
    """
    target_in = tf.keras.layers.Input(shape=[gen_dim, ], name='target')
    hidden = target_in
    for _ in range(2):
        hidden = Dense(256, activation=None)(hidden)
        hidden = tf.keras.layers.LeakyReLU(alpha=0.2)(hidden)
        hidden = tf.keras.layers.Dropout(0.5)(hidden)

    score = Dense(1)(hidden)
    return tf.keras.Model(inputs=[target_in], outputs=score, name='Discriminator')




def get_generators(Exp, load_which_models):
    """Build (and partially load) the per-node generators of the observed DAG.

    For each label in ``Exp.Observed_DAG``:
      * 'covid_19': a trainable ``covidGen`` plus an Adam optimizer.
      * 'xray': an ``xrayGenerator`` loaded from the epoch-288 HDF5 weights
        and frozen (``trainable = False``).
      * 'pneum': ``[frozen xrayDiscriminator, trainable PneumGen]`` plus an
        Adam optimizer for the PneumGen part.
      * 'Rxray': a ``ConvCVAE`` restored from its latest checkpoint (if one
        exists) and frozen.

    Args:
        Exp: experiment config; reads ``Observed_DAG``, ``CONF_NOISE_DIM``,
            ``learning_rate`` and ``IMAGE_SIZE``.
        load_which_models: currently unused.

    Returns:
        ``(label_generators, optimizersMech)`` dicts keyed by label. Only
        'covid_19' and 'pneum' get an optimizer.
    """
    label_generators = {}
    optimizersMech = {}

    pretrained_epoch = 288
    cvae_latent_dim = 128
    cvae_beta = 0.65

    def make_adam():
        # Same Adam settings for every trainable mechanism.
        return tf.keras.optimizers.Adam(Exp.learning_rate, beta_1=0.5, beta_2=0.9)

    for node in Exp.Observed_DAG:
        noise_dims = Exp.CONF_NOISE_DIM

        if node == 'covid_19':
            label_generators[node] = covidGen(noise_dims, 2)
            optimizersMech[node] = make_adam()

        elif node == 'xray':
            xray_gen = xrayGenerator(latent_dim=100, n_classes=3)
            xray_gen.load_weights(f"/SaveDir/params_generator_epoch_{pretrained_epoch}.hdf5")
            xray_gen.trainable = False
            label_generators[node] = xray_gen
            print("xray layers are freezed")

        elif node == 'pneum':
            image_classifier = xrayDiscriminator(input_shape=(Exp.IMAGE_SIZE, Exp.IMAGE_SIZE, 3))
            image_classifier.load_weights(f"/SaveDir/params_discriminator_epoch_{pretrained_epoch}.hdf5")
            image_classifier.trainable = False
            # PneumGen is conditioned on a 3-wide vector and emits 2 values.
            pneum_model = PneumGen(3, noise_dims, 2)
            label_generators[node] = [image_classifier, pneum_model]
            optimizersMech[node] = make_adam()

        elif node == 'Rxray':
            cvae = ConvCVAE(
                Encoder(cvae_latent_dim),
                Decoder(),
                label_dim=1,
                latent_dim=cvae_latent_dim,
                beta=cvae_beta,
                image_dim=[64, 64, 3])

            ckpt_dir = f"./CVAE{cvae_latent_dim}_{cvae_beta}_checkpoint"
            ckpt = tf.train.Checkpoint(module=cvae)
            latest = tf.train.latest_checkpoint(ckpt_dir)
            if latest is None:
                print("No checkpoint!")
            else:
                ckpt.restore(latest)
                print("Checkpoint restored:", latest)

            cvae.trainable = False
            label_generators[node] = cvae   # 64x 64x 3 -> 129

    return label_generators, optimizersMech


def get_discriminators(Exp):
    """Create the critics and their Adam optimizers.

    Critic input widths:
      * 'H2': ``label_dim['covid_19'] + label_dim['pneum'] + 129`` (the 129
        is the CVAE code: 128 latent dims + 1 label).
      * 'covid_19': 2, 'pneum': 2, 'low_joint': 4.

    Args:
        Exp: experiment config; reads ``label_dim`` and ``learning_rate``.

    Returns:
        ``(discriminatorsMech, doptimizersMech)`` dicts with the same keys.
    """
    rep_dim = 129  # latent dim =128 + label =1
    input_widths = {
        'H2': Exp.label_dim['covid_19'] + Exp.label_dim['pneum'] + rep_dim,
        'covid_19': 2,
        'pneum': 2,
        'low_joint': 4,
    }

    discriminatorsMech = {}
    doptimizersMech = {}
    for critic_name, width in input_widths.items():
        discriminatorsMech[critic_name] = Discriminator(width)
        doptimizersMech[critic_name] = tf.keras.optimizers.Adam(
            Exp.learning_rate, beta_1=0.5, beta_2=0.9)

    return discriminatorsMech, doptimizersMech


def get_generated_labels(Exp, label_generators, intervened, batch_size, data_batch=None):
    """Sample one batch of every DAG node, optionally under do(covid_19=v).

    No recursion: one confounding noise draw (``confNoises``) is shared by the
    covid_19 generator and the pneum generator.

    Args:
        Exp: experiment config; reads ``CONF_NOISE_DIM`` (the width the label
            generators are built with) and ``Temperature``.
        label_generators: dict with 'covid_19', 'xray', 'Rxray' and
            'pneum' (a ``[image_classifier, pneum_generator]`` pair).
        intervened: dict. If it contains 'covid_19', covid is clamped to a
            one-hot: ``[1, 0]`` when the value is 0, ``[0, 1]`` otherwise.
        batch_size: number of samples to draw.
        data_batch: optional one-element sequence whose item is a pneumonia
            one-hot array with ``batch_size`` rows; column 1 is used as the
            pneum == 1 indicator. ``None`` or empty uses a random T instead.

    Returns:
        dict with 'covid_19', 'xray' (generated images), 'Rxray' (CVAE code
        of the generated images) and 'pneum' (soft output of
        ``apply_gumbel_softmax``).
    """
    gen_labels = {}
    # Width must match the covidGen / PneumGen noise inputs, which are built
    # with Exp.CONF_NOISE_DIM.
    confNoises = tf.random.normal([batch_size, Exp.CONF_NOISE_DIM], mean=0.0, stddev=1.0,
                                  dtype=tf.dtypes.float32)

    if 'covid_19' in intervened:
        # do(covid_19=v): a fixed 2-class one-hot, regardless of label_dim.
        covid_idx = 0 if intervened['covid_19'] == 0 else 1
        gen_labels['covid_19'] = tf.one_hot(tf.fill([batch_size], covid_idx), depth=2)
    else:
        output = label_generators['covid_19']([confNoises], training=True)
        soft, _ = apply_gumbel_softmax(output, Exp.Temperature)
        gen_labels['covid_19'] = soft

    # ****** xray
    # Class value fed to the xray generator:
    #   random: i0 * (1.5 + T + N), T in {-0.5, 0.5}, N ~ Normal(0, 1e-4)
    #   real:   i0 * (1 + pneum)
    # i0 is column 0 of the covid one-hot, so the value is ~0 when covid_19 == 1.
    i0 = tf.reshape(gen_labels['covid_19'][:, 0], [-1, 1])
    if data_batch is None or len(data_batch) == 0:  # no pneum data: use random T
        T = random.choices([-0.5, 0.5], weights=[0.5, 0.5], k=batch_size)
        T = tf.reshape(T, [-1, 1])
        N = tf.random.normal(shape=(batch_size, 1), mean=0.0, stddev=0.0001, dtype=tf.float32)
        img_par = i0 * (1.5 + T + N)
    else:  # using real pneum data
        pneum_onehot = data_batch[0]
        n1 = tf.cast(tf.reshape(pneum_onehot[:, 1], [-1, 1]), tf.dtypes.float32)  # pneum = 1 column
        n_ones = tf.ones([batch_size, 1], dtype=tf.dtypes.float32)
        img_par = i0 * (n1 + n_ones)

    # Noise width 100 is assumed to match xrayGenerator(latent_dim=100).
    image_noise = np.random.uniform(-1, 1, (batch_size, 100))
    gen_labels['xray'] = label_generators['xray']([image_noise, img_par])

    # ****** Rxray: CVAE code of the generated image, conditioned on img_par.
    cvae = label_generators['Rxray']
    resized_images = tf.image.resize(gen_labels['xray'], (64, 64))
    _, input_label, conditional_input = cvae.conditional_input([resized_images, img_par])
    encoded = cvae.encoder(conditional_input, cvae.latent_dim, is_train=False)
    z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
    gen_labels['Rxray'] = cvae.reparametrization(z_mean, z_log_var, input_label)

    # ****** pneum: classify the generated image, then sample pneum from it.
    image_classifier, pneum_gen = label_generators['pneum']
    _, class3_labels = image_classifier([gen_labels['xray']])  # image label from the pre-trained model
    output = pneum_gen([class3_labels, confNoises], training=True)
    soft, _ = apply_gumbel_softmax(output, Exp.Temperature)
    gen_labels['pneum'] = soft

    return gen_labels



def calculate_joint(Exp, keep_G_fake):
    """Compute empirical distributions of a one-hot (covid_19, pneum) batch.

    Columns 0:2 are treated as the covid_19 one-hot and columns 2:4 as the
    pneum one-hot. Each pair is reduced to a class index with argmax, and the
    slice shape is printed along the way.

    Args:
        Exp: experiment config, forwarded to ``compare_conditionals_within``.
        keep_G_fake: array/tensor with at least 4 columns laid out as above.

    Returns:
        ``(joint_prob, cond_prob_list, covid_prob, pneum_prob)``: the joint
        P(covid_19, pneum), the conditional list from
        ``compare_conditionals_within`` for pneum given covid_19, and the two
        marginals.
    """
    def to_index_column(one_hot_pair):
        # Print the slice shape, then argmax it into an (n, 1) index column.
        print(one_hot_pair.shape)
        return tf.reshape(tf.math.argmax(one_hot_pair, axis=1), [-1, 1])

    covid19 = to_index_column(keep_G_fake[:, 0:2])
    pneum = to_index_column(keep_G_fake[:, 2:4])

    joint = tf.concat(values=[tf.cast(covid19, tf.int32), tf.cast(pneum, tf.int32)], axis=1)

    joint_prob = get_joint_distributions_from_samples(['covid_19', 'pneum'], [2, 2], joint.numpy())
    covid_prob = get_joint_distributions_from_samples(['covid_19'], [2], covid19.numpy())
    pneum_prob = get_joint_distributions_from_samples(['pneum'], [2], pneum.numpy())

    # P(pneum|covid)
    cond_prob_list = compare_conditionals_within(
        Exp, joint.numpy(), ['pneum'], ['covid_19'], ['covid_19', 'pneum'])

    return joint_prob, cond_prob_list, covid_prob, pneum_prob

def train_D(Exp,  label_generators, discriminators, D_optimizer, data_batch, image_batch):
    """Run one update of the joint critic ``discriminators['H2']``.

    Before the update, about 10% of the covid_19 == 1 rows (column 0) are
    duplicated with their images and relabelled as pneumonia = 0. The labels
    are then one-hot encoded, and the real images are encoded with the frozen
    CVAE. The critic compares ``[labels | CVAE code]`` rows against generated
    ones using ``mean(D_fake - D_real + LAMBDA_GP * penalty)``.

    Args:
        Exp: experiment config (``LAMBDA_GP`` and whatever
            ``get_generated_labels`` reads).
        label_generators: dict from ``get_generators``.
        discriminators: dict of critics; only 'H2' is trained here.
        D_optimizer: dict of optimizers; only 'H2' is used.
        data_batch: tensor of [covid_19, pneumonia] labels, assumed binary
            0/1 integers (int64 to match the appended rows). TODO confirm.
        image_batch: real images aligned row-wise with ``data_batch``.

    Returns:
        The H2 critic loss.
    """
    print('Training Discriminator')

    # Rows with covid_19 == 1.
    covid_ids = [ii for ii, row in enumerate(data_batch) if row[0] == 1]
    n_dummy = int(len(covid_ids) * 0.1)  # 10% of the covid-positive rows
    # Explicit int32 dtype: an empty Python list would otherwise become a
    # float tensor, and tf.gather would reject it.
    selected_ids = tf.constant(random.sample(covid_ids, n_dummy), dtype=tf.int32)

    covid_rows = tf.cast(tf.reshape(tf.gather(data_batch[:, 0], selected_ids), [-1, 1]), tf.int64)
    zero_pneum = tf.zeros([n_dummy, 1], dtype=tf.int64)
    dummy_data = tf.concat(axis=1, values=[covid_rows, zero_pneum])
    dummy_images = tf.gather(image_batch, selected_ids)

    data_batch = tf.concat(axis=0, values=[data_batch, dummy_data])
    image_batch = tf.concat(axis=0, values=[image_batch, dummy_images])

    # Fixed categories keep the encoding at 4 columns even if a batch is
    # missing one of the classes.
    enc = OneHotEncoder(categories=[[0, 1], [0, 1]])
    data_batch = enc.fit_transform(data_batch).toarray()

    # Encode real images with the CVAE, conditioned on one-hot column 0.
    cvae = label_generators['Rxray']
    resized_images = tf.image.resize(image_batch, (64, 64))
    _, input_label, conditional_input = cvae.conditional_input(
        [resized_images, tf.reshape(data_batch[:, 0], [-1, 1])])
    encoded = cvae.encoder(conditional_input, cvae.latent_dim, is_train=False)
    z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
    encoded_real_image = cvae.reparametrization(z_mean, z_log_var, input_label)
    databatch_wimg = tf.concat([data_batch, encoded_real_image], 1)

    G_fake = get_generated_labels(Exp, label_generators, {}, data_batch.shape[0])
    fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
    fakebatch_wimg = tf.concat(axis=1, values=[fake_batch, G_fake['Rxray']])

    with tf.GradientTape() as disc_tape:
        D_real = discriminators['H2']([databatch_wimg], training=True)
        D_fake = discriminators['H2']([fakebatch_wimg], training=True)
        penalty = penalty_calculation(discriminators['H2'], databatch_wimg, fakebatch_wimg)
        D_loss = tf.reduce_mean(D_fake - D_real + Exp.LAMBDA_GP * penalty)
    gradients_of_discriminator = disc_tape.gradient(D_loss, discriminators['H2'].trainable_variables)
    D_optimizer['H2'].apply_gradients(zip(gradients_of_discriminator, discriminators['H2'].trainable_variables))

    return D_loss
def train_G(Exp,  label_generators, G_optimizers, discriminators, data_batch):
    """Run one generator update against the joint critic ``discriminators['H2']``.

    Samples ``data_batch.shape[0]`` fake ``[covid_19 | pneum | Rxray]`` rows,
    scores them with the H2 critic, and applies one optimizer step to the
    covid_19 generator and the PneumGen half of ``label_generators['pneum']``.

    Returns:
        The generator loss ``-mean(H2(fake))``.
    """
    print('Training Generator')

    # (optimizer key, model whose variables are updated), in update order.
    trained = [
        ('covid_19', label_generators['covid_19']),
        ('pneum', label_generators['pneum'][1]),
    ]

    with tf.GradientTape() as gen_tape:
        G_fake = get_generated_labels(Exp, label_generators, {}, data_batch.shape[0])
        fakebatch_wimg = tf.concat(
            axis=1, values=[G_fake['covid_19'], G_fake['pneum'], G_fake['Rxray']])

        # for P(C, Rxray, Pn)
        G_loss = -tf.reduce_mean(discriminators['H2']([fakebatch_wimg], training=True))

        print(f'P(C,Rxray, Pneum) G_loss--->  {G_loss}')

    grads = gen_tape.gradient(G_loss, [model.trainable_variables for _, model in trained])
    for (opt_key, model), grad in zip(trained, grads):
        G_optimizers[opt_key].apply_gradients(zip(grad, model.trainable_variables))

    return G_loss


def do_train(Exp,  label_generators, discriminators,G_optimizers, D_optimizer, data_batch, image_batch):
    """Run one alternating training step: generators first, then the critics.

    1. Generator step: minimise ``-mean(H2(fake)) - mean(covid_19(fake covid))``.
       The xray images are conditioned on the real pneumonia one-hot columns.
    2. Critic steps, each using ``mean(D_fake - D_real + LAMBDA_GP * penalty)``:
       'H2' on ``[covid one-hot | pneum one-hot | CVAE code]`` rows, and
       'covid_19' on the covid one-hot columns alone.

    Args:
        Exp: experiment config (``LAMBDA_GP`` and whatever
            ``get_generated_labels`` reads).
        label_generators: dict from ``get_generators``.
        discriminators: dict of critics; 'H2' and 'covid_19' are trained.
        G_optimizers: dict of generator optimizers ('covid_19', 'pneum').
        D_optimizer: dict of critic optimizers ('H2', 'covid_19').
        data_batch: [covid_19, pneumonia] label columns, assumed binary 0/1.
            TODO confirm. One-hot encoded here into 4 columns.
        image_batch: real images aligned row-wise with ``data_batch``.

    Returns:
        ``(G_loss, D_loss)``. ``D_loss`` is the sum of the H2 and covid_19
        critic losses, mirroring ``G_loss = l3 + l1``.
    """
    # Fixed categories keep the encoding at 4 columns even when a batch is
    # missing one of the classes.
    enc = OneHotEncoder(categories=[[0, 1], [0, 1]])
    data_batch = enc.fit_transform(data_batch).toarray()
    batch_size = data_batch.shape[0]
    real_pneum = [data_batch[:, 2:4]]  # send only pneumonia

    # -----------------------------------------------------------------------------
    print('Training Generator')
    with tf.GradientTape() as gen_tape:
        G_fake = get_generated_labels(Exp, label_generators, {}, batch_size, real_pneum)
        fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
        fakebatch_wimg = tf.concat(axis=1, values=[fake_batch, G_fake['Rxray']])

        # for P(C, Rxray, Pn)
        l3 = -tf.reduce_mean(discriminators['H2']([fakebatch_wimg], training=True))
        # P(Covid)
        l1 = -tf.reduce_mean(discriminators['covid_19']([fake_batch[:, 0:2]], training=True))

        G_loss = l3 + l1

        print(f'P(C,Rxray, Pneum)+P(C) : G_loss--->  {G_loss}')

    covid_gen = label_generators['covid_19']
    pneum_gen = label_generators['pneum'][1]
    grad_covid, grad_pneum = gen_tape.gradient(
        G_loss, [covid_gen.trainable_variables, pneum_gen.trainable_variables])
    G_optimizers['covid_19'].apply_gradients(zip(grad_covid, covid_gen.trainable_variables))
    G_optimizers['pneum'].apply_gradients(zip(grad_pneum, pneum_gen.trainable_variables))

    # -----------------------------------------------------------------------------
    print('Training Discriminator')

    # Encode real images with the CVAE, conditioned on one-hot column 0.
    cvae = label_generators['Rxray']
    resized_images = tf.image.resize(image_batch, (64, 64))
    _, input_label, conditional_input = cvae.conditional_input(
        [resized_images, tf.reshape(data_batch[:, 0], [-1, 1])])
    encoded = cvae.encoder(conditional_input, cvae.latent_dim, is_train=False)
    z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
    encoded_real_image = cvae.reparametrization(z_mean, z_log_var, input_label)
    databatch_wimg = tf.concat([data_batch, encoded_real_image], 1)

    # Fresh fake batch, sampled after the generator update.
    G_fake = get_generated_labels(Exp, label_generators, {}, batch_size, real_pneum)
    fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
    fakebatch_wimg = tf.concat(axis=1, values=[fake_batch, G_fake['Rxray']])

    with tf.GradientTape() as disc_tape:
        D_real = discriminators['H2']([databatch_wimg], training=True)
        D_fake = discriminators['H2']([fakebatch_wimg], training=True)
        penalty = penalty_calculation(discriminators['H2'], databatch_wimg, fakebatch_wimg)
        D_loss_h2 = tf.reduce_mean(D_fake - D_real + Exp.LAMBDA_GP * penalty)
    gradients_of_discriminator = disc_tape.gradient(D_loss_h2, discriminators['H2'].trainable_variables)
    D_optimizer['H2'].apply_gradients(zip(gradients_of_discriminator, discriminators['H2'].trainable_variables))

    real_covid = data_batch[:, 0:2]
    fake_covid = fake_batch[:, 0:2]
    with tf.GradientTape() as covid_tape:
        D_real = discriminators['covid_19']([real_covid], training=True)
        D_fake = discriminators['covid_19']([fake_covid], training=True)
        penalty = penalty_calculation(discriminators['covid_19'], real_covid, fake_covid)
        D_loss_covid = tf.reduce_mean(D_fake - D_real + Exp.LAMBDA_GP * penalty)
    gradients_of_discriminator = covid_tape.gradient(D_loss_covid, discriminators['covid_19'].trainable_variables)
    D_optimizer['covid_19'].apply_gradients(zip(gradients_of_discriminator, discriminators['covid_19'].trainable_variables))

    # Report both critics; previously the H2 loss was overwritten here.
    return G_loss, D_loss_h2 + D_loss_covid





def _intervened_pneum_prob(Exp, label_generators, covid_value, n_samples):
    """Return the empirical P(pneum | do(covid_19=covid_value)) from n_samples draws."""
    G_fake = get_generated_labels(Exp, label_generators, {'covid_19': covid_value}, n_samples)
    intv_pneum = tf.reshape(tf.math.argmax(G_fake['pneum'], axis=1), [-1, 1])
    return get_joint_distributions_from_samples(['pneum'], [2], intv_pneum.numpy())


def _record_tvd(Exp, key, fake_prob, real_prob):
    """Append round(TVD(fake_prob, real_prob), 4) to ``Exp.tvd_diff[key]``."""
    obs_tvd = calculate_TVD(fake_prob, real_prob, doPrint=False)
    Exp.tvd_diff[key].append(round(obs_tvd, 4))


def trainloop(Exp, cur_hnodes, label_generators, G_optimizers, discriminators, D_optimizers, train_dataset):
    """Train for one epoch, then anneal, evaluate, log and maybe checkpoint.

    Args:
        Exp: experiment config/state. Reads ``curr_epoochs``, ``num_epochs``,
            ``Temperature`` and ``exp_name``; calls ``anneal_temperature``;
            appends to ``tvd_diff``.
        cur_hnodes: printed alongside the losses for context.
        label_generators, G_optimizers, discriminators, D_optimizers: the
            models and optimizers passed through to ``do_train``.
        train_dataset: dict with 'img' and 'labels' iterables, zipped batch by
            batch. Label batches are dicts with 'covid_19' and 'pneumonia'.

    Side effects:
        Writes ``/SaveDir/{exp_name}/tvd/*.npy`` every evaluated epoch. After
        epoch 200, when the joint TVD is < 0.20, it also saves model weights
        under ``/projectroot/checkpoints`` and pickles the interventional
        distributions.
    """
    n_batches = len(train_dataset['labels'])
    iteration = 0
    for img_batch, label_batch in zip(train_dataset['img'], train_dataset['labels']):
        batch1 = tf.reshape(label_batch['covid_19'], [-1, 1])
        batch2 = tf.reshape(label_batch['pneumonia'], [-1, 1])
        udata_batch = tf.concat(axis=1, values=[batch1, batch2])

        # Images were already normalised when the dataset was loaded;
        # per-image standardisation is applied on top.
        image_batch = tf.image.per_image_standardization(img_batch)

        # Joint G/D step; train_G / train_D implement a separate-step alternative.
        G_loss, D_loss = do_train(Exp, label_generators, discriminators, G_optimizers,
                                  D_optimizers, udata_batch, image_batch)

        print('Epoch [%d/%d], Step [%d/%d],' % (
            Exp.curr_epoochs + 1, Exp.num_epochs, iteration + 1, n_batches),
              'mechanism: ', cur_hnodes, ' D_loss: %.4f, G_loss: %.4f' % (D_loss.numpy(), G_loss.numpy()))

        print('Reduced temperature:', Exp.Temperature)

        iteration += 1

    # Global step = completed epochs * batches per epoch + batches this epoch.
    # (len(train_dataset) is the dict length, always 2, not the batch count.)
    # NOTE(review): this anneals faster than the old len(train_dataset) form;
    # ANNEAL_RATE may need retuning.
    tot_iter = Exp.curr_epoochs * n_batches + iteration
    Exp.anneal_temperature(tot_iter)

    print("--->", Exp.curr_epoochs)
    eval_every = 1  # evaluate every N epochs
    if Exp.curr_epoochs % eval_every != 0:
        return

    test_size = 1000
    G_fake = get_generated_labels(Exp, label_generators, {}, test_size)
    fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
    f_joint_prob, f_cond_list, f_covid_prob, f_pneum_prob = calculate_joint(Exp, fake_batch)

    # NOTE(review): the "real" distribution is estimated from the last training
    # batch only (udata_batch); consider accumulating all batches.
    # Fixed categories guarantee 4 one-hot columns even if the batch lacks a class.
    enc = OneHotEncoder(categories=[[0, 1], [0, 1]])
    data_batch = enc.fit_transform(udata_batch).toarray()
    r_joint_prob, r_cond_list, r_covid_prob, r_pneum_prob = calculate_joint(Exp, data_batch)

    # P(covid,pneum)
    print(f'Real prob:joint_prob:{r_joint_prob}')
    print(f'Fake prob: joint_prob:{f_joint_prob}')
    _record_tvd(Exp, 'joint', f_joint_prob, r_joint_prob)

    # P(pneum|covid)
    print(f'Real prob: P(pneum|covid=0) :{r_cond_list[0]}')
    print(f'Fake prob: P(pneum|covid=0):{f_cond_list[0]}')
    _record_tvd(Exp, 'cond_cov0', f_cond_list[0], r_cond_list[0])

    print(f'Real prob:P(pneum|covid=1):{r_cond_list[1]}')
    print(f'Fake prob: P(pneum|covid=1):{f_cond_list[1]}')
    _record_tvd(Exp, 'cond_cov1', f_cond_list[1], r_cond_list[1])

    # P(covid)
    print(f'Real prob:covid_prob:{r_covid_prob}')
    print(f'Fake prob: covid_prob:{f_covid_prob}')
    _record_tvd(Exp, 'covid', f_covid_prob, r_covid_prob)

    # P(Pneum)
    print(f'Real prob:pneum_prob:{r_pneum_prob}')
    print(f'Fake prob: pneum_prob:{f_pneum_prob}')
    _record_tvd(Exp, 'pneum', f_pneum_prob, r_pneum_prob)

    # ATE = P(pneum=1 | do(covid=1)) - P(pneum=1 | do(covid=0))
    intv_pneum_prob_do1 = _intervened_pneum_prob(Exp, label_generators, 1, test_size)
    print('P(Y|do(X=1)', intv_pneum_prob_do1)
    intv_pneum_prob_do0 = _intervened_pneum_prob(Exp, label_generators, 0, test_size)
    print('P(Y|do(X=0)', intv_pneum_prob_do0)

    ATE = intv_pneum_prob_do1[(1,)] - intv_pneum_prob_do0[(1,)]
    Exp.tvd_diff['ATE'].append(ATE)

    # Show the last (up to) 10 values of every tracked metric.
    ll = -min(10, len(next(iter(Exp.tvd_diff.values()))))
    for dist in Exp.tvd_diff:
        print("###", dist, " loss%:",  [round(val, 4) for val in Exp.tvd_diff[dist][ll:]])

    path = f"/SaveDir/{Exp.exp_name}/tvd"
    os.makedirs(path, exist_ok=True)
    for dist, values in Exp.tvd_diff.items():
        np.save(f'{path}/{dist}.npy', np.array(values))
    print('files saved')

    if Exp.curr_epoochs > 200 and Exp.tvd_diff['joint'][-1] < 0.20:
        tag = f'Epoch{Exp.curr_epoochs}_{Exp.tvd_diff["joint"][-1]}'
        ckpt_dir = f'/projectroot/checkpoints/{Exp.exp_name}/{tag}'

        label_generators['covid_19'].save_weights(f'{ckpt_dir}/covid_19_gen/gen')
        label_generators['pneum'][1].save_weights(f'{ckpt_dir}/pneumonia_gen/gen')
        for disc_key, subdir in (('H2', 'joint_disc'),
                                 ('covid_19', 'covid_19_disc'),
                                 ('pneum', 'pneumonia_disc'),
                                 ('low_joint', 'low_joint_disc')):
            discriminators[disc_key].save_weights(f'{ckpt_dir}/{subdir}/disc')

        print('model saved!!!')

        # Persist the interventional distributions used for this epoch's ATE,
        # so the saved files match the logged value.
        for covid_value, prob in ((1, intv_pneum_prob_do1), (0, intv_pneum_prob_do0)):
            out_dir = f'{path}/intv{covid_value}_pneum'
            os.makedirs(out_dir, exist_ok=True)
            with open(f'{out_dir}/{tag}.pkl', 'wb') as f:
                pickle.dump(prob, f)


if __name__ == '__main__':

    # Experiment name from the first CLI argument, with a default for quick runs.
    exp_name = sys.argv[1] if len(sys.argv) > 1 else 'xrayEncodertest'

    Exp = Experiment(set_Xray,
                     exp_name=exp_name,
                     NOISE_DIM=64,
                     CONF_NOISE_DIM=64,
                     Temperature=1,
                     temp_min=0.01,
                     ANNEAL_RATE=0.003,
                     CRITIC_ITERATIONS=1,
                     LAMBDA_GP=10,
                     batch_size=200,
                     ENCODED_DIM=128,
                     Data_intervs=[{}],
                     num_epochs=1000,
                     IMAGE_SIZE=112,
                     new_experiment=True
                     )

    # Do rejection sampling if necessary to test. P(Y|D)

    print('Here Experiment name:', Exp.exp_name)
    Exp.tvd_diff = {'joint': [], 'cond_cov0': [], 'cond_cov1': [], 'covid': [], 'pneum': [], 'ATE': []}

    root = "/Dataset/COVIDx-splitted-resized-112"
    data = pd.read_csv(f'{root}/train_dataset.csv')

    ##### normalized Image data load [-1,1]
    image_data, valid_id = load_dataset(Exp.batch_size, Exp.IMAGE_SIZE, root, data, split='train')

    ##### label data load
    # .copy() so the relabelling below edits an independent frame rather than
    # a possible view of `data` (avoids pandas chained-assignment issues).
    label_data = data[["covid_19", "pneumonia"]].iloc[valid_id].copy()

    # Relabel up to 1000 random covid_19 == 1 rows as pneumonia = 0, keeping the same images.
    n_relabel_target = 1000
    covid_idx = label_data.index[label_data['covid_19'] == 1].tolist()
    n_relabel = min(n_relabel_target, len(covid_idx))
    if n_relabel < n_relabel_target:
        print(f'Only {len(covid_idx)} covid_19 rows available; relabelling all of them.')
    relabel_idx = random.sample(covid_idx, n_relabel)
    label_data.loc[relabel_idx, ["pneumonia"]] = 0

    label_dataset = tf.data.Dataset.from_tensor_slices(dict(label_data)).batch(Exp.batch_size)
    train_dataset = {'img': image_data, 'labels': label_dataset}

    # Learning rate: staircase exponential decay from 5e-4 to 1e-4 over num_epochs epochs.
    initial_learning_rate = 5*1e-4
    final_learning_rate = 1e-4
    learning_rate_decay_factor = (final_learning_rate / initial_learning_rate) ** (1 / Exp.num_epochs)
    # dataset size / batch size; at least 1, because decay_steps == 0 would divide by zero.
    steps_per_epoch = max(1, len(valid_id) // Exp.batch_size)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_learning_rate,
        decay_steps=steps_per_epoch,
        decay_rate=learning_rate_decay_factor,
        staircase=True)
    Exp.learning_rate = lr_schedule

    # Models
    cur_hnodes = {"H2": ["covid_19", "pneum"]}
    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)
    discriminatorsMech, doptimizersMech = get_discriminators(Exp)

    for epoch in tqdm(range(Exp.num_epochs)):
        Exp.curr_epoochs = epoch
        trainloop(Exp, cur_hnodes, label_generators, optimizersMech, discriminatorsMech, doptimizersMech, train_dataset)




