from RMSN import RMSN
from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

import torch
import torch.nn.functional as F

class RMSNPropensityNetworkTreatment(RMSN):
    """RMSN propensity network that conditions only on previous treatments.

    ``batch['prev_treatments']`` is passed through ``self.lstm`` and
    ``self.output_layer`` to produce pre-sigmoid scores for the current
    treatment. Training compares those scores against
    ``batch['current_treatments']`` using the inherited ``get_bce_loss``.

    NOTE(review): ``self.lstm``, ``self.output_layer``, ``self.dim_treatments``
    and ``get_bce_loss`` are not defined in this class; presumably they come
    from ``RMSN`` / ``_init_specific`` -- confirm against the base class.
    """

    model_type = 'propensity_treatment'
    model_name = "RMSN"
    tuning_criterion = 'bce'

    def __init__(self, args: DictConfig,
                 dataset_collection: dict,
                 **kwargs):
        """Set the layer sizes and build the network.

        Args:
            args: Configuration; ``args.model.propensity_treatment`` is handed
                to ``_init_specific``.
            dataset_collection: Forwarded unchanged to ``RMSN.__init__``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(args, dataset_collection)

        # Treatments in, treatments out: the only input is the treatment history.
        self.input_size = self.dim_treatments
        self.output_size = self.dim_treatments

        self._init_specific(args.model.propensity_treatment)

    def forward(self, batch):
        """Return pre-sigmoid treatment scores for ``batch``.

        Only ``batch['prev_treatments']`` is read.
        """
        hidden = self.lstm(batch['prev_treatments'], init_states=None)
        return self.output_layer(hidden)

    def training_step(self, batch):
        """Switch to train mode and return the mask-weighted mean BCE loss.

        The loss from ``get_bce_loss`` is weighted by
        ``batch['active_entries']`` (last axis squeezed), summed, and divided
        by the sum of ``batch['active_entries']``.

        NOTE(review): if ``active_entries`` sums to zero the division yields
        NaN; presumably every batch has at least one active entry -- confirm.
        """
        self.train()
        treatment_pred = self(batch)
        targets = batch['current_treatments'].double()
        per_entry_loss = self.get_bce_loss(treatment_pred, targets, kind='predict')

        mask = batch['active_entries']
        weighted_total = (mask.squeeze(-1) * per_entry_loss).sum()
        return weighted_total / mask.sum()

    def predict_step(self, batch) -> torch.Tensor:
        """Switch to eval mode and return sigmoid probabilities on the CPU.

        Inference runs under ``torch.no_grad()`` so that no autograd graph is
        built or kept alive by the returned tensor.
        """
        self.eval()
        with torch.no_grad():
            return torch.sigmoid(self(batch)).cpu()


class RMSNPropensityNetworkHistory(RMSN):
    """RMSN propensity network that conditions on the full observed history.

    The input is the concatenation (along the last axis) of
    ``batch['prev_treatments']``, ``batch['inp_d_prev']``, ``batch['inp_x']``
    and ``batch['inp_v']``. It is passed through ``self.lstm`` and
    ``self.output_layer`` to produce pre-sigmoid scores for the current
    treatment, which are trained against ``batch['current_treatments']``.

    NOTE(review): ``self.lstm``, ``self.output_layer``, the ``dim_*``
    attributes and ``get_bce_loss`` are not defined in this class; presumably
    they come from ``RMSN`` / ``_init_specific`` -- confirm against the base
    class.
    """

    model_type = 'propensity_history'
    model_name = "RMSN"
    tuning_criterion = 'bce'

    def __init__(self, args: DictConfig,
                 dataset_collection: dict,
                 **kwargs):
        """Set the layer sizes and build the network.

        Args:
            args: Configuration; ``args.model.propensity_history`` is handed
                to ``_init_specific``.
            dataset_collection: Forwarded unchanged to ``RMSN.__init__``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(args, dataset_collection)

        # Must equal the width of the tensor concatenated in ``forward``.
        # NOTE(review): assumes ``inp_d_prev`` / ``inp_x`` / ``inp_v`` have
        # dim_dosages / dim_outcome / dim_static_features features -- confirm.
        self.input_size = (
            self.dim_treatments
            + self.dim_dosages
            + self.dim_static_features
            + self.dim_outcome
        )
        self.output_size = self.dim_treatments

        self._init_specific(args.model.propensity_history)

    def forward(self, batch, detach_treatment=False):
        """Return pre-sigmoid treatment scores for ``batch``.

        Args:
            batch: Mapping providing ``'prev_treatments'``, ``'inp_d_prev'``,
                ``'inp_x'`` and ``'inp_v'``.
            detach_treatment: Currently unused by this method; kept so the
                signature stays compatible with existing callers.
        """
        history = torch.cat(
            (
                batch['prev_treatments'],
                batch["inp_d_prev"],
                batch["inp_x"],
                batch['inp_v'],
            ),
            dim=-1,
        )
        hidden = self.lstm(history, init_states=None)
        return self.output_layer(hidden)

    def training_step(self, batch):
        """Switch to train mode and return the mask-weighted mean BCE loss.

        The loss from ``get_bce_loss`` is weighted by
        ``batch['active_entries']`` (last axis squeezed), summed, and divided
        by the sum of ``batch['active_entries']``.

        NOTE(review): if ``active_entries`` sums to zero the division yields
        NaN; presumably every batch has at least one active entry -- confirm.
        """
        self.train()
        treatment_pred = self(batch)
        targets = batch['current_treatments'].double()
        per_entry_loss = self.get_bce_loss(treatment_pred, targets, kind='predict')

        mask = batch['active_entries']
        weighted_total = (mask.squeeze(-1) * per_entry_loss).sum()
        return weighted_total / mask.sum()

    def predict_step(self, batch) -> torch.Tensor:
        """Switch to eval mode and return sigmoid probabilities on the CPU.

        Inference runs under ``torch.no_grad()`` so that no autograd graph is
        built or kept alive by the returned tensor.
        """
        self.eval()
        with torch.no_grad():
            return torch.sigmoid(self(batch)).cpu()