import logging
from copy import deepcopy

import numpy as np
import torch
from omegaconf import DictConfig
from omegaconf.errors import MissingMandatoryValue
from torch import nn

from utils_lstm import VariationalLSTM
from utils import set_data, set_data_multi
from RMSNtrain import RMSNtrain

class RMSN(RMSNtrain):
    """Shared building blocks for the RMSN sub-models, layered on ``RMSNtrain``.

    The class keeps references to the dataset splits it is given, builds the
    LSTM-plus-linear network used by a sub-model, and offers helpers for
    computing propensity scores and stabilized weights.
    """

    # Identifies which sub-model this is; left unset here.
    model_type = None  # Will be defined in subclasses
    # The model type names this class family knows about.
    possible_model_types = {'encoder', 'decoder', 'propensity_treatment', 'propensity_history'}
    # Not set at this level.
    tuning_criterion = None

    def __init__(self, args: DictConfig, dataset_collection: dict, **kwargs):
        """Initialise the parent class and attach the dataset splits.

        Args:
            args: Configuration, forwarded unchanged to the parent initialiser.
            dataset_collection: Mapping from split name to that split's data.
                Every entry that is not a tuple is treated as a mapping whose
                values expose ``.to(device)``; each value is overwritten with
                the result of ``.to(self.device_ori)``, so the caller's
                collection is modified. The keys ``train_f``, ``valid_f``,
                ``test_cf``, ``test_cf_multi`` and ``train_scaling_params``
                must be present.
            **kwargs: Accepted but not used by this method.
        """
        super().__init__(args)

        # ---------------------------------------------------------------------
        # Move the contents of every split onto the working device.
        # NOTE(review): self.device_ori is not assigned in this method;
        # presumably the parent initialiser provides it -- confirm.
        # ---------------------------------------------------------------------
        for split_name, split in dataset_collection.items():
            if not isinstance(split, tuple):
                for field, tensor in split.items():
                    dataset_collection[split_name][field] = tensor.to(self.device_ori)

        # Expose the splits and the scaling parameters as attributes named
        # after their keys in the collection.
        for attr_name in ('train_f', 'valid_f', 'test_cf', 'test_cf_multi',
                          'train_scaling_params'):
            setattr(self, attr_name, dataset_collection[attr_name])
        
    def _init_specific(self, sub_args: DictConfig, encoder_r_size: int = None):
        """Create the recurrent network and output head for this sub-model.

        Copies ``seq_hidden_units``, ``dropout_rate`` and ``num_layer`` from
        ``sub_args`` onto ``self`` and then builds ``self.lstm`` and
        ``self.output_layer`` on ``self.device_ori``. When ``self.model_type``
        is ``'decoder'`` a linear ``self.memory_adapter`` mapping
        ``encoder_r_size`` to ``seq_hidden_units`` is built as well.

        If reading a value raises ``MissingMandatoryValue``, or if
        ``seq_hidden_units`` or ``dropout_rate`` is ``None``, no layers are
        created and a warning is logged instead of raising, so that the
        object can still be used for a later hyperparameter search.

        Args:
            sub_args: Configuration node holding the parameters listed above.
            encoder_r_size: Input width of the decoder memory adapter; only
                read when ``self.model_type == 'decoder'``.
        """
        # Encoder/decoder-specific parameters
        try:
            self.seq_hidden_units = sub_args.seq_hidden_units
            self.dropout_rate = sub_args.dropout_rate
            self.num_layer = sub_args.num_layer

            # Pytorch model init
            # NOTE(review): num_layer is not included in this None check --
            # confirm whether that is intentional.
            if self.seq_hidden_units is None or self.dropout_rate is None:
                raise MissingMandatoryValue()

            if self.model_type == 'decoder':
                self.memory_adapter = nn.Linear(encoder_r_size, self.seq_hidden_units).to(self.device_ori)

            # self.input_size / self.output_size are not assigned in this
            # method; presumably set by the subclass or parent -- TODO confirm.
            self.lstm = VariationalLSTM(self.input_size,
                                        self.seq_hidden_units,
                                        self.num_layer,
                                        self.dropout_rate).to(self.device_ori)
            self.output_layer = nn.Linear(self.seq_hidden_units,
                                          self.output_size).to(self.device_ori)

        except MissingMandatoryValue:
            # Obtain the logger explicitly: no module-level ``logger`` is
            # visible in this file, so referencing one here would turn the
            # intended warning into a NameError.
            logging.getLogger(__name__).warning(
                f"{self.model_type} not fully initialised - some mandatory args are missing! "
                f"(It's ok, if one will perform hyperparameters search afterward).")


    def clip_normalize_stabilized_weights(self, stabilized_weights, active_entries, multiple_horizons=False):
        """
        Used by RMSNs
        """
        active_entries = active_entries.to(torch.bool)
        
        stabilized_weights[~torch.squeeze(active_entries)] = torch.nan
        sw_tilde = torch.clip(stabilized_weights, torch.nanquantile(stabilized_weights, 0.01), torch.nanquantile(stabilized_weights, 0.99))
        if multiple_horizons:
            sw_tilde = sw_tilde / torch.nanmean(sw_tilde, axis=0, keepdim=True)
        else:
            sw_tilde = sw_tilde / torch.nanmean(sw_tilde)

        sw_tilde[~torch.squeeze(active_entries)] = 0.0
        return sw_tilde
    
    # ------------------------------------------------------------------------
    # Propensity score computation
    # ------------------------------------------------------------------------
    def process_propensity_train_f(self, propensity_treatment, propensity_history):
        prop_treat_train_f = propensity_treatment.get_propensity_scores(self.train_f)
        prop_hist_train_f  = propensity_history.get_propensity_scores(self.train_f)
        sw = torch.prod(prop_treat_train_f / prop_hist_train_f, axis=2).to(self.device_ori)
        self.train_f['stabilized_weights'] = sw
        
    def get_propensity_scores(self, dataset: dict) -> np.array:
        if self.model_type == 'propensity_treatment' or self.model_type == 'propensity_history':
            propensity_scores = None
            for batch in self.val_dataloader():
                if propensity_scores is None:
                    propensity_scores = self.predict_step(batch)
                else:
                    propensity_scores = torch.cat((propensity_scores, 
                                                  self.predict_step(batch)),
                                                  dim = 0)
        else:
            raise NotImplementedError()
        return propensity_scores.detach()

    