import numpy as np
import math

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.functional import binary_cross_entropy_with_logits
from copy import deepcopy

# Dataset
from torch_ema import ExponentialMovingAverage
from pytorch_lightning import seed_everything

from TimeVaryingCausalModel import TimeVaryingCausalModel
from GANCausalModel import GANCausalModel 

#######################################################################################
# PyTorch Lightning Modules                                                           #
#######################################################################################
class GANTrain(GANCausalModel):
    def __init__(self, args):
        """Build the trainer from a configuration object.

        Args:
            args: Configuration object; the attributes read from it are the
                ones accessed in ``init_specific``.
        """
        super().__init__()
        self.init_specific(args)
        
        # Turn automatic optimization off: ``train_GAN`` calls zero_grad /
        # backward / step on every optimizer itself.
        self.automatic_optimization = False
           
    def init_specific(self, args):
        """Copy the settings used by this trainer from ``args`` onto ``self``.

        Dataset dimensions come from ``args.dataset``; experiment options
        (GPU index, epochs, alpha, EMA) come from ``args.exp``; optimizer
        settings and the batch size come from ``args.model.decoder`` when
        ``self.isDecoder`` is truthy and from ``args.model.encoder`` otherwise.

        NOTE(review): ``self.isDecoder`` is read here but not assigned anywhere
        in this class, so it must already exist when this method runs --
        confirm where it is set.
        """
        dataset_cfg = args.dataset

        # number of treatments / dosage samples
        self.num_treatments     = dataset_cfg.dim_treatments
        self.num_dosage_samples = dataset_cfg.dim_n_dosage_samples

        # divisor used when weighting the BCE losses
        self.MAX_CANCER_VOLUME  = dataset_cfg.max_cancer_volume

        # projection horizon
        self.projection_horizon = dataset_cfg.projection_horizon

        # feature dimensions; input_size is their sum
        self.dim_treatments      = dataset_cfg.dim_treatments
        self.dim_dosages         = dataset_cfg.dim_dosages
        self.dim_static_features = dataset_cfg.dim_static_features
        self.dim_outcome         = dataset_cfg.dim_outcomes
        self.input_size          = (self.dim_treatments
                                    + self.dim_dosages
                                    + self.dim_static_features
                                    + self.dim_outcome)

        exp_cfg = args.exp

        # device string (always built as a CUDA device name)
        self.device_ori = 'cuda:{}'.format(exp_cfg.gpu)
        print("use gpu:", self.device_ori)

        # number of training epochs
        self.max_epochs = exp_cfg.max_epochs

        # encoder and decoder each carry their own optimizer / batch settings
        sub_args = args.model.decoder if self.isDecoder else args.model.encoder

        optimizer_cfg = sub_args.optimizer
        self.optimizer_cls = optimizer_cfg.optimizer_cls
        self.learning_rate = optimizer_cfg.learning_rate
        self.weight_decay  = optimizer_cfg.weight_decay
        self.lr_scheduler  = optimizer_cfg.lr_scheduler

        self.train_batch_size = sub_args.train_batch_size

        # adversarial weight; train_GAN uses alpha as the upper bound of its schedule
        self.alpha        = exp_cfg.alpha
        self.update_alpha = exp_cfg.update_alpha

        # exponential moving average of the weights
        self.weights_ema  = exp_cfg.weights_ema
        self.beta         = exp_cfg.beta

        if self.isDecoder:
            self.weight_mode = None
    #
    # train GAN
    #
    def train_GAN(self):
        """Run the adversarial training loop, then evaluate on the test data.

        The global seed is fixed to 10.  Every epoch iterates over
        ``self.train_dataloader()``; for each batch the generator
        (non-discriminator) parameters are updated first and, when
        ``alpha > 0``, the treatment discriminator and then each dosage
        discriminator are updated.  EMA weights are updated after every batch
        when ``self.weights_ema`` is set, the LR schedulers are stepped once
        per epoch, and ``validation_step`` is called after every epoch.
        ``test_step`` runs once after the last epoch.

        Returns:
            tuple: ``(ret_rmse, all_loss)``.

            ``ret_rmse`` is a dict with the result of the last
            ``validation_step`` call (key ``"valid_multi"`` for a decoder,
            ``"valid_one"`` otherwise; ``None`` if no epoch was run) and the
            ``test_step`` result under ``"test_multi"``.

            ``all_loss`` maps the 1-based epoch number to a three-element list
            ``[generator, treatment discriminator, dosage discriminators]``;
            each entry is the sum over batches of the loss multiplied by the
            number of active entries in the batch.
        """
        seed_everything(10)
        max_epoch   = self.max_epochs
        alpha_max   = self.alpha

        ###########################################################
        # get optimizers                                          #
        ###########################################################
        # configure_optimizers() is unpacked as (optimizers, schedulers) only
        # when a scheduler is configured.
        if self.lr_scheduler:
            optimizers, lr_schedulers = self.configure_optimizers()
        else:
            optimizers = self.configure_optimizers()
            lr_schedulers = []

        non_discriminator_optimizer, disc_treatment_optimizer, disc_dosage_optimizers = optimizers

        def _as_numpy(value):
            # Detached CPU copy, so the stored statistic keeps no graph alive.
            return value.detach().to('cpu').numpy().copy()

        ###########################################################
        # training                                                #
        ###########################################################
        all_loss = {}
        # Defined up front so the result dict can still be built when
        # max_epoch < 1 and the loop body never runs.
        rmse_valid = None
        for epoch in range(1, max_epoch + 1):

            epoch_loss = [0.0, 0.0, 0.0]

            # Sigmoid-shaped ramp of alpha towards alpha_max.
            # NOTE(review): `epoch` is already 1-based here, so p runs from
            # 2 / max_epoch up to (max_epoch + 1) / max_epoch -- confirm the
            # "+ 1" is intended.  self.update_alpha is not consulted.
            p = float(epoch + 1) / float(max_epoch)
            alpha = (2 / (1. + math.exp(-10. * p )) - 1.0) * alpha_max

            for batch in self.train_dataloader():
                # --------------------------------------------
                # generator
                # --------------------------------------------
                mse_loss, bce_loss_treat, bce_loss_dosage = self.training_step(batch = batch,
                                                                               optimizer_idx = 0)
                g_loss = mse_loss + alpha * bce_loss_treat + alpha * bce_loss_dosage

                non_discriminator_optimizer.zero_grad()
                g_loss.backward()
                non_discriminator_optimizer.step()

                epoch_loss[0] += _as_numpy(g_loss * batch['active_entries'].sum())

                if alpha > 0.0:
                    # --------------------------------------------
                    # discriminator (treatment)
                    # --------------------------------------------
                    bce_loss_treat, bce_loss_dosages = self.training_step(batch = batch,
                                                                          optimizer_idx = 1)
                    bce_loss_treat = bce_loss_treat * alpha

                    disc_treatment_optimizer.zero_grad()
                    bce_loss_treat.backward()
                    disc_treatment_optimizer.step()

                    epoch_loss[1] += _as_numpy(bce_loss_treat * batch['active_entries'].sum())

                    # --------------------------------------------
                    # discriminator (dosage): one optimizer per entry of
                    # bce_loss_dosages
                    # --------------------------------------------
                    scaled_dosage_losses = []
                    for w in range(self.num_treatments - 1):
                        dosage_loss = bce_loss_dosages[w] * alpha

                        disc_dosage_optimizers[w].zero_grad()
                        dosage_loss.backward()
                        disc_dosage_optimizers[w].step()

                        scaled_dosage_losses.append(dosage_loss.detach())

                    # Skip the bookkeeping when there is no dosage
                    # discriminator; there is then no tensor to convert.
                    if scaled_dosage_losses:
                        n_active = batch['active_entries'].sum()
                        epoch_loss[2] += _as_numpy(
                            sum(loss * n_active for loss in scaled_dosage_losses))

                # ------------------------------
                # update EMA
                # ------------------------------
                if self.weights_ema:
                    self.ema_non_discriminator.update()
                    self.ema_discriminator.update()

            all_loss[epoch] = epoch_loss

            # ------------------------------
            # update schedulers (once per epoch)
            # ------------------------------
            for scheduler in lr_schedulers:
                scheduler.step()

            # ---------------------------------------------------------------
            # print result
            # ---------------------------------------------------------------
            if self.isDecoder:
                print("Decoder Epoch: {:3d}  alpha: {:.5f}".format(epoch, alpha), end = "")
            else:
                print("Encoder Epoch: {:3d}  alpha: {:.5f}".format(epoch, alpha), end = "")

            # The first printed value is the accumulated generator loss
            # (MSE plus the alpha-weighted BCE terms).
            print("\tMSE Loss: {:>10.0f}, Disc Loss: treament {:>10.0f}, dosage {:>10.0f}".format(
                epoch_loss[0], epoch_loss[1], epoch_loss[2]))

            # validation
            rmse_valid = self.validation_step()

        # test data
        rmse_test = self.test_step()

        if self.isDecoder:
            ret_rmse = {
                "valid_multi": rmse_valid,
                "test_multi" : rmse_test
            }
        else:
            ret_rmse = {
                "valid_one": rmse_valid,
                "test_multi" : rmse_test
            }

        return ret_rmse, all_loss
                    
    # --------------------------------------------------------------------------------------
    # Training step (Pytorch lightning Module)
    # --------------------------------------------------------------------------------------    
    def training_step(self, 
                      batch, 
                      optimizer_idx: int = None):
        """Compute the losses for one generator or discriminator update.

        Puts the module in training mode and runs one forward pass on
        ``batch``.  No backward pass or optimizer step happens here.

        Args:
            batch: Mapping with the keys ``'active_entries'``,
                ``'current_iw'``, ``'current_ow'``, ``'current_od'`` and
                ``'unscaled_x'``; ``'out_x_next'`` is read as well when
                ``optimizer_idx == 0``.
            optimizer_idx: ``0`` for the generator losses, ``1`` for the
                discriminator losses.

        Returns:
            * ``optimizer_idx == 0``: ``(mse_loss, bce_loss_treat,
              bce_loss_dosage)``, each a masked mean.
            * ``optimizer_idx == 1``: ``(bce_loss_treat, bce_loss_dosages)``
              where ``bce_loss_dosages`` is a list with one masked-mean loss
              per treatment index ``1 .. num_treatments - 1``.
            * any other value: ``None``.
        """
        self.train()

        ###########################################################
        # optimize Generator                                      #
        ###########################################################
        if optimizer_idx == 0:
            # With EMA enabled the forward pass runs inside the discriminator
            # EMA's average_parameters() context.
            if self.weights_ema:
                with self.ema_discriminator.average_parameters():
                    factual_x_next, _, D_treatment_logit, D_dosage_logits = self(batch)
            else:
                factual_x_next, _, D_treatment_logit, D_dosage_logits = self(batch)

            # element-wise MSE loss
            mse_loss = self.get_mse_loss(factual_x_next = factual_x_next,
                                         out_x_next     = batch['out_x_next'])

            # element-wise BCE loss (treatment)
            bce_loss_treat = self.get_bce_loss_treat(D_treatment_logit,
                                                     batch['current_ow'],
                                                     batch['unscaled_x'],
                                                     isGeneratorLoss = True)

            # element-wise BCE loss (dosage) on the summed dosage logits
            D_dosage_logit_sum = self.get_dosage_logit_sum(D_dosage_logits)
            bce_loss_dosage    = self.get_bce_loss_dosage(D_dosage_logit_sum,
                                                          batch['current_od'],
                                                          batch['unscaled_x'],
                                                          isGeneratorLoss = True)

            active_entries = batch['active_entries']
            # The dosage loss only counts entries whose current_iw is non-zero.
            dosage_mask = active_entries * (batch['current_iw'] != 0)

            return (self._masked_mean(mse_loss, active_entries),
                    self._masked_mean(bce_loss_treat, active_entries),
                    self._masked_mean(bce_loss_dosage, dosage_mask))

        ###########################################################
        # optimize Discriminator                                  #
        ###########################################################
        if optimizer_idx == 1:
            # With EMA enabled the forward pass runs inside the
            # non-discriminator EMA's average_parameters() context.
            if self.weights_ema:
                with self.ema_non_discriminator.average_parameters():
                    _, _, D_treatment_logit, D_dosage_logits = self(batch, detached = True)
            else:
                _, _, D_treatment_logit, D_dosage_logits = self(batch, detached = True)

            active_entries = batch['active_entries']

            # ------------------------------------------------------------------
            # treatment
            # ------------------------------------------------------------------
            bce_loss_treat = self.get_bce_loss_treat(D_treatment_logit,
                                                     batch['current_ow'],
                                                     batch['unscaled_x'],
                                                     isGeneratorLoss = False)
            bce_loss_treat = self._masked_mean(bce_loss_treat, active_entries)

            # ------------------------------------------------------------------
            # dosage: one loss per treatment index w >= 1, averaged over the
            # active entries whose current_iw equals w
            # ------------------------------------------------------------------
            bce_loss_dosages = []
            for w in range(1, self.num_treatments):
                bce_loss_dosage = self.get_bce_loss_dosage(D_dosage_logits[w],
                                                           batch['current_od'],
                                                           batch['unscaled_x'],
                                                           isGeneratorLoss = False)

                treatment_mask = active_entries * (batch['current_iw'] == w)
                bce_loss_dosages.append(self._masked_mean(bce_loss_dosage, treatment_mask))

            return bce_loss_treat, bce_loss_dosages

        # Unknown optimizer index: nothing is computed.
        return None

    @staticmethod
    def _masked_mean(values, mask):
        """Return ``(values * mask).sum() / mask.sum()``, or 0 for an empty mask."""
        total = mask.sum()
        # A mask that sums to zero would give 0 / 0 = NaN, and a NaN loss would
        # poison the following optimizer step.  The numerator is 0 in that
        # case, so dividing by 1 yields a loss of 0 with zero gradients.
        safe_total = torch.where(total > 0, total, torch.ones_like(total))
        return (values * mask).sum() / safe_total
       
    
    
    def validation_step(self):
        """Switch to eval mode, print the validation metrics and return them.

        Returns:
            For a decoder, the first element returned by
            ``get_multi_step_factual_rmse()`` (iterated with ``.items()``
            when printing).  Otherwise the value returned by
            ``get_one_step_factual_rmse()``.
        """
        self.eval()

        if not self.isDecoder:
            # One-step model: also report the "test_cf" split next to the
            # validation score; only the validation score is returned.
            valid_rmse   = self.get_one_step_factual_rmse()
            test_cf_rmse = self.get_one_step_factual_rmse("test_cf")
            print("\t validation data: {:.5f}  test data: {:.5f}".format(valid_rmse, test_cf_rmse))
            return valid_rmse

        # Multi-step model: one "key: value, " entry per item.
        multi_step_rmse, _ = self.get_multi_step_factual_rmse()
        line = "\t validation   F: "
        for horizon, rmse in multi_step_rmse.items():
            line += "{}: {:.3f}".format(horizon, rmse) + ', '
        print(line)
        return multi_step_rmse
    
    def test_step(self):
        """Switch to eval mode and return the multi-step counterfactual RMSE.

        Returns:
            The first of the two values returned by
            ``get_multi_step_counterfactual_rmse()``.
        """
        self.eval()
        counterfactual_rmse, _unused = self.get_multi_step_counterfactual_rmse()
        return counterfactual_rmse
     # ------------------------------------------------------------------------------------------------------
    # mse loss
    # ------------------------------------------------------------------------------------------------------
    def get_mse_loss(self, factual_x_next, out_x_next):
        """Return the element-wise squared error between prediction and target.

        No reduction is applied; the caller masks and averages the result.

        Args:
            factual_x_next: Predicted values.
            out_x_next: Target values.

        Returns:
            Tensor of un-reduced squared errors.
        """
        # `reduce=False` is a deprecated argument of F.mse_loss;
        # `reduction='none'` is the supported spelling of the same behaviour.
        return F.mse_loss(factual_x_next, out_x_next, reduction = 'none')
    
   # ---------------------------------------------------
    # bce loss (Treatment)
    # ---------------------------------------------------
    def get_bce_loss_treat(self, D_treatment_logits, current_w, unscaled_x_next,
                           isGeneratorLoss = True):
        """Return the element-wise weighted BCE loss of the treatment discriminator.

        Args:
            D_treatment_logits: Discriminator logits.
            current_w: Treatment labels; used as targets without gradient.
            unscaled_x_next: Values that scale the per-element weight after
                division by ``self.MAX_CANCER_VOLUME``.
            isGeneratorLoss: When true the targets are inverted
                (``1 - current_w``) and the base weight is 1.  When false the
                targets are ``current_w`` and the base weight is
                ``current_w * 0.5 + 0.25``.

        Returns:
            Un-reduced BCE loss tensor.
        """
        observed = current_w.detach().clone()

        if isGeneratorLoss:
            # inverted labels, uniform base weight
            target      = 1 - observed
            base_weight = torch.ones_like(current_w)
        else:
            # original labels; base weight 0.75 where the label is 1 and
            # 0.25 where it is 0
            target      = observed
            base_weight = observed * 0.50 + 0.25

        # scale the weight by the value relative to the maximum cancer volume
        scaled_weight = base_weight * unscaled_x_next / self.MAX_CANCER_VOLUME

        return binary_cross_entropy_with_logits(D_treatment_logits,
                                                target,
                                                weight = scaled_weight,
                                                reduction = 'none')
    

    # ---------------------------------------------------
    # bce loss (Dosage)   
    # ---------------------------------------------------    
    def get_bce_loss_dosage(self,
                            D_dosage_logit,
                            current_od,
                            unscaled_x,
                            isGeneratorLoss = True):
        """Return the element-wise weighted BCE loss of a dosage discriminator.

        Args:
            D_dosage_logit: Discriminator logits.
            current_od: Dosage labels; used as targets without gradient.
            unscaled_x: Values that scale the per-element weight after
                division by ``self.MAX_CANCER_VOLUME``.
            isGeneratorLoss: When true the targets are inverted
                (``1 - current_od``) and the base weight is 1.  When false the
                targets are ``current_od`` and the base weight is 0.5.

        Returns:
            Un-reduced BCE loss tensor; masking and averaging are left to the
            caller.
        """
        factual = current_od.detach().clone()

        # generator: inverted labels; discriminator: original labels
        target = (1 - factual) if isGeneratorLoss else factual

        # generator: base weight 1; discriminator: base weight 1/2
        base_weight = torch.ones_like(current_od)
        if not isGeneratorLoss:
            base_weight = base_weight / 2.0

        # scale the weight by the value relative to the maximum cancer volume
        scaled_weight = base_weight * unscaled_x / self.MAX_CANCER_VOLUME

        return binary_cross_entropy_with_logits(D_dosage_logit,
                                                target,
                                                weight = scaled_weight,
                                                reduction = 'none')
    
    # ---------------------------------------------------
    # get_dosage_logit_sum
    # --------------------------------------------------- 
    def get_dosage_logit_sum(self, gene_D_dosage_logits):
        """Sum the dosage logits over treatment indices ``1 .. num_treatments - 1``.

        Args:
            gene_D_dosage_logits: Container indexable by treatment index;
                index 0 is not read.

        Returns:
            A new tensor shaped like ``gene_D_dosage_logits[1]`` holding the sum.
        """
        total = torch.zeros_like(gene_D_dosage_logits[1])
        for treatment_idx in range(1, self.num_treatments):
            # accumulate in place into the freshly allocated buffer
            total.add_(gene_D_dosage_logits[treatment_idx])
        return total