import os
import time
import logging

import csv
import json

import pandas as pd
import numpy as np

from statsmodels.nonparametric.kernel_regression import KernelReg
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import scipy.stats

import doubleml as dml
from doubleml.rdd import RDFlex
from doubleml.rdd.datasets.area_yield_dgp import dgp_area_yield
from rdrobust import rdrobust


# Module-level logger; benchmark progress (info) and recovered estimation
# failures (warning) are reported through it.
logger = logging.getLogger(__name__)


def generate_data(dgp_params, setting, rep_seed, n_obs):
    """Simulate one area-yield sample and turn it into RDD benchmark inputs.

    A sample is drawn from ``dgp_area_yield`` with ``seed=rep_seed`` and
    ``n_obs=n_obs`` overriding any such entries in ``dgp_params``. The two
    coordinate channels of ``X`` are summarised per observation (std, skew,
    kurtosis, mean, min, max and a 2-component PCA of the means) and stored
    next to outcomes, treatments and scores in a DataFrame.

    That frame is passed to :func:`rdd_setting` with the keyword arguments
    in ``setting``; ``distance_cutoff`` and ``yield_cutoff`` are always taken
    from ``dgp_params['treatment_dist']`` and
    ``dgp_params['treatment_improvement']`` and override any values in
    ``setting``.
    """
    sample = dgp_area_yield(**dict(dgp_params, seed=rep_seed, n_obs=n_obs))

    coords = sample['X']
    coord_x = coords[:, :, 0]
    coord_y = coords[:, :, 1]
    mean_x = np.mean(coord_x, axis=1)
    mean_y = np.mean(coord_y, axis=1)

    pca = PCA(2, random_state=rep_seed)
    pca_of_means = pca.fit_transform(np.stack([mean_x, mean_y], axis=1))

    frame = pd.DataFrame({
        'outcome': sample['Y'],
        'assigned_treatment': sample['T'],
        'actual_treatment': sample['D'],
        'score_distance': sample['score_distance'],
        'score_improvement': sample['score_improvement'],
        # potential outcomes and complier flag, used only by the oracle estimators
        'outcome_nottreated': sample['Y0'],
        'outcome_treated': sample['Y1'],
        'complier': sample["complier"],
        # covariates summarising the per-observation coordinates
        'stddev_Cx': np.std(coord_x, axis=1),
        'stddev_Cy': np.std(coord_y, axis=1),
        'skewness_Cx': scipy.stats.skew(coord_x, axis=1),
        'skewness_Cy': scipy.stats.skew(coord_y, axis=1),
        'kurtosis_Cx': scipy.stats.kurtosis(coord_x, axis=1),
        'kurtosis_Cy': scipy.stats.kurtosis(coord_y, axis=1),
        'Mean_PCA_0': pca_of_means[:, 0],
        'Mean_PCA_1': pca_of_means[:, 1],
        'mean_Cx': mean_x,
        'min_Cx': np.min(coord_x, axis=1),
        'max_Cx': np.max(coord_x, axis=1),
        'mean_Cy': mean_y,
        'min_Cy': np.min(coord_y, axis=1),
        'max_Cy': np.max(coord_y, axis=1),
    })

    # the cutoffs always follow the DGP, whatever the setting says
    cutoffs = {
        'distance_cutoff': dgp_params['treatment_dist'],
        'yield_cutoff': dgp_params['treatment_improvement'],
    }
    return rdd_setting(frame, **{**setting, **cutoffs})


def dgp_datagen(
    dgp_params,
    n_obs,
    n_rep,
    rnd
):
    """Build a reusable sample generator for a fixed set of repetition seeds.

    ``n_rep`` seeds are drawn eagerly from the numpy ``Generator`` ``rnd``
    (one ``rnd.integers(0, 2**32 - 1)`` call per repetition). The returned
    function takes a ``setting`` dict and lazily yields
    ``(generate_data(dgp_params, setting, seed, n_obs), seed)`` for each seed,
    so every setting is evaluated on the same simulated samples.
    """
    seeds = []
    for _ in range(n_rep):
        seeds.append(rnd.integers(0, 2**32 - 1))

    def generator(setting):
        for seed in seeds:
            sample = generate_data(dgp_params, setting, seed, n_obs)
            yield sample, seed

    return generator


def _scaled_boundary_offsets(df, distance_cutoff, yield_cutoff):
    """Shift both scores to their cutoffs and scale each column by its std.

    Returns ``(scaler, scaled)``: the ``StandardScaler(with_mean=False)``
    fitted on the shifted scores, and the transformed ``(n, 2)`` array whose
    columns are ``score_distance - distance_cutoff`` and
    ``score_improvement - yield_cutoff`` divided by their standard deviation.
    """
    offsets = np.column_stack([df.score_distance, df.score_improvement])
    offsets = offsets - np.array([distance_cutoff, yield_cutoff])
    scaler = StandardScaler(with_mean=False)
    return scaler, scaler.fit_transform(offsets)


def rdd_setting(
    df,
    remove_unexplained,
    subrule,
    score,
    distance_cutoff,
    yield_cutoff,
    covariates=(),
):
    """Turn a simulated sample frame into the inputs of the RDD estimators.

    Parameters
    ----------
    df : pandas.DataFrame
        Sample with outcome, treatment, score, complier and covariate columns.
    remove_unexplained : bool
        Keep only rows whose assigned and actual treatment agree.
    subrule : bool
        Additionally restrict the sample by the *other* score; only allowed
        for ``score`` in ``{'distance', 'yieldimprovement'}``.
    score : str
        Running-variable construction: ``'distance'``, ``'yieldimprovement'``,
        ``'boundarymindist'``, ``'boundarylinf'`` or ``'boundarypoint'``.
    distance_cutoff, yield_cutoff : float
        Treatment thresholds of the two scores.
    covariates : sequence of str
        Extra column names appended to the default covariate set.

    Returns
    -------
    dict
        ``covariates``, ``score``, ``outcome``, ``treatment_assigned``,
        ``treatment_actual``, ``complier``, ``cutoff`` and ``debug``, plus
        ``outcome_treated``/``outcome_nottreated`` when both counterfactual
        columns are present.

    Raises
    ------
    ValueError
        For an unknown ``score`` or an unsupported ``score`` with ``subrule``.
    """
    if remove_unexplained:
        df = df[df.assigned_treatment == df.actual_treatment]

    if subrule:
        if score == 'distance':
            # NOTE(review): filters on score_improvement > 0 rather than
            # > yield_cutoff (the yield subrule uses distance_cutoff) - confirm.
            keep = df.score_improvement > 0
        elif score == 'yieldimprovement':
            keep = df.score_distance >= distance_cutoff
        else:
            raise ValueError('invalid score for subruling')
        df = df[keep]

    debug = {
        'score_distance': df.score_distance.values,
        'score_improvement': df.score_improvement.values,
    }

    kind = score
    if kind == 'distance':
        running, cutoff = df.score_distance.values, distance_cutoff
        complier = (df.score_improvement > yield_cutoff) & df.complier
    elif kind == 'yieldimprovement':
        running, cutoff = df.score_improvement.values, yield_cutoff
        complier = (df.score_distance > distance_cutoff) & df.complier
    elif kind in ('boundarymindist', 'boundarylinf', 'boundarypoint'):
        scaler, scaled = _scaled_boundary_offsets(df, distance_cutoff, yield_cutoff)
        if kind == 'boundarymindist':
            debug['boundary_distance'] = scaled
            # positive only when both scaled scores exceed their cutoffs
            running = np.min(scaled, axis=1)
        elif kind == 'boundarylinf':
            # sup-norm distance to the boundary, signed by the min coordinate
            running = np.sign(np.min(scaled, axis=1)) * np.max(abs(scaled), axis=1)
        else:
            # Euclidean distance (in scaled space) to a fixed reference point
            # given in cutoff-shifted coordinates, signed by actual treatment
            reference = scaler.transform([[0, 0.05]])
            radius = np.sqrt(np.sum((scaled - reference) ** 2, axis=1))
            running = radius * (2 * df.actual_treatment - 1)
        cutoff = 0.0
        complier = df.complier
    else:
        raise ValueError('invalid score')

    covariate_columns = [
        'stddev_Cx', 'stddev_Cy',
        'skewness_Cx', 'skewness_Cy',
        'kurtosis_Cx', 'kurtosis_Cy',
        # NOTE(review): 'mean_Cx' is absent although 'mean_Cy' is included;
        # confirm the asymmetry is intended.
        "min_Cx", "max_Cx", "mean_Cy", "min_Cy", "max_Cy",
        *covariates,
    ]
    result_data = {
        'covariates': df[covariate_columns].values,
        'score': running,
        'outcome': df.outcome.values,
        'treatment_assigned': df.assigned_treatment.values,
        'treatment_actual': df.actual_treatment.values,
        'complier': complier,
        'cutoff': cutoff,
        'debug': debug,
    }

    # counterfactual outcomes are only available for simulated data (oracles)
    if all(col in df.columns for col in ('outcome_nottreated', 'outcome_treated')):
        result_data['outcome_treated'] = df.outcome_treated.to_numpy()
        result_data['outcome_nottreated'] = df.outcome_nottreated.to_numpy()

    return result_data


class CsvResultsWriter:
    """Write result rows to a single CSV file with a fixed column set.

    ``init`` creates/truncates the file and writes the header; each ``write``
    appends one row. Files are opened with ``newline=''`` as the :mod:`csv`
    module requires, so rows are not followed by blank lines on Windows.
    """

    def __init__(self, path, cols):
        self.path = path
        self.cols = cols

    def init(self):
        """Create or truncate the output file and write the header row."""
        with open(self.path, "w", encoding='utf-8', newline='') as output_csv:
            writer = csv.DictWriter(output_csv, fieldnames=self.cols)
            writer.writeheader()

    def write(self, row):
        """Append ``row`` (a mapping keyed by column name) to the file.

        Columns missing from ``row`` are written empty; keys not listed in
        ``cols`` make :class:`csv.DictWriter` raise ``ValueError``.
        """
        with open(self.path, "a", encoding='utf-8', newline='') as output_csv:
            writer = csv.DictWriter(output_csv, fieldnames=self.cols)
            writer.writerow(row)


class JsonResultsWriter:
    """Write every result row to its own JSON file.

    Files are named ``<path>/<basename>_<n>.json`` where ``n`` is a counter
    that starts at 0 after :meth:`init` and increases by one per row.
    """

    def __init__(self, path, basename):
        self.path = path
        self.basename = basename
        self.filecnt = 0

    def init(self):
        """Reset the file counter and make sure the output directory exists."""
        # an empty path means "current directory", which needs no creation
        if self.path:
            os.makedirs(self.path, exist_ok=True)
        self.filecnt = 0

    def convert_dtypes(self, value):
        """Return ``value`` converted to a JSON-serialisable Python object.

        Any numpy scalar (float64, float32, int64, int32, bool_, ...) becomes
        the equivalent builtin via ``.item()`` and numpy arrays become nested
        lists; everything else is returned unchanged.
        """
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    def write(self, row):
        """Serialise ``row`` (a dict) to the next numbered JSON file."""
        row = {
            key: self.convert_dtypes(value)
            for key, value in row.items()
        }
        fname = os.path.join(self.path, f'{self.basename}_{self.filecnt}.json')
        with open(fname, "w", encoding='utf-8') as json_output:
            json.dump(row, json_output)
        self.filecnt += 1


def benchmark_setting(
    data_generator,
    setting_name,
    setting_values,
    methods,
    results_writer,
):
    """Run every estimator on every simulated sample of one setting.

    ``data_generator(setting_values['data'])`` yields ``(data, rep_seed)``
    pairs. For each sample the following are run in order, each producing
    one row passed to ``results_writer.write`` (tagged with ``rep`` and
    ``rep_seed``):

    * the neighbourhood and kernel oracles, when ``data`` carries both
      counterfactual outcomes;
    * ``rdrobust`` without covariates, then with covariates;
    * every RDFlex configuration in ``methods`` (name -> method dict).
    """
    def record(result, rep, rep_seed):
        # every result row is tagged with the repetition index and its seed
        results_writer.write({**result, 'rep': rep, 'rep_seed': rep_seed})

    samples = data_generator(setting_values['data'])
    for rep, (data, rep_seed) in enumerate(samples):
        logger.info(f'data sample {rep}: setting {setting_name} seed {rep_seed}')
        dml_data = dml.DoubleMLData.from_arrays(
            x=data['covariates'],
            y=data['outcome'],
            d=data['treatment_actual'],
            s=data['score'],
        )
        cutoff = data['cutoff']

        has_counterfactuals = 'outcome_nottreated' in data and 'outcome_treated' in data
        if has_counterfactuals:
            logger.info(f"Benchmarking setting {setting_name} using oracles...")
            oracles = (
                ('oracle_neighbour', benchmark_oracle_neighbour),
                ('oracle_kernel', benchmark_oracle_kernel),
            )
            for oracle_name, oracle in oracles:
                record(
                    oracle(oracle_name, setting_name, setting_values, data, cutoff),
                    rep, rep_seed,
                )

        logger.info(f"Benchmarking setting {setting_name} using rdrobust...")
        # rdrobust is run once without and once with the covariates
        without_covs = {**data, 'covariates': None}
        for variant_name, variant_data in (('rdrobustnocovs', without_covs), ('rdrobust', data)):
            record(
                benchmark_rdrobust(variant_name, setting_name, setting_values, variant_data, cutoff),
                rep, rep_seed,
            )

        for method_name, method in methods.items():
            logger.info(f"Benchmarking setting {setting_name} using {method_name}...")
            record(
                benchmark_dml_method(
                    method_name, method, setting_name, setting_values, dml_data, cutoff
                ),
                rep, rep_seed,
            )


def benchmark_oracle_neighbour(
    method_name,
    setting_name,
    setting_values,
    data,
    cutoff
):
    """Oracle estimate: mean treatment effect of observations near the cutoff.

    Both potential outcomes (``data['outcome_treated']`` and
    ``data['outcome_nottreated']``) are averaged over observations whose score
    lies strictly inside ``(cutoff - delta, cutoff + delta)``. Here ``delta`` is
    ``setting_values['params']['neighborhood_oracle_delta']`` (default 0.02).

    In the fuzzy setting only compliers (``data['complier']``) are used.
    Otherwise both potential outcomes of non-compliers are set to 0. This is
    done on copies, so ``data`` is left unmodified.

    Returns a result row with only ``coef_conventional`` filled. An error row
    is returned instead when no observations remain, the cutoff lies outside
    the score range, or the neighbourhood is empty.
    """
    complier = data['complier']
    score = data['score']
    outcome_treated = data['outcome_treated']
    outcome_nottreated = data['outcome_nottreated']

    if setting_values['params']['fuzzy']:
        # boolean indexing already returns new arrays
        score = score[complier]
        outcome_treated = outcome_treated[complier]
        outcome_nottreated = outcome_nottreated[complier]
    else:
        # Copy first: the arrays in ``data`` are shared with the other
        # benchmarks of this repetition and must not be zeroed in place.
        outcome_treated = outcome_treated.copy()
        outcome_nottreated = outcome_nottreated.copy()
        outcome_treated[~complier] = 0
        outcome_nottreated[~complier] = 0

    # e.g. no compliers at all in the fuzzy setting; min() would raise
    if len(score) == 0:
        msg = 'no observations available for oracle'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    if min(score) >= cutoff or max(score) <= cutoff:
        msg = 'cutoff out of score range'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    delta = setting_values['params'].get('neighborhood_oracle_delta', 0.02)
    t = time.time()
    neighborhood = (score > cutoff - delta) & (score < cutoff + delta)
    if not np.any(neighborhood):
        # averaging an empty selection would silently produce NaN
        took = time.time() - t
        msg = 'no observations in oracle neighbourhood'
        logger.warning(msg)
        return _error_result(method_name, setting_name, took, msg)
    estimate = outcome_treated[neighborhood].mean() - outcome_nottreated[neighborhood].mean()
    took = time.time() - t

    return {
      "setting": setting_name,
      "method": method_name,
      "duration": took,
      "error": None,
      # coef
      "coef_conventional": estimate,
      "coef_robust": None,
      "coef_bias_corrected": None,
      # se
      "se_conventional": None,
      "se_robust": None,
      "se_bias_corrected": None,
      # ci
      "ci_0025_conventional": None,
      "ci_0975_conventional": None,
      "ci_0025_robust": None,
      "ci_0975_robust": None,
      "ci_0025_bias_corrected": None,
      "ci_0975_bias_corrected": None,
      # learner
      "rmse_mlg": None,
      "logloss_mlm": None
    }


def benchmark_oracle_kernel(
    method_name,
    setting_name,
    setting_values,
    data,
    cutoff
):
    """Oracle estimate: local-linear kernel regression of the ITE at the cutoff.

    The individual treatment effect is ``outcome_treated - outcome_nottreated``.
    A statsmodels ``KernelReg`` (continuous regressor, local-linear) is fitted
    on it against the score and evaluated at ``cutoff``. In the fuzzy setting
    only compliers are used; otherwise the ITE of non-compliers is set to 0.

    Returns a result row with only ``coef_conventional`` filled. An error row
    is returned when no observations remain, the cutoff lies outside the score
    range, or the regression raises ``LinAlgError``/``ZeroDivisionError``.
    """
    t = time.time()
    score = data['score']
    complier = data['complier']
    # the subtraction creates a new array, so the masking below leaves data intact
    ite = data['outcome_treated'] - data['outcome_nottreated']

    if setting_values['params']['fuzzy']:
        ite = ite[complier]
        score = score[complier]
    else:
        ite[~complier] = 0

    # e.g. no compliers at all in the fuzzy setting; min() would raise
    if len(score) == 0:
        msg = 'no observations available for oracle'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    if min(score) >= cutoff or max(score) <= cutoff:
        msg = 'cutoff out of score range'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    try:
        kernel_reg = KernelReg(endog=ite, exog=score, var_type='c', reg_type='ll')
        effect_at_cutoff, _ = kernel_reg.fit(np.array([cutoff]))
        took = time.time() - t
        logger.info('oracle kernel r**2: %s', kernel_reg.r_squared())
        return {
          "setting": setting_name,
          "method": method_name,
          "duration": took,
          "error": None,
          # coef
          "coef_conventional": effect_at_cutoff.item(),
          "coef_robust": None,
          "coef_bias_corrected": None,
          # se
          "se_conventional": None,
          "se_robust": None,
          "se_bias_corrected": None,
          # ci
          "ci_0025_conventional": None,
          "ci_0975_conventional": None,
          "ci_0025_robust": None,
          "ci_0975_robust": None,
          "ci_0025_bias_corrected": None,
          "ci_0975_bias_corrected": None,
          # learner
          "rmse_mlg": None,
          "logloss_mlm": None
        }
    except (np.linalg.LinAlgError, ZeroDivisionError) as e:
        took = time.time() - t
        logger.warning('cached linalg error in kernel oracle... continue')
        return _error_result(method_name, setting_name, took, e)


def benchmark_rdrobust(
    method_name,
    setting_name,
    setting_values,
    data,
    cutoff
):
    """Run ``rdrobust`` on one sample and convert its estimates to a result row.

    ``data['covariates']`` is passed as ``covs`` (it may be ``None``). In the
    fuzzy setting ``data['treatment_actual']`` cast to bool is passed as
    ``fuzzy``. The row holds the conventional, robust and bias-corrected
    coefficient, standard error and 95% interval bounds read from rdrobust's
    ``coef``/``se``/``ci`` tables. An error row is returned when the cutoff is
    outside the score range or rdrobust raises ``LinAlgError``/``ZeroDivisionError``.
    """
    running = data['score']
    if min(running) >= cutoff or max(running) <= cutoff:
        msg = 'cutoff out of score range'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    t = time.time()
    fuzzy_treatment = None
    if setting_values['params']['fuzzy']:
        fuzzy_treatment = data['treatment_actual'].astype("bool")

    # result-key suffix -> row label in rdrobust's output tables
    variants = (
        ("conventional", "Conventional"),
        ("robust", "Robust"),
        ("bias_corrected", "Bias-Corrected"),
    )
    try:
        res = rdrobust(
            y=data['outcome'],
            x=data['score'],
            fuzzy=fuzzy_treatment,
            covs=data['covariates'],
            c=cutoff
        )
        took = time.time() - t

        row = {
            "setting": setting_name,
            "method": method_name,
            "duration": took,
            "error": None,
        }
        for suffix, label in variants:
            row[f"coef_{suffix}"] = res.coef.loc[label, "Coeff"]
        for suffix, label in variants:
            row[f"se_{suffix}"] = res.se.loc[label, "Std. Err."]
        for suffix, label in variants:
            row[f"ci_0025_{suffix}"] = res.ci.loc[label, "CI Lower"]
            row[f"ci_0975_{suffix}"] = res.ci.loc[label, "CI Upper"]
        # rdrobust has no ML nuisance learners
        row["rmse_mlg"] = None
        row["logloss_mlm"] = None
        return row
    except (np.linalg.LinAlgError, ZeroDivisionError) as e:
        took = time.time() - t
        logger.warning('cached linalg error in rd robust... continue')
        return _error_result(method_name, setting_name, took, e)


def benchmark_dml_method(
    method_name,
    method,
    setting_name,
    setting_values,
    dml_data,
    cutoff
):
    """Fit one ``RDFlex`` configuration on ``dml_data`` and summarise it.

    ``method`` provides ``'ml_g'``, ``'ml_m'`` and ``'n_rep'``. The fuzzy flag
    comes from ``setting_values['params']['fuzzy']``. The model uses 5 folds, a
    triangular first-stage kernel and ``fit(n_iterations=2)``.

    Returns a result row with the conventional, robust and bias-corrected
    coefficient, standard error and 95% interval, plus the left/right nuisance
    losses and R² of the learners. ``LinAlgError``, ``ZeroDivisionError`` and
    ``ValueError`` during fitting are turned into an error row. On
    ``ValueError`` the model state is also dumped to ``error_data.csv`` on a
    best-effort basis.
    """
    if min(dml_data.s) >= cutoff or max(dml_data.s) <= cutoff:
        msg = 'cutoff out of score range'
        logger.warning(msg)
        return _error_result(method_name, setting_name, 0, msg)

    # Bound before the try so the ValueError handler can always refer to it,
    # even when the RDFlex constructor itself is what raised.
    rdflex_model = None
    t = time.time()
    try:
        rdflex_model = RDFlex(
            dml_data,
            ml_g=method['ml_g'],
            ml_m=method['ml_m'],
            n_rep=method['n_rep'],
            cutoff=cutoff,
            n_folds=5,
            # setting parameter for the method
            fuzzy=setting_values['params']['fuzzy'],
            fs_kernel="triangular",
        )
        rdflex_model.fit(n_iterations=2)
        took = time.time() - t
        confint = rdflex_model.confint()
        # ``coef``/``se`` are positional arrays whose entries follow the row
        # order of ``confint()`` (Conventional, Bias-Corrected, Robust in
        # doubleml). Resolve positions by label: hard-coding [1] as "robust"
        # reported the bias-corrected (= conventional) se as the robust one.
        row_of = {label: i for i, label in enumerate(confint.index)}
        i_conv = row_of["Conventional"]
        i_rob = row_of["Robust"]
        i_bc = row_of["Bias-Corrected"]
        return {
            "setting": setting_name,
            "method": method_name,
            "duration": took,
            "error": None,
            # coef
            "coef_conventional": rdflex_model.coef[i_conv],
            "coef_robust": rdflex_model.coef[i_rob],
            "coef_bias_corrected": rdflex_model.coef[i_bc],
            # se
            "se_conventional": rdflex_model.se[i_conv],
            "se_robust": rdflex_model.se[i_rob],
            "se_bias_corrected": rdflex_model.se[i_bc],
            # ci
            "ci_0025_conventional": confint.loc["Conventional", "2.5 %"],
            "ci_0975_conventional": confint.loc["Conventional", "97.5 %"],
            "ci_0025_robust": confint.loc["Robust", "2.5 %"],
            "ci_0975_robust": confint.loc["Robust", "97.5 %"],
            "ci_0025_bias_corrected": confint.loc["Bias-Corrected", "2.5 %"],
            "ci_0975_bias_corrected": confint.loc["Bias-Corrected", "97.5 %"],
            # learner (first repetition)
            "rmse_mlg_left": rdflex_model._nuisance_loss["ml_g"]['left'][0],
            "logloss_mlm_left": rdflex_model._nuisance_loss["ml_m"]['left'][0],
            "r2score_mlg_left": rdflex_model._r2test["ml_g"]['left'][0],
            "rmse_mlg_right": rdflex_model._nuisance_loss["ml_g"]['right'][0],
            "logloss_mlm_right": rdflex_model._nuisance_loss["ml_m"]['right'][0],
            "r2score_mlg_right": rdflex_model._r2test["ml_g"]['right'][0],
        }
    except (np.linalg.LinAlgError, ZeroDivisionError) as e:
        took = time.time() - t
        logger.warning('cached linalg error in dml... continue')
        return _error_result(method_name, setting_name, took, e)
    except ValueError as e:
        took = time.time() - t
        logger.warning('cached ValueError in dml... continue')

        # Best-effort dump of the failing fit for offline inspection; any
        # failure here (e.g. the model was never constructed) is only logged.
        try:
            dic = {"w": rdflex_model.w,
                   "h": rdflex_model.h,
                   "score": dml_data.s,
                   "y": dml_data.y,
                   "d": dml_data.d,
                   **rdflex_model.predictions}
            pd.DataFrame(dic).to_csv("error_data.csv")
            logger.warning("Errorous Object Saved")
        except Exception:
            logger.warning("An error occurred while saving the model.")

        return _error_result(method_name, setting_name, took, e)


def _error_result(method_name, setting_name, took, error):
    return {
        "setting": setting_name,
        "method": method_name,
        "duration": took,
        "error": str(error),
        # coef
        "coef_conventional": None,
        "coef_robust": None,
        "coef_bias_corrected": None,
        # se
        "se_conventional": None,
        "se_robust": None,
        "se_bias_corrected": None,
        # ci
        "ci_0025_conventional": None,
        "ci_0025_robust": None,
        "ci_0025_bias_corrected": None,
    }


def benchmark(data_generator, settings, methods, output, basename):
    """Benchmark all ``methods`` over every setting in ``settings``.

    Results are written through a :class:`JsonResultsWriter`, one file per
    row, named ``<output>/<basename>_<n>.json``.
    """
    results_writer = JsonResultsWriter(output, basename)
    results_writer.init()
    for name, values in settings.items():
        benchmark_setting(data_generator, name, values, methods, results_writer)


def fetch_json_results(path, basename, settings):
    """Load the JSON result rows written by ``JsonResultsWriter`` into a frame.

    Only files named ``<basename>_<...>.json`` in ``path`` are read. The
    underscore is required so that e.g. ``run2_0.json`` is not picked up for
    ``basename='run'``. Files are read in order of their numeric counter,
    falling back to name order for non-numeric suffixes. Each row is extended
    with ``setting_param_<k>`` / ``setting_data_<k>`` columns taken from
    ``settings[row['setting']]['params']`` and ``['data']``.

    Returns an empty DataFrame when no matching files exist.
    """
    prefix = f'{basename}_'
    suffix = '.json'

    def _order(fname):
        # numeric counters first, in numeric order (res_2 before res_10)
        stem = fname[len(prefix):-len(suffix)]
        return (0, int(stem), fname) if stem.isdigit() else (1, 0, fname)

    files = sorted(
        (f for f in os.listdir(path) if f.startswith(prefix) and f.endswith(suffix)),
        key=_order,
    )
    data_frames = []
    for file in files:
        with open(os.path.join(path, file), mode='r', encoding='utf-8') as json_file:
            data = json.load(json_file)
        setting = settings[data['setting']]
        data = {
            **data,
            **{f'setting_param_{k}': v for k, v in setting['params'].items()},
            **{f'setting_data_{k}': v for k, v in setting['data'].items()},
        }
        data_frames.append(pd.DataFrame([data]))
    if not data_frames:
        return pd.DataFrame()
    return pd.concat(data_frames, ignore_index=True)
