import numpy as np
from copy import deepcopy
from tqdm import tqdm

from omegaconf import DictConfig
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

from MSM import MSM
from MSMPropensity import MSMPropensityTreatment, MSMPropensityHistory


class MSMRegressor(MSM):

    model_type = 'msm_regressor'

    def __init__(self,
                 args: DictConfig,
                 propensity_treatment: MSMPropensityTreatment = None,
                 propensity_history: MSMPropensityHistory = None,
                 dataset_collection: dict = None):
        """Set up one linear outcome regressor per prediction step.

        Args:
            args: Configuration object, forwarded unchanged to the base class.
            propensity_treatment: Model used by ``prepare_data`` for the
                numerator of the stabilized weights.
            propensity_history: Model used by ``prepare_data`` for the
                denominator of the stabilized weights.
            dataset_collection: Forwarded unchanged to the base class.
        """
        super().__init__(args, dataset_collection)

        # Propensity models are only stored here; they are queried in
        # ``prepare_data``.
        self.propensity_treatment = propensity_treatment
        self.propensity_history = propensity_history

        # Bookkeeping sizes derived from the dimensions the base class is
        # expected to have set (assumption: ``dim_*`` come from MSM.__init__).
        self.output_size = self.dim_outcome
        self.input_size = self.dim_treatments + self.dim_dosages + self.dim_outcome + self.dim_static_features

        # One regressor for every tau in 0..projection_horizon (inclusive).
        n_regressors = self.projection_horizon + 1
        self.msm_regressor = [MultiOutputRegressor(LinearRegression())
                              for _ in range(n_regressors)]


    def get_inputs(self, dataset: dict, projection_horizon = 0, tau = 0) -> np.array:
        active_entries = dataset['active_entries']
        
        #
        # 
        #
        lagged_entries = active_entries - np.concatenate([active_entries[:, self.lag_features + 1:, :],
                                                          np.zeros((active_entries.shape[0], self.lag_features + 1, 1))], 1)
        
        if projection_horizon > 0:
            lagged_entries = np.concatenate([lagged_entries[:, projection_horizon:, :],
                                             np.zeros((active_entries.shape[0], projection_horizon, 1))], 1)
        #
        prev_outputs = dataset['prev_outputs']
        prev_outputs = prev_outputs[np.repeat(lagged_entries, self.dim_outcome, axis = 2) == 1.0].reshape(prev_outputs.shape[0],
                                                                                                          (self.lag_features + 1) * self.dim_outcome)
        prev_dosages = dataset['prev_dosages']
        prev_dosages = prev_dosages[np.repeat(lagged_entries, self.dim_dosages, axis = 2) == 1.0].reshape(prev_dosages.shape[0],
                                                                                           (self.lag_features + 1) * self.dim_dosages)
        
        # prev_treatments
        active_entries_before_protection = np.concatenate([active_entries[:, projection_horizon:, :],
                                                          np.zeros((active_entries.shape[0], projection_horizon, 1))], 1)        
        #
        prev_treatments = dataset['prev_treatments']
        prev_treatments = (prev_treatments * active_entries_before_protection).sum(1)
        
        #
        static_features = dataset['static_features']

        # Adding current actions
        prediction_entries = active_entries - np.concatenate(
            [active_entries[:, tau + 1:, :], np.zeros((active_entries.shape[0], tau + 1, 1))], axis=1)
        prediction_entries = np.concatenate([prediction_entries[:, projection_horizon - tau:, :],
                                             np.zeros((prediction_entries.shape[0], projection_horizon - tau, 1))], axis=1)

        current_treatments = dataset['current_treatments']
        current_treatments = (current_treatments * prediction_entries).sum(1)
      
        current_dosages = dataset['current_dosages']
        current_dosages = (current_dosages * prediction_entries).sum(1)
         
        inputs = [prev_treatments]
        inputs.append(prev_dosages)
        inputs.append(prev_outputs)
        inputs.append(static_features)
        inputs.append(current_treatments)
        inputs.append(current_dosages)
        
        return np.concatenate(inputs, axis=1)

    def get_sample_weights(self, dataset: dict, tau=0) -> np.array:
        active_entries = dataset['active_entries']
        stabilized_weights = dataset['stabilized_weights']

        prediction_entries = active_entries - np.concatenate(
            [active_entries[:, tau + 1:, :], np.zeros((active_entries.shape[0], tau + 1, 1))],
            axis=1)
        stabilized_weights = stabilized_weights[np.squeeze(prediction_entries) == 1.0].reshape(stabilized_weights.shape[0],
                                                                                               tau + 1)
        sw = np.prod(stabilized_weights, axis=1)
        sw_tilde = np.clip(sw, np.nanquantile(sw, 0.01), np.nanquantile(sw, 0.99))
        return sw_tilde

    def prepare_data(self) -> None:
        """Attach stabilized weights to the factual training data.

        Both propensity models are queried on ``self.train_f``; the ratio
        treatment / history is multiplied over axis 2 and stored under the
        key ``'stabilized_weights'`` of ``self.train_f``.
        """
        factual = self.train_f
        numerator = self.propensity_treatment.get_propensity_scores(factual)
        denominator = self.propensity_history.get_propensity_scores(factual)
        factual['stabilized_weights'] = np.prod(numerator / denominator, axis=2)
        
    def fit(self):
        """Fit one weighted regressor for every tau in 0..projection_horizon.

        ``prepare_data`` is run first so that stabilized weights are
        available. For each tau the training data is exploded with
        ``min_length = lag_features + tau``, and the regressor at index tau
        is fitted on the inputs from ``get_inputs`` against the output at the
        last active time step, weighted by ``get_sample_weights``.
        """
        self.prepare_data()

        for tau in range(self.projection_horizon + 1):
            exploded = self.get_exploded_dataset(self.train_f, min_length = self.lag_features + tau)

            # Mask that is one only at the final active step of each row.
            mask = exploded['active_entries']
            shifted = np.concatenate([mask[:, 1:, :], np.zeros((mask.shape[0], 1, 1))], axis=1)
            final_step = mask - shifted

            features = self.get_inputs(exploded, projection_horizon = tau, tau = tau)
            weights = self.get_sample_weights(exploded, tau = tau)

            # Regression target: the output observed at the final active step.
            targets = (exploded['outputs'] * final_step).sum(1)

            regressor = self.msm_regressor[tau]
            regressor.fit(features, targets, sample_weight = weights)

    
    def get_one_step_factual_rmse(self, dataset_name):
        # set data
        if dataset_name == 'train_f':
            dataset  = self.train_f
        elif dataset_name == 'valid_f':
            dataset  = self.valid_
        elif dataset_name == 'test_cf':
            dataset  = self.test_cf
        else:
            raise NotImplementedError()
        
        # one-step prediction (factual treatment)
        outcome_pred = self.get_predictions(dataset)

        # unscaled       
        train_means, train_stds = self.train_scaling_params
        outcome_pred = outcome_pred * train_stds["cancer_volume"] + train_means["cancer_volume"]

        # calculate mse loss
        mse     = ((outcome_pred - dataset["unscaled_x_next"]) ** 2) * dataset['active_entries']
        mse_all = mse.sum() / dataset['active_entries'].sum()
        rmse_normalised_all = 100.0 * (np.sqrt(mse_all) / self.MAX_CANCER_VOLUME)

        return rmse_normalised_all
    
    def get_predictions(self, dataset: dict) -> np.array:
        batch_size = 10000
        
        npatients, seq_lengths, num_features = dataset['prev_treatments'].shape
        outcome_pred = np.zeros([npatients, seq_lengths, 1])
        
        for batch in range(len(dataset) // batch_size + 1):
            subset = deepcopy(dataset)
            for (k, v) in subset.items():
                subset[k] = v[batch * batch_size:(batch + 1) * batch_size]
            
            #
            # Exploded dataset
            # 
            exploded_dataset = self.get_exploded_dataset(subset, 
                                                         min_length = self.lag_features, 
                                                         only_active_entries = False,
                                                         #max_length=max(dataset['seq_lengths'][:])
                                                        )
            # inputs
            inputs = self.get_inputs(exploded_dataset, projection_horizon = 0, tau = 0)
            
            # predict
            outcome_pred_batch = self.msm_regressor[0].predict(inputs)

            
            outcome_pred_batch = outcome_pred_batch.reshape(subset['active_entries'].shape[0],
                                                            subset['active_entries'].shape[1] - 1,
                                                            self.dim_outcome)
            
            # First time-step requires two previous outcomes -> duplicating the next prediction
            outcome_pred_batch = np.concatenate([outcome_pred_batch[:, :1, :], outcome_pred_batch], 1)
            outcome_pred[batch * batch_size:(batch + 1) * batch_size] = outcome_pred_batch
            
        return outcome_pred

    def get_autoregressive_predictions(self, dataset: dict) -> np.array:
        predicted_outputs = np.zeros((dataset['prev_outputs'].shape[0], 
                                      self.projection_horizon, 
                                      self.dim_outcome))

        for t in range(1, self.projection_horizon + 1):
            inputs = self.get_inputs(dataset, 
                                     projection_horizon = self.projection_horizon - 1, 
                                     tau = t - 1)

            outcome_pred = self.msm_regressor[t].predict(inputs)
            predicted_outputs[:, t - 1] = outcome_pred
            
            
            if t < self.projection_horizon:
                dataset['prev_outputs'][:, t, :] = outcome_pred
            
        return predicted_outputs
    
    
    def get_multi_step_counterfactual_rmse(self):
        """Compute normalised multi-step counterfactual RMSEs on the test data.

        A flat dataset is built with one block of ``npatients`` rows for every
        (time step, counterfactual sample) pair, ordered time-major. Each row
        holds the factual entry at that time step followed by the
        counterfactual entries of the sample (all but the last). The flat
        dataset is passed to ``get_autoregressive_predictions`` and the
        un-scaled predictions are compared with ``unscaled_x_nahead``.

        Reads ``self.test_cf`` (factual) and ``self.test_cf_multi``
        (counterfactual). Assumes the fourth axis of ``unscaled_x_nahead``
        has length ``projection_horizon + 1`` -- TODO confirm against the
        dataset builder.

        Returns:
            Dict mapping ``ntau + 2`` (``ntau`` in
            ``range(self.projection_horizon)``) to the RMSE expressed in
            percent of ``self.MAX_CANCER_VOLUME``.
        """
        # set data
        dataset_f = self.test_cf
        dataset_cf = self.test_cf_multi

        # shape
        npatients, seq_lengths, nsamples, ntaus, _ = dataset_cf['unscaled_x_nahead'].shape
        predict_x_cf = np.zeros([npatients, seq_lengths, nsamples, ntaus, 1])

        # Static features for the counterfactual rows: only the first time
        # step is ever consumed, so take that slice directly.
        static_features = np.asarray(dataset_f['inp_v'])[:, 0, :]

        #
        # one-step prediction (factual treatment)
        #
        outcome_pred = self.get_predictions(dataset_f)

        # There is a single factual branch, so the same one-step prediction
        # is replicated across all counterfactual samples (broadcast over
        # the sample axis).
        predict_x_cf[:, :, :, 0, 0] = outcome_pred

        # multi-step prediction (counterfactual treatment)
        # Collect the per-(t, sample) chunks and concatenate once at the end;
        # growing the arrays inside the loop would copy them repeatedly.
        chunks = {}
        for t in tqdm(range(seq_lengths)):
            fact = slice(t, t + 1)
            for nsample in range(nsamples):
                batchd = {}

                batchd['prev_outputs'] = np.concatenate([dataset_f['inp_x'][:, fact, :],
                                                         predict_x_cf[:, t, nsample, :-1]], 1)
                batchd['static_features'] = static_features

                # prev
                batchd['prev_treatments'] = np.concatenate([dataset_f["prev_treatments"][:, fact, :],
                                                            dataset_cf["prev_treatments"][:, t, nsample, :-1]], 1)
                batchd['prev_dosages'] = np.concatenate([dataset_f['inp_d_prev'][:, fact, :],
                                                         dataset_cf['prev_d_scaled'][:, t, nsample, :-1]], 1)
                # current
                batchd['current_treatments'] = np.concatenate([dataset_f['current_treatments'][:, fact, :],
                                                               dataset_cf['current_treatments'][:, t, nsample, :-1]], 1)
                batchd['current_dosages'] = np.concatenate([dataset_f['current_d'][:, fact, :],
                                                            dataset_cf["current_d_scaled"][:, t, nsample, :-1]], 1)

                # Mask used for the error computation below.
                batchd['active_entries_cf'] = np.concatenate([dataset_f['active_entries'][:, fact, :],
                                                              dataset_cf['active_entries'][:, t, nsample, :-1]], 1)
                batchd['unscaled_outputs'] = dataset_cf['unscaled_x_nahead'][:, t, nsample, :]

                # All steps are treated as active when building the inputs.
                batchd["active_entries"] = np.ones((npatients, self.projection_horizon + 1, 1))

                for key, value in batchd.items():
                    chunks.setdefault(key, []).append(value)

        data_seq = {key: np.concatenate(parts, 0) for key, parts in chunks.items()}

        #
        # auto regressive
        #
        predicted_outputs = self.get_autoregressive_predictions(data_seq)

        # calculate normalised RMSE
        train_means, train_stds = self.train_scaling_params
        predicted_outputs = predicted_outputs * train_stds["cancer_volume"] + train_means["cancer_volume"]

        rmses = {}
        for ntau in range(self.projection_horizon):
            mask = data_seq['active_entries_cf'][:, ntau, :]
            squared_error = ((predicted_outputs[:, ntau, :] - data_seq["unscaled_outputs"][:, ntau + 1, :]) ** 2) * mask
            mse_cf_all = squared_error.sum() / mask.sum()
            rmse_cf_all = 100.0 * (np.sqrt(mse_cf_all) / self.MAX_CANCER_VOLUME)

            rmses[ntau + 2] = rmse_cf_all

        return rmses
                        
