import argparse
import os, sys
from datetime import timedelta
from functools import partial
from pathlib import Path

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path, '../model'))
sys.path.append(os.path.join(dir_path, '../data'))
sys.path.append(os.path.join(dir_path, '../expt'))

import numpy as np

from common import MetricsCallback, LagrangeStart, LitProgressBar
from counterfactuals import create_uniform_gauss_atd
from load_dag import gen_params
from pytorch_lightning import Trainer, seed_everything
from load_data import gen_data
from sinkhorn_gn import SinkhornGN
from utils import get_diam_bounds

import pandas as pd

from average_derivative_expt import get_best_hyper_params

# Passed to gen_data as data_args['num_workers'] in get_dose_estimation.
NUM_WORKER = 4
# Passed to gen_data as data_args['batch_size'] in get_dose_estimation.
BATCH_SIZE = 4096

def get_d_values(setting):
    """Return the fixed dose value and the grid of candidate dose values for a setting.

    Args:
        setting: Name of the experimental setting (e.g. ``'bow'``,
            ``'iv-strong'``, ``'iv-weak'``). Unknown names (including
            ``None``) fall back to a default grid.

    Returns:
        A tuple ``(ade1_value, values)`` where ``ade1_value`` is the fixed
        dose used as ``d1_value`` by the caller and ``values`` is a 1-D NumPy
        array of 10 evenly spaced doses from which the caller picks
        ``d0_value`` by index.
    """
    # setting -> (ade1_value, grid start, grid stop)
    grids = {
        'bow': (4, -3, 2),
        'iv-strong': (9, -6, 6),
        'backdoor-nonlinear': (6, -6, 3),
        'iv-linear': (4, -6, 2),
        'iv-weak': (10, -15, 5),
        'leaky': (6, -6, 3),
    }
    # The fallback uses linspace like every known setting; the previous
    # np.arange(-2, 2, 10) treated 10 as a step and produced only [-2].
    ade1_value, start, stop = grids.get(setting, (0, -2, 2))
    values = np.linspace(start, stop, 10)

    return ade1_value, values


def get_dose_estimation(args):
    """Train a SinkhornGN model for one (d0, d1) dose pair and append the result to a CSV.

    The function seeds the run, loads the best hyper-parameters for
    ``args.setting``, generates data, fits the model with a Lightning
    ``Trainer`` and then extracts ``min_ade`` / ``max_ade`` from the metrics
    logged by ``MetricsCallback``. One header-less row is appended to
    ``<args.output_dir>/ate_estimation.csv``.

    Args:
        args: Parsed command-line namespace (see ``build_args``). Reads
            ``seed``, ``setting``, ``sample_size``, ``d0_index``,
            ``max_epochs``, ``callbacks`` and ``output_dir``; the whole
            namespace is also forwarded to ``Trainer.from_argparse_args``.
            ``args.callbacks`` is extended in place.

    Raises:
        ValueError: If ``args.d0_index`` is ``None``.
    """
    # Fail before the (expensive) training run: indexing a NumPy array with
    # None inserts a new axis instead of selecting a scalar dose value.
    if args.d0_index is None:
        raise ValueError('--d0_index is required to select the d0 dose value')

    seed_everything(args.seed, workers=True)
    hyper = get_best_hyper_params(args.setting)

    name, graph, do_var, target_var, n_latent, latent_dim = gen_params(hyper['causal_graph'])
    data_args = {'linear': hyper['linear'], 'n_samples': args.sample_size, 'batch_size': BATCH_SIZE,
                 'num_workers': NUM_WORKER,
                 'validation_size': 0.1}

    if name == 'iv':
        data_args['weak'] = hyper['weak']

    res = gen_data(name, data_args)
    data, dm, var_dims, true_atd = res['data'], res['dm'], res['var_dims'], res['true_atd']
    diam, lower_bound, upper_bound = get_diam_bounds(data, var_dims)

    d1_value, values = get_d_values(args.setting)
    d0_value = values[args.d0_index]

    param_fn = create_uniform_gauss_atd(do_var, target_var, var_dims, delta=0.1, d1_value=d1_value, d0_value=d0_value)

    model = SinkhornGN(param_fn, graph, var_dims, n_latent, latent_dim,
                       upper_bounds=None, lower_bounds=None,
                       diameter=diam, n_hidden=hyper['n_hidden'], n_layers=hyper['n_layers'],
                       lr=hyper['lr'], lagrange_lr=hyper['lagrange_lr'],
                       do_var=do_var, ade_d1_value=d1_value, ade_d0_value=d0_value, target_var=target_var)

    metrics_callback = MetricsCallback()
    lagrange_start = LagrangeStart(monitor='val_dist', min_delta=0.001,
                                   patience=args.max_epochs // 20 + 10, verbose=True, mode='min')
    prog_bar = LitProgressBar(refresh_rate=20)
    lr_monitor = LearningRateMonitor()
    args.callbacks.extend([metrics_callback, prog_bar, lagrange_start, lr_monitor])
    trainer = Trainer.from_argparse_args(args, log_every_n_steps=5)

    trainer.fit(model, dm)

    model.to('cpu')
    # Defaults that are written out if the logged metrics cannot be processed.
    dist_min = 0.
    dist_max = 0.
    max_ade, min_ade = 0., 0.
    try:
        tol = model.tol.detach().to('cpu').numpy()
        for p, d in [('min_ade', 'dist_min'), ('max_ade', 'dist_max')]:
            params = np.array([i[p] for i in metrics_callback.metrics if p in i])
            dist = np.array([i[d] for i in metrics_callback.metrics if d in i])
            # Only consider logged values whose distance is at most 1.15 * tol.
            feasible = params[dist <= 1.15 * tol]
            if p == 'min_ade':
                min_ade = min(feasible)
                dist_min = dist[-1]
            else:
                max_ade = max(feasible)
                dist_max = dist[-1]
    except Exception as err:
        # Best-effort extraction: keep whatever was computed so far and still
        # write a row, but report why instead of failing silently. Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
        print(f'Warning: could not extract ADE bounds from logged metrics ({err!r}); '
              'writing fallback values.', file=sys.stderr)
        tol = 0.

    df = pd.DataFrame(
        {'causal_graph': hyper['causal_graph'], 'linear': hyper['linear'], 'weak': hyper['weak'],
         'd0_value': d0_value, 'd1_value': d1_value, 'min_ade': min_ade,
         'max_ade': max_ade, 'seed': args.seed, 'n_samples': args.sample_size, 'n_hidden': hyper['n_hidden'],
         'n_layers': hyper['n_layers'],
         'dist_min': dist_min, 'dist_max': dist_max, 'tol': tol, 'setting': args.setting}, index=[0])

    out_dir = os.path.abspath(args.output_dir)
    # exist_ok avoids a crash when another run creates the directory between
    # a separate existence check and the makedirs call.
    os.makedirs(out_dir, exist_ok=True)
    # Appended without a header so repeated runs accumulate rows in one file.
    df.to_csv(os.path.join(out_dir, 'ate_estimation.csv'), index=False, header=False, mode='a')


def build_args(arg_defaults=None):
    """Parse command-line arguments for the dose-estimation experiment.

    Args:
        arg_defaults: Optional overrides merged on top of the built-in
            defaults before they are installed on the parser.

    Returns:
        The ``argparse.Namespace`` produced by parsing the process's
        command-line arguments.
    """
    overrides = arg_defaults
    defaults = dict(
        n_hidden=64,
        n_layers=3,
        lr=0.05,
        lagrange_lr=0.5,
        max_epochs=1000,
        seed=5963,
        gpus=1,
        callbacks=[],
        sample_size=5000,
        gradient_clip_val=0.5,
        output_dir='results/',
    )
    if overrides is not None:
        defaults.update(overrides)

    parser = argparse.ArgumentParser()
    # Script-level options; all optional, types as listed.
    script_options = (
        ('--d0_index', int),
        ('--max_epochs', int),
        ('--seed', int),
        ('--setting', str),
        ('--output_dir', str),
    )
    for flag, kind in script_options:
        parser.add_argument(flag, required=False, type=kind)

    # Let the model class register its own options on the same parser.
    parser = SinkhornGN.add_model_specific_args(parser)
    parser.set_defaults(**defaults)

    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point: parse the CLI arguments and run a single estimation.
    get_dose_estimation(build_args())
