import numpy as np
from copy import deepcopy
import math
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from pytorch_lightning import seed_everything

# Model
from BRCausalModel import BRCausalModel

#######################################################################################
# PyTorch Lightning Modules                                                           #
#######################################################################################
class BRTrain(BRCausalModel):
    def __init__(self, args):
        """Build the trainer from the experiment configuration.

        Args:
            args: Configuration object forwarded unchanged to
                ``init_specific``, which copies the fields it needs.
        """
        super().__init__()
        self.init_specific(args)

        # ``train_BR`` calls zero_grad / backward / step on the optimizers
        # itself, so automatic optimization is switched off.
        self.automatic_optimization = False
        
    def init_specific(self, args):
        self.num_treatments     = args.dataset.dim_treatments 
        self.num_dosage_samples = args.dataset.dim_n_dosage_samples
                    
        # max cancer volume
        self.MAX_CANCER_VOLUME   = args.dataset.max_cancer_volume
        
        # projection_horizon
        self.projection_horizon  = args.dataset.projection_horizon
                
        # treatment_mode
        self.treatment_mode      = args.dataset.treatment_mode
        
        # 
        self.dim_treatments      = args.dataset.dim_treatments
        self.dim_dosages         = args.dataset.dim_dosages
        self.dim_static_features = args.dataset.dim_static_features
        self.dim_outcome         = args.dataset.dim_outcomes
        self.input_size          = self.dim_treatments + self.dim_dosages + self.dim_static_features + self.dim_outcome

        # device
        self.device_ori = f'cuda:{args.exp.gpu}'
        print("use gpu:", self.device_ori)
      
        # -------------------------------------------------------
        # Learning 
        # -------------------------------------------------------
        # epochs
        self.max_epochs       = args.exp.max_epochs
        
        # sub_model parameters
        if self.isDecoder:
            sub_args = args.model.decoder
        else:
            sub_args = args.model.encoder
            
        # optimiser parameters
        self.optimizer_cls = sub_args.optimizer.optimizer_cls
        self.learning_rate = sub_args.optimizer.learning_rate
        self.weight_decay  = sub_args.optimizer.weight_decay
        self.lr_scheduler  = sub_args.optimizer.lr_scheduler
        
        # batch_size
        self.train_batch_size = sub_args.train_batch_size
            
        # Balancing representation training parameters
        self.balancing    = args.exp.balancing
        
        # Used for gradient-reversal
        self.alpha        = args.exp.alpha          
        self.update_alpha = args.exp.update_alpha
        
        # EMA
        self.weights_ema  = args.exp.weights_ema  # Exponential moving average of weights
        self.beta         = args.exp.beta         # EMA beta
        
    def train_BR(self):
        """Run the manual training loop, then evaluate on validation and test data.

        Each epoch iterates over ``self.train_dataloader()``. For every batch
        the generator loss is optimized first; when
        ``self.balancing == 'domain_confusion'`` the treatment and dosage
        discriminator losses are optimized afterwards with their own
        optimizers. Validation runs after every epoch and the test evaluation
        once at the end.

        Returns:
            tuple: ``(ret_rmse, all_loss)``.

            * ``ret_rmse`` is a dict holding the last validation result under
              ``"valid_multi"`` (decoder) or ``"valid_one"`` (encoder) and the
              test result under ``"test_multi"``. The validation entry is
              ``None`` when ``self.max_epochs`` is 0.
            * ``all_loss`` maps the 1-based epoch number to
              ``[generator, treatment, dosage]`` loss totals, each batch mean
              being re-weighted by ``batch['active_entries'].sum()``.
        """
        seed_everything(10)
        max_epoch = self.max_epochs
        alpha_max = self.alpha

        ###########################################################
        # get optimizers                                          #
        ###########################################################
        if self.lr_scheduler:
            optimizers, lr_schedulers = self.configure_optimizers()
        else:
            optimizers = self.configure_optimizers()

        if self.balancing == 'grad_reverse':
            g_opt = optimizers
        else:
            g_opt, w_opt, d_opt = optimizers

        domain_confusion = self.balancing == 'domain_confusion'

        ###########################################################
        # Training                                                #
        ###########################################################
        all_loss = {}
        # Defined up front so the return below is valid even if no epoch runs.
        rmse_valid = None
        for epoch in range(1, max_epoch + 1):

            epoch_loss = [0.0, 0.0, 0.0]

            # Adversarial weight schedule. ``epoch`` is already 1-based, so the
            # training progress is epoch / max_epoch and reaches exactly 1.0 on
            # the last epoch (the previous ``epoch + 1`` overshot by one step).
            # NOTE(review): ``self.update_alpha`` is read from the config but
            # never consulted here -- confirm whether a constant alpha should
            # be used when that flag is false.
            p = float(epoch) / float(max_epoch)
            alpha = (2 / (1. + math.exp(-10. * p)) - 1.0) * alpha_max

            for batch in self.train_dataloader():
                # Weight used to turn per-batch mean losses into epoch totals.
                n_active = batch['active_entries'].sum()

                # --------------------------------------------
                # generator
                # --------------------------------------------
                g_loss = self.training_step(batch = batch,
                                            optimizer_idx = 0,
                                            br_treatment_outcome_head_alpha = alpha)
                g_opt.zero_grad()
                g_loss.backward()
                g_opt.step()

                if domain_confusion:
                    w_treat_loss, d_dosage_loss = self.training_step(batch = batch,
                                                                     optimizer_idx = 1,
                                                                     br_treatment_outcome_head_alpha = alpha)

                    # --------------------------------------------
                    # discriminator (treatment)
                    # --------------------------------------------
                    w_opt.zero_grad()
                    w_treat_loss.backward()
                    w_opt.step()

                    # --------------------------------------------
                    # discriminator (dosage)
                    # --------------------------------------------
                    d_opt.zero_grad()
                    d_dosage_loss.backward()
                    d_opt.step()

                # EMA update
                if self.weights_ema:
                    self.ema_non_discriminator.update()
                    self.ema_discriminator.update()

                epoch_loss[0] += (g_loss * n_active).to('cpu').detach().numpy().copy()
                if domain_confusion:
                    epoch_loss[1] += (w_treat_loss * n_active).to('cpu').detach().numpy().copy()
                    epoch_loss[2] += (d_dosage_loss * n_active).to('cpu').detach().numpy().copy()

            all_loss[epoch] = epoch_loss

            # update scheduler
            if self.lr_scheduler:
                for scheduler in lr_schedulers:
                    scheduler.step()

            # print learning information
            if self.isDecoder:
                print("Decoder Epoch: {:3d}  alpha: {:.5f}".format(epoch, alpha), end = "")
            else:
                print("Encoder Epoch: {:3d}  alpha: {:.5f}".format(epoch, alpha), end = "")

            # Loss
            print("\tMSE Loss: {:>10.0f}, Disc Loss: treament {:>10.0f}, dosage {:>10.0f}".format(
                epoch_loss[0], epoch_loss[1], epoch_loss[2]))

            # validation data
            rmse_valid = self.validation_step()

        # test data
        rmse_test = self.test_step()

        if self.isDecoder:
            ret_rmse = {
                "valid_multi": rmse_valid,
                "test_multi" : rmse_test
            }
        else:
            ret_rmse = {
                "valid_one": rmse_valid,
                "test_multi" : rmse_test
            }

        return ret_rmse, all_loss
            
    # --------------------------------------------------------------------------------------
    # Training step (Pytorch lightning Module)
    # --------------------------------------------------------------------------------------    
    def training_step(self, 
                      batch, 
                      optimizer_idx: int = None,
                      br_treatment_outcome_head_alpha = None):
        #
        self.train()
        ###########################################################
        # optimize Generator                                      #
        ###########################################################
        if optimizer_idx == 0:
            if self.weights_ema:
                with self.ema_discriminator.average_parameters():
                    outcome_pred, br, treatment_pred, dosage_pred = self(batch)
            else:
                outcome_pred, br, treatment_pred, dosage_pred = self(batch)

            # MSE loss
            mse_loss = F.mse_loss(outcome_pred, batch['out_x_next'], reduce = False)

            # BCE treat loss
            if self.balancing == 'grad_reverse':
                bce_loss_treatment = self.get_bce_loss(treatment_pred, batch['current_ow'].double(), kind='predict')
                bce_loss_dosage    = self.get_bce_loss(dosage_pred, batch['current_od'].double(), kind='predict')
            elif self.balancing == 'domain_confusion':
                bce_loss_treatment = self.get_bce_loss(treatment_pred, batch['current_ow'].double(), kind='confuse')
                bce_loss_treatment = br_treatment_outcome_head_alpha * bce_loss_treatment
                
                bce_loss_dosage = self.get_bce_loss(dosage_pred, batch['current_od'].double(), kind='confuse')
                bce_loss_dosage = br_treatment_outcome_head_alpha * bce_loss_dosage
            else:
                raise NotImplementedError()

                
            mse_loss = (batch['active_entries'] * mse_loss).sum() / batch['active_entries'].sum()
                
            # treatment
            active_entries  = batch['active_entries']
            bce_loss_treatment = (active_entries * torch.unsqueeze(bce_loss_treatment, dim = -1)).sum() / active_entries.sum()
            
            # dosage
            current_iw_mask = (batch['current_iw'] != 0)
            active_mask = active_entries * current_iw_mask    
            bce_loss_dosage = (active_mask * torch.unsqueeze(bce_loss_dosage, dim = -1)).sum() / active_mask.sum()
            
            g_loss = mse_loss + bce_loss_treatment + bce_loss_dosage

            
            return g_loss

        ###########################################################
        # optimize Discriminator                                  #
        ###########################################################            
        elif optimizer_idx == 1:
            if self.weights_ema:
                with self.ema_non_discriminator.average_parameters():
                    outcome_pred, br, treatment_pred, dosage_pred = self(batch, detach = True)
            else:
                outcome_pred, br, treatment_pred, dosage_pred = self(batch, detach = True)

            # treatment
            bce_loss_treatment = self.get_bce_loss(treatment_pred, batch['current_ow'].double(), kind='predict')
            bce_loss_treatment = br_treatment_outcome_head_alpha * bce_loss_treatment
  
            # dosage
            bce_loss_dosage = self.get_bce_loss(dosage_pred, batch['current_od'].double(), kind='predict')
            bce_loss_dosage = br_treatment_outcome_head_alpha * bce_loss_dosage

            
            # treatment mask
            active_entries = batch['active_entries']
            bce_loss_treatment = (active_entries * torch.unsqueeze(bce_loss_treatment, dim = -1)).sum() / active_entries.sum()
            
            # dosage mask
            current_iw_mask = (batch['current_iw'] != 0)
            active_mask = active_entries * current_iw_mask    
            bce_loss_dosage = (active_mask * torch.unsqueeze(bce_loss_dosage, dim = -1)).sum() / active_mask.sum()
            
            
            return bce_loss_treatment, bce_loss_dosage

    def validation_step(self):
        self.eval()
        # Normalized RMSE
        if self.isDecoder:
            rmse_f_valid_multi, _ = self.get_multi_step_factual_rmse()
            print("\t evaluation F: ", end="")
            for tau, value in rmse_f_valid_multi.items():
                print("{}: {:.3f}".format(tau, value), end=', ')
            print("")
            #
            return rmse_f_valid_multi
        else:
            rmse_f_valid = self.get_one_step_factual_rmse()
            rmse_cf_test = self.get_one_step_factual_rmse("test_cf")
            print("\t evaluation: {:.5f}  test: {:.5f}".format(rmse_f_valid, rmse_cf_test))
            
            return rmse_f_valid
        
    def test_step(self):
        self.eval()
        rmse_cf_test_multi, _ = self.get_multi_step_counterfactual_rmse()
        return rmse_cf_test_multi
        
    # ---------------------------------------------------
    # get_dosage_logit_sum
    # --------------------------------------------------- 
    def get_dosage_logit_sum(self, gene_D_dosage_logits):
        gene_D_dosage_logit_sum = torch.zeros_like(gene_D_dosage_logits[1])
        for w in range(1, self.num_treatments):
            gene_D_dosage_logit_sum += gene_D_dosage_logits[w]
            
        return gene_D_dosage_logit_sum

    
