import numpy as np
from copy import deepcopy
from omegaconf import DictConfig

class MSM():
    """Shared data-preparation logic for the MSM model components.

    On construction, every dataset split passed in ``dataset_collection`` is
    deep-copied and its integer treatment codes (``"current_iw"``) are decoded
    into two binary treatment indicators (see ``set_data``). The class also
    turns patient trajectories into one row per (patient, prefix) pair
    (``explode_trajectories``) and queries a fitted classifier for per-step
    treatment probabilities (``get_propensity_scores``).

    Subclasses are expected to set ``model_type`` (one of
    ``possible_model_types``), store a classifier under the attribute named by
    ``model_type``, and provide ``get_inputs``; none of these are defined here.
    """
    model_type = None  # Will be defined in subclasses
    possible_model_types = {'msm_regressor', 'propensity_treatment', 'propensity_history'}
    tuning_criterion = None

    # Row k holds the two binary indicators for treatment code k:
    # 0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1]  (i.e. [k // 2, k % 2]).
    _TREATMENT_CODE_TABLE = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)

    def __init__(self,
                 args: "DictConfig",
                 dataset_collection: dict,
                 **kwargs):
        """Validate the config and prepare all dataset splits.

        Args:
            args: OmegaConf ``DictConfig``; ``args.model.name`` must be ``"MSM"``.
                Also reads ``args.model.lag_features`` and
                ``args.dataset.{dim_dosages, dim_static_features, dim_outcomes,
                projection_horizon, max_seq_length, max_cancer_volume}``.
            dataset_collection: dict with keys ``'train_f'``, ``'valid_f'``,
                ``'test_cf'``, ``'test_cf_multi'`` (dataset dicts, deep-copied
                before decoding) and ``'train_scaling_params'`` (stored as is).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``args.model.name`` is not ``"MSM"``.
        """
        if args.model.name != "MSM":
            raise ValueError(f"Model mismatch: expected 'MSM', got {args.model.name!r}")

        self.model_name = args.model.name
        super().__init__()

        # Decode copies so the caller's dataset dicts are never mutated.
        self.train_f = self.set_data(deepcopy(dataset_collection['train_f']))
        self.valid_f = self.set_data(deepcopy(dataset_collection['valid_f']))
        self.test_cf = self.set_data(deepcopy(dataset_collection['test_cf']))
        self.test_cf_multi = self.set_data_multi(deepcopy(dataset_collection['test_cf_multi']))

        # scaling params
        self.train_scaling_params = dataset_collection['train_scaling_params']

        # parameters
        self.lag_features = args.model.lag_features
        # Number of binary indicators produced per step by set_data / set_data_multi.
        self.dim_treatments = 2
        self.dim_dosages = args.dataset.dim_dosages
        self.dim_static_features = args.dataset.dim_static_features
        self.dim_outcome = args.dataset.dim_outcomes

        # projection horizon / sequence length
        self.projection_horizon = args.dataset.projection_horizon
        self.max_length = args.dataset.max_seq_length

        # max cancer volume
        self.MAX_CANCER_VOLUME = args.dataset.max_cancer_volume

    # -------------------------------------------------------------------------------
    # dataset collection
    # -------------------------------------------------------------------------------
    @classmethod
    def _decode_treatment_codes(cls, codes):
        """Map codes in {0, 1, 2, 3} to pairs of binary treatment indicators.

        The trailing axis of ``codes`` must have size 1; it is replaced by an
        axis of size 2 holding ``[code // 2, code % 2]`` as float64.

        Raises:
            ValueError: if the trailing axis of ``codes`` does not have size 1.
            NotImplementedError: if any code is not one of 0, 1, 2, 3.
        """
        codes = np.asarray(codes)
        if codes.shape[-1] != 1:
            raise ValueError(
                f"expected a trailing feature axis of size 1, got shape {codes.shape}")
        codes = codes[..., 0]
        # np.isin compares by value, so 1.0 is accepted like 1; NaN and other values are rejected.
        if not np.isin(codes, (0, 1, 2, 3)).all():
            raise NotImplementedError("treatment codes must be one of 0, 1, 2, 3")
        return cls._TREATMENT_CODE_TABLE[codes.astype(int)]

    def set_data(self, dataset):
        """Add binary treatment arrays decoded from ``dataset["current_iw"]``.

        ``dataset["current_iw"]`` must be 3-D ``(num_patients, max_seq_length, 1)``
        with codes in {0, 1, 2, 3}. Adds ``"current_treatments"`` of shape
        ``(num_patients, max_seq_length, 2)`` and ``"prev_treatments"``, which
        is ``current_treatments`` shifted one step later along axis 1 with
        zeros at step 0.

        Returns:
            The same ``dataset`` dict, mutated in place.

        Raises:
            ValueError: if ``current_iw`` is not 3-D or its last axis is not 1.
            NotImplementedError: on a code outside {0, 1, 2, 3}.
        """
        current_iw = np.asarray(dataset["current_iw"])
        if current_iw.ndim != 3:
            raise ValueError(f"current_iw must be 3-D, got shape {current_iw.shape}")

        current_treatments = self._decode_treatment_codes(current_iw)
        prev_treatments = np.zeros_like(current_treatments)
        prev_treatments[:, 1:, :] = current_treatments[:, :-1, :]

        dataset["prev_treatments"] = prev_treatments
        dataset["current_treatments"] = current_treatments
        return dataset

    def set_data_multi(self, dataset):
        """Multi-sample variant of ``set_data``.

        ``dataset["current_iw"]`` must be 5-D
        ``(num_patients, max_seq_length, nsamples, ntaus, 1)``. Adds
        ``"current_treatments"`` of shape
        ``(num_patients, max_seq_length, nsamples, ntaus, 2)`` and
        ``"prev_treatments"``, which is ``current_treatments`` shifted one
        step later along axis 3 (the ``ntaus`` axis) with zeros at index 0.
        NOTE(review): unlike ``set_data`` the shift is not along the time
        axis (axis 1) — confirm this is intended.

        Returns:
            The same ``dataset`` dict, mutated in place.

        Raises:
            ValueError: if ``current_iw`` is not 5-D or its last axis is not 1.
            NotImplementedError: on a code outside {0, 1, 2, 3}.
        """
        current_iw = np.asarray(dataset["current_iw"])
        if current_iw.ndim != 5:
            raise ValueError(f"current_iw must be 5-D, got shape {current_iw.shape}")

        current_treatments = self._decode_treatment_codes(current_iw)
        prev_treatments = np.zeros_like(current_treatments)
        prev_treatments[:, :, :, 1:] = current_treatments[:, :, :, :-1]

        dataset["prev_treatments"] = prev_treatments
        dataset["current_treatments"] = current_treatments
        return dataset

    # -------------------------------------------------------------------
    # get propensity scores
    # -------------------------------------------------------------------
    def get_propensity_scores(self, dataset: dict) -> np.ndarray:
        """Predict per-step treatment probabilities for every patient.

        Every patient is treated as having ``self.max_length`` active steps.
        The trajectories are exploded from ``self.lag_features`` onwards and
        fed through ``self.get_inputs`` (provided by subclasses) and the
        classifier stored under ``self.model_type``.

        Returns:
            Array of shape ``(num_patients, max_seq_length, dim_treatments)``
            where the first ``lag_features`` steps are filled with 0.5.
        """
        exploded_dataset = self.get_exploded_dataset(dataset,
                                                     min_length=self.lag_features,
                                                     only_active_entries=False)
        inputs = self.get_inputs(exploded_dataset)
        classifier = getattr(self, self.model_type)

        # Assumes predict_proba returns one (n_rows, 2) array per treatment
        # (multi-output style) — TODO confirm; column 1 is taken as P(treated).
        propensity_scores = np.stack(classifier.predict_proba(inputs), 1)[:, :, 1]

        num_patients, max_seq_length = dataset['active_entries'].shape[:2]
        propensity_scores = propensity_scores.reshape(num_patients,
                                                      max_seq_length - self.lag_features,
                                                      self.dim_treatments)
        # No prediction exists for the first lag_features steps; pad them with 0.5.
        padding = np.full((num_patients, self.lag_features, self.dim_treatments), 0.5)
        return np.concatenate([padding, propensity_scores], axis=1)

    # -------------------------------------------------------------------
    # Exploded dataset
    # -------------------------------------------------------------------
    def get_exploded_dataset(self,
                             dataset: dict,
                             min_length: int,
                             only_active_entries = True,
                             max_length = None):
        """Explode ``dataset`` without modifying it.

        Args:
            dataset: dataset dict as accepted by ``explode_trajectories``.
            min_length: first prefix end index ``t`` to emit (passed through as
                ``projection_horizon``).
            only_active_entries: when False, every patient is treated as fully
                active — ``active_entries`` is set to 1 everywhere and every
                ``seq_lengths`` entry to ``self.max_length`` (on copies).
            max_length: unused; kept for backward compatibility.

        Returns:
            The dict produced by ``explode_trajectories``.
        """
        # Shallow copy: explode_trajectories only reads, so only the two arrays
        # overridden below need fresh storage.
        exploded_dataset = dict(dataset)
        if not only_active_entries:
            exploded_dataset['active_entries'] = np.ones_like(dataset['active_entries'])
            exploded_dataset['seq_lengths'] = np.full_like(dataset['seq_lengths'], self.max_length)

        return self.explode_trajectories(exploded_dataset, min_length)

    def explode_trajectories(self, data, projection_horizon):
        """Turn each patient trajectory into one row per prefix.

        For patient ``i`` and every ``t`` in
        ``range(projection_horizon, int(data['seq_lengths'][i]))`` one row is
        emitted holding steps ``0..t`` of every per-timestep array, zero-padded
        to the full length ``T`` of ``data['out_x_next']``. Rows are ordered by
        patient, then by ``t``.

        Per-timestep inputs (assumed ``(N, T, dim)``) are renamed on output:
        ``prev_treatments``, ``current_treatments``, ``inp_d_prev`` ->
        ``prev_dosages``, ``current_d`` -> ``current_dosages``,
        ``out_x_next`` -> ``outputs``, ``unscaled_x_next`` ->
        ``unscaled_outputs``, ``inp_x`` -> ``prev_outputs``,
        ``active_entries``. ``static_features`` takes ``inp_v[i, 0, :]`` per
        row, ``sequence_lengths`` is ``t + 1``, and ``stabilized_weights``
        (``(N, T)``, only if present in ``data``) is exploded like the rest.

        Returns:
            New dict of freshly allocated float64 arrays; ``data`` is not modified.
        """
        outputs = data['out_x_next']
        static_features = data['inp_v']
        sequence_lengths = data['seq_lengths']
        has_weights = 'stabilized_weights' in data

        num_patients, max_seq_length, _ = outputs.shape

        # Per-timestep source arrays keyed by their output name.
        sources = {
            'prev_treatments': data["prev_treatments"],
            'current_treatments': data["current_treatments"],
            'prev_dosages': data["inp_d_prev"],
            'current_dosages': data["current_d"],
            'outputs': outputs,
            'unscaled_outputs': data["unscaled_x_next"],
            'prev_outputs': data["inp_x"],
            'active_entries': data['active_entries'],
        }

        # (patient, prefix end t) for every output row, so buffers can be sized exactly.
        rows = [(i, t)
                for i in range(num_patients)
                for t in range(projection_horizon, int(sequence_lengths[i]))]
        num_rows = len(rows)

        exploded = {key: np.zeros((num_rows, max_seq_length, src.shape[-1]))
                    for key, src in sources.items()}
        exploded_static = np.zeros((num_rows, static_features.shape[-1]))
        exploded_lengths = np.zeros(num_rows)
        if has_weights:
            weights = data['stabilized_weights']
            exploded_weights = np.zeros((num_rows, max_seq_length))

        for row, (i, t) in enumerate(rows):
            prefix = slice(t + 1)  # steps 0..t
            for key, src in sources.items():
                exploded[key][row, prefix, :] = src[i, prefix, :]
            exploded_static[row] = static_features[i, 0, :]
            exploded_lengths[row] = t + 1
            if has_weights:
                exploded_weights[row, prefix] = weights[i, prefix]

        new_data = {
            # treatments
            'prev_treatments'   : exploded['prev_treatments'],
            'current_treatments': exploded['current_treatments'],
            # dosages
            'prev_dosages'      : exploded['prev_dosages'],
            'current_dosages'   : exploded['current_dosages'],
            #
            'static_features'   : exploded_static,
            'outputs'           : exploded['outputs'],
            'unscaled_outputs'  : exploded['unscaled_outputs'],
            'prev_outputs'      : exploded['prev_outputs'],
            #
            'active_entries'    : exploded['active_entries'],
            'sequence_lengths'  : exploded_lengths,
        }
        if has_weights:
            new_data['stabilized_weights'] = exploded_weights

        return new_data

        
    

