from functools import partial
from itertools import chain
import numpy as np
import os
import torch
from torch import optim, distributions as dists
from torch.optim import lr_scheduler
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from pathlib import Path
import sys
file_path = Path(__file__)
project_path = file_path.parents[2]
sys.path.append(project_path.as_posix())
from Tests import local_utils  # noqa: E402
from cqc.utils.data import (list_of_dicts_transpose, gen_error)  # noqa: E402
from cqc.estimators import nonparamcdf as npcdf  # noqa: E402
from cqc.models.cqc_models import LinearGenerator, MLPGenerator, RFFGenerator  # noqa: E402
from cqc.estimators import grad_estimators as est  # noqa: E402
from cqc.estimators import sklearn_estimators as skest  # noqa: E402
from cqc.estimators import old_estimators as oldest  # noqa: E402


class MultivariateUniform(dists.Uniform):
    """Uniform distribution on the hypercube [low, high)^d.

    Wraps a scalar ``torch.distributions.Uniform`` and treats the ``d``
    coordinates as independent: ``sample`` appends a trailing axis of length
    ``d`` and ``log_prob`` sums the scalar log-densities over that last axis.
    """

    def __init__(self, low: float, high: float, d: int):
        super().__init__(low, high)
        self.d = d  # number of independent coordinates per draw

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``(*sample_shape, d)``."""
        return super().sample(torch.Size([*sample_shape, self.d]))

    def log_prob(self, value):
        """Joint log-density: per-coordinate log-densities summed over the last axis."""
        return torch.sum(super().log_prob(value), dim=-1)


#######################
# # MARK: Data Set-up #
#######################
dim = 10
y_base_dist = dists.Normal(0, 1)
x_base_dist = MultivariateUniform(-1, 1, dim)

torch.manual_seed(123)
# Random direction rescaled to norm sqrt(dim), so the linear index
# x @ coef_vector has a gradient of length sqrt(dim) in its steepest direction.
# Use ``dim`` (not a literal 10) so changing the dimension stays consistent.
coef_vector = torch.randn(dim)
coef_vector = coef_vector/torch.linalg.norm(coef_vector)*np.sqrt(dim)
# Sanity check: sum(c**2) / ||c|| == ||c||, which should print sqrt(dim).
print(torch.sum(coef_vector**2)/torch.linalg.norm(coef_vector))


def propensity_freq(x, vec):
    """Propensity score sigmoid(x @ vec), clipped to the interval [0.1, 0.9].

    A trailing singleton axis of the linear predictor is squeezed away (e.g.
    when ``vec`` is passed as a column vector).
    """
    linear_index = x @ vec
    probs = linear_index.sigmoid().squeeze(-1)
    # Clipping keeps the score away from 0 and 1.
    return probs.clamp(min=0.1, max=0.9)


# Pairs of callables (g, h), each taking (x, slope, vec), handed to
# local_utils.get_all_obj for the control (gs_0_slope) and treated
# (gs_1_slope) arms. Presumably g is a location and h a scale of Y given X —
# TODO confirm against local_utils. Only the treated g uses ``slope``.
def _g0_location(x, slope, vec):
    return torch.sin(torch.pi * (x @ vec))


def _g1_location(x, slope, vec):
    # Kept as (slope * x) @ vec to match the original evaluation order exactly.
    return torch.sin(torch.pi * (x @ vec)) + (slope * x) @ vec


def _unit_scale(x, slope, vec):
    return 1


gs_0_slope = (_g0_location, _unit_scale)
gs_1_slope = (_g1_location, _unit_scale)

save_path = project_path / "Results" / "SimulatedData" / "10DimLinearCQC" / "VarySlope"
sample_size = 500
n_repeats = 100
os.makedirs(save_path, exist_ok=True)
# Per-slope kernel hyperparameters, indexed below by the slope index ``i``.
hyper_vals = torch.load(save_path / "hyperparameter_vals.pt")

save_file = f"results_wbias_n_{sample_size}_hpc_2.pt"

methods = ["ipw", "ipw_est", "dr", "dr_est", "dr_mlp", "dr_est_mlp", "dr_rff", "dr_est_rff"]
old_methods = ["OldDR", "OldDR_est", "CQTE", "CQTE_est", "Separate"]
diffs = {method: [] for method in chain(methods, old_methods, ["dr_fixed"])}
linear_args = {"optimiser": optim.Adam, "opt_args": {"lr": 1e-1}, "fit_type": "Split"}
rff_args = {"optimiser": optim.Adam, "opt_args": {"lr": 1e-1, "weight_decay": 1e-2}, "fit_type": "Split"}
mlp_args = {"optimiser": optim.Adam, "opt_args": {"lr": 1e-3}, "fit_type": "Split"}
training_args = {"niters": 1000, "snapshot_freq": 100}
# Upper bound on re-initialise-and-refit attempts when training hits NaNs.
max_fit_attempts = 100


def _make_generator(method):
    """Return a freshly initialised generator for the family named by ``method``.

    Names containing "mlp" get an MLPGenerator, names containing "rff" an
    RFFGenerator, and anything else a LinearGenerator. Used both for the
    initial models and for re-initialisation after a NaN failure, so a retried
    estimator always keeps the same architecture as its original model.
    """
    if "mlp" in method:
        return MLPGenerator(dim, [20, 20])
    if "rff" in method:
        return RFFGenerator(dim, 50, sigma=2.)
    return LinearGenerator(dim)


def _fit_with_retries(estimator_obj, method, x, y, a):
    """Fit ``estimator_obj``; on a NaN ValueError, swap in a fresh generator and retry.

    Raises:
        ValueError: re-raised unchanged when its message does not mention NaN.
        RuntimeError: if all ``max_fit_attempts`` attempts fail with NaNs,
            instead of silently evaluating an estimator that never fitted.
    """
    last_error = None
    for attempt in range(max_fit_attempts):
        if attempt > 0:
            # The previous attempt failed with NaNs: restart from a new model.
            estimator_obj.update_model(_make_generator(method))
        try:
            estimator_obj.fit(x, y, a, **training_args)
            return
        except ValueError as e:
            if "nan" not in str(e).lower():
                raise
            print("NaN error, retrying...")
            last_error = e
    raise RuntimeError(
        f"Fitting {method!r} failed with NaNs on all {max_fit_attempts} attempts") from last_error


# slopes
slopes = torch.linspace(0, 5, 6)
# Loop over slopes; the index ``i`` also selects the per-slope hyperparameters.
for i, slope in enumerate(slopes):
    print(f"Slope: {slope}")
    for repeat in tqdm(range(n_repeats)):
        # ### Set-up funcs ####
        # Fresh random direction, rescaled to norm sqrt(dim) (steepest-direction gradient sqrt(dim)).
        coef_vector = torch.randn(dim)
        coef_vector = coef_vector/torch.linalg.norm(coef_vector)*np.sqrt(dim)
        gs_0 = [partial(func, slope=slope, vec=coef_vector) for func in gs_0_slope]
        gs_1 = [partial(func, slope=slope, vec=coef_vector) for func in gs_1_slope]
        propensity = partial(propensity_freq, vec=coef_vector)
        cdf_0, cdf_1, icdf_0, icdf_1, true_transform, density_0, density_1 = local_utils.get_all_obj(
            y_base_dist, gs_0, gs_1, x_base_dist, propensity)

        # ### Set up nuisance estimators ####
        # Give Kernels
        cdf_kernel = npcdf.kernel.KGauss(hyper_vals["CDF0"][i])
        internal_kernel = npcdf.kernel.KGauss(sigma2=0.5)
        cqte_kernel = npcdf.kernel.KGauss(sigma2=hyper_vals["CDF0"][i])
        # Give estimators
        est_propensity = skest.TorchModelWrapper(LogisticRegression, min=0.1, max=0.9)
        est_cdf = npcdf.kernel_cdf(kernel=cdf_kernel)
        est_quantile = npcdf.kernel_cdf(kernel=cdf_kernel, inverse_main=True)
        cqte_regressor = npcdf.kernel_regressor(cqte_kernel)
        true_cdf_0 = npcdf.exact_cdf(cdf_0)
        true_cdf_1 = npcdf.exact_cdf(cdf_1)

        # ### Set up estimators ####
        estimator: dict[str, est.GenericEstimator] = {}
        # New Methods
        cqc = _make_generator("linear")
        mlp_cqc = _make_generator("mlp")
        rff_cqc = _make_generator("rff")
        # NOTE(review): each generator instance is shared by several estimators
        # (cqc by ipw/ipw_est/dr/dr_est/dr_fixed; mlp_cqc and rff_cqc by two each).
        # Unless the estimators copy their model internally, later fits start
        # from (and overwrite) earlier fits' weights — confirm this is intended.
        estimator["ipw"] = est.CrossIPWEstimator(cqc, propensity, **linear_args)
        estimator["ipw_est"] = est.CrossIPWEstimator(cqc, est_propensity, **linear_args)
        estimator["dr"] = est.CrossDREstimator(cqc, propensity, cdf_0, cdf_1, **linear_args)
        estimator["dr_est"] = est.CrossDREstimator(cqc, est_propensity, est_cdf, **linear_args)
        # MLP approach
        estimator["dr_mlp"] = est.CrossDREstimator(mlp_cqc, propensity, cdf_0, cdf_1, **mlp_args)
        estimator["dr_est_mlp"] = est.CrossDREstimator(mlp_cqc, est_propensity, est_cdf, **mlp_args)
        estimator["dr_rff"] = est.CrossDREstimator(rff_cqc, propensity, cdf_0, cdf_1, **rff_args)
        estimator["dr_est_rff"] = est.CrossDREstimator(rff_cqc, est_propensity, est_cdf, **rff_args)
        # Add in fixed estimator
        estimator["dr_fixed"] = est.CrossDREstimator(cqc, propensity, cdf_0, cdf_1, fit_type="Split")

        # Add scheduler to each
        for method in methods + ["dr_fixed"]:
            estimator[method].set_scheduler(lr_scheduler.StepLR, step_size=100, gamma=0.8)

        # Do old methods
        estimator["OldDR"] = oldest.CrossDRInverseEstimator(internal_kernel, propensity, true_cdf_0, true_cdf_1,
                                                            fit_type="Split")
        estimator["OldDR_est"] = oldest.CrossDRInverseEstimator(internal_kernel, est_propensity, est_cdf,
                                                                fit_type="Split")
        estimator["CQTE"] = oldest.CrossCQTEEstimator(
            cqte_regressor, None, propensity, icdf_0, density_0, icdf_1, density_1,
            fit_type="Split", compatibility_mode=True)
        estimator["CQTE_est"] = oldest.CrossCQTEEstimator(
            cqte_regressor, None, est_propensity, est_quantile, density_0,
            density1_model=density_1, fit_type="Split", compatibility_mode=True)
        estimator["Separate"] = oldest.CrossSeparateEstimator(est_cdf, fit_type="Split")

        # ### Create data ####
        y, x, a = local_utils.uneven_data_gen(gs_0, gs_1, propensity, x_base_dist, y_base_dist, sample_size)
        # Train new methods; on NaN failures re-initialise with the SAME generator family.
        for method in methods:
            _fit_with_retries(estimator[method], method, x, y, a)
        estimator["dr_fixed"].fit(x, y, a, niters=0)  # Just to set up the model

        # Train old methods
        prob_test_point = torch.rand(1)
        estimator["CQTE"].update_prob(prob_test_point)
        estimator["CQTE_est"].update_prob(prob_test_point)
        for method in old_methods:
            estimator[method].fit(x, y, a)

        # Run testing
        test_x0s = x_base_dist.sample((10000,))
        test_y0s = icdf_0(prob_test_point, test_x0s)

        # Evaluate every estimator (new, old and dr_fixed) on the same test points.
        for method in chain(methods, old_methods, ["dr_fixed"]):
            temp_diffs = {"g_error": gen_error((test_y0s, test_x0s), estimator[method].predict,
                                               true_transform, batch_size=1000).item(),
                          "slope": slope, "sample_size": sample_size, "coef_vector": coef_vector, "repeat": repeat}

            diffs[method].append(temp_diffs)
        # Checkpoint after every repeat (not every method) so partial runs are kept.
        torch.save(diffs, save_path/save_file)
# Final save: one dict of columns per method instead of a list of row dicts.
diffs = {k: list_of_dicts_transpose(v) for k, v in diffs.items()}
torch.save(diffs, save_path/save_file)
