from pytorch_lightning import LightningModule
import torch
from torch import nn
import torch.optim as optim
from copy import deepcopy
from tqdm import tqdm
import math
from torch.utils.data import DataLoader, Dataset, Subset
from utils import CancerDataset

# Learning
from torch_ema import ExponentialMovingAverage

from TimeVaryingCausalModel import TimeVaryingCausalModel

class GANCausalModel(TimeVaryingCausalModel): 
    # ----------------------------------------------------------------------------------------------------
    # forward
    # ----------------------------------------------------------------------------------------------------
    def forward(self, batch, detached = False):
        """Generate dose-response curves and score them with both discriminators.

        Args:
            batch: Mapping passed to ``self.buildBR`` and read here for
                ``'current_ow'``, ``'current_od'`` and ``'current_iw'``. When
                ``detached`` is true, ``'counterfactual_mask'`` and
                ``'masked_factual_dr_curve'`` are read as well.
            detached: When true, the representation and the generated curves
                are detached from the autograd graph, and the curves are
                recombined with the batch's masked factual curve before they
                reach the discriminators.

        Returns:
            Tuple ``(factual_x_next, br, D_treatment_logit, D_dosage_logits)``.
            ``D_dosage_logits`` is a list of length ``self.num_treatments``;
            the slot of every non-zero treatment key in ``self.dict_wd`` holds
            that treatment's dosage-discriminator output multiplied by
            ``current_iw == w``, and every other slot keeps the placeholder
            ``0``. ``br`` is the detached representation when ``detached`` is
            true.
        """
        # Shared representation of the batch; its first two axes size everything below.
        br = self.buildBR(batch)
        num_patients, max_length, _ = br.shape

        # One slot per treatment; slots the dosage loop never fills stay at 0.
        D_dosage_logits = [0 for _ in range(self.num_treatments)]

        # Noise and dosage grids share the same 5-D layout.
        grid_shape = [num_patients, max_length, self.num_treatments, self.num_dosage_samples, 1]

        # Noise (drawn before genCF runs, so the RNG stream order is unchanged).
        z = torch.rand(grid_shape, device = self.device_ori)

        # Dosage: start from ones and scale each (treatment, sample) slice by its dosage value.
        d = torch.ones(grid_shape, device = self.device_ori)
        for treatment_idx, dosage_values in self.dict_wd.items():
            for sample_idx, dosage_value in enumerate(dosage_values):
                d[:, :, treatment_idx, sample_idx, 0] = d[:, :, treatment_idx, sample_idx, 0] * dosage_value

        # Generate the counterfactual dose-response curves.
        dr_curves = self.genCF(br, z, d).to(self.device_ori)

        if detached:
            # Cut the graph so nothing downstream back-propagates into br / genCF.
            br = br.detach()
            dr_curves = dr_curves.detach()

            # Replace factual: keep generated values under the counterfactual mask
            # and add the masked factual curve supplied by the batch.
            dr_curves = dr_curves * batch["counterfactual_mask"] + batch["masked_factual_dr_curve"]

        # Discriminate treatment.
        D_treatment_logit = self.discTreat(br, dr_curves, d)

        # Discriminate dosage, starting from the curve of the factual treatment.
        factual_dr_curve, factual_x_next = self.get_factual_dr_curve(dr_curves,
                                                                     batch['current_ow'],
                                                                     batch['current_od'])
        current_iw = batch['current_iw']
        for w, dosage_values in self.dict_wd.items():
            # Treatment 0 has no dosage discriminator (the discriminator for w sits at index w - 1).
            if w == 0:
                continue

            # Positions whose current treatment index equals w.
            is_active = (current_iw == w)

            # The treatment's dosage values tiled over (patient, step).
            tiled = torch.Tensor(dosage_values).repeat(num_patients, max_length, 1)
            tiled = tiled.to(self.device_ori)

            logit_w = self.discDosages_list[w-1](br, factual_dr_curve, tiled)

            # Zero the logits wherever treatment w is not the active one.
            D_dosage_logits[w] = logit_w * is_active

        return factual_x_next, br, D_treatment_logit, D_dosage_logits

    # ----------------------------------------------------------------------------------------------------
    # factual_dr_curve
    # ----------------------------------------------------------------------------------------------------    
    def get_factual_dr_curve(self, dr_curves, current_ow, current_od):
        num_patients, max_length, _ = current_ow.shape
        
        # treatment
        current_w_mask   = (current_ow == 1)
        current_w_mask   = torch.unsqueeze(current_w_mask, dim = -1).repeat(1,1,1,2)
        
        factual_dr_curve = torch.masked_select(dr_curves, current_w_mask).reshape(num_patients, 
                                                                                  max_length, 
                                                                                  self.num_dosage_samples)
        # dosage
        current_d_mask = (current_od == 1)       
        factual_x_next = torch.masked_select(factual_dr_curve, current_d_mask).reshape(num_patients,
                                                                                       max_length, 
                                                                                       1)
        return factual_dr_curve, factual_x_next
    
    # --------------------------------------------------------------------------------------
    # Optimizers
    # -------------------------------------------------------------------------------------- 
    def configure_optimizers(self):
        """Split the parameters into groups and build one optimizer per group.

        Groups are formed by name prefix:
          * counterfactual generator heads (``genCF.<gencf_head_params>``),
          * treatment discriminator heads (``discTreat.<treatment_head_params>``),
          * one group per dosage discriminator
            (``discDosages_list.<w>.<dosage_head_params>``),
          * everything that is in neither discriminator group
            ("non discriminator", which includes the generator heads).

        When ``self.weights_ema`` is truthy, ``self.ema_discriminator`` and
        ``self.ema_non_discriminator`` are (re)created over the matching
        parameters with decay ``self.beta``.

        Returns:
            ``[non_discriminator_optimizer, disc_treatment_optimizer,
            disc_dosage_optimizers]`` where the last element is itself a list
            with one optimizer per dosage discriminator. When
            ``self.lr_scheduler`` is truthy, a tuple of that list and the
            result of ``self._get_lr_schedulers`` over the flat optimizer list.

        Raises:
            AssertionError: if the groups do not partition the parameters.
        """
        # NOTE(review): the third element of the returned list is a nested list of
        # optimizers; kept as-is because callers may unpack it that way -- confirm
        # against the training loop.

        # Snapshot the parameters once; every lookup below reuses this mapping.
        named_params = dict(self.named_parameters())

        def names_under(prefixes):
            # Names starting with one of `prefixes`, in named_parameters order
            # (one entry per matching (name, prefix) pair).
            return [k for k in named_params for prefix in prefixes if k.startswith(prefix)]

        def select(names):
            # (name, parameter) pairs for `names`, in named_parameters order.
            wanted = set(names)
            return [(k, v) for k, v in named_params.items() if k in wanted]

        #
        # cf generator
        #
        gen_cf_outcomes_param = names_under(['genCF.' + s for s in self.genCF.gencf_head_params])

        #
        # discriminator treatment
        #
        disc_treatment_param = names_under(['discTreat.' + s for s in self.discTreat.treatment_head_params])

        #
        # discriminator dosage (one name list per dosage discriminator)
        #
        disc_dosage_params = [
            names_under([f'discDosages_list.{w}.' + s for s in self.discDosages_list[w].dosage_head_params])
            for w in range(self.num_treatments - 1)
        ]
        sum_dosage_params = [x for row in disc_dosage_params for x in row]

        # Sets give O(1) membership tests for the partitioning below.
        discriminator_names = set(disc_treatment_param) | set(sum_dosage_params)
        gen_cf_names = set(gen_cf_outcomes_param)

        # non discriminator
        non_discriminator_params = [k for k in named_params if k not in discriminator_names]

        # for ema
        discriminator_params = [k for k in named_params if k in discriminator_names]

        transformer_params = [k for k in non_discriminator_params if k not in gen_cf_names]

        # check: the groups must partition the parameters. The dosage heads are counted
        # through the flattened list, so any number of dosage discriminators is supported.
        assert len(non_discriminator_params) + len(disc_treatment_param) + len(sum_dosage_params) \
            == len(named_params)

        assert len(transformer_params) + len(gen_cf_outcomes_param) == len(non_discriminator_params)

        # optimizer params
        disc_treatment_param     = select(disc_treatment_param)
        disc_treatment_optimizer = self._get_optimizers(disc_treatment_param)

        non_discriminator_params    = select(non_discriminator_params)
        non_discriminator_optimizer = self._get_optimizers(non_discriminator_params)

        if self.weights_ema:
            discriminator_params = select(discriminator_params)

            self.ema_discriminator = ExponentialMovingAverage([par[1] for par in discriminator_params],
                                                              decay=self.beta)
            self.ema_non_discriminator = ExponentialMovingAverage([par[1] for par in non_discriminator_params],
                                                                  decay=self.beta)

        # dosage: one optimizer per dosage discriminator
        disc_dosage_optimizers = [self._get_optimizers(select(head_names))
                                  for head_names in disc_dosage_params]

        # Flat list, used only to build the schedulers.
        optimizers = [non_discriminator_optimizer, disc_treatment_optimizer, *disc_dosage_optimizers]

        if self.lr_scheduler:
            lr_scheduler = self._get_lr_schedulers(optimizers)
            return [non_discriminator_optimizer, disc_treatment_optimizer, disc_dosage_optimizers], lr_scheduler

        return [non_discriminator_optimizer, disc_treatment_optimizer, disc_dosage_optimizers]
     
    # --------------------------------------------------------------------------------------
    # DataLoader
    # --------------------------------------------------------------------------------------  
    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training ``DataLoader``.

        The data wrapped in ``CancerDataset`` is ``self.train_f_sequential``
        when ``self.isDecoder`` is truthy and ``self.train_f`` otherwise.
        Batches have ``self.train_batch_size`` items and a trailing incomplete
        batch is dropped.
        """
        # Only the data source differs between the two modes.
        source = self.train_f_sequential if self.isDecoder else self.train_f
        return DataLoader(CancerDataset(source),
                          shuffle = True,
                          batch_size = self.train_batch_size,
                          drop_last = True)

    
    