

from abc import ABC, abstractmethod
import copy
from functools import partial
from scipy.stats import trim_mean
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator
import torch
from torch import nn
from torch import optim as opt
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import TensorDataset
from torch.utils.tensorboard import SummaryWriter
from typing import Union, Literal  # noqa: F401
# import sys
# sys.path.append("..")
# import losses
# from utils.data import infinite_loader
# from nonparamcdf import kernel_cdf
from .. import losses
from ..utils.data import infinite_loader
from .nonparamcdf import kernel_cdf


# class kernel_cdf:
#     pass


class GenericEstimator(nn.Module, ABC):
    """
    Generic class for gradient estimators.
    """
    def __init__(self, cqc_model: nn.Module, optimiser=opt.SGD, opt_args=None, device=None, log_dir=None, **kwargs):
        super().__init__()
        self.fit_type = "Full"
        self.model = copy.deepcopy(cqc_model)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        opt_args = opt_args if opt_args is not None else {}
        self.set_optimiser(optimiser, **opt_args)
        self.set_regulariser()
        self.set_scheduler()
        self.true_loss_diff = False
        self._opt_args = [optimiser, copy.deepcopy(opt_args)]
        if log_dir is not None:
            self.writer = SummaryWriter(log_dir=log_dir)
        else:
            self.writer = None

    def update_model(self, model: nn.Module):
        """
        Update the model with a new model.
        """
        self.model = copy.deepcopy(model)
        self._init_optim()

    @abstractmethod
    def loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the training loss of the current model.

        Subclasses must implement this. It is called positionally by ``loss_wreg`` with
        ``y0`` first, followed by the tensors of one training batch.
        """

    def true_loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the true loss for the given model, if known.
        """
        if self.true_loss_diff:
            raise NotImplementedError("True loss is different but has not implemented for this estimator")
        return self.loss(y0, x, y, a, **kwargs)

    def set_optimiser(self, optimiser, lr=0.01, **kwargs):
        self.optimiser: opt.Optimizer = optimiser(self.model.parameters(), lr=lr, **kwargs)

    def set_scheduler(self, scheduler: _LRScheduler = None, *scheduler_args, **scheduler_kwargs):
        if not hasattr(self, "optimiser"):
            raise ValueError("Optimiser must be set before setting the scheduler")
        if scheduler is None:
            self.scheduler = None
        else:
            self.scheduler = scheduler(self.optimiser, *scheduler_args, **scheduler_kwargs)
        self._scheduler_args = [scheduler, copy.deepcopy(scheduler_args), copy.deepcopy(scheduler_kwargs)]

    def set_regulariser(self, regulariser=None, *regulariser_args, **regulariser_kwargs):
        if regulariser is None:
            self.regulariser = None
        else:
            self.regulariser = regulariser(self.model, *regulariser_args, **regulariser_kwargs)
        self._regulariser_args = [regulariser, copy.deepcopy(regulariser_args), copy.deepcopy(regulariser_kwargs)]

    def _init_optim(self):
        # Re-init optimiser, scheduler, and regulariser
        self.set_optimiser(self._opt_args[0], **self._opt_args[1])
        self.set_scheduler(self._scheduler_args[0], *self._scheduler_args[1], **self._scheduler_args[2])
        self.set_regulariser(self._regulariser_args[0], *self._regulariser_args[1], **self._regulariser_args[2])

    def loss_wreg(self, *args):
        self.optimiser.zero_grad()
        outer_loss = self.loss(*args).mean()
        if self.regulariser is not None:
            if type(self.regulariser) is list:
                for reg in self.regulariser:
                    outer_loss += reg(*args)
            else:
                outer_loss += self.regulariser()
        return outer_loss

    def step(self, *args):
        outer_loss = self.loss_wreg(*args)
        outer_loss.backward()
        if isinstance(self.optimiser, opt.LBFGS):
            closure = partial(self.loss_wreg, *args)
            self.optimiser.step(closure)
        else:
            self.optimiser.step()
        return outer_loss

    def init_dataset(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
                     batch_size=None, y0_type: Literal["Fixed", "Random", "Conditional"] = "Random",
                     fixed_y0=None, prob_range=(0., 1.), seed=None, **kwargs):
        """
        Build the training data loader and set up how y0 values are drawn.

        Args:
            x, y, a: Tensors stored together in a ``TensorDataset``. ``a == 0`` is used to
                select the rows of ``y`` that form the y0 pool for "Random".
            batch_size: Loader batch size; ``None`` means the whole dataset.
            y0_type: "Fixed" repeats ``fixed_y0``, "Random" samples from ``y[a == 0]``,
                "Conditional" samples through ``self.cdf0_model``.
            fixed_y0: Value used for "Fixed"; converted to a tensor on ``self.device``.
            prob_range: For "Random", the quantile range of the y0 pool that is kept;
                for "Conditional", stored as ``self.prob_range``.
            seed: Seed for the global torch RNG while setting up; drawn at random if omitted.
            **kwargs: Forwarded to ``infinite_loader``.

        Raises:
            ValueError: If "Fixed" is requested without ``fixed_y0``, or "Conditional"
                without a ``cdf0_model`` attribute.
        """
        initial_seed = torch.initial_seed()
        seed = torch.randint(0, 1000000, (1,)).item() if seed is None else seed
        torch.manual_seed(seed)
        loader_seed = torch.randint(0, 1000000, (1,)).item()
        if self.fit_type == "Valid":
            # Hold out 20% of the data for validation.
            x, x_valid, y, y_valid, a, a_valid = train_test_split(
                x, y, a, test_size=0.2, random_state=seed)
            self.valid_dataset = TensorDataset(x_valid, y_valid, a_valid)
        self.dataset = TensorDataset(x, y, a)
        batch_size = batch_size if batch_size is not None else len(self.dataset)
        # Set up general data loader
        self.dataloader = infinite_loader(self.dataset, **kwargs, batch_size=batch_size, seed=loader_seed)

        # Set up y0 dataset
        self.y0_type = y0_type
        if y0_type == "Fixed":
            if fixed_y0 is None:
                raise ValueError("fixed_y0 must be provided if y_0_type is Fixed")
            # Store the value for both tensor and non-tensor input; get_y0s reads self.fixed_y0.
            if isinstance(fixed_y0, torch.Tensor):
                self.fixed_y0 = fixed_y0.to(self.device)
            else:
                self.fixed_y0 = torch.tensor(fixed_y0, device=self.device)
        elif y0_type == "Random":
            # The pool has to exist before its quantiles can be taken.
            y0_data = y[a == 0]
            # If prob_range is not (0., 1.) keep only the pool values inside that quantile range
            if tuple(prob_range) != (0., 1.):
                # torch.quantile requires q to share the input's dtype and device.
                q = torch.tensor(prob_range, dtype=y0_data.dtype, device=y0_data.device)
                q1, q2 = torch.quantile(y0_data, q)
                y0_data = y0_data[(y0_data >= q1) & (y0_data <= q2)]
            if y0_data.shape[0] < batch_size:
                # Tile the pool so a full batch of y0 values can be drawn
                # (assumes y is 1-D -- TODO confirm).
                y0_data = y0_data.repeat(batch_size // y0_data.shape[0] + 1)
            self.y0_data = y0_data
        elif y0_type == "Conditional":
            if not hasattr(self, "cdf0_model"):
                raise ValueError("cdf0_model must be provided if y_0_type is Conditional")
            self.prob_range = prob_range
        # NOTE(review): this re-seeds with the process's initial seed rather than restoring
        # the RNG state from before the call -- confirm that is intended.
        torch.manual_seed(initial_seed)

    def init_valid_dataset(self, valid_data=None, batch_size=None, **kwargs):
        if valid_data is None:
            self.cheatvalid_dataloader = None
            return
        else:
            self.cheatvalid_dataset = TensorDataset(*valid_data)
            self.cheatvalid_dataloader = torch.utils.data.DataLoader(
                self.cheatvalid_dataset, shuffle=False, drop_last=False, **kwargs, batch_size=batch_size)

    def get_y0s(self, batch, y0_gen=None):
        if self.y0_type == "Fixed":
            y0_val = self.fixed_y0.repeat(batch[1].shape[0])
        elif self.y0_type == "Random":
            y0_val = self.y0_data[torch.randperm(self.y0_data.shape[0], generator=y0_gen)][:batch[1].shape[0]]
        elif self.y0_type == "Conditional":
            alphas = (torch.rand(batch[1].shape[0], generator=y0_gen)
                      * (self.prob_range[1] - self.prob_range[0]) + self.prob_range[0])
            y0_val = self.cdf0_model.inverse_cdf(alphas, batch[0])
        return y0_val

    @staticmethod
    def snapshot_func(snapshot_freq, i):
        # Check for snapshot
        if type(snapshot_freq) is int:
            return (i % snapshot_freq == 0)
        elif type(snapshot_freq) is list:
            return (i in snapshot_freq)
        else:
            raise TypeError("snapshot_freq must be an int or a list of ints")

    def get_model_gradients(self):
        with torch.no_grad():
            return sum(torch.norm(p.grad.flatten(), 2) for p in self.model.parameters())

    def train(
        self, niters=10000,
        snapshot_freq=100, loss_avg_length=1,
        min_loss_val=-torch.inf, max_loss_val=torch.inf,
        true_loss=False, gradients=False, verbose=False, seed=None,
        **kwargs
    ):
        """Runs a training loop for the missing score matching model.

        Args:
            epochs (int, optional): Max number of runs through the data. Defaults to 1000.
            niters (int, optional): Max number of overall iterations. Defaults to 10000.
            snapshot_freq (int, optional): How often to store the values of the model. Defaults to 100.
            loss_avg_length (int, optional): Length of the moving average window for the loss. Defaults to 100.
            min_loss_val (float, optional): Minimum loss value to stop training. Defaults to -float("inf").
            verbose (bool, optional): Whether to print progress messages. Defaults to False.
            **score_args: Additional arguments to pass to the score model.

        Returns:
            list|none: If stored_vals is None, returns a list of the stored values. Otherwise, returns None
                and stores the results in stored_vals.
        """
        self.stored_vals = {"Losses": [],
                            "State_dicts": []}
        if true_loss:
            self.stored_vals["True_Losses"] = []
        if gradients:
            self.stored_vals["Gradients"] = []
        if self.cheatvalid_dataloader is not None:
            self.stored_vals["Valid_Losses"] = []
        if hasattr(self, "valid_dataset") and true_loss:
            self.stored_vals["True_Losses_Val"] = []
        temp_loss_list = []
        seed = torch.randint(0, 1000000, (1,)).item() if seed is None else seed
        torch.manual_seed(seed)
        y0_gen_seed = torch.randint(0, 1000000, (1,)).item()
        y0_valid_gen_seed = torch.randint(0, 1000000, (1,)).item()
        y0_gen = torch.Generator(device=self.device)
        y0_valid_gen = torch.Generator(device=self.device)
        y0_gen.manual_seed(y0_gen_seed)
        y0_valid_gen.manual_seed(y0_valid_gen_seed)
        # Set up data and batch size
        for i in range(niters):
            # Send to device
            batch: list[torch.Tensor] = [dat.to(self.device) for dat in next(self.dataloader)]
            y0 = self.get_y0s(batch, y0_gen)
            outer_loss = self.step(y0, *batch)

            # Append loss
            temp_loss_list.append(outer_loss.item())
            if self.writer:
                self.writer.add_scalar('Loss/train', outer_loss.item(), i)
            ##########################
            # MARK: Validation Snapshot
            ###########################
            if self.snapshot_func(snapshot_freq, i):
                if ((outer_loss <= min_loss_val) | (outer_loss >= max_loss_val)) | torch.isnan(outer_loss):
                    # Raise an error due to loss breaking down
                    raise ValueError("Model stopped training due to nan or inf loss")

                if true_loss:
                    # If true loss is separate explicitly calculate it
                    if self.true_loss_diff:
                        true_loss = self.true_loss(y0, *batch)
                        if self.writer:
                            self.writer.add_histogram('True_Loss/train_hist', true_loss, i)
                        true_loss = trim_mean(true_loss.detach().numpy(), .05)
                    else:
                        true_loss = outer_loss

                    # self.stored_vals["True_Losses"].append(true_loss.item())
                    if self.writer:
                        self.writer.add_scalar('True_Loss/train', true_loss.item(), i)
                if self.writer:
                    self.writer.add_histogram('Predictions/hist', self.predict(y0, batch[0]), i)
                    self.writer.add_scalar('Gradients/train', self.get_model_gradients(), i)
                # Get average internal loss within epoch
                avg_loss = torch.mean(torch.tensor(temp_loss_list[-loss_avg_length:]))
                if verbose:
                    print(f"Iteration {i}: {avg_loss.item()}")
                self.stored_vals["Losses"].append(avg_loss.item())
                if gradients:
                    self.stored_vals["Gradients"].append({k: p.grad for k, p in self.model.named_parameters()})
                    if self.writer:
                        for n, p in self.model.named_parameters():
                            self.writer.add_histogram(f'Gradients/{n}', p.grad, i)

                # Do cheat validation if possible
                if self.cheatvalid_dataloader is not None:
                    val_loss = 0.
                    for val_batch in self.cheatvalid_dataloader:
                        y0_val, y1_val, x_val = [dat.to(self.device) for dat in val_batch]
                        val_loss += torch.sum(torch.abs(self.predict(y0_val, x_val) - y1_val)).item()
                    val_loss /= len(self.cheatvalid_dataset)
                    self.stored_vals["Valid_Losses"].append(val_loss)
                    if self.writer:
                        self.writer.add_scalar('Loss/valid', val_loss, i)

                # Do proper validation if possible
                if hasattr(self, "valid_dataset") and true_loss:
                    batch = self.valid_dataset.tensors
                    # Get y0
                    y0_val = self.get_y0s(batch, y0_valid_gen)

                    val_true_loss = self.true_loss(y0_val, *batch)
                    self.stored_vals["True_Losses_Val"].append(trim_mean(val_true_loss.detach().numpy(), .05))
                    if self.writer:
                        self.writer.add_scalar('True_Loss/valid', trim_mean(val_true_loss.detach().numpy(), .05), i)
                        self.writer.add_histogram('True_Loss/valid_hist',
                                                  val_true_loss, i)

                self.stored_vals["State_dicts"].append(
                    copy.deepcopy(self.model.state_dict()))

            if self.scheduler is not None:
                self.scheduler.step()

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
            y0_type: Literal["Fixed", "Random", "Conditional"] = "Random", fixed_y0=None, prob_range=(0., 1.),
            scale=False, batch_size=None, niters=10000,
            snapshot_freq=100, loss_avg_length=1,
            valid_data: Union[None, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, valid_batch_size=None,
            min_loss_val=-torch.inf, max_loss_val=torch.inf,
            true_loss=False, gradients=False, seed=None, verbose=False, **kwargs):
        self.scale = scale
        if self.scale:
            self.x_mean = x.mean(dim=0)
            self.x_std = x.std(dim=0)
            x = (x - self.x_mean) / self.x_std
            self.y_mean = y.mean(dim=0)
            self.y_std = y.std(dim=0)
            y = (y - self.y_mean) / self.y_std

        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.fit_seed = seed
        torch.manual_seed(seed)
        init_data_seed, train_seed = torch.randint(0, 1000000, (2,)).item()
        self.init_data_seed = init_data_seed
        self.train_seed = train_seed
        self.init_dataset(x, y, a, batch_size=batch_size,
                          y0_type=y0_type, fixed_y0=fixed_y0, prob_range=prob_range, seed=init_data_seed)
        self.init_valid_dataset(valid_data, batch_size=valid_batch_size)

        self._init_optim()
        self.train(snapshot_freq=snapshot_freq, niters=niters,
                   loss_avg_length=loss_avg_length, min_loss_val=min_loss_val,
                   max_loss_val=max_loss_val, true_loss=true_loss, verbose=verbose,
                   seed=train_seed, gradients=gradients, **kwargs)

    def predict(self, y0: torch.Tensor, x: torch.Tensor):
        """
        Predict the outcome for a given set of covariates and treatment assignment.
        """
        if self.scale:
            x = (x - self.x_mean) / self.x_std
            y0 = (y0 - self.y_mean) / self.y_std
        y1 = self.model(y0, x)
        if self.scale:
            y1 = y1 * self.y_std + self.y_mean
        return y1

    def load_state_dict(self, state_dict, strict=True, assign=False):
        self.model.load_state_dict(state_dict, strict, assign)


class BasicEstimator(GenericEstimator):
    """Estimator that uses ``losses.grad_loss`` directly, with no nuisance models."""

    def loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, **kwargs):
        """
        Compute the gradient loss for one batch.

        The parameter order follows the base-class contract ``(y0, x, y, a)``. The training
        loop calls the loss positionally with ``y0`` first, so any other order binds every
        tensor to the wrong name.
        """
        y1 = self.model(y0, x)
        return losses.grad_loss(y, a, y0, y1)


class EstimatorwithNuisance(GenericEstimator):
    @abstractmethod
    def __init__(self, cqc_model: nn.Module, optimiser=opt.SGD, opt_args=None, device=None, log_dir=None, **kwargs):
        """
        Initialise the base estimator and an empty nuisance-model registry.

        Marked abstract, so concrete subclasses must define their own ``__init__``.
        """
        super().__init__(cqc_model, optimiser=optimiser, opt_args=opt_args,
                         device=device, log_dir=log_dir, **kwargs)
        # Subclasses replace this with a mapping from name to nuisance model.
        self.nuisance_models = {}

    @abstractmethod
    def fit_nuisance(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor):
        """
        Fit the nuisance models on the given data.

        Subclasses must implement this; ``fit`` calls it before the training data set is built.
        """

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
            y0_type: Literal["Fixed", "Random", "Conditional"] = "Random", fixed_y0=None, prob_range=(0., 1.),
            scale=False, batch_size=None, niters=10000, snapshot_freq=100, loss_avg_length=1,
            valid_data: Union[None, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, valid_batch_size=None,
            min_loss_val=-torch.inf, max_loss_val=torch.inf,
            true_loss=False, seed=None, verbose=False, **kwargs):
        """
        Fit the nuisance models, then prepare the data and train the main model.

        Args:
            x, y, a: Training tensors, used for both the nuisance models and the main model.
            y0_type, fixed_y0, prob_range, batch_size: Forwarded to ``init_dataset``.
            scale: If True, ``x`` and ``y`` are standardised with their own mean and standard
                deviation, which are stored on the estimator.
            valid_data, valid_batch_size: Forwarded to ``init_valid_dataset``.
            niters, snapshot_freq, loss_avg_length, min_loss_val, max_loss_val, true_loss,
                verbose: Forwarded to ``train``.
            seed: Master seed, stored as ``self.fit_seed``. The data-loader and training
                seeds are derived from it. Drawn at random when omitted.
            **kwargs: Forwarded to ``train``.
        """
        initial_seed = torch.initial_seed()
        self.scale = scale
        if self.scale:
            self.x_mean = x.mean(dim=0)
            self.x_std = x.std(dim=0)
            x = (x - self.x_mean) / self.x_std
            self.y_mean = y.mean(dim=0)
            self.y_std = y.std(dim=0)
            y = (y - self.y_mean) / self.y_std

        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        self.fit_seed = seed
        # Apply the seed before deriving the sub-seeds; otherwise `seed` has no effect on them.
        torch.manual_seed(seed)
        self.dataloader_seed = torch.randint(0, 1000000, (1,)).item()
        self.train_seed = torch.randint(0, 1000000, (1,)).item()

        # Fit nuisance models
        self.fit_nuisance(x, y, a)

        self.init_dataset(x, y, a, batch_size=batch_size,
                          y0_type=y0_type, fixed_y0=fixed_y0, prob_range=prob_range, seed=self.dataloader_seed)
        self.init_valid_dataset(valid_data, batch_size=valid_batch_size)

        self.train(snapshot_freq=snapshot_freq, niters=niters,
                   loss_avg_length=loss_avg_length, min_loss_val=min_loss_val,
                   max_loss_val=max_loss_val, true_loss=true_loss, verbose=verbose, seed=self.train_seed, **kwargs)
        # Re-seed the global RNG with the process's initial seed, as the other fit paths do.
        torch.manual_seed(initial_seed)


class IPWEstimator(EstimatorwithNuisance):
    def __init__(self, cqc_model: nn.Module, propensity_model: BaseEstimator, optimiser=opt.SGD, opt_args=None,
                 device=None, log_dir=None, **kwargs):
        """
        Set up an inverse-propensity-weighted estimator.

        Args:
            cqc_model: Main model; handled by the parent constructor.
            propensity_model: Propensity model; a deep copy is stored on the estimator.
            optimiser, opt_args, device, log_dir, **kwargs: Passed to the parent constructor.
        """
        super().__init__(cqc_model, optimiser=optimiser, opt_args=opt_args,
                         device=device, log_dir=log_dir, **kwargs)
        # Work on a private copy so the caller's model is never fitted in place.
        self.propensity_model = copy.deepcopy(propensity_model)
        self.nuisance_models = {"propensity": self.propensity_model}

    def init_dataset(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
                     batch_size=None, y0_type: Literal["Fixed", "Random"] = "Random",
                     fixed_y0=None, prob_range=(0., 1.), seed=None, **kwargs):
        """
        Build the training loader, including propensities, and set up the y0 source.

        Args:
            x, y, a: Tensors stored with the propensities in a ``TensorDataset``.
                ``a == 0`` selects the rows of ``y`` that form the y0 pool for "Random".
            batch_size: Loader batch size; ``None`` means the whole dataset.
            y0_type: "Fixed" repeats ``fixed_y0``; "Random" samples from ``y[a == 0]``.
            fixed_y0: Value used for "Fixed"; converted to a tensor on ``self.device``.
            prob_range: Quantile range of the y0 pool that is kept for "Random".
            seed: Seed for the global torch RNG while setting up; drawn at random if omitted.
            **kwargs: Forwarded to ``infinite_loader``.

        Raises:
            ValueError: If "Fixed" is requested without ``fixed_y0``.
        """
        initial_seed = torch.initial_seed()
        seed = torch.randint(0, 1000000, (1,)).item() if seed is None else seed
        torch.manual_seed(seed)
        loader_seed = torch.randint(0, 1000000, (1,)).item()

        # NOTE(review): TensorDataset needs tensors, so get_propensities is presumably
        # expected to return a tensor -- confirm against the propensity model used.
        propensities = self.get_propensities(x, y, a)
        self.propensities = propensities
        if self.fit_type == "Valid":
            # Hold out 20% of the data for validation.
            x, x_valid, y, y_valid, a, a_valid, propensities, propensities_valid = train_test_split(
                x, y, a, propensities, test_size=0.2, random_state=seed)
            self.valid_dataset = TensorDataset(x_valid, y_valid, a_valid, propensities_valid)
        self.dataset = TensorDataset(x, y, a, propensities)
        batch_size = batch_size if batch_size is not None else len(self.dataset)
        # Set up general data loader
        self.dataloader = infinite_loader(self.dataset, **kwargs, batch_size=batch_size, seed=loader_seed)

        # Set up y0 dataset
        self.y0_type = y0_type
        if y0_type == "Fixed":
            if fixed_y0 is None:
                raise ValueError("fixed_y0 must be provided if y_0_type is Fixed")
            # Only self.fixed_y0 is needed; it is repeated per batch when y0 values are drawn.
            if isinstance(fixed_y0, torch.Tensor):
                self.fixed_y0 = fixed_y0.to(self.device)
            else:
                self.fixed_y0 = torch.tensor(fixed_y0, device=self.device)
        elif y0_type == "Random":
            y0_data = y[a == 0]
            # If prob_range is not (0., 1.) keep only the pool values inside that quantile range
            if tuple(prob_range) != (0., 1.):
                # torch.quantile requires q to share the input's dtype and device.
                q = torch.tensor(prob_range, dtype=y0_data.dtype, device=y0_data.device)
                q1, q2 = torch.quantile(y0_data, q)
                y0_data = y0_data[(y0_data >= q1) & (y0_data <= q2)]
            if y0_data.shape[0] < batch_size:
                # Tile the pool so a full batch of y0 values can be drawn
                y0_data = y0_data.repeat(batch_size // y0_data.shape[0] + 1)
            self.y0_data = y0_data
        torch.manual_seed(initial_seed)

    def fit_nuisance(self, x, y, a):
        """
        Fit the propensity model on ``(x, a)`` when it exposes a ``fit`` method.

        ``y`` is not used. A model without ``fit`` is left as it is.
        """
        propensity = self.propensity_model
        if not hasattr(propensity, "fit"):
            return
        propensity.fit(x, a)

    def get_propensities(self, x, y, a):
        """
        Evaluate the propensity model at ``x``.

        Uses the model's ``predict`` method when it has one and calls the model directly
        otherwise. ``y`` and ``a`` are not used.
        """
        propensity = self.propensity_model
        evaluate = propensity.predict if hasattr(propensity, "predict") else propensity
        return evaluate(x)

    def loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, propensities: torch.Tensor):
        """
        Compute the inverse-propensity-weighted loss for one batch.

        ``propensities`` is the fourth tensor of each training batch.
        """
        predicted_y1 = self.model(y0, x)
        return losses.ipw_loss(y, a, y0, predicted_y1, propensities)

    # The IPW loss is used unchanged as the true loss (which therefore also takes propensities).
    true_loss = loss


class DREstimator(EstimatorwithNuisance):
    def __init__(self, cqc_model: nn.Module, propensity_model: BaseEstimator, cdf0_model: kernel_cdf,
                 cdf1_model: kernel_cdf = None, optimiser=opt.SGD, opt_args=None, device=None, log_dir=None, **kwargs):
        """
        Set up a doubly robust estimator.

        Args:
            cqc_model: Main model; handled by the parent constructor.
            propensity_model: Propensity model; a deep copy is stored.
            cdf0_model: CDF model for the ``a == 0`` group; a deep copy is stored.
            cdf1_model: CDF model for the ``a == 1`` group; when omitted, a second copy of
                ``cdf0_model`` is used.
            optimiser, opt_args, device, log_dir, **kwargs: Passed to the parent constructor.
        """
        super().__init__(cqc_model, optimiser=optimiser, opt_args=opt_args,
                         device=device, log_dir=log_dir, **kwargs)
        self.propensity_model = copy.deepcopy(propensity_model)
        self.cdf0_model = copy.deepcopy(cdf0_model)
        if cdf1_model is None:
            cdf1_model = cdf0_model
        self.cdf1_model = copy.deepcopy(cdf1_model)
        self.nuisance_models = {
            "propensity": self.propensity_model,
            "cdf0": self.cdf0_model,
            "cdf1": self.cdf1_model,
        }
        # The training loss is not the true loss for this estimator; see true_loss().
        self.true_loss_diff = True

    def fit_nuisance(self, x, y, a):
        """
        Fit every nuisance model that exposes a ``fit`` method.

        The propensity model is fitted on ``(x, a)``, ``cdf0_model`` on the ``a == 0`` rows
        and ``cdf1_model`` on the ``a == 1`` rows. Each model is checked on its own, so a
        CDF model without ``fit`` is skipped without affecting the other.
        """
        if hasattr(self.propensity_model, "fit"):
            self.propensity_model.fit(x, a)
        if hasattr(self.cdf0_model, "fit"):
            self.cdf0_model.fit(x[a == 0], y[a == 0])
        if hasattr(self.cdf1_model, "fit"):
            self.cdf1_model.fit(x[a == 1], y[a == 1])

    # Evaluate cdf1 at a single point and return the result as a numpy array
    def single_cdf1_func(self, y1: torch.Tensor, x: torch.Tensor, w1: Union[torch.Tensor, None] = None):
        """
        Evaluate ``cdf1_model`` for one observation.

        ``y1`` is wrapped in a one-element tensor and ``x`` (and ``w1`` when given) gets a
        leading batch dimension; the detached result is converted with ``.numpy()``.
        """
        y1_batch = torch.tensor([y1])
        x_batch = x.unsqueeze(0)
        if w1 is None:
            cdf_value = self.cdf1_model(y1_batch, x_batch)
        else:
            cdf_value = self.cdf1_model(y1_batch, x_batch, w1.unsqueeze(0))
        return cdf_value.detach().numpy()

    def init_dataset(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
                     batch_size=None, y0_type: Literal["Fixed", "Random"] = "Random",
                     fixed_y0=None, prob_range=(0., 1.), seed=None, **kwargs):
        """
        Build the training loader (with propensities and optional CDF weights) and the y0 source.

        Args:
            x, y, a: Training tensors. ``a == 0`` selects the rows of ``y`` that form the
                y0 pool for "Random".
            batch_size: Loader batch size; ``None`` means the whole dataset.
            y0_type: "Fixed" repeats ``fixed_y0``, "Random" samples from ``y[a == 0]``,
                "Conditional" only records ``prob_range``.
            fixed_y0: Value used for "Fixed"; converted to a tensor on ``self.device``.
            prob_range: Quantile range kept for "Random"; stored for "Conditional".
            seed: Seed for the global torch RNG while setting up; drawn at random if omitted.
            **kwargs: Forwarded to ``infinite_loader``.

        Raises:
            ValueError: If "Fixed" is requested without ``fixed_y0``.
        """
        initial_seed = torch.initial_seed()
        seed = torch.randint(0, 1000000, (1,)).item() if seed is None else seed
        torch.manual_seed(seed)
        loader_seed = torch.randint(0, 1000000, (1,)).item()

        self.min_integral_val = torch.min(y).item()-1
        propensities = self.get_propensities(x, y, a)
        self.propensities = propensities
        w0, w1 = None, None
        if hasattr(self.cdf0_model, "get_y_weights"):
            # Precompute the CDF models' weights once so they travel with each batch.
            w0 = self.cdf0_model.get_y_weights(x)
            w1 = self.cdf1_model.get_y_weights(x)
        has_weights = w0 is not None and w1 is not None

        # Optionally split data into train and valid
        if self.fit_type == "Valid":
            if has_weights:
                (x, x_valid, y, y_valid, a, a_valid, propensities, p_valid,
                 w0, w0_valid, w1, w1_valid
                 ) = train_test_split(
                    x, y, a, propensities, w0, w1, test_size=0.2, random_state=seed)
                self.valid_dataset = TensorDataset(x_valid, y_valid, a_valid, p_valid, w0_valid, w1_valid)
            else:
                x, x_valid, y, y_valid, a, a_valid, propensities, p_valid = train_test_split(
                    x, y, a, propensities, test_size=0.2, random_state=seed)
                self.valid_dataset = TensorDataset(x_valid, y_valid, a_valid, p_valid)

        if has_weights:
            self.dataset = TensorDataset(x, y, a, propensities, w0, w1)
        else:
            self.dataset = TensorDataset(x, y, a, propensities)
        batch_size = batch_size if batch_size is not None else len(self.dataset)
        # Set up general data loader
        self.dataloader = infinite_loader(self.dataset, **kwargs, batch_size=batch_size, seed=loader_seed)

        # Set up y0 dataset
        self.y0_type = y0_type
        if y0_type == "Fixed":
            if fixed_y0 is None:
                raise ValueError("fixed_y0 must be provided if y_0_type is Fixed")
            if isinstance(fixed_y0, torch.Tensor):
                self.fixed_y0 = fixed_y0.to(self.device)
            else:
                self.fixed_y0 = torch.tensor(fixed_y0, device=self.device)
        elif y0_type == "Random":
            y0_data = y[a == 0]
            # If prob_range is not (0., 1.) keep only the pool values inside that quantile range
            if tuple(prob_range) != (0., 1.):
                # torch.quantile requires q to share the input's dtype and device.
                q = torch.tensor(prob_range, dtype=y0_data.dtype, device=y0_data.device)
                q1, q2 = torch.quantile(y0_data, q)
                y0_data = y0_data[(y0_data >= q1) & (y0_data <= q2)]
            if y0_data.shape[0] < batch_size:
                # Tile the pool so a full batch of y0 values can be drawn
                y0_data = y0_data.repeat(batch_size // y0_data.shape[0] + 1)
            self.y0_data = y0_data
        elif y0_type == "Conditional":
            self.prob_range = prob_range
        torch.manual_seed(initial_seed)

    def get_propensities(self, x, y, a):
        """
        Evaluate the propensity model at ``x``.

        Uses the model's ``predict`` method when it has one and calls the model directly
        otherwise. ``y`` and ``a`` are not used.
        """
        propensity = self.propensity_model
        evaluate = propensity.predict if hasattr(propensity, "predict") else propensity
        return evaluate(x)

    def get_cdfs(self, x, y0, y1, w0=None, w1=None):
        """
        Evaluate both CDF models: ``cdf0_model`` at ``(y0, x)`` and ``cdf1_model`` at ``(y1, x)``.

        The weights are passed to both models whenever ``w0`` is given.

        Returns:
            Tuple ``(cdf0, cdf1)``.
        """
        if w0 is None:
            return self.cdf0_model(y0, x), self.cdf1_model(y1, x)
        return self.cdf0_model(y0, x, w0), self.cdf1_model(y1, x, w1)

    def loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, propensities: torch.Tensor,
             w0=None, w1=None):
        """
        Compute the doubly robust training loss for one batch.

        The model prediction ``y1`` and both estimated CDFs are passed to ``losses.dr_loss``.
        """
        predicted_y1 = self.model(y0, x)
        cdf0_values, cdf1_values = self.get_cdfs(x, y0, predicted_y1, w0, w1)
        return losses.dr_loss(y, a, y0, predicted_y1, propensities, cdf0_values, cdf1_values)

    def true_loss(self, y0: torch.Tensor, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor, propensities: torch.Tensor,
                  w0=None, w1=None):
        """
        Compute the true doubly robust loss for one batch.

        Only the ``cdf0`` values are precomputed here; ``losses.true_dr_loss`` receives
        ``single_cdf1_func`` and evaluates ``cdf1`` through it.
        """
        predicted_y1 = self.model(y0, x)
        cdf0_values, _unused_cdf1 = self.get_cdfs(x, y0, predicted_y1, w0, w1)
        return losses.true_dr_loss(y, x, a, y0, predicted_y1, propensities, cdf0_values, self.single_cdf1_func)


class CrossEstimator(EstimatorwithNuisance):
    def __init__(self, cqc_model: nn.Module, *args, fit_type: Literal["Split", "Cross", "Valid"] = "Split", **kwargs):
        """
        Set up a sample-splitting / cross-fitting estimator.

        Args:
            cqc_model: Main model, passed to the next constructor in the MRO.
            *args, **kwargs: Passed to the next constructor in the MRO.
            fit_type: "Split", "Cross" or "Valid"; stored as ``self.fit_type``.
        """
        super().__init__(cqc_model, *args, **kwargs)
        self.fit_type = fit_type
        # Keep untouched copies so the second fold can restart from scratch.
        untrained_model = copy.deepcopy(self.model)
        untrained_nuisance = copy.deepcopy(self.nuisance_models)
        self.initial_model = untrained_model
        self.inital_nuisance_models = untrained_nuisance

    def split_fit(self, x_s0, y_s0, a_s0, x_s1, y_s1, a_s1, y0_type, fixed_y0, prob_range,
                  batch_size, valid_data, valid_batch_size, seed=None, *args, **kwargs):
        """
        Fit the nuisance models on one half of the data and train the main model on the other.

        Args:
            x_s0, y_s0, a_s0: Half used to fit the nuisance models.
            x_s1, y_s1, a_s1: Half used to build the training data set.
            y0_type, fixed_y0, prob_range, batch_size: Forwarded to ``init_dataset``.
            valid_data, valid_batch_size: Forwarded to ``init_valid_dataset``.
            seed: Seed from which the data-initialisation and training seeds are derived;
                drawn at random when omitted.
            *args, **kwargs: Forwarded to ``train``.
        """
        initial_seed = torch.initial_seed()
        if seed is None:
            seed = torch.randint(0, 1000000, (1,)).item()
        torch.manual_seed(seed)
        self.datainit_seed = torch.randint(0, 1000000, (1,)).item()
        self.train_seed = torch.randint(0, 1000000, (1,)).item()

        # Nuisance models only ever see the first half.
        self.fit_nuisance(x_s0, y_s0, a_s0)

        # The main model is trained on the second half.
        self.init_dataset(
            x_s1, y_s1, a_s1,
            batch_size=batch_size, y0_type=y0_type, fixed_y0=fixed_y0,
            prob_range=prob_range, seed=self.datainit_seed,
        )
        self.init_valid_dataset(valid_data, batch_size=valid_batch_size)
        self._init_optim()
        self.train(*args, **kwargs, seed=self.train_seed)
        torch.manual_seed(initial_seed)

    def fit(self, x: torch.Tensor, y: torch.Tensor, a: torch.Tensor,
            y0_type: Literal["Fixed", "Random", "Conditional"] = "Random", fixed_y0=None, prob_range=(0., 1.),
            scale=False, batch_size=None, niters=10000,
            snapshot_freq=100, loss_avg_length=1,
            valid_data: Union[None, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, valid_batch_size=None,
            min_loss_val=-torch.inf, max_loss_val=torch.inf,
            true_loss=False, gradients=False, verbose=False, seed=None, **kwargs):
        """
        Fit with sample splitting, or with two-fold cross-fitting when ``fit_type`` is "Cross".

        The data are split in half. Nuisance models are fitted on one half and the main
        model is trained on the other. For "Cross" the process is then repeated with the
        halves swapped, and the results are kept as ``model0`` / ``model1``,
        ``nuisance_models0`` / ``nuisance_models1`` and ``stored_vals0`` / ``stored_vals1``.

        Args:
            x, y, a: Full data tensors.
            y0_type, fixed_y0, prob_range, batch_size, valid_data, valid_batch_size:
                Forwarded to ``split_fit``.
            scale: If True, ``x`` and ``y`` are standardised with their own mean and standard
                deviation, which are stored for use in ``predict``.
            niters, snapshot_freq, loss_avg_length, min_loss_val, max_loss_val, true_loss,
                gradients, verbose, **kwargs: Forwarded through ``split_fit`` to ``train``.
            seed: Master seed, stored as ``self.fit_seed``. The split seed and one seed per
                fold are derived from it. Drawn at random when omitted.
        """
        initial_seed = torch.initial_seed()
        self.scale = scale
        if self.scale:
            self.x_mean = x.mean(dim=0)
            self.x_std = x.std(dim=0)
            x = (x - self.x_mean) / self.x_std
            self.y_mean = y.mean(dim=0)
            self.y_std = y.std(dim=0)
            y = (y - self.y_mean) / self.y_std

        seed = torch.randint(0, 1000000, (1,)).item() if seed is None else seed
        self.fit_seed = seed
        torch.manual_seed(seed)
        # One seed for the data split and one per fold
        self.split_seed = torch.randint(0, 1000000, (1,)).item()
        self.splitfit_seeds = torch.randint(0, 1000000, (2,)).tolist()
        # Split data into 2 parts, keeping the row indices of each half
        ind_s0, ind_s1, x_s0, x_s1, y_s0, y_s1, a_s0, a_s1 = train_test_split(
            torch.arange(0, y.shape[0]), x, y, a, test_size=0.5, random_state=self.split_seed)
        self.split_ind0 = ind_s0
        self.split_ind1 = ind_s1

        # Settings shared by both folds, so the second fold cannot drift from the first.
        shared_args = dict(
            y0_type=y0_type, fixed_y0=fixed_y0, prob_range=prob_range,
            batch_size=batch_size,
            valid_data=valid_data, valid_batch_size=valid_batch_size,
            niters=niters, snapshot_freq=snapshot_freq,
            loss_avg_length=loss_avg_length, min_loss_val=min_loss_val, max_loss_val=max_loss_val,
            true_loss=true_loss, gradients=gradients, verbose=verbose)

        # First fold: nuisance models on half 0, main model on half 1
        self.split_fit(x_s0, y_s0, a_s0, x_s1, y_s1, a_s1,
                       seed=self.splitfit_seeds[0], **shared_args, **kwargs)

        if self.fit_type in ["Split", "Valid"]:
            torch.manual_seed(initial_seed)
            return

        # Cross-fitting: keep the first fold's results, then refit with the roles swapped.
        self.model0 = copy.deepcopy(self.model)
        self.nuisance_models0 = copy.deepcopy(self.nuisance_models)
        self.stored_vals0 = copy.deepcopy(self.stored_vals)

        # Restart from a copy of the untrained model, so the stored initial model stays
        # untrained for any later call to fit().
        self.model = copy.deepcopy(self.initial_model)
        for key, value in self.inital_nuisance_models.items():
            fresh_model = copy.deepcopy(value)
            setattr(self, key+"_model", fresh_model)
            # Keep the registry pointing at the models that are actually refitted.
            self.nuisance_models[key] = fresh_model

        # Second fold uses its own seed: nuisance models on half 1, main model on half 0
        self.split_fit(x_s1, y_s1, a_s1, x_s0, y_s0, a_s0,
                       seed=self.splitfit_seeds[1], **shared_args, **kwargs)

        self.model1 = self.model
        self.nuisance_models1 = copy.deepcopy(self.nuisance_models)
        self.stored_vals1 = copy.deepcopy(self.stored_vals)
        torch.manual_seed(initial_seed)

    def predict(self, y0: torch.Tensor, x: torch.Tensor):
        """
        Evaluate the fitted model(s) at ``(y0, x)``.

        For "Split" and "Valid" the single model is used; otherwise the outputs of
        ``model0`` and ``model1`` are averaged. With ``scale=True`` the inputs are
        standardised and the output is mapped back to the original ``y`` scale.
        """
        rescale = self.scale
        if rescale:
            x = (x - self.x_mean) / self.x_std
            y0 = (y0 - self.y_mean) / self.y_std

        if self.fit_type in ["Split", "Valid"]:
            out = self.model(y0, x)
        else:
            fold0_out = self.model0(y0, x)
            fold1_out = self.model1(y0, x)
            out = (fold0_out + fold1_out) / 2

        if rescale:
            out = out * self.y_std + self.y_mean
        return out

    def load_state_dict(self, state_dict1, state_dict2=None, strict=True, assign=False):
        """
        Load saved weights into the fitted model(s).

        Args:
            state_dict1: State dict for ``self.model`` ("Split" / "Valid"), or for
                ``self.model0`` ("Cross").
            state_dict2: State dict for ``self.model1``; required for "Cross" only.
            strict, assign: Passed on to the wrapped model's ``load_state_dict``.

        Raises:
            ValueError: If ``fit_type`` is "Cross" and ``state_dict2`` is missing.
        """
        # Membership test: comparing the string to the list with == is always False.
        if self.fit_type in ["Split", "Valid"]:
            self.model.load_state_dict(state_dict1, strict, assign)
            return

        if state_dict2 is None:
            raise ValueError("2 State dicts (one for each fitted model) must be provided if fit_type is Cross")
        self.model0.load_state_dict(state_dict1, strict, assign)
        self.model1.load_state_dict(state_dict2, strict, assign)


class CrossDREstimator(CrossEstimator, DREstimator):
    """Doubly robust estimator fitted with sample splitting or cross-fitting."""

    def __init__(self, cqc_model: nn.Module, propensity_model: BaseEstimator, cdf0_model: kernel_cdf,
                 cdf1_model: kernel_cdf = None, optimiser=opt.SGD, opt_args=None,
                 fit_type: Literal["Split", "Cross"] = "Split", device=None, log_dir=None, **kwargs):
        """Forward every argument along the MRO; ``fit_type`` is consumed by ``CrossEstimator``."""
        super().__init__(
            cqc_model,
            propensity_model=propensity_model,
            cdf0_model=cdf0_model,
            cdf1_model=cdf1_model,
            optimiser=optimiser,
            opt_args=opt_args,
            fit_type=fit_type,
            device=device,
            log_dir=log_dir,
            **kwargs,
        )


class CrossIPWEstimator(CrossEstimator, IPWEstimator):
    """Inverse-propensity-weighted estimator fitted with sample splitting or cross-fitting."""

    def __init__(self, cqc_model: nn.Module, propensity_model: BaseEstimator,
                 optimiser=opt.SGD, opt_args=None, fit_type: Literal["Split", "Cross"] = "Split",
                 device=None, log_dir=None, **kwargs):
        """Forward every argument along the MRO; ``fit_type`` is consumed by ``CrossEstimator``."""
        super().__init__(
            cqc_model,
            propensity_model=propensity_model,
            optimiser=optimiser,
            opt_args=opt_args,
            fit_type=fit_type,
            device=device,
            log_dir=log_dir,
            **kwargs,
        )
