from pytorch_lightning import LightningModule
import torch
from torch import nn
from copy import deepcopy
from tqdm import tqdm
import math
from TimeVaryingCausalModel import TimeVaryingCausalModel
# Learning
from torch_ema import ExponentialMovingAverage
# Dataset
from utils import CancerDataset
from torch.utils.data import DataLoader, Dataset, Subset

class BRCausalModel(TimeVaryingCausalModel):
    def forward(self, batch, detach=False):
        """Run the model on one batch and return all head predictions.

        The representation ``br`` is produced by ``self.buildBr``. For the
        ``"lstm"`` base model it is called with the unpacked batch tensors;
        for ``"transformer"`` it receives the raw batch. ``br`` is then fed to
        the outcome, treatment and dosage heads, in that order.

        Args:
            batch: mapping read with the keys ``'inp_x'``, ``'inp_v'``,
                ``'inp_w_prev'``, ``'inp_d_prev'`` (cast with ``.float()``),
                ``'current_ow'``, ``'current_d'`` and, when
                ``self.isDecoder`` is true with the LSTM base model,
                ``'init_state'``.
            detach: forwarded unchanged to the treatment and dosage heads.

        Returns:
            Tuple ``(outcome_pred, br, treatment_pred, dosage_pred)``.

        Raises:
            NotImplementedError: if ``self.base_model`` is neither ``"lstm"``
                nor ``"transformer"``.
        """
        # All inputs are read up front, regardless of the base model.
        prev_outputs, static_features, prev_treatments, prev_dosages = (
            batch[key].float() for key in ('inp_x', 'inp_v', 'inp_w_prev', 'inp_d_prev')
        )
        current_treatment = batch['current_ow']
        current_dosage = batch['current_d']

        if self.base_model == "lstm":
            # Only the decoder variant reads initial states from the batch.
            initial_states = batch['init_state'] if self.isDecoder else None
            br = self.buildBr(prev_treatments=prev_treatments,
                              prev_dosages=prev_dosages,
                              prev_outputs=prev_outputs,
                              static_features=static_features,
                              init_states=initial_states)
        elif self.base_model == "transformer":
            br = self.buildBr(batch)
        else:
            raise NotImplementedError()

        # Heads are evaluated in the same fixed order: outcome, treatment, dosage.
        predicted_outcome = self.brOutcomeHead(br=br,
                                               current_treatment=current_treatment,
                                               current_dosage=current_dosage)
        predicted_treatment = self.brTreamentHead(br=br, detach=detach)
        predicted_dosage = self.brDosageHead(br=br,
                                             current_treatment=current_treatment,
                                             detach=detach)

        return predicted_outcome, br, predicted_treatment, predicted_dosage
    

    # --------------------------------------------------------------------------------------
    # Optimizers
    # --------------------------------------------------------------------------------------            
    def configure_optimizers(self):
        """Build the optimizer(s) and optional LR scheduler(s).

        With ``self.balancing == 'grad_reverse'`` and no weights EMA, a single
        optimizer over all named parameters is created. Otherwise the
        parameters are split into three disjoint groups: treatment-head
        params, dosage-head params, and everything else ("non-discriminator").
        Each group gets its own optimizer, returned in the order
        ``[non_discriminator, treatment_head, dosage_head]``.

        When ``self.weights_ema`` is truthy (multi-optimizer branch only), two
        ``ExponentialMovingAverage`` trackers with decay ``self.beta`` are
        attached as side effects:

        * ``self.ema_discriminator`` covers both heads.
        * ``self.ema_non_discriminator`` covers the rest.

        Returns:
            The optimizer (or list of optimizers). When ``self.lr_scheduler``
            is truthy, it instead returns the pair
            ``(optimizer(s), self._get_lr_schedulers(optimizer(s)))``.

        Raises:
            ValueError: if some parameter is selected by both the treatment-
                and the dosage-head prefixes, so the groups would not be
                disjoint.
        """
        if self.balancing == 'grad_reverse' and not self.weights_ema:  # one optimizer
            optimizer = self._get_optimizers(list(self.named_parameters()))

            if self.lr_scheduler:
                lr_scheduler = self._get_lr_schedulers(optimizer)
                return optimizer, lr_scheduler

            return optimizer

        # three optimizers - simultaneous gradient descent update
        named_params = list(self.named_parameters())

        def head_param_names(head, suffixes, default_prefix):
            # Resolve the name under which `head` is actually registered, so the
            # prefix matches what named_parameters() reports. NOTE: the former
            # hard-coded prefixes ('br_treatment_outcome_head.' etc.) select
            # nothing if the heads are registered as brTreamentHead/brDosageHead.
            # Fall back to the old prefix if the head is not found as a submodule.
            prefix = next((name for name, module in self.named_modules() if module is head),
                          default_prefix)
            prefixes = tuple(prefix + '.' + s for s in suffixes)
            # startswith(tuple) keeps each parameter once, even when it matches
            # several prefixes (the old nested comprehension duplicated it).
            return {name for name, _ in named_params if name.startswith(prefixes)}

        treatment_names = head_param_names(self.brTreamentHead,
                                           self.brTreamentHead.treatment_head_params,
                                           'br_treatment_outcome_head')
        dosage_names = head_param_names(self.brDosageHead,
                                        self.brDosageHead.dosage_head_params,
                                        'br_dosage_outcome_head')

        # The three groups must partition the parameters (this replaces the old
        # assert, which would be stripped under `python -O`).
        overlap = treatment_names & dosage_names
        if overlap:
            raise ValueError('Parameters selected by both treatment and dosage heads: '
                             f'{sorted(overlap)}')
        discriminator_names = treatment_names | dosage_names

        # Each group keeps the named_parameters() order.
        treatment_head_params = [(k, v) for k, v in named_params if k in treatment_names]
        treatment_head_optimizer = self._get_optimizers(treatment_head_params)

        dosage_head_params = [(k, v) for k, v in named_params if k in dosage_names]
        dosage_head_optimizer = self._get_optimizers(dosage_head_params)

        non_discriminator_params = [(k, v) for k, v in named_params if k not in discriminator_names]
        non_discriminator_optimizer = self._get_optimizers(non_discriminator_params)

        if self.weights_ema:
            # EMA over both heads ("discriminator") and over the remaining params.
            discriminator_params = [(k, v) for k, v in named_params if k in discriminator_names]
            self.ema_discriminator = ExponentialMovingAverage([par[1] for par in discriminator_params],
                                                              decay=self.beta)
            self.ema_non_discriminator = ExponentialMovingAverage([par[1] for par in non_discriminator_params],
                                                                  decay=self.beta)

        optimizers = [non_discriminator_optimizer, treatment_head_optimizer, dosage_head_optimizer]

        if self.lr_scheduler:
            lr_scheduler = self._get_lr_schedulers(optimizers)
            return optimizers, lr_scheduler

        return optimizers
                               
    # --------------------------------------------------------------------------------------
    # DataLoader
    # --------------------------------------------------------------------------------------  
    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader.

        When ``self.isDecoder`` is true, the data comes from
        ``self.train_f_sequential``; otherwise it comes from ``self.train_f``.
        Either source is wrapped in ``CancerDataset`` and shuffled. Batches
        have size ``self.train_batch_size``, and the last incomplete batch is
        dropped.
        """
        training_data = self.train_f_sequential if self.isDecoder else self.train_f
        return DataLoader(
            CancerDataset(training_data),
            shuffle=True,
            batch_size=self.train_batch_size,
            drop_last=True,
        )
        
    