import numpy as np
import torch
import torch.nn as nn
from pr_nn import diff_CSDI_PR


class CSDI_base_PR(nn.Module):
    """Pattern-recognizer base model trained with a masked binary cross-entropy.

    ``self.diffmodel`` (a ``diff_CSDI_PR`` built from ``config["pr"]``) is
    called on the generated data and its output is treated as per-entry
    probabilities. ``calc_loss`` returns two negative-log-likelihood terms:
    one over entries selected by ``true_mask`` and one over the complement
    ``1 - true_mask``.

    Subclasses must implement ``process_data(batch)`` returning a 4-tuple
    ``(generated_data, target_mask, observed_mask, observed_tp)``; ``forward``
    depends on it.
    """

    def __init__(self, target_dim, config, device, logger=None):
        """Build the recognizer network.

        Args:
            target_dim: stored as ``self.target_dim``; not used in this class.
            config: mapping with a ``"pr"`` sub-config; ``config["pr"]["is_pr"]``
                is stored as ``self.is_pr`` and the whole sub-config is passed
                to ``diff_CSDI_PR``.
            device: device that subclasses move batch tensors to.
            logger: optional logger; if given, the trainable parameter count
                is logged.
        """
        super().__init__()
        self.device = device
        self.target_dim = target_dim
        self.logger = logger

        config_pr = config["pr"]
        self.is_pr = config_pr["is_pr"]

        # Presumably the channel count of the network input: calc_loss adds a
        # single size-1 axis at dim 1 before calling diffmodel. TODO confirm.
        input_dim = 1
        self.diffmodel = diff_CSDI_PR(config_pr, input_dim)
        if self.logger is not None:
            num_params = sum(
                p.numel() for p in self.diffmodel.parameters() if p.requires_grad
            )
            self.logger.info(f"Num of total trainable pattern recognizer params is: {num_params}")

    @staticmethod
    def _masked_nll(mask, log_prob):
        """Return ``-sum(mask * log_prob) / sum(mask)``, dividing by 1 when ``sum(mask)`` is not positive."""
        count = mask.sum()
        # torch.where keeps the zero guard on-device instead of calling
        # bool() on a tensor, which would force a host/device sync.
        denom = torch.where(count > 0, count, torch.ones_like(count))
        return -torch.sum(mask * log_prob) / denom

    def calc_loss(self, generated_data, true_mask, observed_mask):
        """Compute the true/fake cross-entropy terms for one batch.

        Args:
            generated_data: 3-D tensor; the original code named the axes
                (B, K, L) — exact axis meaning not established here.
            true_mask: tensor weighting entries that should be scored as
                "true"; ``1 - true_mask`` weights the "fake" entries. Must
                broadcast against the diffmodel output.
            observed_mask: unused; kept so the call signature stays the same.

        Returns:
            ``(loss_true, loss_fake)`` scalar tensors: mask-weighted means of
            ``-log(p)`` and ``-log(1 - p)`` respectively (with a 1e-10 epsilon
            inside the log). A term whose mask sums to 0 is divided by 1.

        Raises:
            ValueError: if ``generated_data`` is not 3-D.
        """
        if generated_data.dim() != 3:
            raise ValueError(
                f"generated_data must be 3-D, got shape {tuple(generated_data.shape)}"
            )
        fake_mask = 1.0 - true_mask

        # Insert a size-1 axis at dim 1 before feeding the network.
        # Output presumably (B, K, L) probabilities in [0, 1] — TODO confirm.
        pred_prob = self.diffmodel(generated_data.unsqueeze(1))

        loss_true = self._masked_nll(true_mask, torch.log(pred_prob + 1e-10))
        loss_fake = self._masked_nll(fake_mask, torch.log(1.0 - pred_prob + 1e-10))
        return loss_true, loss_fake

    def forward(self, batch):
        """Preprocess ``batch`` with ``process_data`` and return ``calc_loss`` on it.

        The observed mask is used as the "true" mask, so observed entries
        feed ``loss_true`` and all other entries feed ``loss_fake``.
        """
        generated_data, _, observed_mask, _ = self.process_data(batch)
        true_mask = observed_mask
        return self.calc_loss(generated_data, true_mask, observed_mask)


class CSDI_PR(CSDI_base_PR):
    """Concrete pattern recognizer whose batches are dicts of tensors.

    ``target_dim`` defaults to 7 and is forwarded to the base class.
    """

    def __init__(self, config, device, target_dim=7, logger=None):
        super().__init__(target_dim, config, device, logger)

    def process_data(self, batch):
        """Move batch tensors to ``self.device`` as float and reorder axes.

        Reads ``batch["observed_data"]``, ``batch["observed_mask"]`` and
        ``batch["timepoints"]``. The data and mask have their last two axes
        swapped via ``permute(0, 2, 1)`` (so they must be 3-D); timepoints are
        returned unpermuted.

        Returns:
            ``(observed_data, None, observed_mask, observed_tp)`` — the second
            slot (target mask) is always ``None``.
        """
        def _prepare(tensor):
            return tensor.to(self.device).float()

        data, mask, timepoints = (
            _prepare(batch[key])
            for key in ("observed_data", "observed_mask", "timepoints")
        )
        return data.permute(0, 2, 1), None, mask.permute(0, 2, 1), timepoints


