import abc

import numpy as np
import torch
from tensordict import TensorDict

from imputation_methods.ImputationMethod import ImputationMethod


class SampleImputator(ImputationMethod):
    """Impute by drawing, uniformly with replacement, from stored calibration targets.

    ``calibrate`` keeps the targets of the non-deleted calibration samples as a
    "bank"; ``predict`` returns one randomly drawn bank row per row of ``y``.
    This class adds no training of its own (``fit`` only delegates to the base).
    """

    def __init__(self):
        super().__init__()
        # Bank of calibration targets to sample from (leading dim = samples); set by ``calibrate``.
        self.saved_y = None

    @property
    def name(self):
        return "sample"

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64,
            n_wait=20, **kwargs):
        """Delegate to the base-class ``fit`` and return its result; nothing extra is trained here."""
        return super().fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
                           epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)

    def predict(self, x, y):
        """Draw ``len(y)`` rows uniformly with replacement from the calibration bank.

        Args:
            x: only its ``.device`` is used; the result is placed there.
            y: only its length is used (number of rows to draw).

        Returns:
            Tensor of the default floating dtype on ``x.device`` with leading
            dimension ``len(y)``; trailing dimensions match the stored bank rows.

        Raises:
            RuntimeError: if ``calibrate`` has not been called yet.
        """
        if self.saved_y is None:
            raise RuntimeError(f"{self.name} imputator must be calibrated before calling predict")
        bank = self.saved_y.detach()
        # Sample row *indices* with NumPy's global RNG (the same random stream that
        # np.random.choice over the values consumed), then index the bank on its own
        # device: no CPU round-trip of the bank, and multi-column targets are sampled
        # row-wise instead of failing the 1-D requirement of np.random.choice.
        idx = np.random.choice(len(bank), size=len(y))
        idx = torch.as_tensor(idx, dtype=torch.long, device=bank.device)
        # Cast to the default dtype to match the previous torch.Tensor(...) construction.
        return bank[idx].to(device=x.device, dtype=torch.get_default_dtype())

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Store the targets of non-deleted calibration samples as the sampling bank.

        If ``y_cal`` is a ``TensorDict``, its ``'y'`` entry is used. When fewer than
        4 valid samples remain, a warning is printed and the bank is replaced by a
        single all-zero row on ``x_cal.device``.
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal)
        y_cal = y_cal[~deleted_cal]
        if isinstance(y_cal, TensorDict):
            y_cal = y_cal['y']
        self.saved_y = self._as_sample_bank(y_cal)
        if len(self.saved_y) < 4:
            print(f"warning: {self.name} got only {len(self.saved_y)} valid calibration samples, skipping calibration phase")
            # One zero row shaped like a bank row (shape (1,) for scalar targets).
            self.saved_y = torch.zeros((1, *self.saved_y.shape[1:]), device=x_cal.device)

    @staticmethod
    def _as_sample_bank(y):
        """Return ``y`` with singleton non-leading dims removed, keeping one row per sample."""
        if y.dim() == 0:
            return y.unsqueeze(0)
        # Do not use a bare squeeze(): it would also drop the sample dim when there is
        # exactly one sample, turning a (1, k) row into k scalar "samples".
        kept = [d for d in y.shape[1:] if d != 1]
        return y.reshape(y.shape[0], *kept)




