import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from copy import deepcopy
from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

# pytorch
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

# Models
from buildEDBR import BuildEDBR
from GANDiscGenHead import GenCFModel, DiscTreatModel, DiscDosageModel
from GANTrain import GANTrain

# --------------------------------------------------------------------------------------------------------
# Transformer + Counterfactual Generator
# --------------------------------------------------------------------------------------------------------
class EDTSEncoder(GANTrain):
    """Encoder-stage EDTS model: BuildEDBR representation + counterfactual GAN heads.

    The constructor wires together, in this order, a ``BuildEDBR`` network,
    a ``GenCFModel`` generator, a ``DiscTreatModel`` treatment discriminator
    and one ``DiscDosageModel`` per treatment index ``1 .. num_treatments - 1``.

    ``self.device_ori``, ``self.dim_outcome`` and ``self.dim_dosages`` are read
    here but never assigned in this class; they are expected to be provided by
    ``GANTrain.__init__`` (not visible in this file -- TODO confirm).
    """

    # ----------------------------------------------------------------------------------------------------
    # init
    # ----------------------------------------------------------------------------------------------------
    def __init__(self,
                 args: DictConfig,
                 dataset_collection: dict,
                 **kwargs):
        """Build the encoder and all of its GAN sub-modules.

        Args:
            args: Experiment configuration. ``args.model.name`` must be
                ``"EDTS"``; ``args.model.encoder.*`` and ``args.dataset.*``
                supply the dimensions used below.
            dataset_collection: Mapping that must contain the keys
                ``'train_f'``, ``'valid_f'``, ``'test_cf'``, ``'test_cf_multi'``
                and ``'train_scaling_params'``. Every non-tuple entry is
                treated as a mapping whose values support ``.to(device)``.
                NOTE: those values are replaced **in place** by their
                on-device copies, so the caller's mapping is modified.
            **kwargs: Accepted for call-site compatibility; not used here.

        Raises:
            ValueError: If ``args.model.name`` is not ``"EDTS"``.
        """
        # check -- fail with a descriptive error instead of a bare Exception
        if args.model.name != "EDTS":
            raise ValueError(
                "Model mismatch: EDTSEncoder requires args.model.name == 'EDTS', "
                f"got {args.model.name!r}"
            )

        # These flags are assigned before the parent constructor runs, exactly
        # as in the original ordering (presumably GANTrain.__init__ reads
        # them -- verify against GANTrain).
        self.model_name = args.model.name
        self.isDecoder  = False
        self.base_model = "transformer"
        super().__init__(args)

        # -------------------------------------------------------------------------------
        # dataset collection
        # -------------------------------------------------------------------------------
        # Tuple entries (e.g. the scaling parameters unpacked further below)
        # are skipped; every other entry has each of its values moved to
        # self.device_ori. Only existing keys are reassigned, so mutating the
        # inner mapping while iterating over it is safe.
        for data_list in dataset_collection.values():
            if isinstance(data_list, tuple):
                continue

            for key, value in data_list.items():
                data_list[key] = value.to(self.device_ori)

        self.train_f       = dataset_collection['train_f']
        self.valid_f       = dataset_collection['valid_f']
        self.test_cf       = dataset_collection['test_cf']
        self.test_cf_multi = dataset_collection['test_cf_multi']

        # -------------------------------------------------------------------------------
        # dataset parameters
        # -------------------------------------------------------------------------------
        # scaling params: a (means, stds) pair, both indexed by "treat_dosage"
        self.train_scaling_params = dataset_collection['train_scaling_params']
        train_means, train_stds = self.train_scaling_params

        self.num_treatments     = args.dataset.dim_treatments
        self.num_dosage_samples = args.dataset.dim_n_dosage_samples

        # dict_wd: configured dosages per treatment key, standardised with the
        # training mean / std of "treat_dosage"
        dosage_mean = train_means["treat_dosage"]
        dosage_std  = train_stds["treat_dosage"]
        self.dict_wd = {
            w: (dosages - dosage_mean) / dosage_std
            for w, dosages in args.dataset.dict_wd.items()
        }

        # max cancer volume
        self.MAX_CANCER_VOLUME  = args.dataset.max_cancer_volume

        # projection_horizon
        self.projection_horizon = args.dataset.projection_horizon

        # CF generator: input width = representation size + noise size + dosage size
        self.multitask_input_dim  = (args.model.encoder.br_size
                                     + args.model.encoder.dim_noise
                                     + args.dataset.dim_dosages)
        self.multitask_hidden_dim = args.model.encoder.fc_hidden_units

        # Discriminator
        self.br_size                 = args.model.encoder.br_size
        self.i_inv_eqv_dim           = self.dim_outcome + self.dim_dosages
        self.h_inv_eqv_dim           = args.model.encoder.fc_hidden_units
        self.discTreat_fc_hidden_dim = args.model.encoder.fc_hidden_units

        # -------------------------------------------------------------------------------
        # Model setup (construction order kept: BuildEDBR, generator, discriminators)
        # -------------------------------------------------------------------------------
        # Transformer
        self.buildBR = BuildEDBR(args).to(self.device_ori)

        # generator
        self.genCF = GenCFModel(self.num_treatments,
                                self.num_dosage_samples,
                                self.dict_wd,
                                self.multitask_input_dim,
                                self.multitask_hidden_dim).to(self.device_ori)

        # -------------------------------------------------------------------------------
        # discriminator
        # -------------------------------------------------------------------------------
        # treatment
        self.discTreat = DiscTreatModel(self.dict_wd,
                                        self.br_size,
                                        self.discTreat_fc_hidden_dim,
                                        self.i_inv_eqv_dim,
                                        self.h_inv_eqv_dim).to(self.device_ori)

        # dosage: one discriminator per treatment index 1 .. num_treatments - 1
        # (index 0 gets none), each built from that treatment's dict_wd entry
        self.discDosages_list = nn.ModuleList([
            DiscDosageModel(self.dict_wd[w],
                            self.br_size,
                            self.i_inv_eqv_dim,
                            self.h_inv_eqv_dim).to(self.device_ori)
            for w in range(1, self.num_treatments)
        ])

    
class EDTSDecoder(GANTrain):
    """Decoder-stage EDTS model layered on an already constructed encoder.

    Datasets and dataset-level parameters are taken over from the supplied
    encoder; layer sizes come from ``args.model.decoder``. The constructor
    builds, in this order, a decoder-mode ``BuildEDBR``, a ``GenCFModel``,
    a ``DiscTreatModel`` and one ``DiscDosageModel`` per treatment index
    ``1 .. num_treatments - 1``.

    ``self.device_ori``, ``self.dim_outcome``, ``self.dim_dosages`` and
    ``self.process_sequential`` are used but not defined in this class; they
    are expected to come from ``GANTrain`` (not visible here -- TODO confirm).
    """

    # ----------------------------------------------------------------------------------------------------
    # init
    # ----------------------------------------------------------------------------------------------------
    def __init__(self, 
                 args: DictConfig,
                 encoder: EDTSEncoder,
                 train_f_aug = None,
                 **kwargs):
        """Build the decoder and its GAN sub-modules.

        Args:
            args: Experiment configuration; ``args.model.decoder.*`` supplies
                the layer sizes and ``args.dataset.dim_dosages`` the dosage
                width.
            encoder: Encoder whose datasets and dataset parameters are
                reused; it is also stored as ``self.encoder``.
            train_f_aug: Accepted but not used inside this constructor.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        # Flags assigned ahead of the parent constructor, in the same order
        # as before (presumably read by GANTrain.__init__ -- verify).
        self.model_name = args.model.name
        self.isDecoder = True
        self.base_model = "transformer"
        super().__init__(args)

        self.encoder = encoder
        enc = self.encoder

        # -------------------------------------------------------------------------------
        # datasets (shared with the encoder)
        # -------------------------------------------------------------------------------
        # The sequential view of the factual training set is derived first,
        # before any other dataset attribute is bound on this object.
        self.train_f_sequential = self.process_sequential(enc.train_f)
        self.train_f = enc.train_f
        self.valid_f = enc.valid_f
        self.test_cf = enc.test_cf
        self.test_cf_multi = enc.test_cf_multi

        # -------------------------------------------------------------------------------
        # dataset-level parameters (copied from the encoder)
        # -------------------------------------------------------------------------------
        self.num_treatments = enc.num_treatments
        self.num_dosage_samples = enc.num_dosage_samples
        self.dict_wd = enc.dict_wd
        self.train_scaling_params = enc.train_scaling_params
        self.MAX_CANCER_VOLUME = enc.MAX_CANCER_VOLUME
        self.projection_horizon = enc.projection_horizon

        # -------------------------------------------------------------------------------
        # layer sizes (decoder section of the config)
        # -------------------------------------------------------------------------------
        decoder_cfg = args.model.decoder
        hidden_units = decoder_cfg.fc_hidden_units

        # CF generator
        # NOTE(review): the noise width is read from args.model.encoder while
        # everything else here is decoder config -- confirm this is intended.
        self.multitask_input_dim = (
            decoder_cfg.br_size
            + args.model.encoder.dim_noise
            + args.dataset.dim_dosages
        )
        self.multitask_hidden_dim = hidden_units

        # Discriminator
        self.br_size = decoder_cfg.br_size
        self.i_inv_eqv_dim = self.dim_outcome + self.dim_dosages
        self.h_inv_eqv_dim = hidden_units
        self.discTreat_fc_hidden_dim = hidden_units

        # -------------------------------------------------------------------------------
        # Model setup
        # -------------------------------------------------------------------------------
        # Transformer (decoder mode)
        self.buildBR = BuildEDBR(args, isDecoder = True).to(self.device_ori)

        # generator
        generator = GenCFModel(
            self.num_treatments,
            self.num_dosage_samples,
            self.dict_wd,
            self.multitask_input_dim,
            self.multitask_hidden_dim,
        )
        self.genCF = generator.to(self.device_ori)

        # treatment discriminator
        treat_disc = DiscTreatModel(
            self.dict_wd,
            self.br_size,
            self.discTreat_fc_hidden_dim,
            self.i_inv_eqv_dim,
            self.h_inv_eqv_dim,
        )
        self.discTreat = treat_disc.to(self.device_ori)

        # dosage discriminators: one per treatment index 1 .. num_treatments - 1
        self.discDosages_list = nn.ModuleList([
            DiscDosageModel(
                self.dict_wd[treat_idx],
                self.br_size,
                self.i_inv_eqv_dim,
                self.h_inv_eqv_dim,
            ).to(self.device_ori)
            for treat_idx in range(1, self.num_treatments)
        ])

    