import numpy as np
from copy import deepcopy
from omegaconf import DictConfig
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from MSM import MSM

class MSMPropensityTreatment(MSM):
    """Propensity model predicting the current treatment from treatment history.

    A multi-output, unpenalized logistic regression is fitted with the
    time-summed previous treatments as features and the treatment taken at
    the last active time step as target.

    Attributes such as ``dim_treatments``, ``train_f``, ``lag_features`` and
    the method ``get_exploded_dataset`` are read from ``self`` and are
    expected to be provided by the ``MSM`` base class (not visible here).
    """

    model_type = 'propensity_treatment'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: dict):
        """Build the classifier.

        Args:
            args: configuration; ``args.exp.max_epochs`` is used as the
                solver's ``max_iter``.
            dataset_collection: forwarded unchanged to ``MSM.__init__``.
        """
        super().__init__(args, dataset_collection)

        # input/output size
        self.input_size = self.dim_treatments
        self.output_size = self.dim_treatments

        # Unpenalized logistic regression. ``penalty='none'`` was deprecated in
        # scikit-learn 1.2 and removed in 1.4; an infinite inverse
        # regularization strength with the default L2 penalty yields the same
        # unregularized fit and is accepted by both old and new releases.
        self.propensity_treatment = MultiOutputClassifier(
            LogisticRegression(C=np.inf, max_iter=args.exp.max_epochs))

    def get_inputs(self, dataset: dict) -> np.ndarray:
        """Return the previous treatments summed over the active time steps.

        Args:
            dataset: mapping with the keys ``'active_entries'`` and
                ``'prev_treatments'``.

        Returns:
            ``prev_treatments`` multiplied by ``active_entries`` and summed
            over axis 1 (presumably the time axis -- TODO confirm).
        """
        active_entries = dataset['active_entries']
        prev_treatments = dataset['prev_treatments']

        return (prev_treatments * active_entries).sum(1)

    def fit(self):
        """Fit ``self.propensity_treatment`` on the exploded training data."""
        # exploded dataset
        train_f = self.get_exploded_dataset(self.train_f, min_length=self.lag_features)

        # active_entries must have shape (n, T, 1): it is concatenated along
        # axis 1 with a zero block of shape (n, 1, 1).
        active_entries = train_f['active_entries']
        # Subtracting the mask shifted one step to the left leaves a 1 only at
        # the final active step of each row (assuming the mask is a run of
        # leading ones followed by zeros -- TODO confirm).
        shifted_entries = np.concatenate(
            [active_entries[:, 1:, :], np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - shifted_entries

        # Inputs
        inputs = self.get_inputs(train_f)

        # Outputs: the treatment at the step selected by ``last_entries``
        current_treatments = train_f['current_treatments']
        outputs = (current_treatments * last_entries).sum(1)

        self.propensity_treatment.fit(inputs, outputs)
        
class MSMPropensityHistory(MSM):
    """Propensity model predicting the current treatment from patient history.

    A multi-output, unpenalized logistic regression is fitted on the
    time-summed previous treatments, a window of the most recent
    ``lag_features + 1`` previous dosages and previous outputs, and the
    static features; the target is the treatment at the last active step.

    Attributes such as ``dim_treatments``, ``dim_dosages``, ``dim_outcome``,
    ``dim_static_features``, ``train_f``, ``lag_features`` and the method
    ``get_exploded_dataset`` are read from ``self`` and are expected to be
    provided by the ``MSM`` base class (not visible here).
    """

    model_type = 'propensity_history'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: dict):
        """Build the classifier.

        Args:
            args: configuration; ``args.exp.max_epochs`` is used as the
                solver's ``max_iter``.
            dataset_collection: forwarded unchanged to ``MSM.__init__``.
        """
        super().__init__(args, dataset_collection)

        # input/output size
        # NOTE(review): ``get_inputs`` emits (lag_features + 1) * dim_dosages
        # and (lag_features + 1) * dim_outcome columns, so this value only
        # matches the real feature width when lag_features == 0 -- confirm how
        # ``input_size`` is consumed before relying on it.
        self.input_size = self.dim_treatments + self.dim_dosages + self.dim_outcome + self.dim_static_features
        self.output_size = self.dim_treatments

        # Unpenalized logistic regression. ``penalty='none'`` was deprecated in
        # scikit-learn 1.2 and removed in 1.4; an infinite inverse
        # regularization strength with the default L2 penalty yields the same
        # unregularized fit and is accepted by both old and new releases.
        self.propensity_history = MultiOutputClassifier(
            LogisticRegression(C=np.inf, max_iter=args.exp.max_epochs))

    def get_inputs(self, dataset: dict, projection_horizon: int = 0) -> np.ndarray:
        """Assemble the 2-D feature matrix for the history propensity model.

        Args:
            dataset: mapping with the keys ``'active_entries'``,
                ``'prev_treatments'``, ``'prev_dosages'``, ``'prev_outputs'``
                and ``'static_features'``.
            projection_horizon: number of steps by which the masks are shifted
                to the left before selecting features; ``0`` applies no shift.

        Returns:
            Column-wise concatenation of the summed previous treatments, the
            flattened window of previous dosages, the flattened window of
            previous outputs, and the static features.

        Raises:
            ValueError: from ``reshape`` if a row does not select exactly
                ``lag_features + 1`` time steps.
        """
        active_entries = dataset['active_entries']
        n_samples = active_entries.shape[0]
        window = self.lag_features + 1

        # Mask minus itself shifted ``window`` steps to the left: keeps a 1 on
        # the last ``window`` active steps of each row (assuming the mask is a
        # run of leading ones followed by zeros -- TODO confirm).
        lagged_entries = active_entries - np.concatenate(
            [active_entries[:, window:, :], np.zeros((n_samples, window, 1))], axis=1)
        if projection_horizon > 0:
            # Move the window ``projection_horizon`` steps earlier.
            lagged_entries = np.concatenate(
                [lagged_entries[:, projection_horizon:, :],
                 np.zeros((n_samples, projection_horizon, 1))], axis=1)

        # Active mask shifted left by the projection horizon and zero-padded
        # at the end.
        active_entries_before_projection = np.concatenate(
            [active_entries[:, projection_horizon:, :],
             np.zeros((n_samples, projection_horizon, 1))], axis=1)

        prev_treatments = dataset['prev_treatments']
        prev_treatments = (prev_treatments * active_entries_before_projection).sum(1)

        # Boolean-mask selection flattens in row-major order, so the reshape
        # regroups the selected values per sample; this relies on every row
        # selecting exactly ``window`` steps.
        prev_dosages = dataset['prev_dosages']
        dosage_mask = np.repeat(lagged_entries, self.dim_dosages, 2) == 1.0
        prev_dosages = prev_dosages[dosage_mask].reshape(prev_dosages.shape[0],
                                                         window * self.dim_dosages)

        prev_outputs = dataset['prev_outputs']
        outcome_mask = np.repeat(lagged_entries, self.dim_outcome, 2) == 1.0
        prev_outputs = prev_outputs[outcome_mask].reshape(prev_outputs.shape[0],
                                                          window * self.dim_outcome)

        static_features = dataset['static_features']

        return np.concatenate(
            [prev_treatments, prev_dosages, prev_outputs, static_features], axis=1)

    def fit(self):
        """Fit ``self.propensity_history`` on the exploded training data."""
        train_f = self.get_exploded_dataset(self.train_f, min_length=self.lag_features)

        # active_entries must have shape (n, T, 1): it is concatenated along
        # axis 1 with a zero block of shape (n, 1, 1).
        active_entries = train_f['active_entries']
        # Subtracting the mask shifted one step to the left leaves a 1 only at
        # the final active step of each row (assuming the mask is a run of
        # leading ones followed by zeros -- TODO confirm).
        shifted_entries = np.concatenate(
            [active_entries[:, 1:, :], np.zeros((active_entries.shape[0], 1, 1))], axis=1)
        last_entries = active_entries - shifted_entries

        # Inputs
        inputs = self.get_inputs(train_f)

        # Outputs: the treatment at the step selected by ``last_entries``
        current_treatments = train_f['current_treatments']
        outputs = (current_treatments * last_entries).sum(1)

        self.propensity_history.fit(inputs, outputs)