import torch
from src.models.utils import bce


def compute_p_t(batch):
    """Estimate per-timestep treatment frequencies within a batch.

    Args:
        batch: mapping with a ``'prev_treatments'`` tensor indexed as
            ``[B, T, K]`` (sample, timestep, treatment).

    Returns:
        Tensor of shape ``[T, K]`` where entry ``[ts, k]`` is the fraction of
        samples whose ``prev_treatments[:, ts, k]`` equals 1.
    """
    prev_treatments = batch['prev_treatments']
    batch_size = prev_treatments.shape[0]

    # The output is allocated with torch.zeros so it keeps the default
    # dtype/device regardless of where the batch tensor lives.
    p_t = torch.zeros(prev_treatments.shape[1], prev_treatments.shape[2])

    # Count the entries equal to 1 over the batch dimension in one vectorized
    # reduction instead of a Python loop over every element.
    p_t += (prev_treatments == 1).sum(dim=0).to(p_t)
    p_t /= batch_size
    return p_t

def compute_factual_loss(est, true, pi_0, p_t, t):
    """Compute the sample-reweighted mean squared error between ``est`` and ``true``.

    Shapes (as noted by the original author):
        est, true: [B, T, 1]
        pi_0:      [B, T, K]
        p_t:       [T, K]
        t:         [B, T, K]

    Every sample whose group value matches a treatment index gets a base
    weight of 1. From the third timestep on (``cur_ts > 1``) the weight is
    increased, for every other treatment ``cur_t2``, by
    ``pi_0[.., cur_t2] / pi_0[.., cur_t1] * p_t[cur_t1] / p_t[cur_t2]``.
    Samples that match no treatment index keep a weight of 0.

    Returns:
        Scalar tensor: ``mean(sample_weight * (est - true) ** 2)``.
    """
    # Reference weights from the original paper (kept for comparison, unused):
    #   w_t = t / (2. * p_t)
    #   w_c = (1. - t) / (2. * (1. - p_t))
    #   sample_weight = 1. * (1. + (1. - pi_0) / pi_0
    #                         * (p_t / (1. - p_t)) ** (2. * t - 1.)) * (w_t + w_c)

    num_treatments = t.shape[-1]

    # Compute sample reweighting.
    sample_weight = torch.zeros_like(est)
    for cur_ts in range(t.shape[1]):
        for cur_t1 in range(num_treatments):
            # Select the samples whose first treatment column equals cur_t1.
            # Indexing column 0 directly (no .squeeze()) keeps the mask 1-D
            # even when B == 1 or K == 1, where the squeezed form raised.
            # NOTE(review): only column 0 of `t` is compared against the
            # treatment index; if `t` is one-hot this looks like it should be
            # `t[:, cur_ts, cur_t1] == 1` -- confirm the intended encoding.
            idx_cur_t = t[:, cur_ts, 0] == cur_t1
            sample_weight[idx_cur_t, cur_ts, 0] = 1.

            # The first two timesteps keep the plain base weight.
            if cur_ts <= 1:
                continue

            for cur_t2 in range(num_treatments):
                if cur_t2 == cur_t1:
                    continue
                # Skip ratios whose denominators are (near) zero to avoid
                # exploding weights.
                if p_t[cur_ts, cur_t2] < 1e-4 or (pi_0[idx_cur_t, cur_ts, cur_t1] < 1e-4).any():
                    continue
                propensity_ratio = pi_0[idx_cur_t, cur_ts, cur_t2] / (pi_0[idx_cur_t, cur_ts, cur_t1] + 1e-7)
                marginal_ratio = p_t[cur_ts, cur_t1] / (p_t[cur_ts, cur_t2] + 1e-7)
                sample_weight[idx_cur_t, cur_ts, 0] += propensity_ratio * marginal_ratio

    return torch.mean(sample_weight * torch.square(est - true))


def compute_cross_entropy_loss(self, est, target):
    """Return the mean cross-entropy between ``est`` and ``target``.

    Args:
        self: unused; kept so the function can be called with the existing
            signature.
        est: predicted class scores. The loss applies log-softmax internally,
            so these are treated as unnormalized logits.
        target: treatment target (class indices or per-class probabilities
            such as a one-hot vector).

    Returns:
        Scalar tensor with the default (mean-reduced) cross-entropy.
    """
    return torch.nn.functional.cross_entropy(est, target)

def mmd2_lin_dfr(upsilon, t, p_t):
    """Linear MMD between treatment groups, accumulated over timesteps.

    Shapes (as noted by the original author):
        upsilon: data / representation, indexed as [B, T, D]
        t:       treatment indicator     [B, T, K]
        p_t:     treatment probability   [T, K]

    For every timestep from the third one on, the mean of ``upsilon`` is taken
    over the samples of each treatment (``t[:, ts, k] == 1``), and the squared
    difference of ``2 * p_t * mean`` between each pair of adjacent treatments
    ``(k - 1, k)`` is summed. A timestep is skipped entirely when any group
    mean contains NaN (e.g. a treatment with no samples in the batch).

    Returns:
        The accumulated MMD as a tensor, or the int ``0`` when no timestep
        contributed.
    """
    mmd = 0
    num_treatments = t.shape[-1]

    # The first two timesteps never contribute, so their means are not computed.
    for cur_ts in range(2, t.shape[1]):
        group_means = []
        for cur_tr in range(num_treatments):
            # Keep the mask 1-D (no .squeeze()) so indexing stays valid for B == 1.
            mask = t[:, cur_ts, cur_tr] == 1
            # The mean over an empty selection is NaN; detected below.
            group_means.append(torch.mean(upsilon[mask, cur_ts, :], dim=0))

        if any(torch.isnan(group_mean).any() for group_mean in group_means):
            continue

        # Compare each treatment with its predecessor. The NaN check above uses
        # its own loop variable so it cannot overwrite this pair index.
        for cur_tr in range(1, num_treatments):
            scaled_cur = 2.0 * p_t[cur_ts, cur_tr] * group_means[cur_tr]
            scaled_prev = 2.0 * p_t[cur_ts, cur_tr - 1] * group_means[cur_tr - 1]
            mmd += torch.sum(torch.square(scaled_cur - scaled_prev))

    return mmd

def bce_loss(self, treatment_pred, current_treatments, kind='predict'):
    """Compute the treatment loss via the project's ``bce`` helper.

    Args:
        self: object providing ``hparams.dataset.treatment_mode``,
            ``hparams.exp.bce_weight`` and, when that flag is truthy,
            ``bce_weights``.
        treatment_pred: predicted treatments, forwarded to ``bce``.
        current_treatments: observed treatments.
        kind: ``'predict'`` scores the prediction against the observed
            treatments (optionally with class weights); ``'confuse'`` scores
            it against a uniform target instead.

    Returns:
        Whatever ``bce`` returns for the selected target.

    Raises:
        NotImplementedError: if ``kind`` is neither ``'predict'`` nor
            ``'confuse'``.
    """
    mode = self.hparams.dataset.treatment_mode
    if self.hparams.exp.bce_weight:
        bce_weights = torch.tensor(self.bce_weights).type_as(current_treatments)
    else:
        bce_weights = None

    if kind == 'predict':
        return bce(treatment_pred, current_treatments, mode, bce_weights)

    if kind != 'confuse':
        raise NotImplementedError()

    # Uniform target: 1/K per class for multiclass, 0.5 per label for
    # multilabel; any other mode leaves the target at all ones.
    uniform_treatments = torch.ones_like(current_treatments)
    if mode == 'multiclass':
        uniform_treatments *= 1 / current_treatments.shape[-1]
    elif mode == 'multilabel':
        uniform_treatments *= 0.5
    return bce(treatment_pred, uniform_treatments, mode)