"""Base classes for experiments."""
import dataclasses

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


@dataclasses.dataclass
class PreDataset:
    """Starting dataset of features and labels, before any performativity.

    Attributes:
    ----------
        features: `np.ndarray`
            Float array of shape [num_samples, num_features] holding the
            individual features of every unit
        outcomes: `np.ndarray`
            Float array of shape [num_samples] holding the outcome (Y) of
            every unit

    Note:
        Only the two attributes above are declared; the ``group_features``
        field is commented out in the class body.

    """

    features: np.ndarray
    # group_features: np.ndarray  (field currently disabled)
    outcomes: np.ndarray


@dataclasses.dataclass
class PostDataset:
    """Observational dataset that a simulation run produces.

    Attributes:
    ----------
        features: `np.ndarray`
            Float array of shape [num_samples, num_features] holding the
            individual features of every unit
        treatments: `np.ndarray`
            Float array of shape [num_samples] holding the treatment (the
            model prediction Y hat) of every unit
        outcomes: `np.ndarray`
            Float array of shape [num_samples] holding the outcome (Y) of
            every unit
        old_outcomes: `np.ndarray`
            Float array of shape [num_samples] holding the earlier outcome
            (Y) of every unit, i.e. the one from the PreDataset

    Note:
        Only the four attributes above are declared; the ``group_features``
        field is commented out in the class body.

    """

    features: np.ndarray
    treatments: np.ndarray
    # group_features: np.ndarray  (field currently disabled)
    outcomes: np.ndarray
    old_outcomes: np.ndarray

    def get_features_with_yhat(self):
        """Return the features with the treatments appended as a last column.

        The treatments get a new axis at position 1 so they can be joined
        to the feature matrix along the column axis.
        """
        yhat_column = np.expand_dims(self.treatments, axis=1)
        return np.concatenate((self.features, yhat_column), axis=1)

    def split(self, test_size, random_state=42):
        """Split all four arrays consistently into a train and a test dataset.

        Parameters
        ----------
        test_size:
            Passed through unchanged to ``train_test_split``.
        random_state:
            Passed through unchanged to ``train_test_split``; defaults to 42.

        Returns
        -------
        tuple
            ``(train_dataset, test_dataset)``, both `PostDataset` instances.
        """
        field_names = ("features", "treatments", "outcomes", "old_outcomes")

        # train_test_split yields, for every input array in order, its train
        # part immediately followed by its test part.
        pieces = train_test_split(
            self.features,
            self.treatments,
            self.outcomes,
            self.old_outcomes,
            test_size=test_size,
            random_state=random_state,
        )

        # Even positions are the train parts, odd positions the test parts.
        train_kwargs = dict(zip(field_names, pieces[0::2]))
        test_kwargs = dict(zip(field_names, pieces[1::2]))

        return PostDataset(**train_kwargs), PostDataset(**test_kwargs)


class Simulation:
    """Bundle everything needed to run one simulation.

    The instance stores a name, the initial dataset, a predictor, a parameter
    object, the callable that performs the performativity step and an
    optional graph. `run` hands these to the callable and wraps its result.
    """

    def __init__(
        self, name, predataset, predictor, params, step_performativity, graph=None
    ):
        """Store the simulation inputs on the instance without modification."""
        self.name = name
        self.graph = graph
        self.params = params
        self.predictor = predictor
        self.predataset = predataset
        self.step_performativity = step_performativity

    def get_parameters(self):
        """Return the parameter object given at construction time."""
        return self.params

    def run(self):
        """Execute the performativity step and return its result as a PostDataset.

        ``step_performativity`` is called with the keyword arguments
        ``predataset``, ``predictor``, ``params`` and ``graph``. Its return
        value must unpack into exactly four items, in the order features,
        treatments, outcomes, old outcomes.
        """
        step_kwargs = {
            "predataset": self.predataset,
            "predictor": self.predictor,
            "params": self.params,
            "graph": self.graph,
        }
        new_features, yhat, new_outcomes, previous_outcomes = (
            self.step_performativity(**step_kwargs)
        )

        return PostDataset(
            features=new_features,
            treatments=yhat,
            outcomes=new_outcomes,
            old_outcomes=previous_outcomes,
        )


def compare_models(
    model1,
    model2,
    features1_train,
    features2_train,
    outcomes_train,
    features1_test,
    features2_test,
    outcomes_test,
    rmse=False,
):
    """Fit two models on the same targets and evaluate both on test data.

    Each model is fitted on its own training features against the shared
    ``outcomes_train`` and then scored on its own test features against the
    shared ``outcomes_test``.

    Parameters
    ----------
    model1, model2:
        Objects exposing ``fit(X, y)`` and ``score(X, y)``; ``predict(X)`` is
        additionally required when ``rmse`` is true.
    features1_train, features2_train:
        Training features for ``model1`` and ``model2`` respectively.
    outcomes_train:
        Training targets shared by both models.
    features1_test, features2_test:
        Test features for ``model1`` and ``model2`` respectively.
    outcomes_test:
        Test targets shared by both models.
    rmse:
        When true, also report the root mean squared error of each model on
        the test data.

    Returns
    -------
    tuple
        ``(score1, score2, rmse1, rmse2)``. The two scores are whatever the
        models' ``score`` methods return. ``rmse1`` and ``rmse2`` are ``None``
        when ``rmse`` is false.
    """
    model1.fit(features1_train, outcomes_train)
    model2.fit(features2_train, outcomes_train)

    score1 = model1.score(features1_test, outcomes_test)
    score2 = model2.score(features2_test, outcomes_test)

    if not rmse:
        return score1, score2, None, None

    # Both errors are reported on the same scale (RMSE). Previously the
    # second model's value was a plain MSE because ``squared=False`` was
    # missing from its call. The square root is taken explicitly so the code
    # does not depend on the ``squared`` keyword of ``mean_squared_error``.
    # NOTE(review): for single-output targets this equals ``squared=False``;
    # multi-output averaging may differ - confirm if 2-D targets are used.
    rmse1 = np.sqrt(
        mean_squared_error(outcomes_test, model1.predict(features1_test))
    )
    rmse2 = np.sqrt(
        mean_squared_error(outcomes_test, model2.predict(features2_test))
    )
    return score1, score2, rmse1, rmse2