import numpy as np

from .base_computer import BaseComputer
from .statistics import compute_expected_max

from sklearn.linear_model import LinearRegression

class QualityComputer(BaseComputer):
    """
    Quality computer that can also score "supermodels", i.e. subsets of
    models selected per question.
    """

    def __init__(self, n_samples=100):
        """
        Create a QualityComputer.

        Parameters:
            n_samples (int): Value forwarded as ``n_samples`` to
                            ``compute_expected_max`` when scoring supermodels.
                            Default is 100.
        """
        # The attribute is assigned before the parent initializer runs;
        # this ordering is kept deliberately.
        self.n_samples = n_samples
        super().__init__()

    def predict_supermodels(
            self,
            questions,
            indices_models_supermodel,
            qualities,
            sigma_qualities,
            model_answers
    ):
        """
        Score one supermodel per question.

        For every question the qualities (and, when available, the covariance
        sub-matrix) of the models that make up the supermodel are selected and
        handed to ``compute_expected_max``.
        Args:
            questions (list): The questions; only their count/order is used here.
            indices_models_supermodel (list): For each question, the indices of the models forming the supermodel.
            qualities (list): For each question, an array of per-model qualities (indexed with the supermodel indices).
            sigma_qualities (list): For each question, either None or a covariance matrix over all models.
            model_answers (list): Model answers. Not used by this method.
        Returns:
            qualities_supermodel (ndarray): First value returned by ``compute_expected_max`` for each question.
            qualities_var_supermodel (ndarray): Second value returned by ``compute_expected_max`` for each question.
        """
        expected_qualities = []
        expected_variances = []
        for idx, _ in enumerate(questions):
            per_model_qualities = qualities[idx]
            members = indices_models_supermodel[idx]
            member_qualities = per_model_qualities[members]

            # Restrict the covariance to the rows/columns of the member
            # models; a missing covariance is passed through as None.
            full_covariance = sigma_qualities[idx]
            member_covariance = (
                None
                if full_covariance is None
                else full_covariance[np.ix_(members, members)]
            )

            # NOTE(review): ``is_independent`` is not set in this class;
            # presumably it is provided by BaseComputer - confirm.
            quality, variance = compute_expected_max(
                member_qualities,
                member_covariance,
                independent=self.is_independent,
                n_samples=self.n_samples,
            )
            expected_qualities.append(quality)
            expected_variances.append(variance)

        return np.array(expected_qualities), np.array(expected_variances)

class GroundTruthQualityComputer(QualityComputer):
    """
    Quality computer that derives its estimates from noisy copies of the
    ground-truth quality values.
    """

    def __init__(self, noise_before_run=0.2, noise_after_run=0.05, n_samples=100):
        """
        Initializes the GroundTruthQualityComputer object.
        Computes the quality by adding noise to the ground truth quality values and 
        then fitting a linear model to the noisy values.

        Args:
            noise_before_run (float): Standard deviation of the Gaussian noise used for the
                estimate of a model that has not produced an answer yet. Defaults to 0.2.
            noise_after_run (float): Standard deviation of the Gaussian noise used for the
                estimate of a model that has produced an answer. Defaults to 0.05.
            n_samples (int): The number of samples, forwarded to QualityComputer. Defaults to 100.
        """
        super().__init__(n_samples)
        self.noise_before_run = noise_before_run
        self.noise_after_run = noise_after_run
        # Both are populated by fit(); None until then.
        self.quality_mapping = None
        self.sigmas = None

    def fit(self, questions, model_answers, measure):
        """
        Build the per-question quality estimates from ground-truth values.

        Each ground-truth value is perturbed twice (once with ``noise_before_run``
        and once with ``noise_after_run``). For every model and noise level a
        linear regression from the noisy values back to the ground truth is
        fitted, and its predictions are stored as the quality estimates.

        Args:
            questions (iterable): Hashable question identifiers, in the same order as
                the rows of ``measure``. They are used as dictionary keys.
            model_answers: Not used by this method.
            measure (array-like): 2-D ground-truth values indexed as [question, model].
                Any input accepted by ``np.asarray`` works (ndarray, nested lists, ...).

        Side effects:
            Sets ``self.quality_mapping`` (question -> array of shape (n_models, 2),
            column 0 = before-run estimate, column 1 = after-run estimate) and
            ``self.sigmas`` (one ``[sigma, 0]`` pair per model; only the first entry
            is ever filled in).
        """
        # Coerce once so that nested lists and other array-likes are supported;
        # the column slicing and ``.shape`` access below need an ndarray.
        # For ndarray input this is a no-op.
        measure = np.asarray(measure)

        self.quality_mapping = dict()

        # The draws are made question by question, model by model, before-run
        # first and after-run second, so seeded runs yield the same values
        # as the previous explicit index loop.
        noisy_values = []
        for measure_value in measure:
            noisy_values.append([
                [
                    np.random.normal(val, self.noise_before_run),
                    np.random.normal(val, self.noise_after_run),
                ]
                for val in measure_value
            ])
        noisy_values = np.array(noisy_values)

        actual_values = np.zeros(noisy_values.shape)

        self.sigmas = [[0, 0] for _ in range(measure.shape[1])]

        for model in range(noisy_values.shape[1]):
            for i in range(noisy_values.shape[2]):
                # Single-feature design matrix: one noisy value per question.
                features = noisy_values[:, model, i].reshape(-1, 1)
                linear_model = LinearRegression()
                linear_model.fit(features, measure[:, model])
                actual_values[:, model, i] = linear_model.predict(features)

            # Spread of the difference between the before-run and after-run
            # estimates; used in predict() as the before-run uncertainty.
            self.sigmas[model][0] = np.std(actual_values[:, model, 0] - actual_values[:, model, 1])

        for q, a in zip(questions, actual_values):
            self.quality_mapping[q] = a

    def predict(self, questions, model_answers):
        """
        Look up the fitted quality estimates for the given questions.

        Args:
            questions (iterable): Questions that were passed to ``fit``.
            model_answers (iterable): For each question, one entry per model; ``None``
                means the model has not been run yet on that question.

        Returns:
            qualities (ndarray): Per question, one estimate per model: the before-run
                estimate where the answer is ``None``, the after-run estimate otherwise.
            sigma_qualities (ndarray): Per question, a diagonal covariance matrix with
                ``sigmas[m][0] ** 2`` for models without an answer and ``1e-6`` for
                models with an answer.

        Raises:
            KeyError: If a question was not seen during ``fit``.
        """
        qualities = []
        sigma_qualities = []
        for question, model_answer in zip(questions, model_answers):
            estimates = self.quality_mapping[question]
            value = np.array([
                estimates[i][0] if answer is None else estimates[i][1]
                for i, answer in enumerate(model_answer)
            ])
            sigma_noise = np.diag([
                self.sigmas[i][0] ** 2 if answer is None else 1e-6
                for i, answer in enumerate(model_answer)
            ])
            qualities.append(value)
            sigma_qualities.append(sigma_noise)

        return np.array(qualities), np.array(sigma_qualities)