from pytorch_lightning import LightningModule
import torch
from torch import nn
from copy import deepcopy
from tqdm import tqdm
import math
import torch.optim as optim

import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits

# Dataset
from utils import CancerDataset
from torch.utils.data import DataLoader

class TimeVaryingCausalModel(LightningModule):
    def __init__(self):
        super().__init__()
        print(self.model_name, " is set")

    def bce(self, treatment_pred, current_treatments, weights = None):
        mode = self.treatment_mode
        if mode == 'multiclass':
            return F.cross_entropy(treatment_pred.permute(0, 2, 1), 
                                   current_treatments.permute(0, 2, 1), 
                                   reduce = False, 
                                   weight = weights)
        
        elif mode == 'multilabel':
            return F.binary_cross_entropy_with_logits(treatment_pred, 
                                                      current_treatments, 
                                                      reduce=False, 
                                                      weight=weights).mean(dim=-1)
        else:
            raise NotImplementedError()

    def get_bce_loss(self, treatment_pred, current_treatments, kind='predict'):
        mode = self.treatment_mode
        bce_weights = None
        #
        if kind == 'predict':
            bce_loss = self.bce(treatment_pred, current_treatments, bce_weights)
        elif kind == 'confuse':
            uniform_treatments = torch.ones_like(current_treatments)
            if mode == 'multiclass':
                uniform_treatments *= 1 / current_treatments.shape[-1]
            elif mode == 'multilabel':
                uniform_treatments *= 0.5
            bce_loss = self.bce(treatment_pred, uniform_treatments)
        else:
            raise NotImplementedError()
            
        return bce_loss
        
    # --------------------------------------------------------------------------------------
    # encode factual dataset
    # --------------------------------------------------------------------------------------
    def encode_factual_dataset(self, dataset):
        """Run the one-step model over a whole dataset in fixed-size batches.

        The dataset is wrapped in ``CancerDataset`` and iterated in order
        (``shuffle=False``) with a batch size of 1000. Each batch is passed
        to ``self.encoder`` when ``self.isDecoder`` is truthy, otherwise to
        ``self`` itself, inside ``torch.inference_mode()``.

        Args:
            dataset: Mapping accepted by ``CancerDataset``; it must contain
                an ``'inp_x'`` entry whose first axis is the sample axis.

        Returns:
            A tuple ``(outcome_preds, encoder_brs)``: elements 0 and 1 of the
            model output, each concatenated over all batches along axis 0.

        Raises:
            ValueError: If the data loader yields no batch at all.
        """
        data_loader = DataLoader(CancerDataset(dataset),
                                 batch_size = 1000, shuffle = False)

        # Collect per-batch outputs and concatenate once at the end instead
        # of re-copying the accumulated tensors after every batch.
        outcome_chunks = []
        encoder_chunks = []
        for batch in data_loader:
            with torch.inference_mode():
                if self.isDecoder:
                    predicts = self.encoder(batch)
                else:
                    predicts = self(batch)

                outcome_chunks.append(predicts[0])
                encoder_chunks.append(predicts[1])

        if not outcome_chunks:
            raise ValueError("encode_factual_dataset received a dataset with no samples")

        if len(outcome_chunks) == 1:
            # A single batch needs no copy.
            outcome_preds = outcome_chunks[0]
            encoder_brs   = encoder_chunks[0]
        else:
            # Concatenate under inference mode so the results have the same
            # tensor kind as the per-batch outputs.
            with torch.inference_mode():
                outcome_preds = torch.cat(outcome_chunks, dim = 0)
                encoder_brs   = torch.cat(encoder_chunks, dim = 0)

        # Internal sanity check: one prediction row per input sample.
        assert dataset['inp_x'].shape[0] == outcome_preds.shape[0]

        return outcome_preds, encoder_brs
        
    # --------------------------------------------------------------------------------------
    # Normalized Masked RMSE Factual
    # --------------------------------------------------------------------------------------      
    def get_one_step_factual_rmse(self, dataset_name = "valid_f"):
        """Return the masked, normalised RMSE of the one-step predictions.

        Predictions are un-scaled with the ``"cancer_volume"`` mean and
        standard deviation from ``self.train_scaling_params``, compared with
        ``dataset['unscaled_x_next']``, masked with
        ``dataset['active_entries']`` and the resulting RMSE is expressed as
        a percentage of ``self.MAX_CANCER_VOLUME``.

        Args:
            dataset_name: ``'train_f'``, ``'valid_f'`` or ``'test_cf'``; the
                attribute of the same name on ``self`` is evaluated.

        Returns:
            The normalised RMSE as a NumPy array copied to the CPU.

        Raises:
            ValueError: If ``dataset_name`` is not one of the accepted names.
        """
        # The accepted names coincide with the attribute names holding the data.
        if dataset_name not in ('train_f', 'valid_f', 'test_cf'):
            raise ValueError(
                f"Unknown dataset_name {dataset_name!r}; "
                "expected one of 'train_f', 'valid_f', 'test_cf'"
            )
        # Work on a private copy so the stored dataset cannot be altered.
        dataset_f = deepcopy(getattr(self, dataset_name))

        # one-step prediction (factual treatment)
        outcome_pred, encoder_br = self.encode_factual_dataset(dataset_f)

        # undo the training-set standardisation
        train_means, train_stds = self.train_scaling_params
        outcome_pred = outcome_pred * train_stds["cancer_volume"] + train_means["cancer_volume"]

        # masked mean squared error over the active entries only
        active_entries = dataset_f['active_entries']
        squared_error = ((outcome_pred - dataset_f["unscaled_x_next"]) ** 2) * active_entries
        mse_all = squared_error.sum() / active_entries.sum()
        rmse_normalised_all = 100.0 * (torch.sqrt(mse_all) / self.MAX_CANCER_VOLUME)

        return rmse_normalised_all.to('cpu').detach().numpy().copy()
      
    # --------------------------------------------------------------------------------------
    # Normalized Masked RMSE Factual multi
    # --------------------------------------------------------------------------------------
    def get_multi_step_factual_rmse(self, dataset_name = "valid_f"):
        """Return masked, normalised RMSEs of multi-step factual roll-outs.

        Horizon 0 is filled with the one-step predictions from
        ``encode_factual_dataset``. For every start index ``t`` in
        ``range(seq_len - projection_horizon)`` and every horizon step
        ``ntau`` in ``1..projection_horizon`` the model is called on a batch
        whose ``'inp_x'`` consists of the model's own earlier predictions and
        whose remaining inputs are sliced from the factual data; the last
        time step of output 0 becomes the prediction for that horizon step.

        When ``self.isDecoder`` is truthy the batch holds only the future
        window plus the encoder state (``'init_state'`` for ``"lstm"``;
        ``'encoder_br'`` and a 0/1 ``'active_encoder_br'`` mask for
        ``"transformer"``). Otherwise the factual history up to ``t`` is
        prepended to every input.

        Args:
            dataset_name: ``'train_f'``, ``'valid_f'`` or ``'test_cf'``; the
                attribute of the same name on ``self`` is evaluated.

        Returns:
            A tuple ``(rmses, last_predictions)``. ``rmses`` maps
            ``tau + 1`` (for ``tau`` in ``0..projection_horizon``) to the
            RMSE as a percentage of ``self.MAX_CANCER_VOLUME`` (NumPy).
            ``last_predictions`` is the NumPy copy of the un-scaled
            predictions for the final horizon step only.
            NOTE(review): only that last slice is returned -- confirm that
            callers do not expect all horizons.

        Raises:
            ValueError: If ``dataset_name`` is not one of the accepted names.
            NotImplementedError: If ``self.isDecoder`` is truthy and
                ``self.base_model`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        # The accepted names coincide with the attribute names holding the data.
        if dataset_name not in ('train_f', 'valid_f', 'test_cf'):
            raise ValueError(
                f"Unknown dataset_name {dataset_name!r}; "
                "expected one of 'train_f', 'valid_f', 'test_cf'"
            )
        # Work on a private copy so the stored dataset cannot be altered.
        dataset_f = deepcopy(getattr(self, dataset_name))

        horizon = self.projection_horizon

        # predict_x[:, t, k, 0] holds the k-step-ahead prediction started at t
        npatients, seq_lengths, _ = dataset_f['inp_x'].shape
        predict_x = torch.zeros(npatients, seq_lengths, horizon + 1, 1).to(self.device_ori)

        # one-step prediction (factual treatment)
        outcome_pred, encoder_br = self.encode_factual_dataset(dataset_f)
        predict_x[:, :, 0, 0] = torch.squeeze(outcome_pred)

        # Inputs whose future window is read straight from the factual data.
        factual_keys = [
            'inp_v', 'inp_w_prev', 'inp_d_prev', 'active_entries', 'current_ow',
            'current_d',                  # for CT, CRN
            'current_iw', 'current_od',   # for TS
        ]
        if self.model_name == "RMSN":
            factual_keys += ['prev_treatments', 'current_treatments']

        # multi-step prediction (factual treatment)
        for t in range(seq_lengths - horizon):
            fact_length = t + 1
            for ntau in range(1, horizon + 1):
                # The model's own predictions so far are the future 'inp_x'.
                batchd = {'inp_x': predict_x[:, t, :ntau]}
                for key in factual_keys:
                    batchd[key] = dataset_f[key][:, fact_length:fact_length + ntau]

                if self.isDecoder:
                    if self.base_model == "lstm":
                        batchd['init_state'] = encoder_br[:, t, :]
                    elif self.base_model == "transformer":
                        # mask exposing only the first fact_length encoder steps
                        active_entries_br = torch.zeros((npatients, seq_lengths)).to(self.device_ori)
                        active_entries_br[:, :fact_length] = 1.0
                        batchd['active_encoder_br'] = active_entries_br
                        batchd['encoder_br']        = encoder_br
                    else:
                        raise NotImplementedError()
                else:
                    # Encoder-style models need the full sequence: prepend the
                    # factual history to every future window.
                    for key in list(batchd):
                        if key == 'inp_v':
                            batchd[key] = dataset_f['inp_v'][:, :fact_length + ntau, :]
                        else:
                            batchd[key] = torch.cat([dataset_f[key][:, :fact_length, :], batchd[key]], dim = 1)

                # prediction
                with torch.inference_mode():
                    results = self(batchd)

                predict_x[:, t, ntau, 0] = torch.squeeze(results[0][:, -1])

        # calculate Normalized RMSE
        train_means, train_stds = self.train_scaling_params
        predict_x = predict_x * train_stds["cancer_volume"] + train_means["cancer_volume"]

        end_t = seq_lengths - horizon

        rmse_normalised_alls = {}
        for tau in range(0, horizon + 1):
            # Targets and mask are shifted by tau to line up with the
            # tau-step-ahead predictions.
            predict_x_next = predict_x[:, :end_t, tau, 0]
            out_x_next     = dataset_f["unscaled_x_next"][:, tau:end_t + tau, 0]
            active_entries = dataset_f["active_entries"][:, tau:end_t + tau, 0]

            # masked mean squared error
            mse = ((predict_x_next - out_x_next) ** 2) * active_entries
            mse_all = mse.sum() / active_entries.sum()
            rmse_normalised_all = 100.0 * (torch.sqrt(mse_all) / self.MAX_CANCER_VOLUME)

            rmse_normalised_alls[tau + 1] = rmse_normalised_all.to('cpu').detach().numpy().copy()

        return rmse_normalised_alls, predict_x_next.to('cpu').detach().numpy().copy()
    
    # ---------------------------------------------------------------------------------------
    # Normalized Masked RMSE Counterfactual multi
    # ---------------------------------------------------------------------------------------
    def get_multi_step_counterfactual_rmse(self):
        """Return masked, normalised RMSEs of multi-step counterfactual roll-outs.

        Factual data comes from ``self.test_cf`` and counterfactual
        trajectories from ``self.test_cf_multi``. Horizon 0 of every sampled
        trajectory is filled with the one-step factual prediction. Then, for
        every time step, horizon step ``1..projection_horizon`` and sample,
        the model is called on a batch whose ``'inp_x'`` consists of its own
        earlier predictions and whose treatment inputs are read from the
        counterfactual dataset; the last time step of output 0 becomes the
        prediction for that horizon step.

        When ``self.isDecoder`` is truthy the batch holds only the future
        window plus the encoder state; otherwise the factual history is
        prepended to every input.

        Returns:
            A tuple ``(rmses, predict_x_cf)``. ``rmses`` maps ``ntau + 1``
            (for ``ntau`` in ``0..projection_horizon``) to the RMSE as a
            percentage of ``self.MAX_CANCER_VOLUME`` (NumPy).
            ``predict_x_cf`` is the un-scaled prediction tensor with the
            same leading four axes as
            ``test_cf_multi['unscaled_x_nahead']`` and a last axis of 1.

        Raises:
            NotImplementedError: If ``self.isDecoder`` is truthy and
                ``self.base_model`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        train_means, train_stds = self.train_scaling_params
        horizon = self.projection_horizon

        # set data
        dataset_f  = self.test_cf
        dataset_cf = self.test_cf_multi

        # shape
        npatients, seq_lengths, nsamples, ntaus, _ = dataset_cf['unscaled_x_nahead'].shape
        predict_x_cf = torch.zeros(npatients, seq_lengths, nsamples, ntaus, 1).to(self.device_ori)

        # Static features for the counterfactual roll-out: extend the time
        # axis by ntaus - 1 steps (re-using the first steps) so that windows
        # reaching past the last factual step can still be sliced.
        # torch.cat allocates a new tensor, so no explicit copy is needed.
        static_features    = dataset_f['inp_v']
        static_cf_features = torch.cat([static_features, static_features[:, 0:ntaus-1, :]],
                                       dim = 1).to(self.device_ori)

        active_entries_cf = dataset_cf["active_entries"]

        # one-step prediction (factual treatment)
        outcome_pred, encoder_br = self.encode_factual_dataset(dataset_f)
        one_step_pred = torch.squeeze(outcome_pred)
        for nsample in range(nsamples):
            predict_x_cf[:, :, nsample, 0, 0] = one_step_pred

        # (batch key, counterfactual-dataset key) pairs; the batch key is also
        # the key under which the factual history is stored.
        cf_sources = [
            ('inp_w_prev',     'prev_ow'),
            ('inp_d_prev',     'prev_d_scaled'),
            ('active_entries', 'active_entries'),
            ('current_ow',     'current_ow'),
            ('current_d',      'current_d_scaled'),   # for CT, CRN
            ('current_od',     'current_od'),         # for TS
            ('current_iw',     'current_iw'),         # for TS
        ]
        if self.model_name == "RMSN":
            cf_sources += [
                ('prev_treatments',    'prev_treatments'),
                ('current_treatments', 'current_treatments'),
            ]

        # multi-step prediction (counterfactual treatment)
        for t in tqdm(range(seq_lengths)):
            fact_length = t + 1
            for ntau in range(1, horizon + 1):
                for nsample in range(nsamples):
                    batchd = {
                        # The model's own predictions so far are the future 'inp_x'.
                        'inp_x': predict_x_cf[:, t, nsample, :ntau],
                        'inp_v': static_cf_features[:, fact_length:fact_length + ntau],
                    }
                    for batch_key, cf_key in cf_sources:
                        batchd[batch_key] = dataset_cf[cf_key][:, t, nsample, :ntau]

                    if self.isDecoder:
                        if self.base_model == "lstm":
                            batchd['init_state'] = encoder_br[:, t, :]
                        elif self.base_model == "transformer":
                            # mask exposing only the first fact_length encoder steps
                            active_entries_br = torch.zeros((npatients, seq_lengths)).to(self.device_ori)
                            active_entries_br[:, :fact_length] = 1.0
                            batchd['active_encoder_br'] = active_entries_br
                            batchd['encoder_br']        = encoder_br
                        else:
                            # Fail loudly, as the factual roll-out does,
                            # instead of calling the model without encoder state.
                            raise NotImplementedError()
                    else:
                        # Encoder-style models need the full sequence: prepend
                        # the factual history to every future window.
                        for key in list(batchd):
                            if key == 'inp_v':
                                batchd[key] = static_cf_features[:, :fact_length + ntau, :]
                            else:
                                batchd[key] = torch.cat([dataset_f[key][:, :fact_length], batchd[key]], dim = 1)

                    # prediction
                    with torch.inference_mode():
                        results = self(batchd)

                    predict_x_cf[:, t, nsample, ntau] = results[0][:, -1]

        # calculate Normalized RMSE
        predict_x_cf = predict_x_cf * train_stds["cancer_volume"] + train_means["cancer_volume"]

        rmses = {}
        for ntau in range(horizon + 1):
            mask    = active_entries_cf[:, :, :, ntau, :]
            mse_cf  = ((predict_x_cf[:, :, :, ntau, :] - dataset_cf["unscaled_x_nahead"][:, :, :, ntau, :]) ** 2) * mask
            mse_cf_all  = mse_cf.sum() / mask.sum()
            rmse_cf_all = 100.0 * (torch.sqrt(mse_cf_all) / self.MAX_CANCER_VOLUME)

            rmses[ntau + 1] = rmse_cf_all.to('cpu').detach().numpy().copy()

        return rmses, predict_x_cf
    
    # ----------------------------------------------------------------
    # optimizers and schedulers
    # ----------------------------------------------------------------
    def _get_optimizers(self, param_optimizer: list):
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': self.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        lr = self.learning_rate
        optimizer = optim.Adam(optimizer_grouped_parameters, lr=lr)

        return optimizer
                   
    def _get_lr_schedulers(self, optimizer):
        if not isinstance(optimizer, list):
            lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
            return [lr_scheduler]
        else:
            lr_schedulers = []
            for opt in optimizer:
                lr_schedulers.append(optim.lr_scheduler.ExponentialLR(opt, gamma=0.97))
            return lr_schedulers 
        
    # -------------------------------------------------------------------------------------------
    # process_sequential
    # -------------------------------------------------------------------------------------------
    def process_sequential(self, dataset):
        """Cut a dataset into fixed-length windows paired with encoder state.

        The model is switched to evaluation mode (``self.eval()``) and the
        whole dataset is encoded once. Then, for every start index ``t`` in
        ``range(1, seq_len - projection_horizon)``, a window
        ``[t, t + projection_horizon)`` is sliced from each sequence tensor
        and the windows of all ``t`` are stacked along axis 0.

        Encoder information attached to each window depends on
        ``self.base_model``:

        * ``"lstm"``: ``'init_state'`` is the encoder output at step
          ``t - 1``.
        * ``"transformer"``: ``'encoder_br'`` is the full encoder output and
          ``'active_encoder_br'`` a 0/1 mask whose first ``t`` positions are
          1.

        If the dataset contains ``'stabilized_weights'``, that entry is
        windowed as ``[t - 1, t + projection_horizon)`` (one step longer),
        and ``'prev_treatments'`` / ``'current_treatments'`` are windowed
        like the other tensors.

        Args:
            dataset: Mapping of tensors; ``dataset['inp_x']`` must be 3-D and
                its second axis gives the sequence length.

        Returns:
            A dict of concatenated windows, always freshly allocated. It is
            empty when ``seq_len - projection_horizon <= 1`` (no window).

        Raises:
            NotImplementedError: If at least one window exists and
                ``self.base_model`` is neither ``"lstm"`` nor
                ``"transformer"``.
        """
        npatients, seq_lengths, _ = dataset['inp_x'].shape
        horizon = self.projection_horizon
        has_stabilized_weights = 'stabilized_weights' in dataset

        # one-step prediction (factual treatment); only the encoder output is
        # used below.
        self.eval()
        outcome_pred, encoder_br = self.encode_factual_dataset(dataset)

        # Tensors that are all windowed identically as [t, t + horizon).
        window_keys = [
            # inp
            'inp_x', 'inp_v', 'inp_w_prev', 'inp_d_prev',
            # out
            'out_x_next',
            # current
            'current_ow', 'current_od', 'current_iw', 'current_id', 'current_d',
            # active entries
            'active_entries',
            'unscaled_x', 'counterfactual_mask', 'masked_factual_dr_curve',
        ]

        # Collect the windows per key and concatenate once at the end; this
        # avoids re-copying the accumulated data for every start index.
        pieces_by_key = {}
        for t in range(1, seq_lengths - horizon):
            pieces = {key: dataset[key][:, t:t + horizon, :] for key in window_keys}

            if has_stabilized_weights:
                # The weights window starts one step earlier than the others.
                pieces['stabilized_weights'] = dataset['stabilized_weights'][:, t - 1:t + horizon]
                pieces['prev_treatments']    = dataset["prev_treatments"][:, t:t + horizon, :]
                pieces['current_treatments'] = dataset["current_treatments"][:, t:t + horizon, :]

            # encoder_br
            if self.base_model == "lstm":
                pieces['init_state'] = encoder_br[:, t - 1, :]
            elif self.base_model == "transformer":
                active_entries_br = torch.zeros((npatients, seq_lengths)).to(self.device_ori)
                active_entries_br[:, :t] = 1.0
                pieces['active_encoder_br'] = active_entries_br
                pieces['encoder_br']        = encoder_br
            else:
                raise NotImplementedError()

            for key, piece in pieces.items():
                pieces_by_key.setdefault(key, []).append(piece)

        dataset_seq = {key: torch.cat(parts, dim = 0) for key, parts in pieces_by_key.items()}

        # Release cached allocator blocks left over from the intermediates.
        torch.cuda.empty_cache()

        return dataset_seq
    