import numpy as np
from copy import deepcopy
#
from pytorch_lightning import LightningModule
from omegaconf import DictConfig
from omegaconf.errors import MissingMandatoryValue
import math
#
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from BRTrain                import BRTrain
from BRTreatmentOutcomeHead import grad_reverse, BuildBr, BROutcomeHead, BRTreatmentHead, BRDosageHead


class CRNEncoder(BRTrain):
    """Encoder half of the CRN model, trained through ``BRTrain``.

    Builds a recurrent balanced-representation network (``BuildBr``) and three
    heads that consume its output: an outcome head (the "generator") and
    treatment / dosage heads (the "discriminators").

    Args:
        args: OmegaConf config. ``args.model.name`` must equal ``"CRN"``;
            encoder hyper-parameters are read from ``args.model.encoder``.
        dataset_collection: Mapping of dataset name to either a tuple (left
            untouched) or a dict whose values have a ``.to(device)`` method
            (presumably tensors). The dict values are moved to
            ``self.device_ori`` IN PLACE, i.e. the caller's dicts are mutated.
            Must contain ``train_f``, ``valid_f``, ``test_cf``,
            ``test_cf_multi`` and ``train_scaling_params``.
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``args.model.name`` is not ``"CRN"``.
    """

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: dict,
                 **kwargs):

        # Fail fast, with an informative message, when a config for another
        # model is passed in. ValueError subclasses Exception, so callers that
        # caught the previous bare Exception still catch this.
        if args.model.name != "CRN":
            raise ValueError(
                "Model mismatch: CRNEncoder requires args.model.name == 'CRN', "
                f"got {args.model.name!r}"
            )

        # Assigned before the base constructor runs, presumably because
        # BRTrain.__init__ reads them (not visible here) -- keep this order.
        self.model_name = args.model.name
        self.isDecoder = False
        self.base_model = "lstm"
        super().__init__(args)

        # -------------------------------------------------------------------------------
        # dataset collection
        # -------------------------------------------------------------------------------
        # Move every dict-valued entry to the training device in place; tuple
        # entries (e.g. presumably the scaling params) are skipped. Only
        # existing keys are reassigned, so mutating while iterating is safe.
        for data_list in dataset_collection.values():
            if isinstance(data_list, tuple):
                continue

            for key, value in data_list.items():
                data_list[key] = value.to(self.device_ori)

        self.train_f       = dataset_collection['train_f']
        self.valid_f       = dataset_collection['valid_f']
        self.test_cf       = dataset_collection['test_cf']
        self.test_cf_multi = dataset_collection['test_cf_multi']

        # scaling params
        self.train_scaling_params = dataset_collection['train_scaling_params']

        # -------------------------------------------------------------------------------
        # Model parameters
        # -------------------------------------------------------------------------------
        # Encoder/decoder-specific parameters
        self.br_size          = args.model.encoder.br_size          # balanced representation size
        self.seq_hidden_units = args.model.encoder.seq_hidden_units
        self.fc_hidden_units  = args.model.encoder.fc_hidden_units
        #
        self.dropout_rate     = args.model.encoder.dropout_rate
        self.num_layer        = args.model.encoder.num_layer

        # -------------------------------------------------------------------------------
        # Model setup
        # -------------------------------------------------------------------------------
        # input_size, dim_*, alpha, update_alpha, balancing, num_dosage_samples
        # and device_ori are not set in this class; presumably BRTrain sets them.

        # LSTM
        self.buildBr = BuildBr(self.input_size,
                               self.seq_hidden_units,
                               self.num_layer,
                               self.dropout_rate,
                               self.br_size).to(self.device_ori)

        # generator
        self.brOutcomeHead = BROutcomeHead(self.br_size,
                                           self.fc_hidden_units,
                                           self.dim_treatments,
                                           self.dim_dosages,
                                           self.dim_outcome).to(self.device_ori)

        # discriminator (treatment)
        # (attribute name spelling "brTreamentHead" is kept: other code may use it)
        self.brTreamentHead = BRTreatmentHead(self.br_size,
                                              self.fc_hidden_units,
                                              self.dim_treatments,
                                              self.alpha,
                                              self.update_alpha,
                                              self.balancing).to(self.device_ori)

        # discriminator (dosage)
        self.brDosageHead = BRDosageHead(self.br_size,
                                         self.fc_hidden_units,
                                         self.dim_treatments,
                                         self.num_dosage_samples,
                                         self.alpha,
                                         self.update_alpha,
                                         self.balancing).to(self.device_ori)
                        
            
class CRNDecoder(BRTrain):
    """Decoder half of the CRN model, trained through ``BRTrain``.

    Reuses the datasets and scaling parameters already held by ``encoder``.
    It then builds its own balanced-representation network and its outcome,
    treatment and dosage heads. Its ``seq_hidden_units`` is taken from the
    encoder's ``br_size``; all other sizes come from ``args.model.decoder``.

    Args:
        args: OmegaConf config; decoder hyper-parameters are read from
            ``args.model.decoder``.
        encoder: The encoder whose ``train_f`` / ``valid_f`` / ``test_cf`` /
            ``test_cf_multi`` / ``train_scaling_params`` and ``br_size`` are
            reused.
        **kwargs: Accepted for interface compatibility; unused here.
    """

    def __init__(self,
                 args: DictConfig,
                 encoder: CRNEncoder,
                 **kwargs):

        # Set before the base constructor, presumably because
        # BRTrain.__init__ reads them (not visible here) -- keep this order.
        self.model_name = args.model.name
        self.isDecoder  = True
        self.base_model = "lstm"
        super().__init__(args)

        # NOTE(review): if BRTrain is an nn.Module (presumably a
        # LightningModule), this assignment registers the encoder as a
        # sub-module, so its parameters show up in self.parameters() and in
        # the decoder's checkpoints -- confirm that is intended.
        self.encoder = encoder

        # ---- datasets: taken over from the encoder ---------------------------
        # process_sequential runs before any decoder hyper-parameter is set,
        # exactly as before.
        self.train_f_sequential = self.process_sequential(self.encoder.train_f)
        enc = self.encoder
        self.valid_f              = enc.valid_f
        self.test_cf              = enc.test_cf
        self.test_cf_multi        = enc.test_cf_multi
        self.train_scaling_params = enc.train_scaling_params

        # ---- hyper-parameters ------------------------------------------------
        dec_cfg = args.model.decoder
        self.br_size          = dec_cfg.br_size          # balanced representation size
        self.fc_hidden_units  = dec_cfg.fc_hidden_units
        self.seq_hidden_units = enc.br_size              # tied to the encoder
        self.dropout_rate     = dec_cfg.dropout_rate
        self.num_layer        = dec_cfg.num_layer

        # ---- sub-networks (creation order preserved) -------------------------
        # LSTM-based balanced representation
        br_net = BuildBr(self.input_size, self.seq_hidden_units,
                         self.num_layer, self.dropout_rate, self.br_size)
        self.buildBr = br_net.to(self.device_ori)

        # generator
        outcome_head = BROutcomeHead(self.br_size, self.fc_hidden_units,
                                     self.dim_treatments, self.dim_dosages,
                                     self.dim_outcome)
        self.brOutcomeHead = outcome_head.to(self.device_ori)

        # discriminator (treatment); attribute spelling kept as-is
        treatment_head = BRTreatmentHead(self.br_size, self.fc_hidden_units,
                                         self.dim_treatments, self.alpha,
                                         self.update_alpha, self.balancing)
        self.brTreamentHead = treatment_head.to(self.device_ori)

        # discriminator (dosage)
        dosage_head = BRDosageHead(self.br_size, self.fc_hidden_units,
                                   self.dim_treatments, self.num_dosage_samples,
                                   self.alpha, self.update_alpha,
                                   self.balancing)
        self.brDosageHead = dosage_head.to(self.device_ori)
                         