import abc

import numpy as np
import torch

from clustering.clustering_method import ClusteringMethod
from imputation_methods.ImputationMethod import ImputationMethod
from imputation_methods.imputation_utils import construct_histogram
from models.ClassificationModel import ClassProbabilities
from models.qr_models.QuantileRegression import QuantileRegression


class ConditionalSampleImputator(ImputationMethod):
    """Impute labels by resampling calibration labels from the matching cluster.

    ``calibrate`` stores the calibration labels that are not masked out by
    ``deleted_cal`` and assigns each of them to a cluster through
    ``clustering_method``. ``predict`` assigns every query point to a cluster
    and draws, with replacement, one stored label from the same cluster.
    When a cluster contains query points but no calibration points, labels
    are drawn from all stored calibration labels instead.
    """

    def __init__(self, clustering_method: ClusteringMethod, dataset, args):
        """Store the clustering method and reset the calibration state.

        Args:
            clustering_method: object providing ``name``, ``fit`` and
                ``predict_cluster``.
            dataset: not used by this class at the moment; kept so the
                constructor signature stays the same for callers.
            args: not used by this class at the moment; kept so the
                constructor signature stays the same for callers.
        """
        super().__init__()
        # Initialised here but neither read nor written elsewhere in this class.
        self.bin_edges = None
        # The three attributes below are filled in by ``calibrate``.
        self.proxy_cal = None
        self.saved_y = None
        self.clustering_method = clustering_method
        self.cal_clusters = None

    @property
    def name(self):
        """Identifier built from the clustering method's name plus ``_sample``."""
        return f"{self.clustering_method.name}_sample"

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Forward all arguments unchanged to ``ImputationMethod.fit``.

        This class adds no training logic of its own.
        """
        super().fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
                    epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)

    @staticmethod
    def _sample_with_replacement(pool, size, device):
        """Draw ``size`` entries from the 1-D tensor ``pool`` and move them to ``device``.

        Sampling goes through ``np.random.choice`` (replacement is its
        default), so results follow NumPy's global random state.
        """
        drawn = np.random.choice(pool.detach().cpu().numpy(), size=size)
        return torch.Tensor(drawn).to(device)

    def predict(self, x, z):
        """Sample one calibration label per row of ``x``.

        Args:
            x: tensor; only its length and device are used here.
            z: features passed to ``clustering_method.predict_cluster`` to
                obtain the cluster of every query point.

        Returns:
            A 1-D tensor of length ``len(x)`` on ``x.device`` holding the
            sampled labels. Entries whose cluster id falls outside
            ``range(n_clusters)`` keep their initial value of zero.
        """
        device = x.device
        sampled_y2 = torch.zeros(len(x), device=device)
        test_clusters = self.clustering_method.predict_cluster(z)
        cal_clusters = self.cal_clusters
        # Cover every cluster id seen in either the calibration or the query set.
        # NOTE(review): assumes cluster ids are non-negative integers - confirm.
        n_clusters = max(torch.max(cal_clusters).int().item(), torch.max(test_clusters).int().item()) + 1
        for c in range(n_clusters):
            test_mask = test_clusters == c
            n_test = int(test_mask.sum().item())
            if n_test == 0:
                # Nothing to impute for this cluster.
                continue
            cal_mask = cal_clusters == c
            n_cal = int(cal_mask.sum().item())
            if n_cal == 0:
                # No calibration label shares this cluster: fall back to the
                # full pool of stored calibration labels.
                print(f"warning: tackled an empty calibration cluster ({n_cal}) "
                      f"while the test cluster is not empty ({n_test}), "
                      f"so sampling marginally instead")
                pool = self.saved_y
            else:
                pool = self.saved_y[cal_mask]
            sampled_y2[test_mask] = self._sample_with_replacement(pool, n_test, device)
        return sampled_y2

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Store the kept calibration labels and cluster the calibration set.

        Args:
            x_cal: calibration features, forwarded to the clustering method
                as ``more_features`` after masking.
            y_cal: calibration labels.
            z_cal: calibration features used for clustering.
            deleted_cal: mask whose negation selects the rows to keep
                (presumably a boolean tensor marking missing labels - confirm).
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal)
        keep = ~deleted_cal
        labels = y_cal[keep].squeeze()
        if labels.dim() == 0:
            # ``squeeze`` collapses a single remaining label to a 0-d tensor,
            # which can neither be mask-indexed nor sampled from in
            # ``predict``; restore a length-1 vector.
            labels = labels.reshape(1)
        self.saved_y = labels
        self.proxy_cal = z_cal[keep]
        x_cal = x_cal[keep]
        y_cal = y_cal[keep]
        self.clustering_method.fit(self.proxy_cal, more_features=x_cal, y=y_cal)
        self.cal_clusters = self.clustering_method.predict_cluster(self.proxy_cal)



