from argparse import ArgumentParser
from contextlib import redirect_stdout
import datetime
from functools import cmp_to_key
import itertools
import json
import os
import pathlib
import pickle
import pprint
import socket
import sys

import git
import numpy as np
import pandas as pd
from ruamel.yaml import YAML
from sklearn.base import clone
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from skorch import NeuralNetClassifier, NeuralNetRegressor
from skorch.callbacks import EarlyStopping, EpochScoring, LRScheduler
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm

import base_learners
from weighting import PermutationWeighter
from evaluation import RankingEvaluator, do_counterfactual_evaluations
from utils import match_wo_replacement, DATASET_PATHSPEC

from typing import Optional


# Supported values for the `encoding_strategy` option (checked in RankingMetaLearner.__init__ and validate_cfg).
_ENCODING_STRATEGIES = ["onehot", "static_embed", "learned_embed"]
# Supported values for cfg["model"]["outcome_model"] (checked in validate_cfg, built by get_outcome_model).
_OUTCOME_MODELS = ["rf", "nn", "pw-nn", "agent-nn", "dragon-nn", "rlearn-nn"]
# Supported values for cfg["model"]["learner"] (checked in validate_cfg, mapped by get_learner_class).
_LEARNERS = ["s", "t", "r", "psm", "dragon"]
# Module-wide ruamel.yaml handle used to load and dump config files; typ='safe' selects the safe loader/dumper.
yaml = YAML(typ='safe')
# Emit mappings in block style rather than inline flow style when dumping.
yaml.default_flow_style = False


class RankingMetaLearner(object):
    """Base class for meta-learners that rank treatment plans via pairwise CATEs.

    Subclasses provide ``fit``, ``estimate_cate`` and ``get_inference_features``.
    This base class supplies:
      * column selection,
      * treatment encoding,
      * ground-truth CATE reporting (synthetic data),
      * the comparison-sort based ``rank`` routine.
    """

    def __init__(
        self,
        overall_model: Optional = None,
        propensity_model: Optional = None,
        outcome_model: Optional = None,
        encoding_strategy: Optional[str] = "onehot",
        covariate_col_prefix: Optional[str] = "x",
        treatment_col: Optional[str] = "t",
        outcome_col: Optional[str] = "d_obs",
        end_to_end: Optional[bool] = False, # unused for now
        counterfactual_col_prefix: Optional[str] = None,
        permutation_seed: Optional[int] = 137,
    ):
        """Store the component models and the column naming conventions.

        Args:
            overall_model: Main estimator; its role depends on the subclass.
            propensity_model: Optional propensity model / sample weighter.
            outcome_model: Optional outcome nuisance model (used by the R-learner).
            encoding_strategy: One of ``_ENCODING_STRATEGIES``.
            covariate_col_prefix: Columns whose name starts with this prefix are covariates.
            treatment_col: Column holding the treatment (plan) id.
            outcome_col: Column holding the observed outcome.
            end_to_end: Stored as ``self.end_to_end``; not used by this class.
            counterfactual_col_prefix: Prefix of the per-plan counterfactual outcome
                columns (synthetic data only). ``None`` disables true-CATE reporting.
            permutation_seed: Seed forwarded to ``propensity_model.fit_weights`` by subclasses.

        Raises:
            ValueError: if ``encoding_strategy`` is not a known strategy.
        """
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.encoding_strategy = encoding_strategy
        if encoding_strategy not in _ENCODING_STRATEGIES:
            raise ValueError("Encoding strategy unknown.")
        self.covariate_col_prefix = covariate_col_prefix
        self.counterfactual_col_prefix = counterfactual_col_prefix
        self.permutation_seed = permutation_seed
        self.end_to_end = end_to_end
        self.overall_model = overall_model
        self.propensity_model = propensity_model
        self.outcome_model = outcome_model
        self.preds = {}      # (plan1, plan2) -> per-unit CATE estimates
        self.metrics = {}    # (plan1, plan2) -> dict of CATE/ATE(/PEHE)
        self.oh_encoder = None
        self.evaluator = None
        self.true_ranking = None  # set by attach_evaluator()

    def get_or_create_learner(self, plan1, plan2):
        """Placeholder for per-pair learners; not implemented in the base class."""
        pass

    def attach_evaluator(self, evaluator, ground_truth_ranking):
        """Register an evaluator and the reference ranking it compares against.

        Returns:
            The evaluator that was attached.
        """
        self.evaluator = evaluator
        self.true_ranking = ground_truth_ranking
        return self.evaluator

    def rank(self, df, plan_df, **kwargs):
        """Return the observed plans sorted in ascending order of estimated effect.

        Plans are compared with ``cmp_to_key``. ``plan1`` sorts before ``plan2``
        when the mean of ``estimate_cate(df, plan_df, plan1, plan2)`` is negative.

        The subclasses in this module estimate the CATE of ``plan1`` vs ``plan2``
        as (outcome under ``plan1``) - (outcome under ``plan2``). Plans therefore
        end up ordered by increasing counterfactual mean outcome.

        If confounders are controlled for, this is a ranking *ascending in
        upcoding intensity*, equivalently *descending in lambda*. By design of
        the simulated dataset, the optimal ranking in simulation is of the form
        [9, 8, 7, ..., 1, 0].

        Extra ``kwargs`` (e.g. ``verbose``) are forwarded to ``estimate_cate``.
        Each comparison calls ``estimate_cate`` afresh.
        """
        t = df[self.treatment_col]
        return sorted(
            t.unique(),
            key=cmp_to_key(
                lambda plan1, plan2: self.estimate_cate(
                    df,
                    plan_df,
                    plan1,
                    plan2,
                    **kwargs,
                ).mean()
            )
        )  # ascending order -- right = higher upcoding

    def fit(self, df):
        """Fit the learner; implemented by subclasses."""
        pass

    def estimate_cate(self, df, plan_df, plan1, plan2):
        """Estimate per-unit CATE of ``plan1`` vs ``plan2``; implemented by subclasses."""
        pass

    def get_inference_features(self, df, plan1, plan2):
        """Build model inputs for a plan comparison; implemented by subclasses."""
        pass

    def get_treatment_features(self, t_, plan_df):
        """Encode a treatment Series according to ``self.encoding_strategy``.

        ``onehot``:
            The encoder is fitted on the first call and reused afterwards.
            ``drop="first"`` removes one redundant column.
        ``static_embed``:
            Row ``t`` of ``plan_df`` is used as the feature vector for plan ``t``.

        Raises:
            NotImplementedError: for ``learned_embed``.
        """
        if self.encoding_strategy == "onehot":
            if self.oh_encoder is None:
                self.oh_encoder = OneHotEncoder(
                    categories='auto',
                    drop="first",
                    sparse_output=False
                )  # otherwise we get multicollinearity
                t_encoded = self.oh_encoder.fit_transform(t_.values.reshape(-1, 1))
            else:
                t_encoded = self.oh_encoder.transform(t_.values.reshape(-1, 1))
        elif self.encoding_strategy == "static_embed":
            t_encoded = plan_df.values[t_.values]
        else:  # learned_embed
            raise NotImplementedError()
        return t_encoded

    def get_feature_subset(self, df, plan1, plan2):
        """Select the rows of ``df`` treated with ``plan1`` or ``plan2``.

        Returns:
            ``(X_, y_, t_, plan_mask)``:
              * covariate frame,
              * outcome Series,
              * treatment Series (all restricted to the selected rows),
              * boolean mask over the full ``df``.
        """
        t = df[self.treatment_col]
        y = df[self.outcome_col]
        plan_mask = (t == plan1) | (t == plan2)
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df.loc[plan_mask, covariate_cols]
        y_ = y[plan_mask]
        t_ = t[plan_mask]
        return X_, y_, t_, plan_mask

    def report_true_cate(self, df, plan1, plan2, verbose=True):
        """Return the true per-unit CATE of ``plan1`` vs ``plan2`` for every row of ``df``.

        Only available when ``counterfactual_col_prefix`` is set (synthetic data).
        Otherwise returns ``None``.
        """
        true_cate = None
        if self.counterfactual_col_prefix is not None:
            true_cate = df[self.counterfactual_col_prefix + str(plan1)] - df[self.counterfactual_col_prefix + str(plan2)]
            if verbose:
                print("\tTrue ATE (synthetic data only):", true_cate.mean())
        return true_cate

    def get_saved_metrics(self):
        """Return the ``(plan1, plan2) -> metrics`` dict filled by ``estimate_cate``."""
        return self.metrics

    def predict_proba(self, *args, **kwargs):
        """Predict outcome probabilities; implemented by subclasses."""
        pass

    def get_factual_labels(self, df):
        """Return each row's outcome under its own treatment from ``d_{plan}`` columns.

        Only possible on synthetic data, where every plan has a counterfactual
        outcome column.

        Columns are built for the observed plan levels. Each row's treatment is
        mapped to its column *position* in that list. This keeps the lookup
        correct when plan ids are non-contiguous, or when a level is absent
        (e.g. in a bootstrap resample).
        """
        t_ = df[self.treatment_col]
        levels = sorted(t_.unique())
        y_cf = df[[f"d_{i}" for i in levels]]
        col_idx = np.searchsorted(levels, t_.values)
        return y_cf.values[np.arange(len(df)), col_idx]


class PropensityScoreMatcher(RankingMetaLearner):
    """Pairwise propensity-score matching estimator.

    Fitting:
        For every pair of observed plans, a clone of ``propensity_model`` is
        fitted to discriminate the lower plan (label 0) from the higher one
        (label 1).

    Inference:
        Units of the two plans are matched with ``match_wo_replacement`` on the
        predicted scores. The CATE is the outcome difference across the matched
        units.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_dict = {}        # (lower_plan, higher_plan) -> fitted propensity classifier
        self.matching_weights = {}  # (plan1, plan2) -> pair returned by match_wo_replacement

    def fit(self, df, plan_df):
        """Fit one binary propensity classifier per unordered pair of observed plans.

        ``plan_df`` is accepted for interface compatibility and unused.
        """
        # fit a propensity model and then read off P(d_i) at inference time
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df.loc[:, covariate_cols]
        t_ = df[self.treatment_col]
        print("Fitting propensity model...")
        for plan1, plan2 in tqdm(list(itertools.combinations(sorted(t_.unique()), 2))):
            t_mask = (t_ == plan1) | (t_ == plan2)
            # Binary target: 0 for the lower plan, 1 for the higher plan.
            # Built directly from a comparison instead of assigning into a slice
            # of the treatment column.
            labels = (t_[t_mask] == plan2).astype(int)
            curr_model = clone(self.propensity_model)
            curr_model.fit(X_[t_mask], labels) # probably like a random forest
            self.model_dict[(plan1, plan2)] = curr_model

    def get_inference_features(self, df, plan_df, plan1, plan2):
        """Split the units treated with ``plan1``/``plan2`` into two dicts.

        Each dict holds covariates ``X_`` and an outcome array ``y_``.
        """
        X_, y_, t_, mask = self.get_feature_subset(df, plan1, plan2)
        xdict_1 = {"X_": X_[t_ == plan1], "y_": y_[t_ == plan1].values}
        xdict_2 = {"X_": X_[t_ == plan2], "y_": y_[t_ == plan2].values}
        return xdict_1, xdict_2

    def estimate_cate(self, df, plan_df, plan1, plan2, verbose=True, report_true_cate=True, save_metrics=True):
        """Estimate matched-pair effects of ``plan1`` vs ``plan2``.

        Side effects:
          * records the matching in ``self.matching_weights``,
          * records the estimates in ``self.preds``,
          * if ``save_metrics``, records CATE/ATE in ``self.metrics``.

        No PEHE is computed: ``cate`` is per matched pair, not per unit of ``df``.

        Returns:
            ``y(plan1 unit) - y(plan2 unit)`` over the matched units.
        """
        if verbose:
            print("Estimating CATE,", plan1, "vs.", plan2)
        xdict_1, xdict_2 = self.get_inference_features(df, plan_df, plan1, plan2)

        # Models were fitted under the sorted key, with label 1 = higher plan.
        model = self.model_dict[(min(plan1, plan2), max(plan1, plan2))]
        preds1 = model.predict_proba(xdict_1["X_"])[:, 1]
        preds2 = model.predict_proba(xdict_2["X_"])[:, 1]

        weights1, weights2 = match_wo_replacement(preds1, preds2)
        self.matching_weights[(plan1, plan2)] = (weights1, weights2)
        cate = xdict_1["y_"][weights1] - xdict_2["y_"][weights2]
        self.preds[(plan1, plan2)] = cate

        if verbose:
            print("\tATE:", cate.mean())
        if report_true_cate:
            self.report_true_cate(df, plan1, plan2, verbose=verbose)
        if save_metrics:
            self.metrics[(plan1, plan2)] = {"CATE": cate, "ATE": cate.mean()}
        return cate
 
class RankingSLearner(RankingMetaLearner):
    """Single-model (S-)learner: one outcome model on covariates plus encoded treatment.

    The CATE of ``plan1`` vs ``plan2`` is the difference between the model's
    column-1 ``predict_proba`` output with the treatment features set to
    ``plan1`` and with them set to ``plan2``. It is evaluated for every unit
    observed under either plan.
    """

    def fit(self, df, plan_df):
        """Fit ``overall_model`` on ``[covariates, encoded treatment]``.

        Because of the underlying EconML SLearner implementation, we have to
        rewrite the fit function manually to accommodate multiple treatments.

        When a propensity model is configured, its weights are also written to
        ``df["sample_weight"]``, mutating the caller's frame. They are passed
        to the outcome model normalised to mean 1.
        """
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df.loc[:, covariate_cols]
        y_ = df[self.outcome_col]
        t_ = df[self.treatment_col]
        t_encoded = self.get_treatment_features(t_, plan_df)
        feats = np.concatenate([X_, t_encoded], axis=1)
        x_dict = {
            "X_": feats.astype(np.float32)
        }
        if self.propensity_model is not None:
            self.propensity_model.fit_weights(df, self.treatment_col, covariate_cols, self.permutation_seed)
            weights = self.propensity_model.estimate_weights(df, self.treatment_col, covariate_cols)
            df["sample_weight"] = weights  # TODO: get rid of magic string
            x_dict["sample_weight"] = weights / weights.mean()
        self.overall_model.fit(x_dict, y_.astype(np.int_))

    def get_inference_features(self, df, plan_df, plan1, plan2):
        """Build model inputs for units observed under ``plan1`` or ``plan2``.

        Returns two dicts with identical covariates. Their treatment features
        are set to ``plan1`` and ``plan2`` respectively. With a propensity model,
        each dict also carries ``sample_weight`` estimated as if every selected
        unit had received that plan.

        Raises:
            ValueError: if one-hot encoding is requested before ``fit()``.
            NotImplementedError: for the ``learned_embed`` strategy.
        """
        if self.oh_encoder is None and self.encoding_strategy == "onehot":
            raise ValueError(
                "For plan-wise comparisons, one-hot encoder has to be fitted already via `.fit()`.")
        X_, _, _, mask = self.get_feature_subset(df, plan1, plan2)
        n_units = len(X_)
        if self.encoding_strategy == "onehot":
            trow1, trow2 = self.oh_encoder.transform(
                np.array([[plan1], [plan2]]))
            t1 = np.tile(trow1, (n_units, 1))
            t2 = np.tile(trow2, (n_units, 1))
        elif self.encoding_strategy == "static_embed":
            t1 = np.tile(plan_df.iloc[plan1], (n_units, 1))
            t2 = np.tile(plan_df.iloc[plan2], (n_units, 1))
        else:
            raise NotImplementedError()
        x_dict1 = {"X_": np.concatenate([X_, t1], axis=1).astype(np.float32)}
        x_dict2 = {"X_": np.concatenate([X_, t2], axis=1).astype(np.float32)}
        if self.propensity_model is not None:
            print("Estimating sample weights...")
            covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
            sample1 = df[mask].copy()
            sample1[self.treatment_col] = plan1
            x_dict1["sample_weight"] = self.propensity_model.estimate_weights(sample1, self.treatment_col, covariate_cols)

            sample2 = df[mask].copy()
            sample2[self.treatment_col] = plan2
            x_dict2["sample_weight"] = self.propensity_model.estimate_weights(sample2, self.treatment_col, covariate_cols)
        return x_dict1, x_dict2

    def estimate_cate(self, df, plan_df, plan1, plan2, verbose=True, report_true_cate=True, save_metrics=True):
        """Estimate per-unit CATE of ``plan1`` vs ``plan2``.

        The estimate is ``p(plan1) - p(plan2)``, where ``p`` is column 1 of
        ``predict_proba``. When a propensity model is set, each term is
        multiplied by its sample weight.

        Side effects:
          * stores the CATE in ``self.preds``,
          * if ``save_metrics``, stores CATE/ATE (and PEHE when true
            counterfactuals are available) in ``self.metrics``.

        All printing is controlled by ``verbose``.
        """
        if verbose:
            print("Estimating CATE,", plan1, "vs.", plan2)
        X1, X2 = self.get_inference_features(df, plan_df, plan1, plan2)
        t = df[self.treatment_col]
        plan_mask = (t == plan1) | (t == plan2)

        p1 = self.overall_model.predict_proba(X1)[:, 1]
        p2 = self.overall_model.predict_proba(X2)[:, 1]
        if self.propensity_model is not None:
            p1 = p1 * X1["sample_weight"]
            p2 = p2 * X2["sample_weight"]
        cate = p1 - p2
        self.preds[(plan1, plan2)] = cate
        if verbose:
            print("\tATE:", cate.mean())
        results = {
            "CATE": cate,
            "ATE": cate.mean(),
        }
        true_cate = self.report_true_cate(df, plan1, plan2, verbose=verbose)
        if true_cate is not None and report_true_cate:
            pehe = np.sqrt(np.square(true_cate.values[plan_mask] - cate).mean())
            if verbose:
                print("\tPEHE:", pehe)
            results["PEHE"] = pehe
        if save_metrics:
            self.metrics[(plan1, plan2)] = results
        return cate

    def predict_proba(self, df, plan_df, t=None):
        """Return column 1 of ``overall_model.predict_proba`` for every row of ``df``.

        Args:
            t: If given, every row is scored as if it had received plan ``t``;
                otherwise the observed treatment column is used. ``df`` is not
                modified either way.
        """
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df[covariate_cols]
        t_obs = df[self.treatment_col]
        if t is None:
            t_ind = t_obs
        else:
            # A fresh Series is required: ``copy(deep=False)`` shares its buffer
            # with ``df``, so filling it in place would overwrite the caller's
            # treatment column (and fails on read-only data under Copy-on-Write).
            t_ind = pd.Series(
                np.full(len(t_obs), t, dtype=t_obs.dtype),
                index=t_obs.index,
                name=t_obs.name,
            )
        t_ = self.get_treatment_features(t_ind, plan_df)
        return self.overall_model.predict_proba({
            "X_": np.concatenate([X_.values.astype(np.float32), t_], axis=1).astype(np.float32)
        })[:, 1]

        
class RankingTLearner(RankingMetaLearner):
    """Two-model (T-)learner: one clone of ``overall_model`` per treatment level.

    The CATE of ``plan1`` vs ``plan2`` is the difference between the column-1
    ``predict_proba`` outputs of the two plans' models. It is evaluated on the
    same covariates.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.models = {}  # treatment level -> fitted model

    def fit(self, df, plan_df):
        """Fit one clone of ``overall_model`` on the rows of each observed treatment level."""
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df.loc[:, covariate_cols].values
        y_ = df[self.outcome_col]
        t_ = df[self.treatment_col]
        if self.encoding_strategy == "static_embed":
            # deprecated: append the plan embedding to the covariates
            # (loop-invariant, so computed once here).
            feats = np.concatenate([X_, self.get_treatment_features(t_, plan_df)], axis=1)
        else:
            feats = X_
        for treatment_level in t_.unique():
            level_mask = t_ == treatment_level
            new_model = clone(self.overall_model)
            x_dict = {
                "X_": feats[level_mask].astype(np.float32),
            }
            new_model.fit(x_dict, y_[level_mask].values.astype(np.int_))
            self.models[treatment_level] = new_model

    def get_inference_features(self, df, plan_df, plan1, plan2):
        """Covariates of the units observed under ``plan1`` or ``plan2``.

        Returns two dicts with equal contents. With a propensity model, both
        carry ``sample_weight`` estimated from the units' observed treatments.
        """
        X_, _, t_, _ = self.get_feature_subset(df, plan1, plan2)
        x_dict = {"X_": X_.values.astype(np.float32)}
        if self.propensity_model is not None:
            covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
            weights = self.propensity_model.estimate_weights(pd.concat([X_, t_], axis=1), self.treatment_col, covariate_cols)
            x_dict["sample_weight"] = weights
        return x_dict, x_dict.copy()

    def estimate_cate(self, df, plan_df, plan1, plan2, verbose=True, report_true_cate=True, save_metrics=True):
        """Estimate per-unit CATE of ``plan1`` vs ``plan2`` from the two plans' models.

        Side effects:
          * stores the CATE in ``self.preds``,
          * if ``save_metrics``, stores CATE/ATE (and PEHE when true
            counterfactuals are available) in ``self.metrics``.

        All printing is controlled by ``verbose``.
        """
        if verbose:
            print("Estimating CATE,", plan1, "vs.", plan2)
        X1, X2 = self.get_inference_features(df, plan_df, plan1, plan2)
        p1 = self.models[plan1].predict_proba(X1)[:, 1]
        p2 = self.models[plan2].predict_proba(X2)[:, 1]
        if self.propensity_model is not None:
            p1 = p1 * X1["sample_weight"]
            p2 = p2 * X2["sample_weight"]
        cate = p1 - p2
        self.preds[(plan1, plan2)] = cate

        if verbose:
            print("\tATE:", cate.mean())
        results = {
            "CATE": cate,
            "ATE": cate.mean(),
        }
        true_cate = self.report_true_cate(df, plan1, plan2, verbose=verbose)
        if true_cate is not None and report_true_cate:
            t = df[self.treatment_col]
            pehe = np.sqrt(
                np.square(true_cate.values[(t == plan1) | (t == plan2)] - cate).mean())
            if verbose:
                print("\tPEHE:", pehe)
            results["PEHE"] = pehe
        if save_metrics:
            self.metrics[(plan1, plan2)] = results
        return cate

    def predict_proba(self, df, plan_df, t=None):
        """Score each row of ``df`` with the model of its (or a forced) treatment.

        Args:
            t: If given, every row is scored with the model of plan ``t``;
                otherwise each row uses its observed treatment. ``df`` is not
                modified.

        Plans are assumed to be labelled ``0..n_plans-1``, so a plan id doubles
        as a column index into the stacked predictions. ``n_plans`` is
        ``len(plan_df)``, or the number of fitted models when ``plan_df`` is
        ``None``.
        """
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df[covariate_cols].reset_index(drop=True).values.astype(np.float32)
        t_obs = df[self.treatment_col].to_numpy()
        if t is None:
            t_ = t_obs
        else:
            # New array: filling a shallow copy in place would overwrite df's column.
            t_ = np.full(len(t_obs), t, dtype=t_obs.dtype)
        n_plans = len(plan_df) if plan_df is not None else len(self.models)
        preds = np.stack(
            [self.models[plan].predict_proba({"X_": X_})[:, 1] for plan in range(n_plans)],
            axis=1,
        )
        return preds[np.arange(len(preds)), t_]

class RankingDragonNet(RankingSLearner):
    """Learner whose network returns one output head per plan.

    ``overall_model.module_`` is called on covariates only and returns
    ``(_, head_outputs)``, with ``head_outputs`` indexed by plan id. CATEs are
    differences between the heads' column-1 outputs. One-hot encoding is
    always forced.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoding_strategy = "onehot"

    def fit(self, df, plan_df):
        """Fit ``overall_model`` on covariates ``X_`` and integer treatment ids ``T_``."""
        X_ = df.loc[:, [c for c in df.columns if c.startswith(self.covariate_col_prefix)]]
        y_ = df[self.outcome_col]
        t_ = df[self.treatment_col]
        x_dict = {
            "X_": X_.values.astype(np.float32),
            "T_": t_.values.astype(int),
        }
        self.overall_model.fit(x_dict, y_.astype(np.int_))

    def get_inference_features(self, df, plan_df, plan1, plan2):
        """Covariates of the units observed under ``plan1`` or ``plan2``.

        Both dicts hold the same covariates. The plan is selected by head index
        in ``estimate_cate``, not through the inputs.
        """
        X_, _, _, _ = self.get_feature_subset(df, plan1, plan2)
        x_dict1 = {"X_": X_.values.astype(np.float32)}
        x_dict2 = {"X_": X_.values.astype(np.float32)}
        return x_dict1, x_dict2

    def estimate_cate(self, df, plan_df, plan1, plan2, verbose=True, report_true_cate=True, save_metrics=True):
        """Estimate per-unit CATE of ``plan1`` vs ``plan2`` as a head-output difference.

        Side effects:
          * stores the CATE in ``self.preds``,
          * if ``save_metrics``, stores CATE/ATE (and PEHE when true
            counterfactuals are available) in ``self.metrics``.

        All printing is controlled by ``verbose``.
        """
        if verbose:
            print("Estimating CATE,", plan1, "vs.", plan2)
        X1, X2 = self.get_inference_features(df, plan_df, plan1, plan2)
        with torch.no_grad():
            _, head_outputs = self.overall_model.module_(torch.from_numpy(X1["X_"]).float().to(self.overall_model.device))
        cate = (head_outputs[plan1][:, 1] - head_outputs[plan2][:, 1]).cpu().detach().numpy()

        self.preds[(plan1, plan2)] = cate

        if verbose:
            print("\tATE:", cate.mean())
        results = {
            "CATE": cate,
            "ATE": cate.mean(),
        }
        true_cate = self.report_true_cate(df, plan1, plan2, verbose=verbose)
        if true_cate is not None and report_true_cate:
            t = df[self.treatment_col]
            pehe = np.sqrt(
                np.square(true_cate.values[(t == plan1) | (t == plan2)] - cate).mean())
            if verbose:
                print("\tPEHE:", pehe)
            results["PEHE"] = pehe
        if save_metrics:
            self.metrics[(plan1, plan2)] = results
        return cate

    def predict_proba(self, df, plan_df, t=None):
        """Return each row's column-1 head output for its (or a forced) treatment.

        Args:
            t: If given, every row reads the head of plan ``t``; otherwise its
                observed treatment's head. ``df`` is not modified.
        """
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df[covariate_cols].reset_index(drop=True).values.astype(np.float32)
        t_obs = df[self.treatment_col].values
        if t is None:
            t_ = t_obs
        else:
            # New array: ``copy(deep=False).values`` is df's own buffer, so
            # filling it would overwrite the caller's treatment column.
            t_ = np.full(len(t_obs), t, dtype=t_obs.dtype)
        with torch.no_grad():
            _, preds = self.overall_model.module_(torch.from_numpy(X_).float().to(self.overall_model.device))
            preds = torch.stack(preds, dim=-1)
        return preds[np.arange(len(preds)), 1, t_].cpu().detach().numpy()

class RankingRLearner(RankingMetaLearner):
    """Generalized R-learner with a factorised effect ``g(X) . h(T)``.

    Stage 1 fits ``outcome_model`` as the conditional-mean outcome nuisance.

    Stage 2 interleaves:
      * steps of ``overall_model`` (switching between its "tau" and
        "propensity" optimizers), and
      * partial fits of ``propensity_model``.

    The CATE of ``plan1`` vs ``plan2`` is ``sum(g(X) * (h(plan1) - h(plan2)))``.
    """

    def __init__(self, *args, stage1_max_epochs=500, stage2_max_epochs=10, num_update_steps=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.stage1_max_epochs = stage1_max_epochs  # epochs for the stage-1 outcome nuisance fit
        self.stage2_max_epochs = stage2_max_epochs  # outer iterations of stage 2
        self.num_update_steps = num_update_steps    # tau/propensity step pairs per stage-2 iteration

    def get_inference_features(self, df, plan_df, plan1, plan2):
        """Covariates of units under ``plan1``/``plan2`` with tiled treatment encodings.

        Returns two dicts with identical ``X_`` (float32). ``T_`` holds the
        encoding of ``plan1`` and ``plan2`` respectively, repeated per unit.

        Raises:
            ValueError: if one-hot encoding is requested before ``fit()``.
            NotImplementedError: for the ``learned_embed`` strategy.
        """
        if self.oh_encoder is None and self.encoding_strategy == "onehot":
            raise ValueError(
                "For plan-wise comparisons, one-hot encoder has to be fitted already via `.fit()`.")
        X_, _, _, _ = self.get_feature_subset(df, plan1, plan2)
        n_units = len(X_)
        if self.encoding_strategy == "onehot":
            trow1, trow2 = self.oh_encoder.transform(
                np.array([[plan1], [plan2]]))
            t1 = np.tile(trow1, (n_units, 1))
            t2 = np.tile(trow2, (n_units, 1))
        elif self.encoding_strategy == "static_embed":
            t1 = np.tile(plan_df.iloc[plan1], (n_units, 1))
            t2 = np.tile(plan_df.iloc[plan2], (n_units, 1))
        else:
            raise NotImplementedError()
        x_dict1 = {"X_": X_.values.astype(np.float32), "T_": t1}
        x_dict2 = {"X_": X_.values.astype(np.float32), "T_": t2}
        return x_dict1, x_dict2

    def fit(self, df, plan_df):
        """Run the two-stage R-learner fit.

        If an evaluator is attached, the current ranking is printed and
        evaluated after each stage-2 iteration.
        """
        X_ = df.loc[:, [c for c in df.columns if c.startswith(self.covariate_col_prefix)]]
        y_ = df[self.outcome_col].values.astype(int)
        t_ = df[self.treatment_col] # maybe just import a SIN at this point
        t_encoded = self.get_treatment_features(t_, plan_df)
        x_dict = {
            "X_": X_.values.astype(np.float32),
            "T_": t_encoded.astype(np.float32),
        }

        print("Stage 1 -- fitting conditional mean outcome nuisance")
        self.outcome_model.fit(x_dict, y_, epochs=self.stage1_max_epochs)

        self.overall_model.initialize()
        assert hasattr(self.overall_model, "module_")
        self.overall_model.module_.attach_outcome_nuisance(self.outcome_model) # Yikes
        # NOTE(review): requires a configured propensity_model exposing
        # initialize()/partial_fit() (presumably the NeuralNetRegressor built by
        # get_propensity_model with features=True) -- None would fail here.
        self.propensity_model.initialize()
        for i in range(self.stage2_max_epochs):
            print(f"Stage 2 -- fitting other nuisances [{i+1}/{self.stage2_max_epochs}]")
            # fit the propensity nuisance and the final outcome model in alternating fashion for n_iterations
            self.overall_model.module_.attach_propensity_featurizer(self.propensity_model)
            for _ in range(self.num_update_steps):
                self.overall_model.switch_optimizer("tau")
                self.overall_model.partial_fit(x_dict, y_.astype(np.float32), epochs=1)

                self.overall_model.switch_optimizer("propensity")
                self.overall_model.partial_fit(x_dict, y_.astype(np.float32), epochs=1)

            # Fit the propensity feature model: one partial_fit step on x_dict,
            # targeting the network's current treatment features
            # propensity_nuisance(T).
            with torch.no_grad():
                t_tensor = torch.from_numpy(t_encoded).float().to(self.overall_model.device)
                t_pred_ = self.overall_model.module_.propensity_nuisance(t_tensor)
            self.propensity_model.partial_fit(x_dict, t_pred_, epochs=1)

            if self.evaluator is not None: # This is probably best refactored as a callback in the future
                print("Post-epoch results")
                ranking = self.rank(df, plan_df, verbose=False, report_true_cate=False, save_metrics=False)
                print("Plan rankings (ascending order):", ranking)
                results = self.evaluator.evaluate(self.true_ranking, ranking)
                pprint.pprint(results)

    def estimate_cate(self, df, plan_df, plan1, plan2, verbose=True, report_true_cate=True, save_metrics=True):
        """Estimate per-unit CATE ``g(X) . (h(plan1) - h(plan2))``.

        Side effects:
          * stores the CATE in ``self.preds``,
          * if ``save_metrics``, stores CATE/ATE (and PEHE when true
            counterfactuals are available) in ``self.metrics``.

        All printing is controlled by ``verbose``.
        """
        if verbose:
            print("Estimating CATE,", plan1, "vs.", plan2)
        X1, X2 = self.get_inference_features(df, plan_df, plan1, plan2)
        device = self.overall_model.device

        # g(X) (h(T) - h(T')) is the final cate -- we use notation from their paper here
        with torch.no_grad(): # yeah, this one ain't as simple as .forward()
            g_x = self.overall_model.module_.covariate_mapper(torch.from_numpy(X1["X_"]).float().to(device)) # equal to X2["X_"]
            h_t1 = self.overall_model.module_.propensity_nuisance(torch.from_numpy(X1["T_"]).float().to(device))
            h_t2 = self.overall_model.module_.propensity_nuisance(torch.from_numpy(X2["T_"]).float().to(device))
            cate = (g_x * (h_t1 - h_t2)).sum(dim=-1).cpu().detach().numpy()

        self.preds[(plan1, plan2)] = cate

        if verbose:
            print("\tATE:", cate.mean())
        results = {
            "CATE": cate,
            "ATE": cate.mean(),
        }
        true_cate = self.report_true_cate(df, plan1, plan2, verbose=verbose)
        if true_cate is not None and report_true_cate:
            t = df[self.treatment_col]
            pehe = np.sqrt(
                np.square(true_cate.values[(t == plan1) | (t == plan2)] - cate).mean())
            if verbose:
                print("\tPEHE:", pehe)
            results["PEHE"] = pehe
        if save_metrics:
            self.metrics[(plan1, plan2)] = results
        return cate

    def predict_proba(self, df, plan_df, t=None):
        """Return ``g(X) . h(T)`` for every row of ``df``.

        Args:
            t: If given, every row uses plan ``t``; otherwise its observed
                treatment. ``df`` is not modified.
        """
        covariate_cols = [c for c in df.columns if c.startswith(self.covariate_col_prefix)]
        X_ = df[covariate_cols].reset_index(drop=True).values.astype(np.float32)

        t_obs = df[self.treatment_col]
        if t is None:
            t_ = t_obs
        else:
            # A fresh Series is required: ``copy(deep=False)`` shares its buffer
            # with ``df``, so filling it in place would overwrite the caller's
            # treatment column.
            t_ = pd.Series(
                np.full(len(t_obs), t, dtype=t_obs.dtype),
                index=t_obs.index,
                name=t_obs.name,
            )
        t_encoded = self.get_treatment_features(t_, plan_df)
        device = self.overall_model.device
        with torch.no_grad():
            g_x = self.overall_model.module_.covariate_mapper(torch.from_numpy(X_).float().to(device))
            h_t = self.overall_model.module_.propensity_nuisance(torch.from_numpy(t_encoded).float().to(device))

        return (g_x * h_t).sum(dim=-1).cpu().detach().numpy()

def get_dataset(dataset_name): # TODO: remove this in the future
    """Load a dataset registered in the ``DATASET_PATHSPEC`` YAML registry.

    The registry entry must provide ``data`` (a CSV path) and ``plans``. It may
    provide ``config``.

    Returns:
        Tuple ``(df, plan_df, data_config_file)``:
          * ``df``: the CSV read with its first column as the index,
          * ``plan_df``: always ``None`` (plan embeddings are deprecated),
          * ``data_config_file``: the optional ``config`` path, or ``None``.
    """
    with open(DATASET_PATHSPEC, "r") as registry_file:
        registry = yaml.load(registry_file)
    entry = registry[dataset_name]
    csv_path = entry["data"]  # e.g. "./analytic/synthetic/synthetic_uniform.csv"
    _plans_path = entry["plans"]  # still required in the registry, but no longer loaded
    config_path = entry.get("config", None)
    frame = pd.read_csv(csv_path, index_col=0, low_memory=False)
    return frame, None, config_path


def get_true_ranking(data_config_file):
    """Return the ``plans`` list from a data-config YAML file as a NumPy array.

    The list is returned unchanged. It is expected to already be in descending
    order of the upcoding parameter, i.e. ascending order of CATE.
    """
    with open(data_config_file, "r") as config_file:
        ordering = yaml.load(config_file)["plans"]
    return np.array(ordering)


def get_learner_class(learner, agentic=False):
    """Map a learner short name (one of ``_LEARNERS``) to its meta-learner class.

    Args:
        learner: ``"s"``, ``"t"``, ``"r"``, ``"psm"`` or ``"dragon"``.
        agentic: Only consulted for ``"s"``; selects the agentic S-learner.

    Raises:
        NotImplementedError: for any other name.
    """
    if learner == "s":
        # NOTE(review): AgenticTreatmentSLearner is neither defined nor imported in
        # the visible part of this module -- confirm it exists before agentic=True.
        return AgenticTreatmentSLearner if agentic else RankingSLearner
    if learner == "t":
        return RankingTLearner
    if learner == "r":
        return RankingRLearner
    if learner == "psm":
        return PropensityScoreMatcher
    if learner == "dragon":
        return RankingDragonNet
    raise NotImplementedError()

def get_callbacks(): # callback config?
    """Build a fresh list of named skorch callbacks for one network.

    Returns:
        A list of two named callbacks:
          * early stopping: patience 10, threshold 1e-6,
          * ``ReduceLROnPlateau`` LR scheduling: patience 5, factor 0.1.
    """
    stopper = EarlyStopping(patience=10, threshold=1e-6)
    lr_schedule = LRScheduler(policy=ReduceLROnPlateau, patience=5, factor=0.1)
    return [
        ('early_stopping', stopper),
        ('reduce_lr_on_plateau', lr_schedule),
    ]

def get_propensity_model(model_cfg):
    """Build the propensity model described by ``model_cfg`` (``None`` -> ``None``).

    Base class selection:
        ``model == "rf"`` -> RandomForestClassifier;
        otherwise ``features`` -> ``base_learners.FeatureMapper``;
        otherwise ``model == "nn"`` -> ``base_learners.BaseNN``.

    Wrapping:
        ``binary`` -> instantiated directly with ``kwargs``;
        ``features`` -> wrapped in a skorch NeuralNetRegressor (generalized
        R-learner);
        otherwise -> wrapped in a PermutationWeighter.

    Raises:
        NotImplementedError: when no base class matches.
    """
    if model_cfg is None:
        return None

    model_name = model_cfg["model"]
    wants_features = model_cfg.get("features", False)
    if model_name == "rf":
        base_class = RandomForestClassifier
    elif wants_features:
        base_class = base_learners.FeatureMapper
    elif model_name == "nn":
        base_class = base_learners.BaseNN
    else:
        raise NotImplementedError()

    extra_kwargs = model_cfg.get("kwargs", {})
    if model_cfg.get("binary", False):
        return base_class(**extra_kwargs)
    if wants_features:
        return NeuralNetRegressor(base_class, **extra_kwargs)  # for generalized R-learner
    return PermutationWeighter(base_class, **extra_kwargs)


def get_outcome_model(model, seed, module_class=None, **module_kwargs):
    """Build the outcome model named by ``model`` (one of ``_OUTCOME_MODELS``).

    ``"rf"`` returns a seeded RandomForestRegressor. Every other option returns
    a skorch-style network configured with:
      * ``verbose=2``,
      * CUDA when available,
      * shuffled training batches,
      * fresh ``get_callbacks()``,
      * any ``module_kwargs``.

    A key in ``module_kwargs`` that repeats one of these raises ``TypeError``.

    Raises:
        AssertionError: for ``"agent-nn"`` with an unsupported ``module_class``.
        NotImplementedError: for unknown ``model`` names.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if model == "rf":
        return RandomForestRegressor(random_state=seed)

    def _net_kwargs():
        # Built on demand so each network gets its own callback instances.
        return dict(
            verbose=2,
            device=device,
            iterator_train__shuffle=True,
            callbacks=get_callbacks(),
        )

    if model == "nn":
        return NeuralNetClassifier(getattr(base_learners, module_class), **_net_kwargs(), **module_kwargs)
    if model == "pw-nn":
        return base_learners.SampleWeightedClassifier(
            getattr(base_learners, module_class), **_net_kwargs(), **module_kwargs
        )
    if model == "agent-nn":
        assert module_class in ["AgenticNN", "AgenticAlternatingNN"]
        net_class = getattr(base_learners, module_class)
        return net_class(base_learners.AgenticTreatmentWrapper, **_net_kwargs(), **module_kwargs)
    if model == "dragon-nn":
        return base_learners.DragonNetWrapper(base_learners.DragonNet, **_net_kwargs(), **module_kwargs)
    if model == "rlearn-nn":
        return base_learners.RLearnerNN(base_learners.RLearnerWrapper, **_net_kwargs(), **module_kwargs)
    raise NotImplementedError()


def save_model(save_path, meta_learner):
    """Pickle ``meta_learner`` to ``save_path`` using the highest pickle protocol."""
    with open(save_path, "wb") as out_file:
        pickle.dump(meta_learner, out_file, protocol=pickle.HIGHEST_PROTOCOL)
    print("Saved to", save_path)


def save_results(result_path, result_dict, ranking, true_ranking=None):
    """Add rankings to ``result_dict`` (in place) and write it as indented JSON.

    ``rank_pred`` (and ``rank_true`` when ``true_ranking`` is given) are stored
    as lists of Python ints so NumPy integers serialise.
    """
    result_dict["rank_pred"] = [int(plan) for plan in ranking]
    if true_ranking is not None:
        result_dict["rank_true"] = [int(plan) for plan in true_ranking]
    with open(result_path, "w") as out_file:
        json.dump(result_dict, out_file, sort_keys=True, indent=4)
    print("Saved results to", result_path)


def save_config(config_path, cfg):
    """Record run provenance in ``cfg["run"]`` (in place) and dump ``cfg`` as YAML.

    The provenance includes:
      * the current git commit hash,
      * a local timestamp,
      * the hostname,
      * ``SLURM_JOB_ID`` when set in the environment.
    """
    repo = git.Repo(search_parent_directories=True)
    run_info = {
        "hash": repo.head.object.hexsha,
        "date": str(datetime.datetime.now()),
        "hostname": socket.gethostname(),
    }
    slurm_id = os.environ.get("SLURM_JOB_ID")
    if slurm_id is not None:
        run_info["slurm_jobid"] = slurm_id
    cfg["run"] = run_info
    with open(config_path, "w") as out_file:
        yaml.dump(cfg, out_file)
    print("Saved config to", config_path)


def save_bootstrap_results(bs_path, bs_result_df, bs_rankings):
    """Write bootstrap outputs into the directory ``bs_path``.

    Files written:
      * ``bootstrap_results.csv``: the metrics frame,
      * ``rankings.npy``: the rankings array.
    """
    metrics_file = os.path.join(bs_path, "bootstrap_results.csv")
    rankings_file = os.path.join(bs_path, "rankings.npy")

    bs_result_df.to_csv(metrics_file)
    print("Saved bootstrap metrics to", metrics_file)

    with open(rankings_file, 'wb') as out_file:
        np.save(out_file, bs_rankings)
    print("Saved rankings array to", rankings_file)

def validate_cfg(cfg):
    """Check that the ``model`` section of a run config names supported components.

    Explicit exceptions are used instead of ``assert`` so validation still
    happens under ``python -O``, which strips assert statements.

    Raises:
        KeyError: if ``cfg["model"]`` or one of the checked keys is missing.
        ValueError: if ``learner``, ``outcome_model`` or ``encoding_strategy``
            is not one of the supported values.
    """
    model_cfg = cfg["model"]
    checks = (
        ("learner", _LEARNERS),
        ("outcome_model", _OUTCOME_MODELS),
        ("encoding_strategy", _ENCODING_STRATEGIES),
    )
    for key, allowed in checks:
        value = model_cfg[key]
        if value not in allowed:
            raise ValueError(f"Unsupported {key} {value!r}; expected one of {allowed}.")

def get_args(argv=None):
    """Parse the command-line options for a run.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``sys.argv[1:]`` is used, as before. Passing a list
            enables programmatic use and testing.

    Returns:
        ``argparse.Namespace`` with fields:
          * ``name``, ``config``: required,
          * ``bootstrap``: int, default 0,
          * ``evaluator_config``,
          * ``overwrite``: flag,
          * ``splits_to_fit``: list of int, or ``None``,
          * ``dataset``.
    """
    psr = ArgumentParser()
    psr.add_argument("--name", required=True, type=str)
    psr.add_argument("--config", required=True, type=str)
    psr.add_argument("--bootstrap", default=0, type=int)
    psr.add_argument("--evaluator-config", type=str,
                     default="./config/default_evaluator.yml")
    psr.add_argument("--overwrite", action="store_true")
    psr.add_argument("--splits-to-fit", nargs='+', type=int)
    psr.add_argument("--dataset", type=str)
    return psr.parse_args(argv)

def setup_metalearner(cfg, df, plan_df):
    """Build an unfitted ranking meta-learner from a run configuration.

    Args:
        cfg: Parsed run config.
            * ``cfg["model"]`` must supply ``learner``, ``outcome_model``,
              ``seed`` and ``encoding_strategy``.
            * Optional model keys are ``agentic``, ``module_class``,
              ``kwargs``, ``propensity_model`` and ``learner_kwargs``.
            * The ``"r"`` learner also needs ``outcome_nuisance``, which
              contains ``model`` and ``seed`` and may contain
              ``module_class`` and ``kwargs``.
            * ``cfg["dataset"]`` may provide ``cf_prefix``.
        df: Currently unused; kept for call-site compatibility.
        plan_df: Currently unused; kept for call-site compatibility.

    Returns:
        An instance of the class returned by ``get_learner_class``.
    """
    model_cfg = cfg["model"]
    learner_name = model_cfg["learner"]

    print("Loading meta-learner...")
    learner_class = get_learner_class(
        learner_name,
        agentic=model_cfg.get("agentic", False)
    )
    print("Initializing outcome model...")
    module_class = model_cfg.get("module_class", None)

    # The "psm" and "kom" learners are built without an overall outcome model.
    outcome_model = None
    if learner_name not in ["psm", "kom"]:
        outcome_model = get_outcome_model(
            model_cfg["outcome_model"],
            model_cfg["seed"],
            module_class=module_class,
            **model_cfg.get("kwargs", {}),
        )
    propensity_model = get_propensity_model(model_cfg.get("propensity_model", None))

    # Only the R-learner receives a separate outcome nuisance model.
    outcome_nuisance = None
    if learner_name == "r":
        nuisance_cfg = model_cfg["outcome_nuisance"]
        outcome_nuisance = get_outcome_model(
            nuisance_cfg["model"],
            nuisance_cfg["seed"],
            # Optional, exactly like the top-level `module_class` above
            # (previously a missing key raised KeyError here).
            module_class=nuisance_cfg.get("module_class", None),
            **nuisance_cfg.get("kwargs", {}),
        )

    meta_learner = learner_class(
        overall_model=outcome_model,
        propensity_model=propensity_model,
        outcome_model=outcome_nuisance,
        encoding_strategy=model_cfg["encoding_strategy"],
        counterfactual_col_prefix=cfg["dataset"].get("cf_prefix", None),
        end_to_end=(module_class == "AgenticNN"),
        **model_cfg.get("learner_kwargs", {})
    )
    return meta_learner


if __name__ == '__main__':
    args = get_args()
    with open(args.config, "r") as f:
        cfg = yaml.load(f)
    validate_cfg(cfg)

    # Every run writes into its own directory ./estimators/<--name>.
    save_dir = os.path.join("./estimators", args.name)
    if os.path.isdir(save_dir) and not args.overwrite:
        raise ValueError(f"{save_dir} exists. Exiting.")

    save_path = os.path.join(save_dir, "model_{}.pkl")  # formatted with the split index
    result_path = os.path.join(save_dir, "results.csv")
    ranking_path = os.path.join(save_dir, "rankings.csv")
    config_path = os.path.join(save_dir, "config.yml")

    # --dataset on the command line takes precedence over the config file.
    dataset_name = cfg["dataset"]["name"] if args.dataset is None else args.dataset
    if args.dataset is not None:
        print("Overriding dataset specification -- using dataset:", args.dataset)

    df, plan_df, data_config_file = get_dataset(dataset_name)
    if cfg["dataset"].get("normalize_agent_features", False):
        # Column-wise z-scoring of the plan features.
        plan_df = (plan_df - plan_df.mean()) / plan_df.std()

    print("# features:", len([c for c in df.columns if c.startswith("x")]))

    all_metrics = []
    all_rankings = []

    # The evaluator config is read once and shared by all splits.
    with open(args.evaluator_config, "r") as f:
        eval_cfg = yaml.load(f)
    evaluator = RankingEvaluator.from_config(eval_cfg)

    # A ground-truth ranking is only loaded for "synth*" and "toy*" datasets.
    if dataset_name.startswith(("synth", "toy")):
        true_ranking = get_true_ranking(data_config_file)
    else:
        true_ranking = None

    # Without a "split" column the whole dataset is treated as a single split.
    has_split_col = "split" in df.columns
    splits = df["split"].unique() if has_split_col else [0]
    for i, split in enumerate(tqdm(splits)):
        torch.manual_seed(i)
        if args.splits_to_fit is not None and i not in args.splits_to_fit:
            print("Skipping split", split)
            continue
        print("Training on split", split)
        # NOTE(review): without --overwrite the save_dir check above already
        # exited if the directory existed, so this branch looks unreachable --
        # confirm whether resuming a partial run was intended.
        if os.path.isfile(save_path.format(i)) and not args.overwrite:
            print(f"Found model #{i}. Skipping.")
            continue
        df_subset = df[df["split"] == split] if has_split_col else df
        meta_learner = setup_metalearner(cfg, df_subset, plan_df)
        if dataset_name.startswith("synth"):
            _ = meta_learner.attach_evaluator(evaluator, true_ranking)
        print("Fitting meta-learner...")
        dev_df, test_df = train_test_split(
            df_subset,
            test_size=cfg["split"]["size"],
            random_state=cfg["split"]["seed"],
        )
        meta_learner.fit(dev_df.reset_index(drop=True), plan_df)
        print("Computing rankings...")
        ranking = meta_learner.rank(
            test_df.reset_index(drop=True), plan_df, verbose=False, save_metrics=False
        )

        print("Plan rankings (ascending order):", ranking)
        print("Ground truth lambdas:", true_ranking)

        if true_ranking is not None:
            metrics = evaluator.evaluate(true_ranking, ranking, save_results=False)

            y_score = meta_learner.predict_proba(test_df, plan_df)
            y_true = meta_learner.get_factual_labels(test_df)

            try:
                print("Conducting observational evaluation...")
                metrics |= evaluator.evaluate_predictions(y_true, y_score)

                print("Conducting counterfactual evaluation...")
                cf_dict = do_counterfactual_evaluations(meta_learner, evaluator, test_df, plan_df)
                metrics |= cf_dict
            except TypeError:
                # Keep the ranking metrics gathered so far; the per-sample
                # comparisons are skipped for this split.
                print("Comparison failed when evaluating predictions. This is fine for estimators that don't output individual-level estimates.")

            print("Evaluation results:")
            pprint.pprint(metrics)
            all_metrics.append(metrics)

        print("Saving model...")
        # parents=True: ./estimators itself may not exist yet (fresh checkout,
        # or a nested --name), which previously raised FileNotFoundError.
        pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
        save_model(save_path.format(i), meta_learner)

        all_rankings.append(ranking)
    print("Saving config...")

    # The directory is otherwise only created when a model is saved; make sure
    # it exists even if every split was skipped via --splits-to-fit.
    pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
    save_config(config_path, cfg)
    metric_df = pd.DataFrame(all_metrics)
    ranking_arr = np.array(all_rankings)

    if len(metric_df):
        with pd.option_context('display.max_columns', 20):
            print(metric_df.describe(percentiles=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.975]))
        metric_df.to_csv(result_path)
    np.savetxt(ranking_path, ranking_arr, delimiter=",")


