import argparse
import sys

import numpy as np
import pandas as pd
import yaml
from folktables import (
    ACSDataSource,
    ACSEmployment,
    ACSIncome,
    ACSMobility,
    ACSPublicCoverage,
    ACSTravelTime,
    BasicProblem,
    adult_filter,
    travel_time_filter,
)
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures

sys.path.append("../")
from models.framework import PostDataset, PreDataset, Simulation, compare_models
from models.performativity_models import *

def apply_cli_overrides(config, args):
    """Apply command-line overrides to the loaded configuration in place.

    Args:
        config: Nested dict loaded from the YAML config file.
        args: Parsed argparse namespace (or any object exposing the attributes
            ``threshold_override``, ``lambda_override``, ``random_seed``,
            ``featurization``, ``shift``, ``train_frac`` and ``test_noise``).

    Returns:
        The same ``config`` object, mutated, for convenience.
    """
    # Numeric overrides are compared against None instead of being tested for
    # truthiness, so an explicit 0 / 0.0 on the command line is honoured
    # rather than silently ignored.
    if args.threshold_override is not None:
        config["performativity"]["params"]["threshold"] = args.threshold_override
    if args.lambda_override is not None:
        config["performativity"]["params"]["lambda"] = args.lambda_override
    if args.random_seed is not None:
        config["random_seed"] = args.random_seed
    if args.featurization:
        config["data"]["featurization"] = args.featurization

    # --shift selects which of the two deployed predictors is fit on the true
    # outcomes; the other one is fit on randomly resampled outcomes.
    if args.shift == 'train_accurate':
        config["distribution_shift"]["train_predictor"] = "accurate"
        config["distribution_shift"]["test_predictor"] = "random"
    elif args.shift == 'test_accurate':
        config["distribution_shift"]["train_predictor"] = "random"
        config["distribution_shift"]["test_predictor"] = "accurate"

    # These two flags have non-None CLI defaults, so the command-line value
    # (explicit or default) always replaces whatever the config file holds.
    if args.train_frac is not None:
        config['data']['train_frac'] = args.train_frac
    if args.test_noise is not None:
        config['test_noise'] = args.test_noise
    return config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="file containing configuration parameters",
    )
    parser.add_argument(
        "--track", action="store_true", help="flag for tracking results with wandb"
    )
    parser.add_argument(
        "--threshold_override",
        type=float,
        help="flag for overriding threshold parameter, for sweeps",
    )
    parser.add_argument(
        "--lambda_override",
        type=float,
        help="override laplace noise lambda for sweeps"
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        # None (rather than 0) so that "flag omitted" keeps the config's seed
        # while an explicit "--random_seed 0" can still be requested.
        default=None,
        help="random seed for train test splitting"
    )
    parser.add_argument(
        "--featurization",
        type=str,
        default=None,
        help="whether there should be overparameterization with more features"
    )
    parser.add_argument(
        "--shift",
        type=str,
        default=None,
        help="type of distribution shift"
    )
    parser.add_argument(
        "--train_frac",
        type=float,
        default=0.8,
        help="Fraction of data to use for training"
    )
    parser.add_argument(
        "--test_noise",
        type=float,
        default=.5,
        help="Noise level for noisy test predictor fitting"
    )
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    apply_cli_overrides(config, args)


    if args.track:
        # Imported here so that wandb is only needed when --track is passed.
        import wandb

        # Start a run in the "test_performativity" project and record the
        # configuration as it stands at this point in the script.
        wandb.init(project="test_performativity", config=config)

    # Custom folktables problem registered as "income_cont" in task_dict: the
    # raw 'PINCP' column is the target and no target transform is applied
    # (target_transform=None), so outcomes stay continuous.
    ACSIncomeCont = BasicProblem(
        features=[
            'AGEP',
            'COW',
            'SCHL',
            'MAR',
            'OCCP',
            'POBP',
            'RELP',
            'WKHP',
            'SEX',
            'RAC1P',
        ],
        target='PINCP',
        target_transform=None,
        group='RAC1P',
        preprocess=adult_filter,
        # NOTE(review): the second positional parameter of np.nan_to_num is
        # `copy`, not `nan`, so this most likely replaces NaN with 0.0 rather
        # than -1. Confirm the intended fill value before changing it.
        postprocess=lambda x: np.nan_to_num(x, -1),
    )
    # Custom folktables problem registered as "travel_time_cont" in task_dict:
    # the raw "JWMNP" column is the target and no target transform is applied.
    ACSTravelTimeCont = BasicProblem(
        features=[
            'AGEP',
            'SCHL',
            'MAR',
            'SEX',
            'DIS',
            'ESP',
            'MIG',
            'RELP',
            'RAC1P',
            'PUMA',
            'ST',
            'CIT',
            'OCCP',
            'JWTR',
            'POWPUMA',
            'POVPIP',
        ],
        target="JWMNP",
        target_transform=None,
        group='RAC1P',
        preprocess=travel_time_filter,
        # NOTE(review): same caveat as for ACSIncomeCont -- -1 is bound to
        # `copy`, so NaN is presumably filled with 0.0, not -1.
        postprocess=lambda x: np.nan_to_num(x, -1),
    )

    # Lookup tables translating the string names used in the config file into
    # the corresponding folktables problems, estimator classes and
    # performativity step functions.
    task_dict = {
        "employment": ACSEmployment,
        "income": ACSIncome,
        "mobility": ACSMobility,
        "public_coverage": ACSPublicCoverage,
        "income_cont": ACSIncomeCont,
        "travel_time": ACSTravelTime,
        "travel_time_cont": ACSTravelTimeCont,
    }
    model_dict = {
        "logisticregression": LogisticRegression,
        "randomforest": RandomForestClassifier,
        "linearregression": LinearRegression,
        "gb": GradientBoostingClassifier,
    }
    performativity_dict = {
        "self_fulfilling": self_fulfilling,
        "linear": linear,
        "laplace": laplace,
    }
    performativity_model = performativity_dict[config["performativity"]["name"]]

    # Model Setup
    # Both "*_cont" tasks are defined with target_transform=None, i.e. with a
    # continuous target, so both must take the regression (RMSE) path.
    # Previously only "income_cont" did, which sent "travel_time_cont" down
    # the classification path (including label flipping for noisy predictors).
    continuous_tasks = ("income_cont", "travel_time_cont")
    rmse = config["data"]["task"] in continuous_tasks

    # The train-time and test-time deployed predictors share one estimator
    # class; they differ only in how they are fit later in the script.
    deployed_model_train = model_dict[config["model"]["deployed_model"]]
    deployed_model_test = model_dict[config["model"]["deployed_model"]]
    evaluated_model = model_dict[config["model"]["evaluated_model"]]
    performativity_params = config["performativity"]["params"]

    # Data Setup
    def setup_folktables(config):
        """Load ACS person data and wrap the configured task in a PreDataset.

        Args:
            config: Configuration dict; reads ``config["data"]["year"]``,
                ``config["data"]["state"]`` and ``config["data"]["task"]``.

        Returns:
            A PreDataset built from the task's feature matrix and labels.
        """
        data_cfg = config["data"]
        source = ACSDataSource(
            survey_year=str(data_cfg["year"]), horizon="1-Year", survey="person"
        )
        # Only the single configured state is requested.
        raw_frame = source.get_data(states=[data_cfg["state"]], download=True)

        problem = task_dict[data_cfg["task"]]

        # df_to_numpy also returns the group column, which is not used here.
        features, label, _group = problem.df_to_numpy(raw_frame)

        return PreDataset(features=features, outcomes=label)

    def setup_kaggle(datapath="TO BE FILLED IN"):
        """Load the Kaggle CSV and wrap it in a PreDataset.

        Args:
            datapath: Path to the CSV file. The default is the original
                placeholder and must be replaced with a real path (or passed
                explicitly) before this source can be used.

        Returns:
            A PreDataset whose outcomes are the "SeriousDlqin2yrs" column and
            whose features are all columns after the first two, with missing
            values filled with -1.
        """
        kaggledf = pd.read_csv(datapath)
        # Converted to an ndarray so both data sources hand PreDataset the
        # same container types (the folktables path passes numpy arrays).
        outcomes = kaggledf["SeriousDlqin2yrs"].to_numpy()
        # The first two columns are dropped from the features -- presumably a
        # row index and the label column; verify against the actual CSV.
        features = kaggledf.drop(kaggledf.columns[0:2], axis=1).fillna(-1).to_numpy()

        predataset = PreDataset(
            features=features,
            outcomes=outcomes,
        )
        return predataset


    # Build the PreDataset from the configured source. An unrecognised source
    # is rejected here; previously it fell through silently and surfaced
    # later as a NameError on `predataset`.
    data_source_name = config["data"]["source"]
    if data_source_name == 'folktables':
        predataset = setup_folktables(config)
    elif data_source_name == 'kaggle':
        predataset = setup_kaggle()
    else:
        raise ValueError(
            f"Unknown data source {data_source_name!r}; expected 'folktables' or 'kaggle'"
        )


    # Build the two deployed predictors (one generates the training
    # distribution, the other the shifted test distribution). Both are always
    # standard-scaled; polynomial features are prepended on request.
    if config["data"]["featurization"] == "polynomial":
        # NOTE: deployed_model_params are not applied on the polynomial path.
        model_train = make_pipeline(PolynomialFeatures(), StandardScaler(), deployed_model_train())
        model_test = make_pipeline(PolynomialFeatures(), StandardScaler(), deployed_model_test())
    elif config["model"]["deployed_model_params"]:
        model_train = make_pipeline(StandardScaler(), deployed_model_train(**config["model"]["deployed_model_params"]))
        model_test = make_pipeline(StandardScaler(), deployed_model_test(**config["model"]["deployed_model_params"]))
    else:
        model_train = make_pipeline(StandardScaler(), deployed_model_train())
        model_test = make_pipeline(StandardScaler(), deployed_model_test())

    # Fit the train-time predictor: "accurate" uses the true outcomes,
    # "random" uses outcomes resampled with replacement, which breaks the
    # feature/label pairing while keeping the label distribution.
    train_type = config["distribution_shift"]["train_predictor"]
    if train_type == "accurate":
        model_train.fit(predataset.features, predataset.outcomes)
    elif train_type == "random":
        model_train.fit(
            predataset.features, np.random.choice(predataset.outcomes, len(predataset.features))
        )

    # Fit the test-time predictor; it additionally supports "noisy" labels.
    test_type = config["distribution_shift"]["test_predictor"]
    if test_type == "accurate":
        model_test.fit(predataset.features, predataset.outcomes)
    elif test_type == "random":
        model_test.fit(
            predataset.features, np.random.choice(predataset.outcomes, len(predataset.features))
        )
    elif test_type == 'noisy':
        if rmse:
            # Continuous target: scale every outcome by its own factor drawn
            # uniformly from (1 - test_noise, 1 + test_noise]. The previous
            # code drew a single scalar, which rescaled all outcomes by the
            # same amount instead of adding per-sample noise.
            jitter = 1 - 2 * np.random.random(size=np.shape(predataset.outcomes))
            new_outcomes = predataset.outcomes * (1 + jitter * config["test_noise"])
        else:
            # Binary target: flip each label independently with probability
            # test_noise.
            flip = np.random.binomial(1, config["test_noise"], size=len(predataset.outcomes))
            new_outcomes = np.where(flip == 1, np.logical_not(predataset.outcomes), predataset.outcomes)
        model_test.fit(predataset.features, new_outcomes)

    # The two simulations are configured identically except for their name
    # and the deployed predictor they use.
    shared_simulation_kwargs = dict(
        predataset=predataset,
        params=performativity_params,
        step_performativity=performativity_model,
        graph=None,
    )
    simulation_train = Simulation(
        name="folktables_exp", predictor=model_train, **shared_simulation_kwargs
    )
    simulation_test = Simulation(
        name="folktables_exp_test", predictor=model_test, **shared_simulation_kwargs
    )

    # Run one step of performativity and collect PostDataset
    train_dataset = simulation_train.run()
    test_dataset = simulation_test.run()

    # Carve a validation split out of the training-distribution data; the
    # held-out share is the complement of the configured training fraction.
    holdout_fraction = 1 - config['data']['train_frac']
    train_dataset, valid_dataset = train_dataset.split(
        test_size=holdout_fraction,
        random_state=config["random_seed"],
    )

    # Compare models that have access to yhat, and those that don't
    # The branches mirror the deployed-model construction: polynomial
    # featurization first, then explicit hyper-parameters, then defaults.
    # Previously the pipelines built from evaluated_model_params were always
    # overwritten by a following if/else, so the parameters were never used.
    if config["data"]["g_featurization"] == 'polynomial':
        # As with the deployed models, params are not applied on this path.
        model_without_yhat = make_pipeline(PolynomialFeatures(), StandardScaler(), evaluated_model())
        model_with_yhat = make_pipeline(PolynomialFeatures(), StandardScaler(), evaluated_model())
    elif config["model"]["evaluated_model_params"]:
        model_without_yhat = make_pipeline(StandardScaler(), evaluated_model(**config["model"]["evaluated_model_params"]))
        model_with_yhat = make_pipeline(StandardScaler(), evaluated_model(**config["model"]["evaluated_model_params"]))
    else:
        model_without_yhat = make_pipeline(StandardScaler(), evaluated_model())
        model_with_yhat = make_pipeline(StandardScaler(), evaluated_model())

    # Fit the yhat-free model on the training split up front.
    # NOTE(review): compare_models may fit both models itself -- confirm
    # whether this explicit fit is still needed.
    model_without_yhat.fit(train_dataset.features, train_dataset.outcomes)

    def _compare_on(eval_dataset):
        """Call compare_models with the training split and ``eval_dataset``.

        The training split always supplies the first (features,
        features-with-yhat, outcomes) triple and ``eval_dataset`` the second;
        presumably these are the fit and scoring data respectively.
        """
        return compare_models(
            model_without_yhat,
            model_with_yhat,
            train_dataset.features,
            train_dataset.get_features_with_yhat(),
            train_dataset.outcomes,
            eval_dataset.features,
            eval_dataset.get_features_with_yhat(),
            eval_dataset.outcomes,
            rmse=rmse,
        )

    # Same comparison on the training split, the validation split and the
    # dataset produced under the test-time deployed predictor.
    score1_train, score2_train, rmse1_train, rmse2_train = _compare_on(train_dataset)

    score1_valid, score2_valid, rmse1_valid, rmse2_valid = _compare_on(valid_dataset)

    score1, score2, rmse1, rmse2 = _compare_on(test_dataset)

    # One row per report section: heading, accuracy without/with yhat, then
    # RMSE without/with yhat. Sections print in train, validation, test order.
    result_sections = (
        (
            "Train accuracy (no distribution shift)",
            score1_train, score2_train, rmse1_train, rmse2_train,
        ),
        (
            "Validation accuracy (no distribution shift)",
            score1_valid, score2_valid, rmse1_valid, rmse2_valid,
        ),
        (
            "Test accuracy (distribution shift where a new predictor has been deployed)",
            score1, score2, rmse1, rmse2,
        ),
    )
    for heading, acc_without, acc_with, rmse_without, rmse_with in result_sections:
        print(heading)
        print("Accuracy of model without yhat: ", acc_without)
        print("Accuracy of model with yhat: ", acc_with)
        print("RMSE of model without yhat: ", rmse_without)
        # The trailing "\n" argument leaves a blank line after each section.
        print("RMSE of model with yhat: ", rmse_with, "\n")

    if args.track:
        # Accuracy metrics for each split, plus the gain from adding yhat.
        accuracy_metrics = {
            "test_acc_without_yhat": score1,
            "test_acc_with_yhat": score2,
            "test_acc_diff": score2 - score1,
            "valid_acc_without_yhat": score1_valid,
            "valid_acc_with_yhat": score2_valid,
            "valid_acc_diff": score2_valid - score1_valid,
            "train_acc_without_yhat": score1_train,
            "train_acc_with_yhat": score2_train,
            "train_acc_diff": score2_train - score1_train,
        }
        # RMSE metrics for each split.
        rmse_metrics = {
            "test_rmse_without_yhat": rmse1,
            "test_rmse_with_yhat": rmse2,
            "valid_rmse_without_yhat": rmse1_valid,
            "valid_rmse_with_yhat": rmse2_valid,
            "train_rmse_without_yhat": rmse1_train,
            "train_rmse_with_yhat": rmse2_train,
        }
        # Everything is sent in a single log call.
        wandb.log({**accuracy_metrics, **rmse_metrics})
