from abc import ABC, abstractmethod
import copy
from sklearn.model_selection import train_test_split
import torch
from typing import Literal
from .nonparamcdf import InverseLearner, DRInverseLearner, IPWInverseLearner, CQTE_Learner, SeparateLearner
from ..utils import torch_kernel as kernel


class Estimator(ABC):
    """Abstract base class for estimators that wrap a learner model.

    Concrete subclasses must implement ``__init__``, ``fit`` and ``predict``.
    """

    @abstractmethod
    def __init__(self, model: InverseLearner):
        super().__init__()
        # The wrapped learner that subclasses delegate fitting/prediction to.
        self.model = model

    @abstractmethod
    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor):
        """Fit the estimator using the data.

        Args:
            x: Covariates.
            y: Outcomes.
            a: Treatment indicator (the subclasses in this module select
                rows with ``a == 0`` and ``a == 1``).
        """

    @abstractmethod
    def predict(self, y0: torch.Tensor, x: torch.Tensor):
        """Predict the outcome using the fitted estimator.

        Args:
            y0: Query outcome values for the control arm.
            x: Covariates to predict at.
        """


class PseudoOutcomeEstimator(Estimator):
    """Estimator that fits nuisance models first, then the wrapped learner.

    Subclasses register their nuisance models in ``self.nuisance_models``
    (normally before calling this initialiser) and implement
    ``fit_nuisance``. ``fit`` then hands the control-arm and treated-arm
    data to the learner.
    """

    @abstractmethod
    def __init__(self, model: InverseLearner, **kwargs):
        # Forward keyword arguments as keywords, not as their dict keys.
        super().__init__(model, **kwargs)
        # Subclasses usually create the registry before calling this
        # initialiser; only provide an empty default when they did not.
        if not hasattr(self, "nuisance_models"):
            self.nuisance_models = {}

    @abstractmethod
    def fit_nuisance(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor):
        """
        Fit the nuisance models using the data.
        """

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, seed=None):
        """Fit the nuisance models on all data, then the learner per arm.

        Args:
            x: Covariates.
            y: Outcomes.
            a: Treatment indicator; rows with ``a == 0`` form the control
                arm and rows with ``a == 1`` the treated arm.
            seed: Seed for ``torch.manual_seed``. When ``None``, a seed is
                drawn from the current torch RNG. It is stored as
                ``self.fit_seed``.
        """
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.fit_seed = seed
        torch.manual_seed(seed)

        self.fit_nuisance(x, y, a)

        control = a == 0
        treated = a == 1
        # Learner argument order: (control outcomes, control covariates,
        # treated outcomes, treated covariates).
        self.model.fit(y[control], x[control], y[treated], x[treated])

    def predict(self, y0: torch.Tensor, x: torch.Tensor, **kwargs):
        """Delegate prediction to the wrapped learner."""
        return self.model.predict(y0, x, **kwargs)


class IPWInverseEstimator(PseudoOutcomeEstimator):
    """Inverse-probability-weighted estimator built on ``IPWInverseLearner``.

    Args:
        kernel: Kernel handed to the learner.
        propensity_model: Propensity model. A deep copy is stored on the
            estimator and shared with the learner, so fitting it in
            ``fit_nuisance`` is visible to the learner.
    """

    def __init__(self, kernel: kernel.Kernel, propensity_model, **kwargs):
        own_propensity = copy.deepcopy(propensity_model)
        self.propensity_model = own_propensity
        # NOTE: this registry key is capitalised ("Propensity"), unlike the
        # "propensity" key used by the other estimators in this module.
        self.nuisance_models = {"Propensity": own_propensity}
        learner = IPWInverseLearner(kernel, own_propensity)
        super().__init__(model=learner, **kwargs)

    def fit_nuisance(self, x, y, a):
        """Fit the propensity model on ``(x, a)`` if it has a ``fit`` method.

        Models without ``fit`` are used as given. ``y`` is not used.
        """
        if not hasattr(self.propensity_model, "fit"):
            return
        self.propensity_model.fit(x, a)


class DRInverseEstimator(PseudoOutcomeEstimator):
    """Doubly-robust estimator built on ``DRInverseLearner``.

    Args:
        kernel: Kernel handed to the learner.
        propensity_model: Propensity model (deep-copied).
        cdf0_model: Conditional CDF model for the control arm (deep-copied).
        cdf1_model: Conditional CDF model for the treated arm. When ``None``,
            an independent deep copy of ``cdf0_model`` is used.
    """

    def __init__(self, kernel: kernel.Kernel, propensity_model, cdf0_model, cdf1_model=None, **kwargs):
        self.propensity_model = copy.deepcopy(propensity_model)
        self.cdf0_model = copy.deepcopy(cdf0_model)
        cdf1_model = cdf1_model if cdf1_model is not None else cdf0_model
        self.cdf1_model = copy.deepcopy(cdf1_model)
        self.nuisance_models = {"propensity": self.propensity_model, "cdf_0": self.cdf0_model, "cdf1": self.cdf1_model}
        # The learner shares these exact objects, so fitting them in
        # ``fit_nuisance`` is visible to it.
        model = DRInverseLearner(kernel, self.cdf0_model, self.cdf1_model, self.propensity_model)
        super().__init__(model=model, **kwargs)

    def fit_nuisance(self, x, y, a):
        """Fit every nuisance model that exposes a ``fit`` method.

        The propensity model is fitted on ``(x, a)``. Each arm's CDF model is
        fitted on that arm's ``(x, y)`` rows. Models without ``fit`` are used
        as given. Each model is checked on its own, so arm models of
        different kinds are handled correctly.
        """
        if hasattr(self.propensity_model, "fit"):
            self.propensity_model.fit(x, a)
        control = a == 0
        treated = a == 1
        if hasattr(self.cdf0_model, "fit"):
            self.cdf0_model.fit(x[control], y[control])
        if hasattr(self.cdf1_model, "fit"):
            self.cdf1_model.fit(x[treated], y[treated])


class SeparateEstimator(PseudoOutcomeEstimator):
    """Estimator built on ``SeparateLearner`` from two per-arm CDF models.

    Args:
        cdf0_model: Conditional CDF model for the control arm (deep-copied).
        cdf1_model: Conditional CDF model for the treated arm. When ``None``,
            an independent deep copy of ``cdf0_model`` is used.
    """

    def __init__(self, cdf0_model, cdf1_model=None, **kwargs):
        self.cdf0_model = copy.deepcopy(cdf0_model)
        cdf1_model = cdf1_model if cdf1_model is not None else cdf0_model
        self.cdf1_model = copy.deepcopy(cdf1_model)
        self.nuisance_models = {"cdf_0": self.cdf0_model, "cdf1": self.cdf1_model}
        # The learner shares these exact objects with the estimator.
        model = SeparateLearner(self.cdf0_model, self.cdf1_model)
        super().__init__(model=model, **kwargs)

    def fit_nuisance(self, x, y, a):
        """Fit each arm's CDF model on that arm's ``(x, y)`` rows.

        Models without a ``fit`` method are used as given. Each model is
        checked on its own, so arm models of different kinds are handled
        correctly.
        """
        control = a == 0
        treated = a == 1
        if hasattr(self.cdf0_model, "fit"):
            self.cdf0_model.fit(x[control], y[control])
        if hasattr(self.cdf1_model, "fit"):
            self.cdf1_model.fit(x[treated], y[treated])


class CrossInverseEstimator(PseudoOutcomeEstimator):
    """Sample-splitting / cross-fitting mixin for pseudo-outcome estimators.

    Place it first in the bases of a concrete estimator (e.g.
    ``DRInverseEstimator``). ``fit`` splits the data into two random halves.
    It fits the nuisance models on one half and the learner on the other.

    * ``fit_type == "Split"``: only that single fit is done, and ``predict``
      uses ``self.model``.
    * Any other value (``"Cross"``): the halves swap roles and a second,
      independent set of models is fitted. ``predict`` then averages
      ``self.model0`` and ``self.model1``.
    """

    def __init__(self, *args, fit_type, **kwargs):
        super().__init__(*args, **kwargs)
        self.fit_type = fit_type
        # Names of the instance attributes (e.g. ``propensity_model``) that
        # hold the registered nuisance models. Resolved by identity because
        # registry keys do not always match attribute names
        # (e.g. "Propensity", "cdf_0").
        nuisance_objects = list(self.nuisance_models.values())
        self._nuisance_attr_names = [
            name
            for name, value in vars(self).items()
            if name not in ("model", "nuisance_models")
            and any(value is obj for obj in nuisance_objects)
        ]
        # Unfitted copies kept for backward compatibility; ``fit`` no longer
        # reads them.
        self.initial_model = copy.deepcopy(self.model)
        self.inital_nuisance_models = copy.deepcopy(self.nuisance_models)

    def _fit_state(self):
        """Collect the learner, the nuisance registry and the nuisance attributes.

        Deep-copying the returned dict in a single call preserves the shared
        references between its entries. If the learner holds references to
        the nuisance models (the learners in this module are constructed
        with them), the copied learner points at the same copied objects as
        the copied attributes and registry.
        """
        state = {name: getattr(self, name) for name in self._nuisance_attr_names}
        state["model"] = self.model
        state["nuisance_models"] = self.nuisance_models
        return state

    def split_fit(self, x_s0, y_s0, a_s0, x_s1, y_s1, a_s1, *args, **kwargs):
        """Fit the nuisance models on part 0 and the learner on part 1.

        Extra ``*args``/``**kwargs`` are forwarded to the learner's ``fit``.
        """
        self.fit_nuisance(x_s0, y_s0, a_s0)

        control = a_s1 == 0
        treated = a_s1 == 1
        self.model.fit(y_s1[control], x_s1[control], y_s1[treated], x_s1[treated], *args, **kwargs)

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, seed=None):
        """Fit with a random 50/50 split and, unless ``"Split"``, cross-fit.

        Args:
            x: Covariates.
            y: Outcomes.
            a: Treatment indicator (arms selected with ``== 0`` / ``== 1``).
            seed: Seed for ``torch.manual_seed``. When ``None``, a seed is
                drawn. It is stored in ``self.fit_seed``. The split seed is
                drawn after seeding and stored in ``self.split_seed``.
        """
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.fit_seed = seed
        torch.manual_seed(seed)
        self.split_seed = torch.randint(0, 1000000, (1,)).item()
        # Split data into 2 parts
        x_s0, x_s1, y_s0, y_s1, a_s0, a_s1 = train_test_split(
            x, y, a, test_size=0.5, random_state=self.split_seed)

        cross_fit = self.fit_type != "Split"
        if cross_fit:
            # Snapshot the current (pre-fit) state in ONE deepcopy, so the
            # second fit starts from copies in which the learner and the
            # nuisance attributes/registry still refer to the same objects.
            # Taking the snapshot here (not in __init__) keeps configuration
            # changes made after construction, such as a new learner ``prob``.
            fresh_state = copy.deepcopy(self._fit_state())

        # Fit first model
        self.split_fit(x_s0, y_s0, a_s0, x_s1, y_s1, a_s1)

        if not cross_fit:
            return

        # Keep the first fit as models 0.
        self.model0 = copy.deepcopy(self.model)
        self.nuisance_models0 = copy.deepcopy(self.nuisance_models)

        # Install the fresh copies: learner, registry and nuisance attributes.
        for name, value in fresh_state.items():
            setattr(self, name, value)

        # Refit with roles switched
        self.split_fit(x_s1, y_s1, a_s1, x_s0, y_s0, a_s0)

        self.model1 = self.model
        self.nuisance_models1 = copy.deepcopy(self.nuisance_models)

    def predict(self, y0: torch.Tensor, x: torch.Tensor):
        """Predict with the single learner ("Split") or the mean of both fits."""
        if self.fit_type == "Split":
            return self.model.predict(y0, x)
        else:
            y1 = self.model0.predict(y0, x)
            y2 = self.model1.predict(y0, x)
            return (y1 + y2) / 2


class CrossDRInverseEstimator(CrossInverseEstimator, DRInverseEstimator):
    """``DRInverseEstimator`` fitted with sample splitting or cross-fitting.

    ``fit_type="Split"`` fits once on a random half/half split. ``"Cross"``
    also refits with the halves swapped and averages the two predictions.
    """

    def __init__(self, kernel: kernel.Kernel, propensity_model, cdf0_model, cdf1_model=None,
                 fit_type: Literal["Split", "Cross"] = "Split", **kwargs):
        dr_args = (kernel, propensity_model, cdf0_model, cdf1_model)
        super().__init__(*dr_args, fit_type=fit_type, **kwargs)


class CrossIPWInverseEstimator(CrossInverseEstimator, IPWInverseEstimator):
    """``IPWInverseEstimator`` fitted with sample splitting or cross-fitting.

    ``fit_type="Split"`` fits once on a random half/half split. ``"Cross"``
    also refits with the halves swapped and averages the two predictions.
    """

    def __init__(self, kernel: kernel.Kernel, propensity_model,
                 fit_type: Literal["Split", "Cross"] = "Split", **kwargs):
        ipw_args = (kernel, propensity_model)
        super().__init__(*ipw_args, fit_type=fit_type, **kwargs)


class CrossSeparateEstimator(CrossInverseEstimator, SeparateEstimator):
    """``SeparateEstimator`` fitted with sample splitting or cross-fitting.

    ``fit_type="Split"`` fits once on a random half/half split. ``"Cross"``
    also refits with the halves swapped and averages the two predictions.
    """

    def __init__(self, cdf0_model, cdf1_model=None,
                 fit_type: Literal["Split", "Cross"] = "Split", **kwargs):
        arm_models = (cdf0_model, cdf1_model)
        super().__init__(*arm_models, fit_type=fit_type, **kwargs)


class CQTEEstimator(PseudoOutcomeEstimator):
    """Estimator built on ``CQTE_Learner`` with propensity, inverse-CDF and density nuisances.

    Args:
        inner_model: Model handed to ``CQTE_Learner`` as its first argument.
        prob: Value handed to ``CQTE_Learner`` (stored as ``model.prob``;
            presumably the quantile level — confirm).
        propensity_model: Propensity model (deep-copied).
        icdf0_model: Inverse-CDF model for the control arm (deep-copied).
        density0_model: Density model for the control arm (deep-copied).
        icdf1_model: Treated-arm inverse-CDF model. When ``None``, an
            independent deep copy of ``icdf0_model`` is used.
        density1_model: Treated-arm density model. When ``None``, an
            independent deep copy of ``density0_model`` is used.
        compatibility_mode: Default for ``predict``. When truthy, ``predict``
            passes ``x_extra`` to the learner instead of ``x``.
    """

    def __init__(self, inner_model, prob, propensity_model, icdf0_model, density0_model,
                 icdf1_model=None, density1_model=None, compatibility_mode=False, **kwargs):
        # Set up nuisance models
        self.compatibility_mode = compatibility_mode
        self.propensity_model = copy.deepcopy(propensity_model)

        self.icdf0_model = copy.deepcopy(icdf0_model)
        icdf1_model = icdf1_model if icdf1_model is not None else icdf0_model
        self.icdf1_model = copy.deepcopy(icdf1_model)

        self.density0_model = copy.deepcopy(density0_model)
        density1_model = density1_model if density1_model is not None else density0_model
        self.density1_model = copy.deepcopy(density1_model)

        self.nuisance_models = {"propensity": self.propensity_model,
                                "icdf0": self.icdf0_model, "icdf1": self.icdf1_model,
                                "density0": self.density0_model, "density1": self.density1_model}

        # The learner shares these exact objects, so fitting them in
        # ``fit_nuisance`` is visible to it.
        model = CQTE_Learner(inner_model, prob, self.icdf0_model, self.icdf1_model,
                             self.density0_model, self.density1_model, self.propensity_model)
        super().__init__(model=model, **kwargs)

    def update_prob(self, prob):
        """Set the learner's ``prob`` and mark the estimator and learner as unfitted."""
        self.model.prob = prob
        self._is_fitted = False
        self.model._is_fitted = False

    def fit_nuisance(self, x, y, a):
        """Fit every nuisance model that exposes a ``fit`` method.

        The propensity model is fitted on ``(x, a)``, the inverse-CDF models
        on their arm's ``(x, y)``, and the density models on their arm's
        ``x``. Models without ``fit`` are used as given. Each model is checked
        on its own, so arm models of different kinds are handled correctly.
        """
        control = a == 0
        treated = a == 1
        if hasattr(self.propensity_model, "fit"):
            self.propensity_model.fit(x, a)
        if hasattr(self.icdf0_model, "fit"):
            self.icdf0_model.fit(x[control], y[control])
        if hasattr(self.icdf1_model, "fit"):
            self.icdf1_model.fit(x[treated], y[treated])
        if hasattr(self.density0_model, "fit"):
            self.density0_model.fit(x[control])
        if hasattr(self.density1_model, "fit"):
            self.density1_model.fit(x[treated])

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, seed=None):
        """Seed torch, fit the nuisance models, then fit the learner on ``(x, y, a)``.

        ``seed`` is drawn at random when ``None`` and stored as ``self.fit_seed``.
        """
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.fit_seed = seed
        torch.manual_seed(seed)
        # Fit nuisance models
        self.fit_nuisance(x, y, a)
        self.model.fit(x, y, a)

    def predict(self, x, x_extra=None, compatibility_mode=None):
        """Predict with the learner.

        ``compatibility_mode`` defaults to the instance setting. When it is
        truthy, ``x_extra`` is passed to the learner instead of ``x``.
        """
        if compatibility_mode is None:
            compatibility_mode = self.compatibility_mode
        if compatibility_mode:
            x = x_extra
        return self.model.predict(x)


class CrossCQTEEstimator(CrossInverseEstimator, CQTEEstimator):
    """``CQTEEstimator`` fitted with sample splitting or cross-fitting.

    ``fit_type="Split"`` fits once on a random half/half split. Any other
    value also refits with the halves swapped, and ``predict`` averages
    ``model0`` and ``model1``.
    """

    def __init__(self, inner_model, prob, propensity_model, icdf0_model, density0_model,
                 icdf1_model=None, density1_model=None, compatibility_mode=False,
                 fit_type: Literal["Split", "Cross"] = "Split", **kwargs):
        cqte_args = (inner_model, prob, propensity_model, icdf0_model, density0_model,
                     icdf1_model, density1_model, compatibility_mode)
        super().__init__(*cqte_args, fit_type=fit_type, **kwargs)

    def split_fit(self, x_s0, y_s0, a_s0, x_s1, y_s1, a_s1, *args, **kwargs):
        """Fit the nuisance models on part 0 and the CQTE learner on part 1.

        Extra ``*args``/``**kwargs`` are accepted for signature compatibility
        and ignored.
        """
        self.fit_nuisance(x_s0, y_s0, a_s0)
        self.model.fit(x_s1, y_s1, a_s1)

    def predict(self, x: torch.Tensor, x_extra=None, compatibility_mode=None):
        """Predict with one learner ("Split") or the mean of both cross-fits.

        ``compatibility_mode`` defaults to the instance setting. When it is
        truthy, ``x_extra`` is used as the learner input instead of ``x``.
        """
        use_extra = self.compatibility_mode if compatibility_mode is None else compatibility_mode
        inputs = x_extra if use_extra else x
        if self.fit_type != "Split":
            first = self.model0.predict(inputs)
            second = self.model1.predict(inputs)
            return (first + second) / 2
        return self.model.predict(inputs)
