import argparse
import json
import os
import warnings

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

import utils.args_parser  as argtools
import utils.tools as utools
from utils.constants import Cte
import numpy as np

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# Command-line interface. The experiment configuration comes either from three
# YAML files (dataset / model / trainer) or from a single saved YAML
# (--yaml_file); the -d/-m/-o/-t flags then override individual entries.
# NOTE(review): the metavar shows ',' as KEY=VAL separator while the help text
# shows '+'; confirm which one argtools.StoreDictKeyPair actually splits on.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--dataset_file', default='_params/dataset_toy.yaml', type=str, help='path to configuration file for the dataset')
parser.add_argument('--model_file', default='_params/model_mcvae.yaml', type=str, help='path to configuration file for the model')
parser.add_argument('--trainer_file', default='_params/trainer.yaml', type=str, help='path to configuration file for the training')
parser.add_argument('--yaml_file', default='', type=str, help='path to trained model configuration')
parser.add_argument('-d', '--dataset_dict', action=argtools.StoreDictKeyPair, metavar="KEY1=VAL1,KEY2=VAL2...", help='manually define dataset configurations as string: KEY1=VALUE1+KEY2=VALUE2+...')
parser.add_argument('-m', '--model_dict', action=argtools.StoreDictKeyPair, metavar="KEY1=VAL1,KEY2=VAL2...", help='manually define model configurations as string: KEY1=VALUE1+KEY2=VALUE2+...')
parser.add_argument('-o', '--optim_dict', action=argtools.StoreDictKeyPair, metavar="KEY1=VAL1,KEY2=VAL2...", help='manually define optimizer configurations as string: KEY1=VALUE1+KEY2=VALUE2+...')
parser.add_argument('-t', '--trainer_dict', action=argtools.StoreDictKeyPair, metavar="KEY1=VAL1,KEY2=VAL2...", help='manually define trainer configurations as string: KEY1=VALUE1+KEY2=VALUE2+...')
parser.add_argument('-s', '--seed', default=-1, type=int, help='set random seed, default: random')
parser.add_argument('-r', '--root_dir', default='', type=str, help='directory for storing results')
parser.add_argument('-i', '--is_training', default=1, type=int, help='run with training (1) or without training (0)')
parser.add_argument('-f', '--eval_fair', default=False, action="store_true", help='run code with counterfactual fairness experiment (only for German dataset), default: False')
parser.add_argument('--show_results', default=1, type=int, help='run with evaluation (1) or without (0), default: 1')
parser.add_argument('--cf_sample', default=False, action="store_true", help='evaluate performance on one counterfactual sample')

parser.add_argument('--plots', default=0, type=int, help='run code with plotting (1) or without (0), default: 0')

args = parser.parse_args()

# %% Assemble the experiment configuration
if args.yaml_file:
    # Re-use the full configuration saved by a previous run.
    cfg = argtools.parse_args(args.yaml_file)
else:
    # Start from the dataset config and layer the model and trainer configs on top.
    cfg = argtools.parse_args(args.dataset_file)
    for extra_file in (args.model_file, args.trainer_file):
        cfg.update(argtools.parse_args(extra_file))

# Command-line values win over whatever the YAML files specified.
if args.root_dir:
    cfg['root_dir'] = args.root_dir
if args.seed >= 0:
    cfg['seed'] = args.seed

# %% Seed everything, then apply the KEY=VALUE overrides from -d/-m/-o/-t
pl.seed_everything(cfg['seed'])
if args.dataset_dict is not None:
    cfg['dataset']['params2'].update(args.dataset_dict)
if args.model_dict is not None:
    cfg['model']['params'].update(args.model_dict)
if args.optim_dict is not None:
    cfg['optimizer']['params'].update(args.optim_dict)
if args.trainer_dict is not None:
    cfg['trainer'].update(args.trainer_dict)

trainer_cfg = cfg['trainer']
if isinstance(trainer_cfg['gpus'], int):
    # NOTE(review): every integer value, including 0, is rewritten to -1 here;
    # confirm that forcing this value is intended for CPU-only configs.
    trainer_cfg.update(auto_select_gpus=False, gpus=-1)

# Effective dataset params: 'params1' overridden key-by-key by 'params2'.
dataset_cfg = cfg['dataset']
dataset_cfg['params'] = dataset_cfg['params1'].copy()
dataset_cfg['params'].update(dataset_cfg['params2'])
for shown in (args.dataset_dict, dataset_cfg['params'], cfg['model']['params']):
    print(shown)

# %% Load dataset
# Select the data-module class for the configured dataset. The import is done
# inside each branch, so only the module for the chosen dataset is loaded.
dataset_name = cfg['dataset']['name']
if dataset_name in [Cte.COLLIDER, Cte.TRIANGLE, Cte.LOAN, Cte.MGRAPH, Cte.CHAIN]:
    from data_modules.toy_scm import ToySCMDataModule
    data_module_cls = ToySCMDataModule
elif dataset_name in [Cte.GERMAN]:
    from data_modules.real_scm import RealSCMDataModule
    data_module_cls = RealSCMDataModule
else:
    # Explicit error rather than `assert`, which disappears under `python -O`.
    raise ValueError(f"Unsupported dataset configuration: {cfg['dataset']}")

# Both data modules take the merged dataset params plus the dataset name.
dataset_params = cfg['dataset']['params'].copy()
dataset_params['dataset_name'] = dataset_name

data_module = data_module_cls(**dataset_params)
data_module.prepare_data()

# %% Load model
# Build the model selected by cfg['model']['name']. The YAML params are
# completed with values taken from the loaded data module (dimensions, graph
# structure, intervention list, scaler, ...).
model_params = cfg['model']['params'].copy()
model_name = cfg['model']['name']

if model_name in (Cte.VCAUSE, Cte.VCAUSE_PIWAE):
    # VCAUSE and its PIWAE variant are built from the same data-dependent params.
    if model_name == Cte.VCAUSE:
        from models.vcause.vcause import VCAUSE
        model_cls = VCAUSE
    else:
        from models.vcause.vcause_piwae import VCAUSE_PIWAE
        model_cls = VCAUSE_PIWAE

    model_params['x_dim'] = data_module.num_features
    model_params['deg'] = data_module.get_deg(indegree=True)
    model_params['num_nodes'] = data_module.num_nodes
    model_params['edge_dim'] = data_module.edge_dimension
    model_params['intervention_list'] = data_module.train_dataset.get_intervention_list(True)
    model_params['scaler'] = data_module.scaler
    model_params['data_is_toy'] = data_module.data_is_toy
    if model_params.get('is_heterogeneous'):
        # Heterogeneous setting: use the list-valued feature sizes and
        # likelihoods exposed by the data module instead of a single x_dim.
        model_params['x_dim'] = data_module.num_features_list
        model_params['distr_x'] = data_module.likelihood_list

    model = model_cls(**model_params)
    model.set_random_train_sampler(data_module.get_random_train_sampler())

# MultiCVAE
elif model_name == Cte.MCVAE:
    from models.multi_cvae.cvae_multi import MCVAE

    model_params['x_dim'] = data_module.num_features
    model_params['intervention_list'] = data_module.train_dataset.get_intervention_list(True)
    model_params['topological_nodes'] = data_module.topological_nodes
    model_params['topological_parents'] = data_module.topological_parents
    model_params['scaler'] = data_module.scaler
    # The max_epochs budget is split evenly (floored) across the topologically ordered nodes.
    model_params['num_epochs_per_nodes'] = int(np.floor((cfg['trainer']['max_epochs'] / len(data_module.topological_nodes))))
    model_params['equations_type'] = data_module.equations_type
    model_params['dataset_name'] = data_module.dataset_name
    model = MCVAE(**model_params)
    model.set_random_train_sampler(data_module.get_random_train_sampler())
    # Early stopping is always disabled for MCVAE, regardless of the config.
    cfg['early_stopping'] = False

# CAREFL
elif model_name == Cte.CARELF:
    from models.carefl.carefl import CAREFL

    model_params['num_nodes'] = data_module.num_nodes
    model_params['intervention_list'] = data_module.train_dataset.get_intervention_list(True)
    model_params['scaler'] = data_module.scaler
    model = CAREFL(**model_params)

else:
    # Explicit error rather than `assert`, which disappears under `python -O`.
    raise ValueError(f'Unknown model name: {model_name!r}')

# Re-enable printing (counterpart of utools.blockPrint, which is not called here).
utools.enablePrint()

model.summarize()
model.set_optim_params(optim_params=cfg['optimizer'],
                       sched_params=cfg['scheduler'])

# %% Prepare training
if args.yaml_file == '':
    # Fresh experiment: results go to <root_dir>/<experiment folder>/<seed>.
    save_dir = argtools.mkdir(os.path.join(cfg['root_dir'],
                                           argtools.get_experiment_folder(cfg),
                                           str(cfg['seed'])))
else:
    # Re-use the folder containing the given YAML. os.path.dirname keeps the
    # leading '/' of absolute paths; it returns '' for a bare filename, hence
    # the fallback to the current directory.
    save_dir = os.path.dirname(args.yaml_file) or os.curdir
print(f'Save dir: {save_dir}')
logger = TensorBoardLogger(save_dir=save_dir, name='logs', default_hp_metric=False)
logger.log_hyperparams(argtools.flatten_cfg(cfg))

save_dir_ckpt = argtools.mkdir(os.path.join(save_dir, 'ckpt'))
# Newest checkpoint in save_dir_ckpt, or None if there is none yet.
ckpt_file = argtools.newest(save_dir_ckpt)
callbacks = []
if args.is_training == 1:
    # Keep only the best checkpoint according to the model's monitored metric.
    checkpoint = ModelCheckpoint(period=1,
                                 monitor=model.monitor(),
                                 mode=model.monitor_mode(),
                                 save_top_k=1,
                                 filename='checkpoint-{epoch:02d}',
                                 dirpath=save_dir_ckpt)

    callbacks = [checkpoint]

    if cfg['early_stopping']:
        early_stopping = EarlyStopping(model.monitor(), mode=model.monitor_mode(), min_delta=0.0, patience=50)
        callbacks.append(early_stopping)

    if ckpt_file is not None:
        # Resume training from the newest existing checkpoint.
        print(f'Loading model training: {ckpt_file}')
        trainer = pl.Trainer(logger=logger, callbacks=callbacks, resume_from_checkpoint=ckpt_file,
                             **cfg['trainer'])
    else:
        trainer = pl.Trainer(logger=logger, callbacks=callbacks, **cfg['trainer'])

    # %% Train
    trainer.fit(model, data_module)
    argtools.save_yaml(cfg, file_path=os.path.join(save_dir, 'hparams_full.yaml'))

else:
    # %% Testing: evaluation only, from an existing checkpoint.
    if ckpt_file is None:
        raise FileNotFoundError(f'No checkpoint found in {save_dir_ckpt}; '
                                f'run with --is_training 1 first.')

    model = model.load_from_checkpoint(ckpt_file)
    if cfg['model']['name'] in [Cte.VCAUSE_PIWAE, Cte.VCAUSE, Cte.MCVAE]:
        # The sampler is attached via a method call, so presumably it is not
        # restored by load_from_checkpoint; re-attach it.
        model.set_random_train_sampler(data_module.get_random_train_sampler())

# Count trainable parameters BEFORE freeze(), which disables requires_grad.
params = int(sum(p.numel() for p in model.parameters() if p.requires_grad))

model.eval()
model.freeze()  # IMPORTANT

if args.show_results:
    # Evaluate the validation split (never plotted) and the test split
    # (plotted according to --plots), then collect everything in one dict.
    output_valid = model.my_evaluator(dataloader=data_module.val_dataloader(),
                                      name='valid',
                                      save_dir=save_dir,
                                      plots=False)
    output_test = model.my_evaluator(dataloader=data_module.test_dataloader(),
                                     name='test',
                                     save_dir=save_dir,
                                     plots=args.plots)

    # Later entries overwrite earlier ones on key collisions.
    for extra in (output_test,
                  argtools.flatten_cfg(cfg),
                  {'ckpt_file': ckpt_file, 'num_parameters': params}):
        output_valid.update(extra)

    output_path = os.path.join(save_dir, 'output.json')
    with open(output_path, 'w') as fh:
        json.dump(output_valid, fh)
    print(f'Experiment folder: {save_dir}')

# ----------------------------------------------------------------------------
# CF SAMPLE: compare generated and ground-truth counterfactuals on one test
# batch, for a grid of intervention values on nodes 1 and 2.
# ----------------------------------------------------------------------------
if args.cf_sample:

    # %%
    import pandas as pd
    print('Dataframe with counterfactuals')
    # Presumably read by test_dataloader(), i.e. the number of samples per
    # intervention below — TODO confirm.
    data_module.batch_size = 56

    list_i = np.linspace(-2.0, 2.0, num=30)
    # Half the grid spacing: tolerance for mapping a parsed do-value to its grid index.
    half_step = (list_i[1] - list_i[0]) / 2

    intervention_list = data_module.train_dataset.get_intervention_list(std_list=list_i,
                                                                        node_list=[1, 2])
    is_noise = False
    # One row per grid value of the intervention.
    df = pd.DataFrame(index=np.arange(len(list_i)))

    id_experiment = '_'.join([cfg['dataset']['name'], cfg['dataset']['params']['equations_type'], cfg['model']['name']])
    df['dataset'] = cfg['dataset']['name']
    df['SEM'] = cfg['dataset']['params']['equations_type']
    df['model'] = cfg['model']['name']
    df['do_values'] = list_i

    for x_I, norm_I_value_str in intervention_list:
        node_id = list(x_I.keys())[0]
        value_I = float(norm_I_value_str.split('_')[0])
        # Nearest grid point instead of exact float equality, which breaks as
        # soon as the value was rounded when formatted into the string.
        idx_i = int(np.argmin(np.abs(list_i - value_I)))
        if abs(list_i[idx_i] - value_I) > half_step:
            raise ValueError(f'do-value {value_I} parsed from {norm_I_value_str!r} is not on the intervention grid')

        inter_id = f'do_{node_id}'

        data_loader = data_module.test_dataloader()
        data_loader.dataset.set_intervention(x_I, is_noise=is_noise)
        print(f'do({x_I}) is_noise: {is_noise}')
        try:
            batch = next(iter(data_loader))
            x_CF, z_factual, z_intervened = model.compute_counterfactual(batch=batch,
                                                                         x_I=data_loader.dataset.x_I,
                                                                         z_I={})
            # Ground-truth counterfactual is computed in the original data space,
            # then mapped back into the scaled space that x_CF lives in.
            x_CF_real = data_loader.dataset.get_counterfactual(x_factual=data_module.scaler.inverse_transform(batch.x.view(batch.num_graphs, -1)),
                                                               u_factual=batch.u.view(batch.num_graphs, -1),
                                                               x_I=x_I,
                                                               is_noise=is_noise)
            x_CF_real = data_module.scaler.transform(x_CF_real)
        finally:
            # Always reset the dataset so later code sees it without the intervention.
            data_loader.dataset.clean_intervention()

        # Factual samples, in the same (scaled) space as the counterfactuals.
        x = batch.x.view(batch.num_graphs, -1).numpy()
        x_CF_real = x_CF_real.numpy()
        x_CF = x_CF.numpy()
        num_samples, dim_sample = x_CF.shape

        for idx_s in range(num_samples):
            for idx_d in range(dim_sample):
                df.at[idx_i, f'{inter_id}_x_cf_{idx_s}_x{idx_d+1}_real'] = x_CF_real[idx_s, idx_d]
                df.at[idx_i, f'{inter_id}_x_cf_{idx_s}_x{idx_d+1}_gener'] = x_CF[idx_s, idx_d]
                df.at[idx_i, f'{inter_id}_x_factual_{idx_s}_x{idx_d+1}'] = x[idx_s, idx_d]

    # The output folder may not exist yet; to_pickle does not create it.
    os.makedirs('_dataframes', exist_ok=True)
    df.to_pickle(os.path.join('_dataframes', f'cf_sample_{id_experiment}.pkl'))


if args.eval_fair:
    # Counterfactual-fairness evaluation is only implemented for the German
    # dataset. Use an explicit error: an `assert` would vanish under `python -O`.
    if cfg['dataset']['name'] not in [Cte.GERMAN]:
        raise ValueError(f"counterfactual fairness not implemented for dataset {cfg['dataset']['name']!r}")

    output_fairness = model.my_cf_fairness(data_module=data_module,
                                           save_dir=save_dir)

    # Attach the flattened config plus run metadata before writing fairness.json.
    output_fairness.update(argtools.flatten_cfg(cfg))
    output_fairness.update({'ckpt_file': ckpt_file,
                            'num_parameters': params})

    with open(os.path.join(save_dir, 'fairness.json'), 'w') as f:
        json.dump(output_fairness, f)
    print(f'Experiment folder: {save_dir}')







