import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy

from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

from RMSN import RMSN
from RMSNPropensityNetwork import RMSNPropensityNetworkTreatment, RMSNPropensityNetworkHistory

class RMSNEncoder(RMSN):
    """Encoder stage of the RMSN model.

    At every step the encoder concatenates the current treatments, current
    dosages, static features and previous outcomes, runs them through
    ``self.lstm`` and maps the resulting representation to an outcome
    prediction with ``self.output_layer``. Training minimises a mean squared
    error that is re-weighted by the clipped/normalised stabilized weights
    stored under ``'sw_tilde_enc'``.

    ``self.lstm`` and ``self.output_layer`` are not created in this class;
    presumably ``_init_specific`` (inherited) builds them -- confirm in RMSN.
    """

    # Identification / tuning metadata read by code outside this class.
    model_type = 'encoder'
    model_name = "RMSN"
    tuning_criterion = 'rmse'
    isDecoder = False

    def __init__(self, args: DictConfig,
                 propensity_treatment: RMSNPropensityNetworkTreatment = None,
                 propensity_history: RMSNPropensityNetworkHistory = None,
                 dataset_collection: dict = None,
                 **kwargs):
        """Build the encoder and precompute its training weights.

        Args:
            args: Experiment configuration; ``args.model.encoder`` is passed
                to ``_init_specific``.
            propensity_treatment: Propensity network handed to
                ``process_propensity_train_f`` when the stabilized weights
                have to be computed.
            propensity_history: Second propensity network handed to
                ``process_propensity_train_f``.
            dataset_collection: Forwarded unchanged to ``RMSN.__init__``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(args, dataset_collection)

        # The input is the concatenation built in ``forward``.
        self.input_size = (self.dim_treatments + self.dim_dosages
                           + self.dim_static_features + self.dim_outcome)
        self.output_size = self.dim_outcome

        self.propensity_treatment = propensity_treatment
        self.propensity_history = propensity_history

        self._init_specific(args.model.encoder)

        # Compute the stabilized weights.
        self.prepare_data()

    def prepare_data(self) -> None:
        """Ensure ``self.train_f['sw_tilde_enc']`` exists.

        Nothing is done when the key is already present. Otherwise the
        propensity networks are applied to the training data and the
        resulting ``'stabilized_weights'`` are clipped/normalised
        (single-horizon mode) and cached under ``'sw_tilde_enc'``.
        """
        if 'sw_tilde_enc' not in self.train_f:
            # Presumably this call populates train_f['stabilized_weights'],
            # which is read right below -- confirm in RMSN.
            self.process_propensity_train_f(self.propensity_treatment, self.propensity_history)
            self.train_f['sw_tilde_enc'] = self.clip_normalize_stabilized_weights(
                self.train_f['stabilized_weights'],
                self.train_f['active_entries'],
                multiple_horizons=False)

    def forward(self, batch, detach_treatment=False):
        """Predict outcomes for a batch.

        Args:
            batch: Mapping providing ``'current_treatments'``,
                ``'current_d'``, ``'inp_v'`` and ``'inp_x'``.
            detach_treatment: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(outcome_pred, r)`` where ``r`` is the output of
            ``self.lstm`` and ``outcome_pred`` is ``self.output_layer(r)``.
        """
        curr_treatments = batch['current_treatments']
        curr_dosages = batch['current_d']
        static_features = batch['inp_v']
        prev_outputs = batch["inp_x"]

        # Features are joined along the last dimension.
        x = torch.cat((curr_treatments, curr_dosages, static_features, prev_outputs), dim=-1)
        r = self.lstm(x, init_states=None)
        outcome_pred = self.output_layer(r)
        return outcome_pred, r

    def training_step(self, batch):
        """Return the stabilized-weight-reweighted, masked MSE for a batch.

        Switches the module to training mode, computes the element-wise
        squared error against ``batch['out_x_next']``, scales it by
        ``batch['sw_tilde_enc']`` and averages over the active entries.
        """
        self.train()
        outcome_pred, _ = self(batch)
        # ``reduction='none'`` keeps the per-element losses; it replaces the
        # deprecated ``reduce=False`` spelling.
        mse_loss = F.mse_loss(outcome_pred, batch['out_x_next'], reduction='none')
        # A trailing axis is added to the weights so they can be multiplied
        # with the loss (presumably broadcasting over outcome dims -- confirm).
        weighted_mse_loss = mse_loss * batch['sw_tilde_enc'].unsqueeze(-1)
        # Masked mean: sum over active entries divided by their count.
        weighted_mse_loss = (batch['active_entries'] * weighted_mse_loss).sum() / batch['active_entries'].sum()
        return weighted_mse_loss

    def predict_step(self, batch, batch_ind, dataset_idx=None):
        """Run ``forward`` in eval mode and return both outputs on the CPU.

        ``batch_ind`` and ``dataset_idx`` are accepted but not used.
        """
        self.eval()
        outcome_pred, r = self(batch)
        return outcome_pred.cpu(), r.cpu()


class RMSNDecoder(RMSN):
    """Decoder stage of the RMSN model.

    The decoder concatenates current treatments, current dosages, static
    features and previous outcomes, and runs them through ``self.lstm``
    whose initial state is ``self.memory_adapter(batch['init_state'])``.
    ``self.output_layer`` then produces the outcome prediction. Training
    minimises a mean squared error re-weighted by ``'sw_tilde_dec'``.

    ``self.lstm``, ``self.memory_adapter`` and ``self.output_layer`` are not
    created in this class; presumably ``_init_specific`` (inherited) builds
    them -- confirm in RMSN.
    """

    # Identification / tuning metadata read by code outside this class.
    model_type = 'decoder'
    model_name = "RMSN"
    tuning_criterion = 'rmse'
    isDecoder = True
    base_model = "lstm"

    def __init__(self, args: DictConfig,
                 encoder: RMSNEncoder = None,
                 dataset_collection: dict = None,
                 encoder_r_size: int = None,
                 **kwargs):
        """Build the decoder and precompute its training weights.

        Args:
            args: Experiment configuration; ``args.model.decoder`` is passed
                to ``_init_specific``.
            encoder: Trained encoder. When given, its ``seq_hidden_units``
                overrides ``encoder_r_size``.
            dataset_collection: Forwarded unchanged to ``RMSN.__init__``.
            encoder_r_size: Size of the encoder representation; only used
                when ``encoder`` is ``None``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__(args, dataset_collection)
        self.encoder = encoder
        # The input is the concatenation built in ``forward``.
        self.input_size = (self.dim_treatments + self.dim_dosages
                           + self.dim_static_features + self.dim_outcome)
        self.output_size = self.dim_outcome

        # Prefer the size reported by the encoder over the explicit argument.
        if encoder is not None:
            encoder_r_size = self.encoder.seq_hidden_units

        self._init_specific(args.model.decoder, encoder_r_size=encoder_r_size)

        # Compute the stabilized weights.
        self.prepare_data()

    def prepare_data(self) -> None:
        """Convert the training data to sequential form and cache weights.

        ``self.train_f`` is replaced by ``process_sequential`` applied to a
        deep copy of itself. If the result has no ``'sw_tilde_dec'`` entry,
        the stabilized weights are turned into cumulative products along the
        last axis (dropping the first column) and then clipped/normalised in
        multiple-horizon mode.
        """
        # Work on a deep copy so the object previously bound to train_f is
        # not modified by process_sequential.
        self.train_f = self.process_sequential(deepcopy(self.train_f))

        if 'sw_tilde_dec' not in self.train_f:
            # Cumulative product along the last axis; the first column is
            # discarded.
            self.train_f['stabilized_weights'] = \
                torch.cumprod(self.train_f['stabilized_weights'], dim=-1)[:, 1:]

            self.train_f['sw_tilde_dec'] = self.clip_normalize_stabilized_weights(
                self.train_f['stabilized_weights'],
                self.train_f['active_entries'],
                multiple_horizons=True)

    def forward(self, batch, detach_treatment=False):
        """Predict outcomes for a batch.

        Args:
            batch: Mapping providing ``'current_treatments'``,
                ``'current_d'``, ``'inp_x'``, ``'inp_v'`` and
                ``'init_state'``.
            detach_treatment: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(outcome_pred, None)``; the second slot mirrors the
            two-element return of ``RMSNEncoder.forward``.
        """
        curr_treatments = batch['current_treatments']
        curr_dosages = batch['current_d']
        prev_outputs = batch["inp_x"]
        static_features = batch['inp_v']

        init_states = batch['init_state']
        # Features are joined along the last dimension.
        x = torch.cat((curr_treatments, curr_dosages, static_features, prev_outputs), dim=-1)
        x = self.lstm(x, init_states=self.memory_adapter(init_states))
        outcome_pred = self.output_layer(x)
        return outcome_pred, None

    def training_step(self, batch):
        """Return the stabilized-weight-reweighted, masked MSE for a batch.

        Switches the module to training mode, computes the element-wise
        squared error against ``batch['out_x_next']``, scales it by
        ``batch['sw_tilde_dec']`` and averages over the active entries.
        """
        self.train()
        outcome_pred, _ = self(batch)
        # ``reduction='none'`` keeps the per-element losses; it replaces the
        # deprecated ``reduce=False`` spelling.
        mse_loss = F.mse_loss(outcome_pred, batch['out_x_next'], reduction='none')
        # A trailing axis is added to the weights so they can be multiplied
        # with the loss (presumably broadcasting over outcome dims -- confirm).
        weighted_mse_loss = mse_loss * batch['sw_tilde_dec'].unsqueeze(-1)
        # Masked mean: sum over active entries divided by their count.
        weighted_mse_loss = (batch['active_entries'] * weighted_mse_loss).sum() / batch['active_entries'].sum()
        return weighted_mse_loss

    def predict_step(self, batch, batch_ind, dataset_idx=None):
        """Run ``forward`` in eval mode and return the prediction on the CPU.

        ``forward`` returns a ``(outcome_pred, None)`` tuple, so the tuple is
        unpacked before ``.cpu()`` is called; calling ``.cpu()`` on the tuple
        itself raises ``AttributeError``. ``batch_ind`` and ``dataset_idx``
        are accepted but not used.
        """
        self.eval()
        outcome_pred, _ = self(batch)
        return outcome_pred.cpu()

    def get_predictions(self, dataset: dict) -> np.ndarray:
        """Predict outcomes for every item of ``dataset`` as a NumPy array.

        NOTE(review): this relies on ``self.trainer``, ``self.hparams`` and
        ``dataset.subset_name``, none of which is defined in this file;
        ``dataset`` is annotated as ``dict`` but is presumably a Dataset-like
        object -- confirm against RMSN and the callers.
        """
        # Imported here because the module does not bring these names into
        # scope at file level.
        import logging
        from torch.utils.data import DataLoader

        logger = logging.getLogger(__name__)
        logger.info(f'Predictions for {dataset.subset_name}.')
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.val_batch_size, shuffle=False)
        # One tensor per batch (see ``predict_step``), joined along dim 0.
        outcome_pred = torch.cat(self.trainer.predict(self, data_loader))
        return outcome_pred.numpy()


        