#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import os
import utils.args_parser  as argtools
import pytorch_lightning as pl
import numpy as np


# # LOAD CONFIG

# In[ ]:


# True  -> build cfg for the custom '2nodes' dataset (Option 2 below).
# False -> load one of the datasets from the paper (Option 1 below).
use_custom_dataset = True


# ### Option 1: Datasets from the paper

# In[ ]:


if not use_custom_dataset:
    print('Using dataset from the paper')
    params_dir = '_params'
    dataset_file = os.path.join(params_dir, 'dataset_adult.yaml')
    model_file = os.path.join(params_dir, 'model_vaca.yaml')
    trainer_file = os.path.join(params_dir, 'trainer.yaml')

    # Path to one complete config YAML; leave empty to assemble cfg from the
    # dataset / model / trainer files above.
    yaml_file = ''

    if yaml_file:
        cfg = argtools.parse_args(yaml_file)
    else:
        # Merge in order: dataset, then model, then trainer (later keys win).
        cfg = argtools.parse_args(dataset_file)
        for part_file in (model_file, trainer_file):
            cfg.update(argtools.parse_args(part_file))


# ### Option 2: New dataset

# In[ ]:


if use_custom_dataset:
    print('Using custom dataset')
    params_dir = '_params'
    model_file = os.path.join(params_dir, 'model_vaca.yaml')
    trainer_file = os.path.join(params_dir, 'trainer.yaml')

    # Path to one complete config YAML; leave empty to assemble cfg from the
    # model / trainer files above (the dataset section is filled in below).
    yaml_file = ''
    if yaml_file:
        cfg = argtools.parse_args(yaml_file)
    else:
        cfg = argtools.parse_args(model_file)
        cfg.update(argtools.parse_args(trainer_file))



# In[ ]:


# Config for new dataset.
# Only applied for the custom dataset: with use_custom_dataset = False the
# dataset section loaded from the paper's YAML must not be overwritten.
# params1 and params2 are merged into cfg['dataset']['params'] further down.
if use_custom_dataset:
    cfg['dataset'] = {
        'name': '2nodes',
        'params1': {
            'data_dir': '../Data',
            'batch_size': 1000,
            'num_workers': 0,
        },
        'params2': {
            'num_samples_tr': 5000,
            'equations_type': 'linear',
            'normalize': 'lik',
            'likelihood_names': 'd',
            'lambda_': 0.05,
            'normalize_A': None,
        },
    }


# ### You can also update any parameter manually

# In[ ]:


# Run-level settings; seed every RNG before anything random happens.
cfg['root_dir'] = 'results'
cfg['seed'] = 0
pl.seed_everything(cfg['seed'])

# Flatten params1 + params2 into a single params dict (params2 wins on clashes),
# then force data_dir to ''.
cfg['dataset']['params'] = {
    **cfg['dataset']['params1'],
    **cfg['dataset']['params2'],
    'data_dir': '',
}

# Use every batch of each split and validate every 10 epochs.
cfg['trainer'].update(
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    check_val_every_n_epoch=10,
)


def print_if_not_dict(key, value, extra=''):
    """Print one config entry and report whether it was a leaf.

    Non-dict values are printed as ``"<extra><key>: <value>"``. A dict value
    only gets a ``"<extra><key>:"`` header line, so the caller can print its
    contents itself.

    Args:
        key: Entry name.
        value: Entry value.
        extra: Prefix (e.g. tab indentation) prepended to the printed line.

    Returns:
        True if ``value`` is not a dict (fully printed), False if it is a dict.
    """
    if isinstance(value, dict):
        print(f"{extra}{key}:")
        # Previously a bare `False` expression, so the function returned None.
        return False
    print(f"{extra}{key}: {value}")
    return True

def _print_cfg_tree(node, depth=0):
    """Print every entry of a (possibly nested) config dict, one tab per level.

    Unlike a fixed three-level loop, this also prints the contents of dicts
    nested deeper than three levels instead of only their header line.
    """
    for name, entry in node.items():
        # print_if_not_dict prints leaves fully and only a header for dicts,
        # returning a falsy value for the latter so we descend into them.
        if not print_if_not_dict(name, entry, extra='\t' * depth):
            _print_cfg_tree(entry, depth + 1)


_print_cfg_tree(cfg)


# # LOAD DATASET

# In[ ]:


from utils.constants import Cte


# Show the built-in datasets, one per tab-indented line, under a header.
print('These are datasets from the paper:',
      *[f"\t{name}" for name in Cte.DATASET_LIST],
      sep='\n')

print(f"\nUsing dataset: {cfg['dataset']['name']}")


# In[ ]:


# Build and prepare the data module for the configured dataset.
dataset_name = cfg['dataset']['name']

if dataset_name in Cte.DATASET_LIST:
    from data_modules.het_scm import HeterogeneousSCMDataModule

    dataset_params = cfg['dataset']['params'].copy()
    dataset_params['dataset_name'] = dataset_name

    data_module = HeterogeneousSCMDataModule(**dataset_params)
    data_module.prepare_data()

elif dataset_name == '2nodes':
    from data_modules.my_toy_scm import MyToySCMDataModule
    # Explicit import instead of `import *`: Normal is the only name used here.
    from utils.distributions import Normal

    dataset_params = cfg['dataset']['params'].copy()
    dataset_params['dataset_name'] = dataset_name

    # Two-node SCM: x1 := u1, x2 := u2 + x1, with Normal(0, 1) noise per node.
    dataset_params['nodes_to_intervene'] = ['x1']
    dataset_params['structural_eq'] = {
        'x1': lambda u1: u1,
        'x2': lambda u2, x1: u2 + x1,
    }
    dataset_params['noises_distr'] = {'x1': Normal(0, 1), 'x2': Normal(0, 1)}
    # Presumably maps each node to its children (x1 -> x2); verify in MyToySCMDataModule.
    dataset_params['adj_edges'] = {'x1': ['x2'], 'x2': []}

    data_module = MyToySCMDataModule(**dataset_params)
    data_module.prepare_data()

else:
    # Fail here instead of with a NameError on `data_module` further down.
    raise ValueError(f"Unknown dataset name: {dataset_name!r}")


# # LOAD MODEL

# In[ ]:


# `model` is filled in by the model-specific cell below. model_params is a
# shallow copy, so keys added there do not leak back into cfg['model']['params'].
model = None
model_params = cfg['model']['params'].copy()

print(f"\nUsing model: {cfg['model']['name']}")


# In[ ]:


# Instantiate the configured model. Each branch adds the data-dependent
# constructor arguments (likelihoods, graph info, scaler, ...) from data_module.
model_name = cfg['model']['name']

# VACA
if model_name == Cte.VACA:
    from models.vaca.vaca import VACA

    model_params['is_heterogeneous'] = data_module.is_heterogeneous
    model_params['likelihood_x'] = data_module.likelihood_list

    model_params['deg'] = data_module.get_deg(indegree=True)
    model_params['num_nodes'] = data_module.num_nodes
    model_params['edge_dim'] = data_module.edge_dimension
    model_params['scaler'] = data_module.scaler

    model = VACA(**model_params)
    model.set_random_train_sampler(data_module.get_random_train_sampler())

# VACA with PIWAE
elif model_name == Cte.VACA_PIWAE:
    from models.vaca.vaca_piwae import VACA_PIWAE

    model_params['is_heterogeneous'] = data_module.is_heterogeneous
    model_params['likelihood_x'] = data_module.likelihood_list

    model_params['deg'] = data_module.get_deg(indegree=True)
    model_params['num_nodes'] = data_module.num_nodes
    model_params['edge_dim'] = data_module.edge_dimension
    model_params['scaler'] = data_module.scaler

    model = VACA_PIWAE(**model_params)
    model.set_random_train_sampler(data_module.get_random_train_sampler())

# MultiCVAE
elif model_name == Cte.MCVAE:
    from models.multicvae.multicvae import MCVAE

    model_params['likelihood_x'] = data_module.likelihood_list

    model_params['topological_node_dims'] = data_module.train_dataset.get_node_columns_in_X()
    model_params['topological_parents'] = data_module.topological_parents
    model_params['scaler'] = data_module.scaler
    # The epoch budget is split evenly across the nodes in topological order.
    model_params['num_epochs_per_nodes'] = int(
        np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))
    model = MCVAE(**model_params)
    model.set_random_train_sampler(data_module.get_random_train_sampler())
    cfg['early_stopping'] = False

# CAREFL
elif model_name == Cte.CARELF:
    from models.carefl.carefl import CAREFL

    model_params['node_per_dimension_list'] = data_module.train_dataset.node_per_dimension_list
    model_params['scaler'] = data_module.scaler
    model = CAREFL(**model_params)

else:
    # Without this, `model` stays None and the next cell fails with an
    # unhelpful AttributeError on model.summarize().
    raise ValueError(f"Unknown model name: {model_name!r}")


# In[ ]:


# Print the model summary, then hand it the optimizer/scheduler config.
model.summarize()
model.set_optim_params(
    optim_params=cfg['optimizer'],
    sched_params=cfg['scheduler'],
)


# # LOAD EVALUATOR

# In[ ]:


from models._evaluator import MyEvaluator

# Attach an evaluator built on the training set's intervention list.
evaluator = MyEvaluator(
    model=model,
    intervention_list=data_module.train_dataset.get_intervention_list(),
    scaler=data_module.scaler,
)
model.set_my_evaluator(evaluator=evaluator)


# In[ ]:


# Each intervention is an (intervention dict, name) pair; print its node values.
for inter_dict, name in data_module.train_dataset.get_intervention_list():
    print(f'Intervention name: {name}')  # fixed typo: was "Interventiona"
    for node_name, value in inter_dict.items():
        print(f"\t{node_name}: {value}")


# # PREPARE TRAINING

# In[ ]:


from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


# Run-mode switches:
#   is_training -- build a Trainer and fit the model.
#   load        -- look for an existing checkpoint to resume from or evaluate.
is_training, load = False, True

print(f'Is training activated? {is_training}')
print(f'Is loading activated? {load}')


# In[ ]:


# Output directory:
#   <root_dir>/<experiment_folder>/<seed>[/<kfold_idx>] when cfg was assembled
#   from the separate YAML files, or the directory containing yaml_file when a
#   single full-config YAML was loaded.
if not yaml_file:
    dataset_cfg = cfg['dataset']
    run_dir_parts = [cfg['root_dir'],
                     argtools.get_experiment_folder(cfg),
                     str(cfg['seed'])]
    # k-fold settings only exist for the German dataset; tolerate their absence.
    kfold_cfg = dataset_cfg.get('params3', {})
    if dataset_cfg['name'] in [Cte.GERMAN] and kfold_cfg.get('train_kfold', False):
        run_dir_parts.append(str(kfold_cfg['kfold_idx']))
    save_dir = argtools.mkdir(os.path.join(*run_dir_parts))
else:
    # os.path.dirname keeps a leading '/' for absolute paths and returns '' for
    # a bare filename; the old split('/') + join lost the former and raised
    # TypeError on the latter.
    save_dir = os.path.dirname(yaml_file)
print(f'Save dir: {save_dir}')


# In[ ]:


# TensorBoard logging under <save_dir>/logs, with the flattened cfg as hparams.
logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)
out = logger.log_hyperparams(argtools.flatten_cfg(cfg))

# Checkpoints live in <save_dir>/ckpt; only look one up when loading is enabled
# (argtools.newest presumably returns the most recent file, or None if empty).
save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))
ckpt_file = argtools.newest(save_dir_ckpt) if load else None
callbacks = []

print(f"ckpt_file: {ckpt_file}")


# In[ ]:


# Set up training and/or checkpoint loading.
#  - is_training: build a Trainer with checkpointing (and optional early stopping).
#  - load + checkpoint found: either rebuild the Trainer so it resumes from the
#    checkpoint (training), or reload the model weights for evaluation only.
# NOTE(review): `period` (ModelCheckpoint) and `resume_from_checkpoint` (Trainer)
# are deprecated/removed in newer PyTorch Lightning releases — confirm the pinned version.
if is_training:
    # Keep the best checkpoint according to model.monitor()/monitor_mode(),
    # plus the last one, in save_dir_ckpt.
    checkpoint = ModelCheckpoint(period=1,
                                 monitor=model.monitor(),
                                 mode=model.monitor_mode(),
                                 save_top_k=1,
                                 save_last=True,
                                 filename='checkpoint-{epoch:02d}',
                                 dirpath=save_dir_ckpt)
    callbacks = [checkpoint]


    # Optional early stopping on the same monitored metric (patience of 50 checks).
    if cfg['early_stopping']:
        early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)
        callbacks.append(early_stopping)
    trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])

if load:
    if ckpt_file is None:
        print(f'No ckpt files in {save_dir_ckpt}')
    else:
        print(f'\nLoading from: {ckpt_file}')
        if is_training:
            # Replaces the Trainer built above with one that resumes from ckpt_file.
            trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,
                             **cfg['trainer'])
        else:

            # Evaluation only: the reloaded model is a new object, so the
            # evaluator and (for the sampler-based models) the random train
            # sampler are re-attached to it.
            model = model.load_from_checkpoint(ckpt_file, **model_params)
            evaluator.set_model(model)
            model.set_my_evaluator(evaluator=evaluator)

            if cfg['model']['name'] in [Cte.VACA_PIWAE, Cte.VACA, Cte.MCVAE]:
                model.set_random_train_sampler(data_module.get_random_train_sampler())


# In[ ]:


if is_training:
    trainer.fit(model, data_module)
    # Persist the full merged config alongside the run outputs.
    argtools.save_yaml(cfg,
                       file_path=os.path.join(save_dir, 'hparams_full.yaml'))


# # TESTING

# In[ ]:


# Count trainable parameters before freeze(), which disables gradients.
params = int(sum(p.numel() for p in model.parameters() if p.requires_grad))

# Switch to inference mode for evaluation.
model.eval()
model.freeze()  # IMPORTANT


# In[ ]:


# Validation metrics (no plots).
output_valid = model.evaluate(
    dataloader=data_module.val_dataloader(),
    name='valid',
    save_dir=save_dir,
    plots=False,
)


# In[ ]:


# Test metrics (with plots), merged with the validation metrics and the
# flattened config into one results dict.
output_test = model.evaluate(
    dataloader=data_module.test_dataloader(),
    name='test',
    save_dir=save_dir,
    plots=True,
)
output_valid.update(output_test)
output_valid.update(argtools.flatten_cfg(cfg))


# In[ ]:


import json
# Record which checkpoint was evaluated and the trainable-parameter count.
output_valid['ckpt_file'] = ckpt_file
output_valid['num_parameters'] = params

output_path = os.path.join(save_dir, 'output.json')
with open(output_path, 'w') as f:
    json.dump(output_valid, f)
print(f'Experiment folder: {save_dir}')


# # Custom interventions

# In[ ]:


# Build a single-sample test batch with the intervention x_I applied.
# Intervention values are given before normalizing; another example: {'x1': 2.4721}.
x_I = {'x1': 0.0}

bs = data_module.batch_size
data_module.batch_size = 1
try:
    data_loader = data_module.test_dataloader()
    data_loader.dataset.set_intervention(x_I)
    # Rebuild the loader after setting the intervention, as the original code
    # did (presumably so it reflects the intervened dataset).
    data_loader = data_module.test_dataloader()
finally:
    # Restore the configured batch size even if building the loader fails.
    data_module.batch_size = bs

batch = next(iter(data_loader))

print(batch)


# In[ ]:


# Sample the model's prediction under the intervention and compare it with
# the observed and intervened data.
x_hat, z = model.get_intervention(
    batch,
    x_I=data_loader.dataset.x_I,
    nodes_list=data_loader.dataset.nodes_list,
    return_type='sample',  # mean or sample
    use_aggregated_posterior=False,
    normalize=True,
)

print(f"Original: {batch.x.flatten()}")
print(f"Intervened: {batch.x_i.flatten()}")
print(f"Reconstructed: {x_hat.flatten()}")


# # Custom counterfactuals

# In[ ]:


# Build a single-sample test batch with the intervention x_I applied, for the
# counterfactual query below.
# Intervention values are given before normalizing; another example: {'x1': 2.4721}.
x_I = {'x1': 0.0}

bs = data_module.batch_size
data_module.batch_size = 1
try:
    data_loader = data_module.test_dataloader()
    data_loader.dataset.set_intervention(x_I)
    # Rebuild the loader after setting the intervention, as the original code
    # did (presumably so it reflects the intervened dataset).
    data_loader = data_module.test_dataloader()
finally:
    # Restore the configured batch size even if building the loader fails.
    data_module.batch_size = bs

batch = next(iter(data_loader))

print(batch)


# In[ ]:


# Sample the model's counterfactual for the batch under x_I and compare it
# with the observed data and the batch's x_i values.
x_CF, z_factual, z_cf_I, z_dec = model.compute_counterfactual(
    batch=batch,
    x_I=data_loader.dataset.x_I,
    nodes_list=data_loader.dataset.nodes_list,
    normalize=True,
    return_type='sample',
)

print(f"Original: {batch.x.flatten()}")
print(f"Counterfactual: {batch.x_i.flatten()}")
print(f"Reconstructed: {x_CF.flatten()}")

