import argparse
import os, sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path, '../model'))
sys.path.append(os.path.join(dir_path, '../data'))
sys.path.append(os.path.join(dir_path, '../expt'))

import torch
import numpy as np
from pytorch_lightning.loggers import WandbLogger

from common import MetricsCallback, LagrangeStart, LitProgressBar
from counterfactuals import create_atd
from load_dag import gen_params
from pytorch_lightning import Trainer, seed_everything
from load_data import gen_data
from sinkhorn_gn import SinkhornGN
from utils import get_diam_bounds

import pandas as pd

# Data-loading settings forwarded to gen_data through data_args in get_atd_estimation.
NUM_WORKER = 4  # passed as data_args['num_workers']
BATCH_SIZE = 4096  # passed as data_args['batch_size']


# Per-setting overrides layered on top of the shared defaults in
# get_best_hyper_params. Every entry defines 'causal_graph' and 'linear',
# because get_atd_estimation reads both keys unconditionally.
_SETTING_OVERRIDES = {
    'backdoor-linear': {'causal_graph': 'backdoor', 'linear': True},
    'backdoor-nonlinear': {'causal_graph': 'backdoor', 'linear': False},
    'leaky': {'causal_graph': 'leaky', 'linear': True},
    'iv-weak': {'causal_graph': 'iv', 'linear': False, 'weak': True, 'lagrange_lr': 0.1},
    'iv-strong': {'causal_graph': 'iv', 'linear': False, 'weak': False},
    'iv-linear': {'causal_graph': 'iv', 'linear': True},
    'frontdoor': {'causal_graph': 'frontdoor', 'linear': False},
    'bow': {'causal_graph': 'bow', 'linear': False},
}


def get_best_hyper_params(setting):
    """Return the hyper-parameters used for the experiment ``setting``.

    Args:
        setting: One of ``'backdoor-linear'``, ``'backdoor-nonlinear'``,
            ``'leaky'``, ``'iv-weak'``, ``'iv-strong'``, ``'iv-linear'``,
            ``'frontdoor'`` or ``'bow'``.

    Returns:
        A new dict with keys ``n_hidden``, ``n_layers``, ``lr``,
        ``lagrange_lr``, ``weak``, ``causal_graph`` and ``linear``. The
        shared defaults are overridden by the per-setting entries in
        ``_SETTING_OVERRIDES``.

    Raises:
        ValueError: If ``setting`` is not a known setting (including ``None``,
            which is what an omitted ``--setting`` flag produces). Failing here
            avoids returning a dict without ``causal_graph``/``linear`` that
            would only break later with an opaque ``KeyError``.
    """
    if setting not in _SETTING_OVERRIDES:
        raise ValueError(
            f"Unknown setting {setting!r}; expected one of {sorted(_SETTING_OVERRIDES)}")

    hyper_params = {
        'n_hidden': 64,
        'n_layers': 3,
        'lr': 0.001,
        'lagrange_lr': 1,
        'weak': False,
    }
    # The override dicts are only read, never mutated, so each call returns a fresh dict.
    hyper_params.update(_SETTING_OVERRIDES[setting])
    return hyper_params


def _feasible_extreme(metrics, param_key, dist_key, tol, reduce_fn):
    """Return ``(bound, last_dist)`` for one optimisation direction.

    ``bound`` is ``reduce_fn`` (the builtin ``min`` or ``max``) applied to the
    ``param_key`` values of those metrics entries whose ``dist_key`` value is
    ``<= tol``. Only entries that logged *both* keys are considered, so a
    parameter is always checked against the distance from the same entry.
    ``last_dist`` is the last ``dist_key`` value found in ``metrics``.

    Either element is ``nan`` when nothing qualifying was logged. This way a
    finished training run still yields a results row, instead of crashing on
    an empty ``min``/``max`` or on indexing an empty array.
    """
    all_dists = np.array([m[dist_key] for m in metrics if dist_key in m])
    last_dist = all_dists[-1] if all_dists.size else np.nan

    paired = [m for m in metrics if param_key in m and dist_key in m]
    params = np.array([m[param_key] for m in paired])
    dists = np.array([m[dist_key] for m in paired])
    feasible = params[dists <= tol]
    bound = reduce_fn(feasible) if feasible.size else np.nan
    return bound, last_dist


def get_atd_estimation(args):
    """Estimate lower/upper ATD bounds for one setting and append them to a CSV.

    Steps:
    1. Seed everything with ``args.seed``.
    2. Build the causal graph and a dataset of ``args.sample_size`` samples
       for ``args.setting``.
    3. Train a ``SinkhornGN`` model.
    4. Append one header-less row to ``<args.output_dir>/atd_estimation.csv``.

    ``min_atd``/``max_atd`` are the smallest/largest logged ``param_min``/
    ``param_max`` values whose matching ``dist_min``/``dist_max`` is within
    ``model.tol``. They are ``nan`` when no logged value met the tolerance.

    Args:
        args: ``argparse.Namespace`` (as produced by ``build_args``). Reads
            ``seed``, ``setting``, ``sample_size``, ``max_epochs``,
            ``output_dir`` and ``callbacks``. The namespace is also forwarded
            to ``Trainer.from_argparse_args``. It is not mutated.
    """
    seed_everything(args.seed, workers=True)
    hyper = get_best_hyper_params(args.setting)
    name, graph, do_var, target_var, n_latent, latent_dim = gen_params(hyper['causal_graph'])
    data_args = {'linear': hyper['linear'], 'n_samples': args.sample_size, 'batch_size': BATCH_SIZE,
                 'num_workers': NUM_WORKER,
                 'validation_size': 0.1}

    if name == 'iv':
        data_args['weak'] = hyper['weak']

    res = gen_data(name, data_args)
    data, dm, var_dims, true_atd = res['data'], res['dm'], res['var_dims'], res['true_atd']
    # Only the diameter is used; the model below is built with bounds set to None.
    diam, _lower_bound, _upper_bound = get_diam_bounds(data, var_dims)

    param_fn = create_atd(do_var, target_var, var_dims, delta=0.1)

    model = SinkhornGN(param_fn, graph, var_dims, n_latent, latent_dim,
                       upper_bounds=None, lower_bounds=None,
                       diameter=diam, n_hidden=hyper['n_hidden'], n_layers=hyper['n_layers'],
                       lr=hyper['lr'], lagrange_lr=hyper['lagrange_lr'])

    metrics_callback = MetricsCallback()
    # Patience scales with the epoch budget (10% of max_epochs).
    lagrange_start = LagrangeStart(monitor='val_dist', min_delta=0.001,
                                   patience=args.max_epochs // 10, verbose=True, mode='min')
    prog_bar = LitProgressBar(refresh_rate=20)
    # Pass callbacks as an explicit kwarg (which overrides the namespace value)
    # instead of extending args.callbacks in place; in-place extension would
    # leak callbacks into the caller's namespace and accumulate across runs.
    callbacks = [*(getattr(args, 'callbacks', None) or []),
                 metrics_callback, prog_bar, lagrange_start]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks, log_every_n_steps=5)

    trainer.fit(model, dm)

    model.to('cpu')
    tol = model.tol.detach().to('cpu').numpy()
    metrics = metrics_callback.metrics
    min_atd, dist_min = _feasible_extreme(metrics, 'param_min', 'dist_min', tol, min)
    max_atd, dist_max = _feasible_extreme(metrics, 'param_max', 'dist_max', tol, max)

    # Column order matters: rows are appended without a header line.
    df = pd.DataFrame(
        {'causal_graph': hyper['causal_graph'], 'linear': hyper['linear'], 'weak': hyper['weak'],
         'true_atd': true_atd, 'min_atd': min_atd,
         'max_atd': max_atd, 'seed': args.seed, 'n_samples': args.sample_size, 'n_hidden': hyper['n_hidden'],
         'n_layers': hyper['n_layers'],
         'dist_min': dist_min, 'dist_max': dist_max, 'tol': tol, 'setting': args.setting}, index=[0])

    out_dir = os.path.abspath(args.output_dir)
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, 'atd_estimation.csv'), index=False, header=False, mode='a')


def build_args(arg_defaults=None, argv=None):
    """Parse command-line arguments for an ATD-estimation run.

    Args:
        arg_defaults: Optional mapping whose entries override the built-in
            defaults below before they are installed on the parser. The
            mapping itself is not modified.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Useful for programmatic or test invocation.

    Returns:
        The parsed ``argparse.Namespace``. ``--setting`` has no built-in
        default, so it is ``None`` unless it is given on the command line or
        via ``arg_defaults``.
    """
    defaults = {
        'n_hidden': 64,
        'n_layers': 3,
        'lr': 0.05,
        'lagrange_lr': 0.5,
        'max_epochs': 500,
        'seed': 1929,
        'gpus': 1,
        'output_dir': 'results/',
        'callbacks': [],
        'gradient_clip_val': 0.5,
        'sample_size': 5000,
    }
    if arg_defaults is not None:
        defaults.update(arg_defaults)

    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', required=False, type=int,
                        help='training epoch budget; also sets LagrangeStart patience')
    parser.add_argument('--sample_size', required=False, type=int,
                        help='number of samples to generate')
    parser.add_argument('--seed', required=False, type=int,
                        help='global random seed')
    parser.add_argument('--setting', required=False, type=str,
                        help='experiment setting, e.g. backdoor-linear or iv-weak')
    parser.add_argument('--output_dir', required=False, type=str,
                        help='directory that receives atd_estimation.csv')
    parser = SinkhornGN.add_model_specific_args(parser)
    # Parser-level defaults take precedence over the argument-level defaults
    # declared above, including those added by SinkhornGN.
    parser.set_defaults(**defaults)
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the CLI, then run a single estimation.
    get_atd_estimation(build_args())
