import torch
import random
import numpy as np
from abc import abstractmethod
from abc import ABC


def set_random_seed(seed, deterministic=False):
    """Seed the global RNGs of Python, NumPy and PyTorch with one value.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``. When CUDA is available it is also
            passed to ``torch.cuda.manual_seed_all``.
        deterministic: If true, also turn off cuDNN autotuning
            (``benchmark = False``) and request deterministic cuDNN kernels
            (``deterministic = True``). If false, the cuDNN flags are left
            untouched (they are not reset to any default).
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


class BenchCompress(ABC):
    """Abstract interface for benchmark-compression strategies.

    ``BenchCompress`` cannot be instantiated directly. A concrete subclass
    must implement ``fit``, ``get_compressed_data_indices``, ``predict``,
    ``save`` and ``load``. Every base-class body is a no-op that returns
    ``None``, so calling one through ``super()`` has no effect.

    NOTE(review): the method names suggest the workflow is ``fit`` ->
    ``get_compressed_data_indices`` -> ``predict``, with ``save``/``load``
    for persistence. Confirm this against the concrete implementations.
    """

    def __init__(self):
        """Initialise the base class; it stores no state of its own."""

    @abstractmethod
    def fit(self, source_full_scores, num_compressed_data, seed=42,
            *args, **kwargs):
        """Fit the compressor on the full scores of the source models.

        Args:
            source_full_scores: Score array. Implementations are expected
                to read ``shape[0]`` as the number of models and
                ``shape[1]`` as the number of data points.
            num_compressed_data: Number of data points to keep after
                compression.
            seed: Random seed. Implementations are expected to pass it to
                ``set_random_seed`` before any randomised step.
            *args: Implementation-specific positional options.
            **kwargs: Implementation-specific keyword options.

        The base class does not enforce these reference checks.
        Implementations are expected to require:

        * more than one model (``shape[0] > 1``);
        * more than one data point (``shape[1] > 1``);
        * ``num_compressed_data > 1``;
        * ``shape[1] > num_compressed_data``.
        """

    @abstractmethod
    def get_compressed_data_indices(self):
        """Return the selection produced by ``fit``.

        NOTE(review): presumably the indices of the retained data points;
        confirm with implementations.
        """

    @abstractmethod
    def predict(self, target_compressed_scores):
        """Make predictions from a target's scores on the compressed data.

        Args:
            target_compressed_scores: Score array restricted to the
                compressed data points. Implementations are expected to
                accept a single 1-D vector by reshaping it to ``(1, -1)``
                before further processing.
        """

    @abstractmethod
    def save(self, path_save):
        """Persist the compressor to ``path_save``.

        NOTE(review): the file format is implementation-defined.
        """

    @abstractmethod
    def load(self, path_load):
        """Restore the compressor from ``path_load``.

        NOTE(review): it should accept whatever ``save`` wrote; confirm
        with implementations.
        """
