import numpy as np
import joblib as jbl
from sklearn.linear_model import RidgeCV
from .base import BenchCompress, set_random_seed


class RandomSelection(BenchCompress):
    """Baseline that keeps a random subset of the data columns.

    ``fit`` only draws the subset; ``predict`` returns the plain mean of the
    compressed scores it is handed, one value per row.
    """

    def __init__(self):
        super().__init__()
        # Column indices chosen by ``fit`` or restored by ``load``.
        self.compressed_data_indices = None

    def fit(self, source_full_scores, num_compressed_data, seed=42, *args, **kwargs):
        """Draw ``num_compressed_data`` random column indices.

        Args:
            source_full_scores: Array indexed as ``[model, data]``; only its
                shape is read here.
            num_compressed_data: Number of columns to keep.
            seed: Value passed to ``set_random_seed`` before drawing from
                ``np.random``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
        """
        shape = source_full_scores.shape
        num_model = shape[0]
        num_data = shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        set_random_seed(seed)

        # Shuffle every column index, then keep the leading slice.
        shuffled = np.random.permutation(num_data)
        self.compressed_data_indices = shuffled[:num_compressed_data]
        return self

    def get_compressed_data_indices(self):
        """Return the stored column indices (``None`` until fit/load)."""
        return self.compressed_data_indices

    def predict(self, target_compressed_scores):
        """Return the per-row mean of the compressed scores.

        A 1-D input is treated as a single row.
        """
        scores = target_compressed_scores
        is_single_row = len(scores.shape) == 1
        if is_single_row:
            scores = scores.reshape(1, -1)

        return scores.mean(1)

    def save(self, path_save):
        """Write the selected indices to ``path_save`` with ``joblib.dump``."""
        jbl.dump(self.compressed_data_indices, path_save)

    def load(self, path_load):
        """Restore the selected indices from ``path_load``; returns ``self``."""
        restored = jbl.load(path_load)
        self.compressed_data_indices = restored
        return self


class RandomSelectionAndLearn(BenchCompress):
    """Random column subset plus a ``RidgeCV`` regressor.

    The regressor maps a model's scores on the selected columns to that
    model's mean score over all columns.
    """

    def __init__(self):
        super().__init__()
        # Column indices chosen by ``fit`` or restored by ``load``.
        self.compressed_data_indices = None
        # Fitted ``RidgeCV`` instance, set by ``fit`` or ``load``.
        self.rgs = None

    def fit(self, source_full_scores, num_compressed_data, seed=42, *args, **kwargs):
        """Pick random columns and fit the regressor on them.

        Args:
            source_full_scores: Array indexed as ``[model, data]``.
            num_compressed_data: Number of columns to keep.
            seed: Value passed to ``set_random_seed`` before drawing from
                ``np.random``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
        """
        shape = source_full_scores.shape
        num_model = shape[0]
        num_data = shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        set_random_seed(seed)

        shuffled = np.random.permutation(num_data)
        self.compressed_data_indices = shuffled[:num_compressed_data]

        self.rgs = RidgeCV()
        # Features: scores on the kept columns. Target: mean over all columns.
        features = source_full_scores[:, self.compressed_data_indices]
        targets = source_full_scores.mean(1)
        self.rgs.fit(features, targets)
        return self

    def get_compressed_data_indices(self):
        """Return the stored column indices (``None`` until fit/load)."""
        return self.compressed_data_indices

    def predict(self, target_compressed_scores):
        """Run the fitted regressor on the compressed scores.

        A 1-D input is treated as a single row.
        """
        scores = target_compressed_scores
        is_single_row = len(scores.shape) == 1
        if is_single_row:
            scores = scores.reshape(1, -1)

        return self.rgs.predict(scores)

    def save(self, path_save):
        """Write ``(indices, regressor)`` to ``path_save`` with ``joblib.dump``."""
        state = (self.compressed_data_indices, self.rgs)
        jbl.dump(state, path_save)

    def load(self, path_load):
        """Restore ``(indices, regressor)`` from ``path_load``; returns ``self``."""
        state = jbl.load(path_load)
        self.compressed_data_indices, self.rgs = state
        return self


class RandomSearchAndLearn(BenchCompress):
    """Random search over column subsets, scored by a held-out ``RidgeCV`` fit.

    Each candidate subset is used to fit a ``RidgeCV`` on 75% of the models
    and is scored by the mean absolute error on the remaining models. The
    best subset is kept and a final regressor is fitted on all models.
    """

    def __init__(self):
        super().__init__()
        # Best column subset found by ``fit`` or restored by ``load``.
        self.compressed_data_indices = None
        # ``RidgeCV`` fitted on all models with the best subset.
        self.rgs = None

    def fit(
        self,
        source_full_scores,
        num_compressed_data,
        num_search=10000,
        seed=42,
        *args,
        **kwargs
    ):
        """Search random subsets and fit the final regressor.

        Args:
            source_full_scores: Array indexed as ``[model, data]``.
            num_compressed_data: Number of columns in each candidate subset.
            num_search: Number of random subsets to evaluate; must be >= 1.
            seed: Value passed to ``set_random_seed`` before drawing from
                ``np.random``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
            ValueError: If ``num_search`` is smaller than 1.
        """
        num_model = source_full_scores.shape[0]
        num_data = source_full_scores.shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        # Without at least one candidate there is no subset to fit on; fail
        # with a clear message instead of an indexing error further down.
        if num_search < 1:
            raise ValueError("num_search must be at least 1")

        set_random_seed(seed)

        # Split the models 75% / 25% into train and validation sets.
        order = np.random.permutation(num_model)
        num_train = int(num_model * 0.75)
        tr_models = order[:num_train]
        val_models = order[num_train:]

        # These do not depend on the candidate subset, so compute them once
        # instead of on every search iteration.
        full_means = source_full_scores.mean(1)
        train_scores = source_full_scores[tr_models]
        train_means = full_means[tr_models]
        val_means = full_means[val_models]

        best_idxs, best_gap = None, np.inf
        for _ in range(num_search):
            selected_idxs = np.random.permutation(num_data)[:num_compressed_data]
            rgs = RidgeCV()
            rgs.fit(train_scores[:, selected_idxs], train_means)
            # Predict for every model and then pick the validation rows, as
            # the original did, so the numerical results stay the same.
            estimated_scores = rgs.predict(source_full_scores[:, selected_idxs])
            gap = np.fabs(estimated_scores[val_models] - val_means).mean()
            if gap < best_gap:
                best_gap = gap
                best_idxs = selected_idxs

        # Refit on all models using the best subset.
        self.compressed_data_indices = best_idxs
        self.rgs = RidgeCV()
        self.rgs.fit(
            source_full_scores[:, self.compressed_data_indices],
            full_means,
        )
        return self

    def get_compressed_data_indices(self):
        """Return the stored column indices (``None`` until fit/load)."""
        return self.compressed_data_indices

    def predict(self, target_compressed_scores):
        """Run the fitted regressor on the compressed scores.

        A 1-D input is treated as a single row.
        """
        if len(target_compressed_scores.shape) == 1:
            target_compressed_scores = target_compressed_scores.reshape(1, -1)

        return self.rgs.predict(target_compressed_scores)

    def save(self, path_save):
        """Write ``(indices, regressor)`` to ``path_save`` with ``joblib.dump``."""
        jbl.dump((self.compressed_data_indices, self.rgs), path_save)

    def load(self, path_load):
        """Restore ``(indices, regressor)`` from ``path_load``; returns ``self``."""
        self.compressed_data_indices, self.rgs = jbl.load(path_load)
        return self


class RandomRandomSelection(BenchCompress):
    """Baseline that draws a new random column subset on every request.

    ``fit`` stores only the sizes; ``get_compressed_data_indices`` draws a
    fresh permutation from ``np.random`` each time it is called.
    """

    def __init__(self):
        super().__init__()
        # Total number of data columns, recorded by ``fit`` or ``load``.
        self.num_data = None
        # Size of each subset handed out, recorded by ``fit`` or ``load``.
        self.num_compressed_data = None

    def fit(self, source_full_scores, num_compressed_data, seed=42, *args, **kwargs):
        """Record the column count and the subset size.

        Args:
            source_full_scores: Array indexed as ``[model, data]``; only its
                shape is read here.
            num_compressed_data: Number of columns in each subset.
            seed: Value passed to ``set_random_seed``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
        """
        shape = source_full_scores.shape
        num_model = shape[0]
        num_data = shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        set_random_seed(seed)

        self.num_data = shape[1]
        self.num_compressed_data = num_compressed_data
        return self

    def get_compressed_data_indices(self):
        """Draw and return a fresh random subset of column indices."""
        shuffled = np.random.permutation(self.num_data)
        return shuffled[: self.num_compressed_data]

    def predict(self, target_compressed_scores):
        """Return the per-row mean of the compressed scores.

        A 1-D input is treated as a single row.
        """
        scores = target_compressed_scores
        is_single_row = len(scores.shape) == 1
        if is_single_row:
            scores = scores.reshape(1, -1)

        return scores.mean(1)

    def save(self, path_save):
        """Write ``(num_data, num_compressed_data)`` with ``joblib.dump``."""
        state = (self.num_data, self.num_compressed_data)
        jbl.dump(state, path_save)

    def load(self, path_load):
        """Restore ``(num_data, num_compressed_data)``; returns ``self``."""
        state = jbl.load(path_load)
        self.num_data, self.num_compressed_data = state
        return self


class RandomSelectionHalfLearn(BenchCompress):
    """Random column subset with a regressor that can fall back to the mean.

    ``predict`` uses the ``RidgeCV`` estimate unless it differs too much from
    the plain mean of the compressed scores, in which case the plain mean is
    returned for that row.
    """

    def __init__(self):
        super().__init__()
        # Column indices chosen by ``fit`` or restored by ``load``.
        self.compressed_data_indices = None
        # Fitted ``RidgeCV`` instance, set by ``fit`` or ``load``.
        self.rgs = None

    def fit(self, source_full_scores, num_compressed_data, seed=42, *args, **kwargs):
        """Pick random columns and fit the regressor on them.

        Args:
            source_full_scores: Array indexed as ``[model, data]``.
            num_compressed_data: Number of columns to keep.
            seed: Value passed to ``set_random_seed`` before drawing from
                ``np.random``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
        """
        num_model = source_full_scores.shape[0]
        num_data = source_full_scores.shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        set_random_seed(seed)

        self.compressed_data_indices = np.random.permutation(num_data)[
            :num_compressed_data
        ]

        # Features: scores on the kept columns. Target: mean over all columns.
        self.rgs = RidgeCV()
        self.rgs.fit(
            source_full_scores[:, self.compressed_data_indices],
            source_full_scores.mean(1),
        )
        return self

    def get_compressed_data_indices(self):
        """Return the stored column indices (``None`` until fit/load)."""
        return self.compressed_data_indices

    def predict(self, target_compressed_scores, eps=0.01):
        """Predict with the regressor, falling back to the row mean.

        Args:
            target_compressed_scores: Scores on the compressed columns; a
                1-D input is treated as a single row.
            eps: Threshold on the scaled gap between the regressor output
                and the row mean. Rows above it use the row mean.

        Returns:
            One estimate per row.
        """
        if len(target_compressed_scores.shape) == 1:
            target_compressed_scores = target_compressed_scores.reshape(1, -1)

        pred_part = self.rgs.predict(target_compressed_scores)
        data_part = target_compressed_scores.mean(-1)

        num_compressed = target_compressed_scores.shape[1]
        # The scale is zero when a row mean is exactly 0 or 1, and NaN when
        # it lies outside [0, 1]. The resulting inf/NaN gaps are handled by
        # the comparison below (inf -> use the mean, NaN -> keep the
        # regressor), so only the floating-point warnings are silenced.
        # NOTE(review): sqrt(n * p * (1 - p)) is the spread of a binomial
        # count, not of a mean; confirm this scaling is the intended one.
        with np.errstate(divide="ignore", invalid="ignore"):
            std = np.sqrt(num_compressed * data_part * (1 - data_part))
            gap = np.fabs(data_part - pred_part) / std
            use_mean = gap > eps

        pred_part[use_mean] = data_part[use_mean]

        return pred_part

    def save(self, path_save):
        """Write ``(indices, regressor)`` to ``path_save`` with ``joblib.dump``."""
        jbl.dump((self.compressed_data_indices, self.rgs), path_save)

    def load(self, path_load):
        """Restore ``(indices, regressor)`` from ``path_load``; returns ``self``."""
        self.compressed_data_indices, self.rgs = jbl.load(path_load)
        return self


class RandomSelectionAndLearnAll(BenchCompress):
    """Random column subset plus a regressor for every unselected column.

    The ``RidgeCV`` regressor maps scores on the selected columns to scores
    on all remaining columns. ``predict`` averages the observed and the
    predicted scores together.
    """

    def __init__(self):
        super().__init__()
        # Column indices chosen by ``fit`` or restored by ``load``.
        self.compressed_data_indices = None
        # Fitted multi-output ``RidgeCV``, set by ``fit`` or ``load``.
        self.rgs = None

    def fit(self, source_full_scores, num_compressed_data, seed=42, *args, **kwargs):
        """Pick random columns and fit a regressor for the remaining ones.

        Args:
            source_full_scores: Array indexed as ``[model, data]``.
            num_compressed_data: Number of columns to keep.
            seed: Value passed to ``set_random_seed`` before drawing from
                ``np.random``.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self``.

        Raises:
            AssertionError: If there are not more than one model, more than
                one data column, more than one column to keep, or if the
                number to keep is not smaller than the number of columns.
        """
        num_model = source_full_scores.shape[0]
        num_data = source_full_scores.shape[1]

        assert num_model > 1
        assert num_data > 1
        assert num_compressed_data > 1
        assert num_data > num_compressed_data

        set_random_seed(seed)

        self.compressed_data_indices = np.random.permutation(num_data)[
            :num_compressed_data
        ]

        self.rgs = RidgeCV()
        # Every column that was not selected, in ascending order. A set
        # difference avoids scanning the selected indices once per column.
        rest_indices = np.setdiff1d(
            np.arange(num_data), self.compressed_data_indices
        )
        self.rgs.fit(
            source_full_scores[:, self.compressed_data_indices],
            source_full_scores[:, rest_indices],
        )
        return self

    def get_compressed_data_indices(self):
        """Return the stored column indices (``None`` until fit/load)."""
        return self.compressed_data_indices

    def predict(self, target_compressed_scores):
        """Estimate the mean over all columns for each row.

        The regressor's outputs for the unselected columns are summed with
        the observed compressed scores and divided by the combined column
        count. A 1-D input is treated as a single row.
        """
        if len(target_compressed_scores.shape) == 1:
            target_compressed_scores = target_compressed_scores.reshape(1, -1)

        pred = self.rgs.predict(target_compressed_scores)
        # Combined count of predicted and observed columns.
        n = pred.shape[1] + target_compressed_scores.shape[1]
        return (pred.sum(-1) + target_compressed_scores.sum(-1)) * 1.0 / n

    def save(self, path_save):
        """Write ``(indices, regressor)`` to ``path_save`` with ``joblib.dump``."""
        jbl.dump((self.compressed_data_indices, self.rgs), path_save)

    def load(self, path_load):
        """Restore ``(indices, regressor)`` from ``path_load``; returns ``self``."""
        self.compressed_data_indices, self.rgs = jbl.load(path_load)
        return self
