import numpy as np
from copy import deepcopy
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F

from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

from utils_lstm      import VariationalLSTM
from utils import set_data, set_data_multi

from GNetTrain import GNetTrain

from BRTreatmentOutcomeHead import ROutcomeHead

from sklearn.model_selection import train_test_split

class GNet(GNetTrain):
    def __init__(self, args: DictConfig,
                 dataset_collection: dict = None):
        """Set up the G-Net model and bind the datasets it trains/evaluates on.

        Args:
            args: Experiment configuration. Read here: ``args.model.name``
                (must equal ``"G-Net"``), ``args.exp.seed`` (seed of the
                hold-out split) and ``args.model.g_net`` (network
                hyperparameters, forwarded to ``_init_specific``).
            dataset_collection: Mapping that provides the entries
                ``'train_f'``, ``'valid_f'``, ``'test_cf'``,
                ``'test_cf_multi'`` and ``'train_scaling_params'``. Tuple
                entries are left untouched; the values of every other entry
                are moved to ``self.device_ori`` in place. Required despite
                the ``None`` default.

        Raises:
            ValueError: If ``args.model.name`` is not ``"G-Net"`` or
                ``dataset_collection`` is ``None``.
        """
        # Fail early with an explicit message instead of a bare Exception.
        if args.model.name != "G-Net":
            raise ValueError(f'Model mismatch: expected "G-Net", got {args.model.name!r}')
        if dataset_collection is None:
            raise ValueError("dataset_collection is required to build GNet")

        # These are assigned before the parent constructor runs.
        # NOTE(review): presumably GNetTrain.__init__ reads them - confirm.
        self.model_name = args.model.name
        self.isDecoder  = False
        self.base_model = "lstm"

        super().__init__(args)

        # -------------------------------------------------------------------------------
        # dataset collection
        # -------------------------------------------------------------------------------
        # Move every non-tuple entry onto the model device. The caller's
        # mapping is updated in place.
        for data_list in dataset_collection.values():
            if isinstance(data_list, tuple):
                continue

            for key, value in data_list.items():
                data_list[key] = value.to(self.device_ori)

        self.train_f = dataset_collection['train_f']
        # Carve the hold-out subset out of the factual training data; this
        # rewrites the entries of self.train_f (and therefore of the caller's
        # 'train_f' entry) in place.
        self.split_train_f_holdout(self.holdout_ratio, args.exp.seed)

        self.valid_f       = dataset_collection['valid_f']
        self.test_cf       = dataset_collection['test_cf']
        self.test_cf_multi = dataset_collection['test_cf_multi']

        # scaling params
        self.train_scaling_params = dataset_collection['train_scaling_params']

        # The sequence model consumes treatments, dosages, static features and
        # previous outcomes concatenated along the feature axis (see build_r).
        self.input_size  = (self.dim_treatments + self.dim_dosages
                            + self.dim_static_features + self.dim_outcome)
        self.output_size = self.dim_outcome

        self._init_specific(args.model.g_net)
        # prepare_data() is intentionally not called here: the hold-out split
        # has already been done above.

    def _init_specific(self, sub_args: DictConfig):
        """Read the G-Net hyperparameters and build the sub-networks.

        Copies the hyperparameters from ``sub_args`` onto ``self``, derives
        ``self.comp_sizes`` (``num_comp`` equal chunks of ``dim_outcome``) and
        instantiates ``self.repr_net`` and ``self.r_outcome_vitals_head`` on
        ``self.device_ori``.

        If a mandatory hyperparameter is missing (OmegaConf raises
        ``MissingMandatoryValue`` on access, or the value is ``None``), a
        warning is logged and the networks are left unbuilt.

        Args:
            sub_args: Config node providing ``dropout_rate``,
                ``seq_hidden_units``, ``r_size``, ``num_layer``, ``num_comp``,
                ``fc_hidden_units`` and ``mc_samples``.

        Raises:
            AssertionError: If the component sizes do not add up to
                ``self.output_size`` (``dim_outcome`` not divisible by
                ``num_comp``).
        """
        # Local import: the file does not import logging at module level.
        import logging

        try:
            self.dropout_rate     = sub_args.dropout_rate
            self.seq_hidden_units = sub_args.seq_hidden_units
            self.r_size           = sub_args.r_size
            self.num_layer        = sub_args.num_layer
            self.num_comp         = sub_args.num_comp
            self.fc_hidden_units  = sub_args.fc_hidden_units
            self.mc_samples       = sub_args.mc_samples

            # Check the mandatory values before anything is computed from
            # them, so a missing num_comp takes the "missing value" path
            # instead of failing with a TypeError in the division below.
            mandatory = (self.seq_hidden_units, self.r_size, self.dropout_rate,
                         self.fc_hidden_units, self.num_comp)
            if any(value is None for value in mandatory):
                raise MissingMandatoryValue()

            # Params for Conditional distribution networks
            self.comp_sizes = [self.dim_outcome // self.num_comp] * self.num_comp
            assert len(self.comp_sizes) == self.num_comp
            assert sum(self.comp_sizes) == self.output_size, \
                "dim_outcome must be divisible by num_comp"

            # Representation network init + Conditional distribution networks init
            self.repr_net = VariationalLSTM(self.input_size,
                                            self.seq_hidden_units,
                                            self.num_layer,
                                            self.dropout_rate).to(self.device_ori)
            # head
            self.r_outcome_vitals_head = ROutcomeHead(self.seq_hidden_units,
                                                      self.r_size,
                                                      self.fc_hidden_units,
                                                      self.dim_outcome,
                                                      self.num_comp,
                                                      self.comp_sizes).to(self.device_ori)

        except MissingMandatoryValue:
            # model_type is not assigned anywhere in this class; fall back to
            # the model name (or the class name) so the warning cannot fail.
            model_label = getattr(self, "model_type",
                                  getattr(self, "model_name", type(self).__name__))
            logging.getLogger(__name__).warning(
                f"{model_label} not fully initialised - some mandatory args are missing! "
                f"(It's ok, if one will perform hyperparameters search afterward).")

    def forward(self, batch, sample=False):
        """Run the representation network and the outcome head on a batch.

        Args:
            batch: Mapping with the entries ``'current_ow'``, ``'current_d'``,
                ``'inp_v'`` and ``'inp_x'``.
            sample: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(outcome_pred, r)``: the head's outcome prediction and the
            representation it was computed from.
        """
        representation = self.build_r(
            batch['current_ow'],  # current treatments
            batch['current_d'],   # current dosages
            batch['inp_v'],       # static features
            batch["inp_x"],       # previous outputs
        )
        prediction = self.r_outcome_vitals_head.build_outcome_vitals(representation)
        return prediction, representation

    def build_r(self, curr_treatments, curr_dosages, static_features, prev_outputs):
        """Compute the representation ``r`` for the given inputs.

        The four inputs are concatenated along the last axis (in the order of
        the parameters), passed through ``self.repr_net`` and then through the
        head's ``build_r``.
        """
        features = (curr_treatments, curr_dosages, static_features, prev_outputs)
        hidden = self.repr_net(torch.cat(features, dim=-1))
        return self.r_outcome_vitals_head.build_r(hidden)

    def training_step(self, batch):
        """Compute the masked MSE training loss for one batch.

        Puts the module in training mode, runs the forward pass and compares
        the first ``self.dim_outcome`` output features with
        ``batch['out_x_next']``.

        Args:
            batch: Mapping with the forward-pass inputs plus
                ``'out_x_next'`` (targets) and ``'active_entries'`` (mask).

        Returns:
            Scalar tensor: sum of the masked element-wise squared errors
            divided by the sum of the mask.
        """
        self.train()
        outcome_next_vitals_pred, _ = self(batch)  # By convention order is (outcomes, vitals)
        outcome_pred = outcome_next_vitals_pred[:, :, :self.dim_outcome]

        # `reduction='none'` keeps the element-wise losses; it replaces the
        # deprecated `reduce=False` argument and yields the same values.
        outcome_mse_loss = F.mse_loss(outcome_pred, batch['out_x_next'], reduction='none')

        # Masking for shorter sequences
        active_entries = batch['active_entries']
        mse_loss = (active_entries * outcome_mse_loss).sum() / active_entries.sum()

        return mse_loss

    def predict_step(self, batch):
        """Return the outcome prediction for ``batch`` as a CPU tensor.

        ``forward`` returns the tuple ``(outcome_pred, r)``; calling ``.cpu()``
        on that tuple would raise ``AttributeError``, so the prediction is
        unpacked first and only it is moved to the CPU.
        """
        outcome_pred, _ = self(batch)
        return outcome_pred.cpu()

    # set resid
    def set_resid(self) -> None:
        """Store the prediction residuals of the hold-out training subset.

        Switches the module to evaluation mode, encodes
        ``self.train_f_holdout`` with ``encode_factual_dataset`` and saves

        * ``self.holdout_resid``: ``out_x_next`` minus the prediction, and
        * ``self.holdout_resid_len``: the subset's ``'seq_lengths'`` entry.
        """
        self.eval()
        holdout = self.train_f_holdout
        # The second return value of the encoder is not needed here.
        predicted, _ = self.encode_factual_dataset(holdout)

        # holdout resid
        self.holdout_resid = holdout['out_x_next'] - predicted
        self.holdout_resid_len = holdout['seq_lengths']

    def get_multi_step_counterfactual_gnet(self):
        """Evaluate multi-step counterfactual predictions as normalized RMSE.

        Runs ``get_autoregressive_predictions`` ``self.mc_samples`` times,
        averages the runs, un-scales the mean with the ``"cancer_volume"``
        training mean/std and compares it with
        ``self.test_cf_multi['unscaled_x_nahead']`` under the
        ``'active_entries'`` mask.

        Requires ``set_resid`` to have been called first (the hold-out
        residuals must exist).

        Returns:
            dict mapping ``horizon index + 1`` to a NumPy scalar array holding
            ``100 * RMSE / self.MAX_CANCER_VOLUME`` for that horizon.
        """
        assert hasattr(self, 'holdout_resid') and hasattr(self, 'holdout_resid_len')

        cf_data = self.test_cf_multi
        targets = cf_data['unscaled_x_nahead']
        n_patients, n_steps, n_samples, n_taus, _ = targets.shape
        # active_entries_next_cf
        mask = cf_data["active_entries"]

        # One slot per Monte-Carlo run.
        mc_shape = (self.mc_samples, n_patients, n_steps, n_samples, n_taus, self.dim_outcome)
        mc_predictions = torch.zeros(mc_shape).to(self.device_ori)

        for run in tqdm(range(self.mc_samples)):
            mc_predictions[run] = self.get_autoregressive_predictions()
            torch.cuda.empty_cache()

        mean_prediction = mc_predictions.mean(0)  # Averaging over mc_samples

        # calculate Normalized RMSE
        train_means, train_stds = self.train_scaling_params
        unscaled_prediction = mean_prediction * train_stds["cancer_volume"] + train_means["cancer_volume"]

        rmses = {}
        for tau in range(self.projection_horizon + 1):
            tau_mask = mask[:, :, :, tau, :]
            squared_error = ((unscaled_prediction[:, :, :, tau, :] - targets[:, :, :, tau, :]) ** 2) * tau_mask
            masked_mse = squared_error.sum() / tau_mask.sum()
            normalized_rmse = 100.0 * (torch.sqrt(masked_mse) / self.MAX_CANCER_VOLUME)

            rmses[tau + 1] = normalized_rmse.to('cpu').detach().numpy().copy()

        return rmses
    
    # ---------------------------------------------------------------------------------------
    # Normalized Masked RMSE Counterfactual multi
    # ---------------------------------------------------------------------------------------
    def get_autoregressive_predictions(self):        
        """Produce one autoregressive roll-out of counterfactual predictions.

        Reads ``self.test_cf`` (factual inputs) and ``self.test_cf_multi``
        (counterfactual inputs) and needs ``self.holdout_resid`` to exist.

        Horizon 0 is filled with the one-step prediction returned by
        ``encode_factual_dataset`` plus randomly drawn hold-out residuals.
        For every time step ``t``, horizon ``ntau >= 1`` and sample, the
        factual history up to ``t`` is concatenated with the predictions
        already made for the earlier horizons, the model is run in inference
        mode and the prediction at the last position is stored (plus a
        residual while ``t + ntau < seq_lengths``).

        Residual indices come from ``np.random.randint``, so the result
        depends on NumPy's global RNG state.

        Returns:
            Tensor of shape ``(npatients, seq_lengths, nsamples, ntaus, 1)``
            on ``self.device_ori``.
        """
        # set data
        dataset_f  = self.test_cf
        dataset_cf = self.test_cf_multi
        
        # shape
        npatients, seq_lengths, nsamples, ntaus, _ = dataset_cf['unscaled_x_nahead'].shape
        predict_x_cf = torch.zeros(npatients, seq_lengths, nsamples, ntaus, 1).to(self.device_ori)
        
        # static features for cf: extend the time axis (dim 1) by repeating the
        # first ntaus-1 steps so slices up to fact_length + ntau stay in range
        # NOTE(review): presumably valid because static features do not change
        # over time - confirm. torch.cat already returns a new tensor, so the
        # deepcopy looks redundant.
        static_features    = dataset_f['inp_v']
        static_cf_features = deepcopy(static_features)
        static_cf_features = torch.cat([static_cf_features, static_cf_features[:, 0:ntaus-1, :]], 
                                       dim = 1).to(self.device_ori)
        
        # active_entries_next_cf
        # NOTE(review): this local is never read below.
        active_entries_cf = dataset_cf["active_entries"]
        
        # one-step prediction (factual treatment) 
        outcome_pred, encoder_br = self.encode_factual_dataset(dataset_f)
        
        # ntau = 0
        for nsample in range(nsamples):
            # add resid: draw one hold-out residual sequence per patient
            rand_resid_ind = np.random.randint(len(self.holdout_resid), size = len(outcome_pred))
            resid_at_split = self.holdout_resid[rand_resid_ind]             
            predict_x_cf[:, :, nsample, 0, 0] = torch.squeeze(outcome_pred + resid_at_split)
            
        # multi-step prediction (counterfactual treatment)
        for t in range(seq_lengths):
            for ntau in range(1, self.projection_horizon + 1):
                for nsample in range(nsamples):
                    batchd = {}
                    fact_length = t + 1
                    
                    # future (counterfactual): entries for horizons 0..ntau-1
                    batchd['inp_x']          = predict_x_cf[:, t, nsample, :ntau]
                    # NOTE(review): this 'inp_v' value is overwritten below before it is used.
                    batchd['inp_v']          = static_cf_features[:, fact_length:fact_length + ntau]  
                    batchd['inp_w_prev']     = dataset_cf['prev_ow'][:, t, nsample, :ntau]
                    batchd['inp_d_prev']     = dataset_cf['prev_d_scaled'][:, t, nsample, :ntau]
                    batchd['active_entries'] = dataset_cf["active_entries"][:, t, nsample, :ntau]
                    batchd['current_ow']     = dataset_cf['current_ow'][:, t, nsample, :ntau, :]
                    batchd['current_d']      = dataset_cf['current_d_scaled'][:, t, nsample, :ntau, :]   
                    
                    # combine past with future: prepend the first fact_length factual steps
                    # NOTE(review): forward() only reads 'current_ow', 'current_d',
                    # 'inp_v' and 'inp_x'; the other entries are built but not consumed by it.
                    batchd['inp_x']          = torch.cat([dataset_f['inp_x'][:, :fact_length, :], batchd['inp_x']], dim = 1)
                    batchd['inp_v']          = static_cf_features[:, :fact_length + ntau, :]  
                    batchd['inp_w_prev']     = torch.cat([dataset_f['inp_w_prev'][:, :fact_length, :], batchd['inp_w_prev']], dim = 1)
                    batchd['inp_d_prev']     = torch.cat([dataset_f['inp_d_prev'][:, :fact_length, :], batchd['inp_d_prev']], dim = 1)
                    batchd['active_entries'] = torch.cat([dataset_f['active_entries'][:, :fact_length, :], batchd['active_entries']], dim = 1) 
                    batchd['current_ow']     = torch.cat([dataset_f['current_ow'][:, :fact_length, :], batchd['current_ow']], dim = 1)
                    batchd['current_d']  = torch.cat([dataset_f['current_d'][:, :fact_length, :], batchd['current_d']], dim = 1)
                    
                    # prediction
                    with torch.inference_mode():
                        results = self(batchd)
                        
                    # forward returns (outcome_pred, r); only the prediction is needed.
                    # This rebinds outcome_pred from the one-step prediction above.
                    outcome_pred = results[0]  
                    if t + ntau < seq_lengths:
                        # add a freshly drawn residual taken at time index t + ntau
                        rand_resid_ind = np.random.randint(len(self.holdout_resid), size = len(outcome_pred))
                        resid_at_split = self.holdout_resid[rand_resid_ind]
                        resid_at_split = resid_at_split[:, t + ntau]
                        predict_x_cf[:, t, nsample, ntau] = outcome_pred[:, -1] + resid_at_split
                    else:
                        # no residual once t + ntau reaches seq_lengths
                        predict_x_cf[:, t, nsample, ntau] = outcome_pred[:, -1]
                                                 
        return predict_x_cf
    
    
    # --------------------------------------------------------------------------------
    # prepare data
    # --------------------------------------------------------------------------------
    def prepare_data(self) -> None:
        """Split off the hold-out subset and replicate the counterfactual data.

        NOTE(review): reads ``self.hparams.dataset.holdout_ratio`` and relies
        on the default seed of ``split_train_f_holdout``, whereas ``__init__``
        splits with ``self.holdout_ratio`` and ``args.exp.seed``; the call to
        this method in ``__init__`` is commented out. Confirm that
        ``self.hparams`` exists before re-enabling it.
        """
        holdout_ratio = self.hparams.dataset.holdout_ratio
        self.split_train_f_holdout(holdout_ratio)

        n_copies = self.mc_samples
        self.explode_cf_treatment_seq(n_copies)
        
    def split_train_f_holdout(self, holdout_ratio = 0.1, seed = 10):
        """Move a share of ``self.train_f`` into ``self.train_f_holdout``.

        Every entry of ``self.train_f`` is split with ``train_test_split``
        using the same ``test_size`` and ``random_state``; the training part
        replaces the entry in ``self.train_f`` and the held-out part is stored
        under the same key in ``self.train_f_holdout``.

        Nothing happens (and ``self.train_f_holdout`` is not created) unless
        ``holdout_ratio`` is greater than zero.

        Args:
            holdout_ratio: Passed to ``train_test_split`` as ``test_size``.
            seed: Passed to ``train_test_split`` as ``random_state``.
        """
        if not holdout_ratio > 0.0:
            return

        # Start from a full copy; each entry is overwritten in the loop below.
        self.train_f_holdout = deepcopy(self.train_f)
        for name, values in self.train_f.items():
            kept, held_out = train_test_split(values,
                                              test_size = holdout_ratio,
                                              random_state = seed)
            self.train_f[name] = kept
            self.train_f_holdout[name] = held_out
    def explode_cf_treatment_seq(self, mc_samples = 1):
        """Create ``mc_samples`` copies of ``self.test_cf_treatment_seq``.

        The result is stored in ``self.test_cf_treatment_seq_mc``. Each list
        element is a separate shallow copy of the source object whose ``data``
        attribute is an independent deep copy.

        Previously the very same object was appended on every iteration, so
        all list elements were aliases of one another (and of the source).

        Args:
            mc_samples: Number of copies to create.
        """
        # Local import: only `deepcopy` is imported at file level.
        from copy import copy as shallow_copy

        source = self.test_cf_treatment_seq
        replicas = []
        for _ in range(mc_samples):
            replica = shallow_copy(source)
            # Give each replica its own data so they can diverge safely.
            replica.data = deepcopy(source.data)
            replicas.append(replica)

        self.test_cf_treatment_seq_mc = replicas