import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, Dataset
import numpy as np
from omegaconf import DictConfig
from src.models.utils import bce
from typing import Union

import sys
# NOTE(review): machine-specific absolute path appended to sys.path, presumably so
# that the top-level `layers` package imported below can be found. This only works
# on the original author's workspace layout — consider making it configurable
# (e.g. via an environment variable) or relative to this file.
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/trend_seasonality_causal_structure/src/')

from layers.Autoformer_EncDec import series_decomp
from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.data.cancer_sim.dataset import SyntheticCancerDatasetCollection
import logging


# Module-level logger for this file.
logger = logging.getLogger(__name__)

class Model(LightningModule):
    """
    DLinear-style causal outcome model built on a trend/seasonal decomposition.

    Paper link: https://arxiv.org/pdf/2205.13504.pdf

    The concatenated covariates (previous outputs, previous treatments and,
    optionally, vitals) are split by ``series_decomp`` into seasonal and trend
    parts. Each part is mapped along the time axis by a linear layer.
    - The trend part plus transformed static features predicts treatments.
    - Trend, seasonal, static features and current treatments together predict outcomes.
    - During training, a linear-MMD penalty on the seasonal representation is
      added to the MSE and BCE losses.
    """
    model_type = 'multi'

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[SyntheticDatasetCollection, SyntheticCancerDatasetCollection, RealDatasetCollection] = None,
                 **kwargs):
        """
        Build the decomposition block, the per-component linear maps and the heads.

        Args:
            args: OmegaConf config. Model options are read from ``args.model.multi``
                and dataset options from ``args.dataset``.
                ``args.model.multi.individual`` selects between two layouts:
                True gives each of the ``enc_in`` variates its own seasonal/trend
                linear pair; False shares one pair across all variates.
            dataset_collection: Provides the datasets used by the
                ``*_dataloader`` hooks. It may be None if those hooks are never called.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.dataset_collection = dataset_collection
        self.args = args
        self.save_hyperparameters(args)
        multi = args.model.multi
        self.task_name = multi.task_name
        self.seq_len = multi.seq_len        # length of the input window in time
        if self.task_name in ('classification', 'anomaly_detection', 'imputation'):
            # Reconstruction-style tasks produce one value per input step.
            self.pred_len = multi.seq_len
        else:
            self.pred_len = multi.pred_len
        # Series decomposition block from Autoformer. The (misspelled) attribute
        # name is kept unchanged so existing references keep working.
        self.decompsition = series_decomp(multi.moving_avg)
        self.autoregressive = args.dataset.autoregressive
        self.individual = multi.individual
        self.enc_in = multi.enc_in
        self.has_vitals = multi.has_vitals

        self.dim_outcome = multi.dims_outcome
        self.dim_treatment = multi.dims_treatment
        self.dim_static = multi.dims_static
        self.dim_vitals = multi.dims_vitals if self.has_vitals else None
        # Width of the decomposed covariate tensor: treatments + (vitals) + outcomes.
        self.dim_input = self.dim_treatment
        self.dim_input += self.dim_vitals if self.has_vitals else 0
        self.dim_input += self.dim_outcome

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.enc_in):
                self.Linear_Seasonal.append(self._make_averaging_linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(self._make_averaging_linear(self.seq_len, self.pred_len))
        else:
            self.Linear_Seasonal = self._make_averaging_linear(self.seq_len, self.pred_len)
            self.Linear_Trend = self._make_averaging_linear(self.seq_len, self.pred_len)

        # Static features are (B x dim_static); each scalar is lifted to pred_len steps.
        self.static_transformation = nn.Linear(1, self.pred_len)
        self.treat_transformation = nn.Linear(self.seq_len, self.pred_len)

        # Outcome head input: trend + seasonal (dim_input each) + static + current treatments.
        self.outcome_layer = nn.Linear(self.dim_input * 2 + self.dim_static + self.dim_treatment, self.dim_outcome)
        # Treatment head input: trend (dim_input) + static features.
        self.treat_layer = nn.Linear(self.dim_input + self.dim_static, self.dim_treatment)

        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(multi.dropout)
            self.projection = nn.Linear(multi.enc_in * multi.seq_len, multi.num_class)

    @staticmethod
    def _make_averaging_linear(in_len, out_len):
        """Create ``nn.Linear(in_len, out_len)`` whose weight starts as a uniform moving average."""
        layer = nn.Linear(in_len, out_len)
        layer.weight = nn.Parameter((1 / in_len) * torch.ones([out_len, in_len]))
        return layer

    # ------------------------------------------------------------------ backbone

    def encoder(self, batch):
        """
        Run the decomposition-linear backbone and both prediction heads.

        Args:
            batch: dict with 'prev_outputs', 'prev_treatments', 'static_features',
                'current_treatments' and, when ``has_vitals``, 'vitals'.

        Returns:
            Tuple ``(est_outcome, est_treatment, trend, seasonal)``, each of shape
            (B x pred_len x features).

        Note:
            Time steps of ``batch['prev_outputs']`` that contain any NaN are
            zeroed *in place* (the batch dict is mutated).
        """
        prev_outputs = batch['prev_outputs']                   # B x T x dim_outcome (assumed)
        prev_treatments = batch['prev_treatments']             # B x T x dim_treatment
        static_features = batch['static_features']             # B x dim_static
        curr_treatments = batch['current_treatments']          # B x T x dim_treatment
        vitals = batch['vitals'] if self.has_vitals else None  # B x T x dim_vitals

        # Zero every time step whose previous outputs contain a NaN.
        nan_rows = prev_outputs.isnan().any(axis=-1)
        prev_outputs[nan_rows] = 0

        static_features = self.static_transformation(static_features.unsqueeze(-1))  # B x dim_static x pred_len
        curr_treatments = curr_treatments.transpose(2, 1)                             # B x dim_treatment x T

        covariates = [prev_outputs, prev_treatments]
        if self.has_vitals:
            covariates.append(vitals)
        cur_covariate = torch.cat(covariates, dim=-1)                                 # B x T x dim_input

        seasonal_init, trend_init = self.decompsition(cur_covariate)
        seasonal_init = seasonal_init.permute(0, 2, 1)                                # B x dim_input x T
        trend_init = trend_init.permute(0, 2, 1)                                      # B x dim_input x T

        if self.individual:
            seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
                                          dtype=seasonal_init.dtype, device=seasonal_init.device)
            trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
                                       dtype=trend_init.dtype, device=trend_init.device)
            for i in range(self.enc_in):
                seasonal_output[:, i, :] = self.Linear_Seasonal[i](seasonal_init[:, i, :])
                trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)                     # B x dim_input x pred_len
            trend_output = self.Linear_Trend(trend_init)                              # B x dim_input x pred_len

        # Concatenating along the feature axis requires T == pred_len for curr_treatments.
        outcome_in = torch.cat([trend_output, seasonal_output, static_features, curr_treatments], dim=1)
        est_outcome = self.outcome_layer(outcome_in.transpose(2, 1))
        treat_in = torch.cat([trend_output, static_features], dim=1)
        est_treatment = self.treat_layer(treat_in.transpose(2, 1))

        return est_outcome, est_treatment, trend_output.transpose(2, 1), seasonal_output.transpose(2, 1)

    def forecast(self, x_enc):
        """Forecasting head: returns ``(est_outcome, est_treatment, trend, seasonal)``."""
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        """Imputation head: returns the full encoder tuple."""
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        """Anomaly-detection head: returns the full encoder tuple."""
        return self.encoder(x_enc)

    def classification(self, x_enc):
        """
        Classification head: project the flattened trend+seasonal sequence to class logits.

        ``encoder`` returns a tuple. As in the reference DLinear, the classifier
        consumes the sum of the trend and seasonal parts. That sum has shape
        (B x seq_len x dim_input). ``projection`` expects
        ``enc_in * seq_len`` inputs, which assumes ``enc_in == dim_input``.
        """
        _, _, trend, seasonal = self.encoder(x_enc)
        enc_out = trend + seasonal
        # (batch_size, seq_length * features)
        output = enc_out.reshape(enc_out.shape[0], -1)
        # (batch_size, num_classes)
        return self.projection(output)

    def forward(self, x_enc, mask=None):
        """Dispatch to the head selected by ``task_name``; returns None for unknown tasks."""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            est_outcome, est_treatment, trend_output, seasonal_output = self.forecast(x_enc)
            return est_outcome[:, -self.pred_len:, :], est_treatment, trend_output, seasonal_output  # [B, L, D]
        if self.task_name == 'imputation':
            return self.imputation(x_enc)  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc)  # [B, N]
        return None

    # ------------------------------------------------------------------ masks

    def generate_continuous_mask(self, B, T, n=5, l=0.1):
        """
        Boolean (B x T) mask with ``n`` random contiguous False spans of length ``l`` per row.

        A float ``n`` or ``l`` is read as a fraction of ``T``. ``n`` is clamped
        to ``[1, T // 2]`` and ``l`` to ``[1, T]``, so the random start is always valid.
        """
        res = torch.full((B, T), True, dtype=torch.bool)
        if isinstance(n, float):
            n = int(n * T)
        n = max(min(n, T // 2), 1)

        if isinstance(l, float):
            l = int(l * T)
        # Clamp to T: otherwise randint(T - l + 1) fails for l > T.
        l = min(max(l, 1), T)

        for i in range(B):
            for _ in range(n):
                t = np.random.randint(T - l + 1)
                res[i, t:t + l] = False
        return res

    def generate_binomial_mask(self, B, T, p=0.5):
        """Boolean (B x T) mask whose entries are independently True with probability ``p``."""
        return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)

    # ------------------------------------------------------------------ optimisation

    def _get_optimizer(self, param_optimizer: list):
        """
        Build the optimizer from ``hparams.model[model_type].optimizer``.

        Bias and layer-norm parameters are exempt from weight decay.

        Raises:
            NotImplementedError: if ``optimizer_cls`` is not adamw / adam / sgd.
        """
        no_decay = ['bias', 'layer_norm']
        opt_args = self.hparams.model[self.model_type]['optimizer']
        decay_params = [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
        no_decay_params = [p for n, p in param_optimizer if any(nd in n for nd in no_decay)]
        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': opt_args['weight_decay']},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        lr = opt_args['learning_rate']
        optimizer_cls = opt_args['optimizer_cls'].lower()
        if optimizer_cls == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=lr)
        if optimizer_cls == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=lr)
        if optimizer_cls == 'sgd':
            return optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=opt_args['momentum'])
        raise NotImplementedError(f"Unsupported optimizer_cls: {opt_args['optimizer_cls']!r}")

    def _get_lr_schedulers(self, optimizer):
        """Attach an ExponentialLR(gamma=0.99) to each optimizer; returns ``(optimizers, schedulers)``."""
        if not isinstance(optimizer, list):
            lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
            return [optimizer], [lr_scheduler]
        lr_schedulers = [optim.lr_scheduler.ExponentialLR(opt, gamma=0.99) for opt in optimizer]
        return optimizer, lr_schedulers

    def configure_optimizers(self):
        """Lightning hook: optimizer, plus LR schedulers when ``optimizer.lr_scheduler`` is set."""
        optimizer = self._get_optimizer(list(self.named_parameters()))
        if self.hparams.model[self.model_type]['optimizer']['lr_scheduler']:
            return self._get_lr_schedulers(optimizer)
        return optimizer

    # ------------------------------------------------------------------ data

    def _train_dataset(self):
        """Training dataset matching the ``autoregressive`` setting."""
        if self.autoregressive:
            return self.dataset_collection.train_f
        return self.dataset_collection.train_f_non

    def _val_dataset(self):
        """Validation dataset matching the ``autoregressive`` setting."""
        if self.autoregressive:
            return self.dataset_collection.val_f
        return self.dataset_collection.val_f_non

    def _predict_dataset(self):
        """One-step counterfactual test dataset matching the ``autoregressive`` setting."""
        if self.autoregressive:
            return self.dataset_collection.test_cf_one_step
        return self.dataset_collection.test_cf_one_step_non

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader (incomplete last batch dropped)."""
        return DataLoader(self._train_dataset(), shuffle=True,
                          batch_size=self.args.dataset.train_batch_size, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Validation loader (incomplete last batch dropped)."""
        return DataLoader(self._val_dataset(), shuffle=False,
                          batch_size=self.args.dataset.val_batch_size, drop_last=True)

    def predict_dataloader(self) -> DataLoader:
        """Prediction loader over the one-step counterfactual test set (incomplete last batch dropped)."""
        return DataLoader(self._predict_dataset(), shuffle=False,
                          batch_size=self.args.dataset.test_batch_size, drop_last=True)

    # ------------------------------------------------------------------ losses

    def compute_p_t(self, batch):
        """
        Empirical per-time-step treatment frequencies.

        Returns:
            Tensor of shape (T x n_treatments) on the batch's device. Each entry is
            the fraction of samples whose ``prev_treatments`` entry equals 1.
        """
        treatments = batch['prev_treatments']
        counts = (treatments == 1).to(torch.get_default_dtype()).sum(dim=0)
        return counts / treatments.shape[0]

    def compute_factual_loss(self, est, true, pi_0, p_t, t):
        """
        Propensity-reweighted squared error (currently unused by the training loop).

        Args:
            est, true: (B x T x 1) predictions and targets.
            pi_0: (B x T x n_treatments) estimated propensities.
            p_t: (T x n_treatments) marginal treatment frequencies.
            t: (B x T x n_treatments) one-hot treatments (assumed).

        Returns:
            Scalar mean of ``sample_weight * (est - true)**2``.
        """
        sample_weight = torch.zeros_like(est)
        n_treat = t.shape[-1]
        for cur_ts in range(t.shape[1]):
            for cur_t1 in range(n_treat):
                # Samples that received treatment cur_t1 at this step (one-hot encoding,
                # consistent with mmd2_lin_dfr).
                idx_cur_t = t[:, cur_ts, cur_t1] == 1
                sample_weight[idx_cur_t, cur_ts, 0] = 1.
                if cur_ts <= 1:
                    continue
                for cur_t2 in range(n_treat):
                    if cur_t2 == cur_t1:
                        continue
                    # Skip near-zero probabilities to avoid exploding weights.
                    if p_t[cur_ts, cur_t2] < 1e-4 or (pi_0[idx_cur_t, cur_ts, cur_t1] < 1e-4).any():
                        continue
                    sample_weight[idx_cur_t, cur_ts, 0] += (
                        pi_0[idx_cur_t, cur_ts, cur_t2] / (pi_0[idx_cur_t, cur_ts, cur_t1] + 1e-7)
                        * (p_t[cur_ts, cur_t1] / (p_t[cur_ts, cur_t2] + 1e-7))
                    )

        return torch.mean(sample_weight * torch.square(est - true))

    def compute_cross_entropy_loss(self, est, target):
        """Mean ``nn.CrossEntropyLoss`` between class scores ``est`` and ``target``."""
        return torch.nn.CrossEntropyLoss()(est, target)

    def mmd2_lin_dfr(self, upsilon, t, p_t):
        """
        Linear MMD between treatment groups, summed over time steps.

        For each time step ``ts >= 2``, the mean representation of every treatment
        group is computed. Each pair of adjacent groups ``(i-1, i)`` then adds
        ``sum((2 p_t[ts, i] m_i - 2 p_t[ts, i-1] m_{i-1})**2)``. A time step is
        skipped when any group is empty, since its mean would be NaN.

        Args:
            upsilon: (B x T' x D) representation; indexed at the first T steps,
                so T' >= T is assumed.
            t: (B x T x n_treatments) one-hot treatments.
            p_t: (T x n_treatments) treatment frequencies.

        Returns:
            Scalar tensor; zero when no time step contributes.
        """
        mmd = upsilon.new_zeros(())
        n_treat = t.shape[-1]
        for cur_ts in range(2, t.shape[1]):
            # Mean representation of samples with treatment cur_tr at this time step.
            group_means = [
                torch.mean(upsilon[t[:, cur_ts, cur_tr] == 1, cur_ts, :], dim=0)
                for cur_tr in range(n_treat)
            ]
            if any(torch.isnan(m).any() for m in group_means):
                continue
            for i in range(1, n_treat):
                mmd = mmd + torch.sum(torch.square(
                    2.0 * p_t[cur_ts, i] * group_means[i] - 2.0 * p_t[cur_ts, i - 1] * group_means[i - 1]))
        return mmd

    def bce_loss(self, treatment_pred, current_treatments, kind='predict'):
        """
        Treatment BCE via the project ``bce`` helper.

        Args:
            kind: 'predict' compares against the true treatments. 'confuse'
                compares against a uniform target: 1/K for multiclass, 0.5 for
                multilabel, 1 otherwise.

        Raises:
            NotImplementedError: for any other ``kind``.
        """
        mode = self.args.dataset.treatment_mode
        # NOTE(review): ``self.bce_weights`` is not assigned anywhere in this class; presumably
        # it is set externally when ``exp.bce_weight`` is enabled (otherwise AttributeError).
        bce_weights = torch.tensor(self.bce_weights).type_as(current_treatments) if self.hparams.exp.bce_weight else None

        if kind == 'predict':
            return bce(treatment_pred, current_treatments, mode, bce_weights)
        if kind == 'confuse':
            uniform_treatments = torch.ones_like(current_treatments)
            if mode == 'multiclass':
                uniform_treatments *= 1 / current_treatments.shape[-1]
            elif mode == 'multilabel':
                uniform_treatments *= 0.5
            return bce(treatment_pred, uniform_treatments, mode)
        raise NotImplementedError()

    def _compute_losses(self, batch):
        """
        Shared loss computation for the train/val/test steps.

        Returns:
            ``(loss, mse_loss, bce_loss, imbalanced_loss)``, where
            ``loss = lambda1 * bce + mse + lambda2 * imbalance``. The BCE is masked
            by ``active_entries``; the MSE is a plain mean.
        """
        outcome_pred, treatment_pred, _, season = self(batch)
        current_treatments = batch['current_treatments']
        active_entries = batch['active_entries']

        mse_loss = F.mse_loss(outcome_pred, batch['outputs'])

        p_t = self.compute_p_t(batch)
        imbalanced_loss = self.mmd2_lin_dfr(season, current_treatments, p_t)

        bce_loss = self.bce_loss(treatment_pred, current_treatments.double(), kind='predict')
        bce_loss = (active_entries.squeeze(-1) * bce_loss).sum() / active_entries.sum()

        loss = self.args.exp.param_lambda1 * bce_loss + mse_loss + self.args.exp.param_lambda2 * imbalanced_loss
        return loss, mse_loss, bce_loss, imbalanced_loss

    # ------------------------------------------------------------------ Lightning steps

    def training_step(self, batch, batch_ind, optimizer_idx=0):
        """Lightning training step: log total, imbalance and MSE losses; return the total loss."""
        # Kept from the original: re-enable gradients in case something froze them.
        # NOTE(review): confirm this is still needed.
        for par in self.parameters():
            par.requires_grad = True

        loss, mse_loss, _, imbalanced_loss = self._compute_losses(batch)

        subset_name = self._train_dataset().subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_imbalanced_loss', imbalanced_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_ind):
        """Lightning validation step: log total, imbalance and MSE losses."""
        loss, mse_loss, _, imbalanced_loss = self._compute_losses(batch)

        subset_name = self._val_dataset().subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_imbalanced_loss', imbalanced_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)

    def test_step(self, batch, batch_ind, **kwargs):
        """Lightning test step: log total, BCE and MSE losses (computed exactly as in validation)."""
        loss, mse_loss, bce_loss, _ = self._compute_losses(batch)

        # NOTE(review): metrics are keyed by the *validation* subset name, as before.
        subset_name = self._val_dataset().subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True)

    def predict_step(self, batch, batch_idx, dataset_idx=0):
        """
        Generates normalised output predictions.

        Returns ``(outcome_pred, treatment_pred, trend, season)`` moved to CPU.
        """
        outcome_pred, treatment_pred, trend, season = self(batch)
        return outcome_pred.cpu(), treatment_pred.cpu(), trend.cpu(), season.cpu()

    # ------------------------------------------------------------------ evaluation helpers

    def get_normalised_masked_rmse(self, dataset: Dataset, one_step_counterfactual=True):
        """
        Masked RMSE of outcome predictions on ``dataset``, normalised by ``dataset.norm_const``.

        ``exp.unscale_rmse`` compares in unscaled units; ``exp.percentage_rmse``
        multiplies the results by 100.

        Returns:
            ``(rmse_orig, rmse_all, rmse_last, rmse_per_step)`` when
            ``one_step_counterfactual`` is true, otherwise ``(rmse_orig, rmse_all)``.
            - ``rmse_orig``: masked average per time step, then averaged over steps.
            - ``rmse_all``: masked average over all entries at once.
            - ``rmse_last``: error at each sequence's last active step only.
            - ``rmse_per_step``: per-time-step array.
        """
        outputs_scaled = self.get_predictions(dataset)

        unscale = self.args.exp.unscale_rmse
        percentage = self.args.exp.percentage_rmse
        logger.info("unscale" if unscale else "not unscale")
        logger.info("percentage error" if percentage else "not percentage error")

        active_entries = dataset.data['active_entries']
        if unscale:
            output_stds, output_means = dataset.scaling_params['output_stds'], dataset.scaling_params['output_means']
            predictions = outputs_scaled * output_stds + output_means
            targets = dataset.data['unscaled_outputs']
        else:
            predictions = outputs_scaled
            targets = dataset.data['outputs']
        squared_error = (predictions - targets) ** 2
        # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
        mse = squared_error * active_entries

        # Calculation like in original paper (masked averaging over datapoints (& outputs),
        # then over the non-masked time axis).
        mse_per_step = mse.sum(0).sum(-1) / active_entries.sum(0).sum(-1)
        mse_step = np.sqrt(mse_per_step) / dataset.norm_const
        rmse_normalised_orig = np.sqrt(mse_per_step.mean()) / dataset.norm_const

        # Masked averaging over all dimensions at once
        mse_all = mse.sum() / active_entries.sum()
        rmse_normalised_all = np.sqrt(mse_all) / dataset.norm_const

        if percentage:
            rmse_normalised_orig *= 100.0
            rmse_normalised_all *= 100.0
            mse_step *= 100

        if one_step_counterfactual:
            # Only consider each sequence's last active entry. active[t] - active[t+1]
            # is 1 exactly at the last active step (the shift must be [:, 1:]).
            num_samples, _, output_dim = active_entries.shape
            last_entries = active_entries - np.concatenate([active_entries[:, 1:, :],
                                                            np.zeros((num_samples, 1, output_dim))], axis=1)
            mse_last = (squared_error * last_entries).sum() / last_entries.sum()
            rmse_normalised_last = np.sqrt(mse_last) / dataset.norm_const

            if percentage:
                rmse_normalised_last *= 100.0

            return rmse_normalised_orig, rmse_normalised_all, rmse_normalised_last, mse_step

        return rmse_normalised_orig, rmse_normalised_all

    def _predict_all(self, dataset: Dataset):
        """Run ``trainer.predict`` over ``dataset`` and concatenate each of the four per-batch outputs."""
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.test_batch_size, shuffle=False)
        return [torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader))]

    def get_predictions(self, dataset: Dataset) -> np.ndarray:
        """Outcome predictions for the whole ``dataset`` as a numpy array."""
        outcome_pred, _, _, _ = self._predict_all(dataset)
        return outcome_pred.numpy()

    def get_trend_seasonality(self, dataset: Dataset) -> tuple:
        """``(trend, seasonality)`` representations for the whole ``dataset`` as numpy arrays."""
        _, _, trend, seasonality = self._predict_all(dataset)
        return trend.numpy(), seasonality.numpy()

    def get_treatment(self, dataset: Dataset) -> np.ndarray:
        """Treatment predictions for the whole ``dataset`` as a numpy array."""
        _, treatment, _, _ = self._predict_all(dataset)
        return treatment.numpy()

    def get_predictions_treatments(self, dataset: Dataset) -> tuple:
        """``(outcome_pred, treatment_pred)`` for the whole ``dataset`` as numpy arrays."""
        outcome_pred, treatment, _, _ = self._predict_all(dataset)
        return outcome_pred.numpy(), treatment.numpy()
