import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

# pytorch
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from BRTrain                import BRTrain
from BRTreatmentOutcomeHead import grad_reverse, BROutcomeHead, BRTreatmentHead, BRDosageHead
from model.base.buildBR     import BuildBR

# --------------------------------------------------------------------------------------------------------
# Transformer + Counterfactual Generator
# --------------------------------------------------------------------------------------------------------
class CT(BRTrain):
    # ----------------------------------------------------------------------------------------------------
    # init 
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, 
                 args: DictConfig,
                 dataset_collection: dict,
                 **kwargs):
        """Build the CT (transformer-based) model, move data to device, create heads.

        Args:
            args: OmegaConf config. ``args.model.name`` must equal ``"CT"``;
                ``args.model.encoder`` provides ``br_size``, ``seq_hidden_units``,
                ``fc_hidden_units``, ``dropout_rate`` and ``num_layer``.
            dataset_collection: Mapping from split name to either a tuple (left
                untouched) or a dict whose values are moved onto
                ``self.device_ori`` via ``.to()`` -- presumably tensors, TODO
                confirm. The dict is mutated IN PLACE. Must contain the keys
                ``'train_f'``, ``'valid_f'``, ``'test_cf'``, ``'test_cf_multi'``
                and ``'train_scaling_params'``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Raises:
            ValueError: If ``args.model.name`` is not ``"CT"``.
        """
        # Guard against being constructed with another model's config.
        if args.model.name != "CT":
            raise ValueError(
                f"Model mismatch: expected args.model.name == 'CT', "
                f"got {args.model.name!r}"
            )

        # These are set BEFORE super().__init__ -- NOTE(review): presumably the
        # BRTrain base constructor reads them; keep this ordering.
        self.model_name = args.model.name
        self.isDecoder  = False
        self.base_model = "transformer"
        super().__init__(args)

        # -------------------------------------------------------------------------------
        # dataset collection
        # -------------------------------------------------------------------------------
        # Move every dict-valued split onto the training device, in place.
        # Tuple entries (e.g. scaling parameters) are skipped.
        # Re-assigning existing keys while iterating does not resize the dict,
        # so this is safe.
        for split in dataset_collection.values():
            if isinstance(split, tuple):
                continue

            for key, value in split.items():
                split[key] = value.to(self.device_ori)

        self.train_f       = dataset_collection['train_f']
        self.valid_f       = dataset_collection['valid_f']
        self.test_cf       = dataset_collection['test_cf']
        self.test_cf_multi = dataset_collection['test_cf_multi']

        # scaling params (kept on whatever device/type they arrived with)
        self.train_scaling_params = dataset_collection['train_scaling_params']

        # -------------------------------------------------------------------------------
        # Model parameters
        # -------------------------------------------------------------------------------
        # Encoder-specific hyper-parameters
        self.br_size          = args.model.encoder.br_size          # balanced representation size
        self.seq_hidden_units = args.model.encoder.seq_hidden_units
        self.fc_hidden_units  = args.model.encoder.fc_hidden_units
        self.dropout_rate     = args.model.encoder.dropout_rate
        self.num_layer        = args.model.encoder.num_layer

        # -------------------------------------------------------------------------------
        # Model setup
        # -------------------------------------------------------------------------------
        # Representation backbone (transformer), built from the full config.
        self.buildBr = BuildBR(args).to(self.device_ori)

        # -------------------------------------------------------------------------------
        # Prediction heads on top of the representation
        # NOTE(review): dim_treatments, dim_dosages, dim_outcome, num_dosage_samples,
        # alpha, update_alpha and balancing are not set in this method --
        # presumably provided by BRTrain.__init__; confirm.
        # -------------------------------------------------------------------------------
        # outcome head (labelled "generator" in the original code)
        self.brOutcomeHead = BROutcomeHead(self.br_size,
                                           self.fc_hidden_units,
                                           self.dim_treatments,
                                           self.dim_dosages,
                                           self.dim_outcome).to(self.device_ori)

        # treatment head (labelled "discriminator" in the original code); it takes
        # alpha/update_alpha/balancing, presumably for adversarial balancing.
        # The attribute name keeps its historical spelling ("Treament"): renaming it
        # would change state_dict keys and break external references.
        self.brTreamentHead = BRTreatmentHead(self.br_size,
                                              self.fc_hidden_units,
                                              self.dim_treatments,
                                              self.alpha,
                                              self.update_alpha,
                                              self.balancing).to(self.device_ori)

        # dosage head
        self.brDosageHead = BRDosageHead(self.br_size,
                                         self.fc_hidden_units,
                                         self.dim_treatments,
                                         self.num_dosage_samples,
                                         self.alpha,
                                         self.update_alpha,
                                         self.balancing).to(self.device_ori)