import numpy as np
import pandas as pd
import torch
from torch import nn
import torchtuples as tt

from pycox import models
from pycox.preprocessing.label_transforms import LabTransCoxTime

class CoxTime(models.cox_cc._CoxCCBase):
    """The Cox-Time model from [1]. A relative risk model without proportional hazards, trained
    with case-control sampling.

    The net receives the (scaled) duration as an extra input next to the covariates, so every
    prediction in this class is made on a `(covariates, durations)` pair.
    
    Arguments:
        net {torch.nn.Module} -- A PyTorch net.
    
    Keyword Arguments:
        optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
        device {str, int, torch.device} -- Device to compute on. (default: {None})
            Preferably pass a torch.device object.
            If 'None': use default gpu if available, else use cpu.
            If 'int': use that gpu: torch.device('cuda:<device>').
            If 'string': string is passed to torch.device('string').
        shrink {float} -- Shrinkage that encourages the net to give g_case and g_control
            closer to zero (a regularizer in a sense). (default: {0.})
        labtrans {pycox.preprocessing.label_transforms.LabTransCoxTime} -- An object for transforming
            durations. Useful for prediction as we can obtain durations on the original scale.
            (default: {None})
        loss {object} -- Forwarded unchanged to the parent class constructor. (default: {None})

    References:
    [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
        Time-to-event prediction with neural networks and Cox regression.
        Journal of Machine Learning Research, 20(129):1–30, 2019.
        http://jmlr.org/papers/v20/18-424.html
    """
    make_dataset = models.data.CoxTimeDataset
    label_transform = LabTransCoxTime

    def __init__(self, net, optimizer=None, device=None, shrink=0., labtrans=None, loss=None):
        # Set before calling the parent constructor, exactly as the original ordering did.
        self.labtrans = labtrans
        super().__init__(net, optimizer, device, shrink, loss)

    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        """Build a prediction dataloader from a `(covariates, durations)` pair.

        The durations are appended to the covariate tuple, and the combined tuple is handed to
        the parent's `make_dataloader_predict`.

        Arguments:
            input {tuple} -- A pair `(covariates, durations)`.
            batch_size {int} -- Batch size.

        Keyword Arguments:
            shuffle {bool} -- Forwarded to the parent implementation. (default: {False})
            num_workers {int} -- Forwarded to the parent implementation. (default: {0})

        Returns:
            The dataloader created by the parent class.
        """
        input, durations = input
        input = tt.tuplefy(input)
        durations = tt.tuplefy(durations)
        new_input = input + durations 
        dataloader = super().make_dataloader_predict(new_input, batch_size, shuffle, num_workers)
        return dataloader

    def predict_surv_df(self, input, max_duration=None, batch_size=256, verbose=False, baseline_hazards_=None,
                        eval_=True, num_workers=0):
        """Predict survival with the parent implementation and, if `labtrans` is set, relabel
        the index with `labtrans.map_scaled_to_orig` so that times are on the original scale.

        All arguments are forwarded positionally to the parent's `predict_surv_df`.

        Returns:
            The object returned by the parent's `predict_surv_df` (its index is replaced when
            `self.labtrans` is not None).
        """
        surv = super().predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_,
                                       eval_, num_workers)
        if self.labtrans is not None:
            surv.index = self.labtrans.map_scaled_to_orig(surv.index)
        return surv

    def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=256,
                                set_hazards=True, eval_=True, num_workers=0):
        """Compute the baseline hazards from `input` and `target`.

        Keyword Arguments:
            input -- Covariates. If both `input` and `target` are None, `self.training_data`
                is used. (default: {None})
            target -- Targets, converted with `self.target_to_df`. (default: {None})
            max_duration {float} -- Only durations `<= max_duration` are kept. (default: {None})
            sample {float or int} -- If `>= 1`, the number of rows to sample; otherwise the
                fraction of rows to sample. If None, all rows are used. (default: {None})
            batch_size {int} -- Batch size used for predictions. (default: {256})
            set_hazards {bool} -- If True, `compute_baseline_cumulative_hazards` is called with
                `set_hazards=True` and the computed baseline hazards. (default: {True})
            eval_ {bool} -- Forwarded to `self.predict`. (default: {True})
            num_workers {int} -- Forwarded to `self.predict`. (default: {0})

        Raises:
            ValueError -- If neither `input` nor `target` is given and the model has no
                `training_data` attribute.

        Returns:
            pd.Series -- The baseline hazards, named 'baseline_hazards'.
        """
        if (input is None) and (target is None):
            if not hasattr(self, 'training_data'):
                raise ValueError('Need to fit, or supply a input and target to this function.')
            input, target = self.training_data
        df = self.target_to_df(target)
        if sample is not None:
            if sample >= 1:
                df = df.sample(n=sample)
            else:
                df = df.sample(frac=sample)
        # `_compute_baseline_hazards` requires the targets to be sorted by duration. Sampled
        # data is always re-sorted (as before); unsampled data is now also sorted when it is
        # not already monotonic, instead of failing later with a RuntimeError. Data that is
        # already sorted is left untouched.
        if (sample is not None) or (not df[self.duration_col].is_monotonic_increasing):
            df = df.sort_values(self.duration_col)
        # Reorder the covariates with the (possibly sampled and sorted) target index so that
        # covariate rows and target rows stay aligned.
        input = tt.tuplefy(input).to_numpy().iloc[df.index.values]
        base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size, eval_, num_workers)
        if set_hazards:
            self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz)
        return base_haz

    def _compute_baseline_hazards(self, input, df_train_target, max_duration, batch_size, eval_=True,
                                  num_workers=0):
        """Compute baseline hazards for duration-sorted data.

        For every distinct event time `t` (up to `max_duration`), the number of events at `t`
        is divided by the sum of `exp(g(x, t))` over the rows from the first occurrence of `t`
        onward. Durations without events get a hazard of 0.

        Arguments:
            input -- Covariates, row-aligned with `df_train_target`.
            df_train_target {pd.DataFrame} -- Targets with the columns `self.duration_col` and
                `self.event_col`, sorted by duration.
            max_duration {float} -- Only durations `<= max_duration` are kept; None means no limit.
            batch_size {int} -- Batch size used for predictions.

        Keyword Arguments:
            eval_ {bool} -- Forwarded to `self.predict`. (default: {True})
            num_workers {int} -- Forwarded to `self.predict`. (default: {0})

        Raises:
            RuntimeError -- If `df_train_target` is not sorted by `self.duration_col`.

        Returns:
            pd.Series -- Baseline hazards indexed by duration, named 'baseline_hazards'.
        """
        if max_duration is None:
            max_duration = np.inf

        def compute_expg_at_risk(ix, t):
            # Because the data is sorted by duration, the rows from position `ix` onward are
            # the ones with duration >= t.
            sub = input.iloc[ix:]
            n = sub.lens().flatten().get_if_all_equal()
            # Every row in the slice is evaluated at the same time `t`.
            t = np.repeat(t, n).reshape(-1, 1).astype('float32')
            return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum()

        if not df_train_target[self.duration_col].is_monotonic_increasing:
            raise RuntimeError(f"Need 'df_train_target' to be sorted by {self.duration_col}")
        input = tt.tuplefy(input)
        # After the reset, index labels equal row positions, which `compute_expg_at_risk`
        # uses to slice `input`.
        df = df_train_target.reset_index(drop=True)
        # First position of every distinct event time within `max_duration`.
        times = (df
                 .loc[lambda x: x[self.event_col] != 0]
                 [self.duration_col]
                 .loc[lambda x: x <= max_duration]
                 .drop_duplicates(keep='first'))
        at_risk_sum = (pd.Series([compute_expg_at_risk(ix, t) for ix, t in times.items()],
                                 index=times.values)
                       .rename('at_risk_sum'))
        # Number of events per duration (0 for durations that only have censored rows).
        events = (df
                  .groupby(self.duration_col)
                  [[self.event_col]]
                  .agg('sum')
                  .loc[lambda x: x.index <= max_duration])
        # Durations without events have no `at_risk_sum`; the resulting NaN becomes 0.
        base_haz =  (events
                     .join(at_risk_sum, how='left', sort=True)
                     .pipe(lambda x: x[self.event_col] / x['at_risk_sum'])
                     .fillna(0.)
                     .rename('baseline_hazards'))
        return base_haz

    def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
                                    eval_=True, num_workers=0):
        """Predict cumulative hazards for `input` at the times in `baseline_hazards_.index`.

        Arguments:
            input -- Covariates (a dataloader is not supported).
            max_duration {float} -- Only baseline times `<= max_duration` are used; None means
                no limit.
            batch_size {int} -- Batch size used for predictions.
            verbose {bool} -- If True, print progress for every baseline time.
            baseline_hazards_ {pd.Series} -- Baseline hazards indexed by time.

        Keyword Arguments:
            eval_ {bool} -- Forwarded to `self.predict`. (default: {True})
            num_workers {int} -- Forwarded to `self.predict`. (default: {0})

        Raises:
            NotImplementedError -- If `input` is a dataloader.

        Returns:
            pd.DataFrame -- Cumulative hazards with one row per baseline time and one column
                per input row.
        """
        def expg_at_time(t):
            t = np.repeat(t, n_cols).reshape(-1, 1).astype('float32')
            if input_is_tensor:
                # Keep the time column the same container type as the covariates.
                t = torch.from_numpy(t)
            return np.exp(self.predict((input, t), batch_size, True, eval_, num_workers=num_workers)).flatten()

        if tt.utils.is_dl(input):
            raise NotImplementedError("Prediction with a dataloader as input is not supported ")
        input = tt.tuplefy(input)
        max_duration = np.inf if max_duration is None else max_duration
        baseline_hazards_ = baseline_hazards_.loc[lambda x: x.index <= max_duration]
        n_rows, n_cols = baseline_hazards_.shape[0], input.lens().flatten().get_if_all_equal()
        # The input type does not change between baseline times, so check it once here
        # rather than on every call to `expg_at_time`.
        input_is_tensor = tt.tuplefy(input).type() is torch.Tensor
        hazards = np.empty((n_rows, n_cols))
        for idx, t in enumerate(baseline_hazards_.index):
            if verbose:
                print(idx, 'of', len(baseline_hazards_))
            hazards[idx, :] = expg_at_time(t)
        hazards[baseline_hazards_.values == 0] = 0.  # in case hazards are inf here
        hazards *= baseline_hazards_.values.reshape(-1, 1)
        return pd.DataFrame(hazards, index=baseline_hazards_.index).cumsum()

    def partial_log_likelihood(self, input, target, batch_size=256, eval_=True, num_workers=0):
        """Compute the partial log-likelihood contribution of every event row.

        For an event row with duration `t`, the contribution is `g(x, t)` minus the log of the
        sum of `exp(g(x_j, t))` over the rows from the first occurrence of `t` onward in the
        duration-sorted data.

        Arguments:
            input -- Covariates.
            target {tuple} -- A pair `(durations, events)`.

        Keyword Arguments:
            batch_size {int} -- Batch size used for predictions. (default: {256})
            eval_ {bool} -- Forwarded to `self.predict`. (default: {True})
            num_workers {int} -- Forwarded to `self.predict`. (default: {0})

        Returns:
            pd.Series -- Named 'pll', with one entry per event row, indexed by the row's
                index label in the data frame built from `target`.
        """
        def expg_sum(t, i):
            # Rows from sorted position `i` onward, all evaluated at time `t`.
            sub = input_sorted.iloc[i:]
            n = sub.lens().flatten().get_if_all_equal()
            t = np.repeat(t, n).reshape(-1, 1).astype('float32')
            return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum()

        durations, events = target
        df = pd.DataFrame({self.duration_col: durations, self.event_col: events})
        df = df.sort_values(self.duration_col)
        input = tt.tuplefy(input)
        input_sorted = input.iloc[df.index.values]

        # One row per distinct event time, holding the sum of exp(g) over the rows from that
        # time's first sorted position onward.
        times =  (df
                  .assign(_idx=np.arange(len(df)))
                  .loc[lambda x: x[self.event_col] == True]
                  .drop_duplicates(self.duration_col, keep='first')
                  .assign(_expg_sum=lambda x: [expg_sum(t, i) for t, i in zip(x[self.duration_col], x['_idx'])])
                  .drop([self.event_col, '_idx'], axis=1))
        
        # Give the index a temporary name so it survives the reset_index/merge/set_index
        # round trip below; the original name is restored before returning.
        idx_name_old = df.index.name
        idx_name = '__' + idx_name_old if idx_name_old else '__index'
        df.index.name = idx_name

        pll = df.loc[lambda x: x[self.event_col] == True]
        input_event = input.iloc[pll.index.values]
        # Cast to float32 like every other time column passed to `predict` in this class, so
        # float64 durations do not produce a dtype mismatch against the covariates.
        durations_event = pll[self.duration_col].values.reshape(-1, 1).astype('float32')
        g_preds = self.predict((input_event, durations_event), batch_size, True, eval_, num_workers=num_workers).flatten()
        pll = (pll
               .assign(_g_preds=g_preds)
               .reset_index()
               .merge(times, on=self.duration_col)
               .set_index(idx_name)
               .assign(pll=lambda x: x['_g_preds'] - np.log(x['_expg_sum']))
               ['pll'])

        pll.index.name = idx_name_old
        return pll


class MLPVanillaCoxTime(nn.Module):
    """CoxTime variant of torchtuples.practical.MLPVanilla.

    Compared to the plain MLPVanilla, `forward` takes `time` as a second argument and
    concatenates it to the covariates as one extra column. The wrapped net is built with a
    single output, no output activation and no output bias.

    Arguments:
        in_features {int} -- Number of covariate columns (the time column is added internally).
        num_nodes -- Forwarded to MLPVanilla.

    Keyword Arguments:
        batch_norm -- Forwarded to MLPVanilla. (default: {True})
        dropout -- Forwarded to MLPVanilla. (default: {None})
        activation -- Forwarded to MLPVanilla. (default: {nn.ReLU})
        w_init_ -- Forwarded to MLPVanilla. (default: kaiming normal for 'relu')
    """
    def __init__(self, in_features, num_nodes, batch_norm=True, dropout=None, activation=nn.ReLU,
                 w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')):
        super().__init__()
        # Positional arguments, in order: in_features (+1 for the time column), num_nodes,
        # out_features (1), batch_norm, dropout, activation, output_activation (None),
        # output_bias (False), w_init_.
        self.net = tt.practical.MLPVanilla(in_features + 1, num_nodes, 1, batch_norm, dropout,
                                           activation, None, False, w_init_)

    def forward(self, input, time):
        """Append `time` as the last column of `input` and run the wrapped net on the result."""
        covariates_and_time = torch.cat((input, time), dim=1)
        return self.net(covariates_and_time)


class MixedInputMLPCoxTime(nn.Module):
    """CoxTime variant of torchtuples.practical.MixedInputMLP.

    Compared to the plain MixedInputMLP, `forward` takes `time` as a third argument and
    concatenates it to the numeric covariates as one extra column. The wrapped net is built
    with a single output, no output activation and no output bias.

    Arguments:
        in_features {int} -- Number of numeric covariate columns (the time column is added
            internally).
        num_embeddings -- Forwarded to MixedInputMLP.
        embedding_dims -- Forwarded to MixedInputMLP.
        num_nodes -- Forwarded to MixedInputMLP.

    Keyword Arguments:
        batch_norm -- Forwarded to MixedInputMLP. (default: {True})
        dropout -- Forwarded to MixedInputMLP. (default: {None})
        activation -- Forwarded to MixedInputMLP. (default: {nn.ReLU})
        dropout_embedding -- Forwarded to MixedInputMLP. (default: {0.})
        w_init_ -- Forwarded to MixedInputMLP. (default: kaiming normal for 'relu')
    """
    def __init__(self, in_features, num_embeddings, embedding_dims, num_nodes, batch_norm=True,
                 dropout=None, activation=nn.ReLU, dropout_embedding=0.,
                 w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')):
        super().__init__()
        # Positional arguments, in order: in_features (+1 for the time column),
        # num_embeddings, embedding_dims, num_nodes, out_features (1), batch_norm, dropout,
        # activation, dropout_embedding, output_activation (None), output_bias (False), w_init_.
        self.net = tt.practical.MixedInputMLP(in_features + 1, num_embeddings, embedding_dims,
                                              num_nodes, 1, batch_norm, dropout, activation,
                                              dropout_embedding, None, False, w_init_)

    def forward(self, input_numeric, input_categoric, time):
        """Append `time` as the last numeric column, then run the wrapped net on the numeric
        and categorical inputs."""
        numeric_and_time = torch.cat((input_numeric, time), dim=1)
        return self.net(numeric_and_time, input_categoric)
