from woods.objectives.ERM import ERM
import torch
import torch.nn.functional as F


class AbstractCAD(ERM):
    """Contrastive adversarial domain bottleneck (abstract class).

    From "Optimal Representations for Covariate Shift"
    <https://arxiv.org/abs/2201.00057>.

    The training objective is the dataset's classification loss plus
    ``hparams['lmbda']`` times the contrastive bottleneck loss computed by
    :meth:`bn_loss`.

    Args:
        model: network called as ``model(X)``; must return a
            ``(predictions, representations)`` pair.
        dataset: object exposing ``ENVS``, ``test_env``, ``get_next_batch()``,
            ``split_tensor_by_domains(X, Y, n)`` and ``loss(pred, Y)``.
        optimizer: optimizer stepped once per :meth:`update` call.
        hparams: mapping read for ``'temperature'``, ``'is_project'``,
            ``'is_normalized'``, ``'is_flipped'`` and ``'lmbda'``.
        is_conditional: whether the bottleneck is conditioned on the label.

    NOTE(review): ``self.dataset`` and ``self.hparams`` are read in
    :meth:`update` but not assigned here -- presumably ``ERM.__init__`` sets
    them; confirm against the base class.
    """

    def __init__(self, model, dataset, optimizer, hparams, is_conditional):
        super().__init__(model, dataset, optimizer, hparams)

        self.num_domains = len(dataset.ENVS)
        # One environment is left out of the training count when a test
        # environment is configured.
        if dataset.test_env is not None:
            self.num_train_domains = self.num_domains - 1
        else:
            self.num_train_domains = self.num_domains

        self.model = model
        self.optimizer = optimizer

        # Parameters of the domain bottleneck loss.
        self.is_conditional = is_conditional  # condition the bottleneck on the label
        self.base_temperature = 0.07
        self.temperature = hparams['temperature']
        # Whether to pass representations through ``self.project`` first.
        # No projection head is built by this class; see ``bn_loss``.
        self.is_project = hparams['is_project']
        # Whether to L2-normalize representations before computing the loss.
        self.is_normalized = hparams['is_normalized']

        # Whether to flip "maximize log(p)" (False) into "minimize -log(1-p)"
        # (True) for the bottleneck loss. Both versions have the same optima,
        # but the latter was found to be more stable.
        self.is_flipped = hparams["is_flipped"]

    @staticmethod
    def _finite_mean(x):
        """Return the mean of the finite entries of a 1-D tensor (0 if none)."""
        zero = torch.tensor(0.0).to(x)
        finite = torch.isfinite(x)
        num_finite = finite.sum()
        if num_finite == 0:
            return zero
        return torch.where(finite, x, zero).sum() / num_finite

    def bn_loss(self, z, y, dom_labels):
        """Contrastive based domain bottleneck loss.

        The implementation is based on the supervised contrastive loss (SupCon)
        introduced by P. Khosla, et al., in "Supervised Contrastive Learning".
        Modified from https://github.com/HobbitLong/SupContrast/blob/8d0963a7dbb1cd28accb067f5144d61f18a77588/losses.py#L11

        Args:
            z: representations; every leading axis is collapsed so that the
                loss sees one row per sample and the last axis is the feature
                axis (e.g. ``(batch, time, features)`` or ``(batch, features)``).
            y: task labels, one per row of the flattened ``z``.
            dom_labels: domain labels, one per row of the flattened ``z``.

        Returns:
            A scalar tensor: the mean of the finite per-sample losses, or 0
            when no sample yields a finite loss.

        Raises:
            AttributeError: if ``is_project`` is set but no ``self.project``
                module has been attached to the instance.
        """
        device = z.device

        # Flatten (batch, time, ...) -> (batch*time, ...). ``reshape`` is used
        # instead of ``view`` so non-contiguous inputs are accepted as well.
        z = z.reshape(-1, z.shape[-1])
        y = y.reshape(-1, 1)
        dom_labels = dom_labels.reshape(-1, 1)
        batch_size = z.shape[0]

        mask_y = torch.eq(y, y.T).to(device)  # same label
        mask_d = torch.eq(dom_labels, dom_labels.T).to(device)  # same domain
        # Drop the "current"/"self" example (the diagonal).
        mask_drop = ~torch.eye(batch_size, dtype=torch.bool, device=device)
        mask_y &= mask_drop
        mask_y_n_d = mask_y & (~mask_d)  # same label but from different domains
        mask_y_d = mask_y & mask_d  # same label and the same domain
        mask_y = mask_y.float()
        mask_drop = mask_drop.float()
        mask_y_n_d = mask_y_n_d.float()
        mask_y_d = mask_y_d.float()

        # Compute logits.
        if self.is_project:
            project = getattr(self, "project", None)
            if project is None:
                raise AttributeError(
                    "hparams['is_project'] is set but no projection head is "
                    "available: attach a module as `self.project` (and register "
                    "its parameters with the optimizer) or disable 'is_project'."
                )
            z = project(z)
        if self.is_normalized:
            z = F.normalize(z, dim=1)

        # Pairwise similarities between all flattened samples (every
        # prediction of every time series takes part in the objective).
        logits = (z @ z.T) / self.temperature
        logits = logits * mask_drop
        # Subtract the row-wise maximum for numerical stability.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Adding ``mask.log()`` puts -inf on excluded pairs, which removes them
        # from the following ``logsumexp``.
        scale = self.temperature / self.base_temperature
        if not self.is_conditional:
            # Unconditional CAD loss.
            denominator = torch.logsumexp(logits + mask_drop.log(), dim=1, keepdim=True)
            log_prob = logits - denominator

            # Keep samples that share their label with at least one other sample.
            mask_valid = (mask_y.sum(1) > 0)
            log_prob = log_prob[mask_valid]
            mask_d = mask_d[mask_valid]

            if self.is_flipped:  # maximize log prob of samples from different domains
                bn_loss = - scale * torch.logsumexp(
                    log_prob + (~mask_d).float().log(), dim=1)
            else:  # minimize log prob of samples from same domain
                bn_loss = scale * torch.logsumexp(
                    log_prob + mask_d.float().log(), dim=1)
        else:
            # Conditional CAD loss.
            if self.is_flipped:
                mask_valid = (mask_y_n_d.sum(1) > 0)
            else:
                mask_valid = (mask_y_d.sum(1) > 0)

            mask_y = mask_y[mask_valid]
            mask_y_d = mask_y_d[mask_valid]
            mask_y_n_d = mask_y_n_d[mask_valid]
            logits = logits[mask_valid]

            # Log-probabilities restricted to samples with the same label.
            denominator = torch.logsumexp(logits + mask_y.log(), dim=1, keepdim=True)
            log_prob_y = logits - denominator

            if self.is_flipped:  # maximize log prob of samples from different domains and with same label
                bn_loss = - scale * torch.logsumexp(
                    log_prob_y + mask_y_n_d.log(), dim=1)
            else:  # minimize log prob of samples from same domains and with same label
                bn_loss = scale * torch.logsumexp(
                    log_prob_y + mask_y_d.log(), dim=1)

        # Rows left without any admissible pair evaluate to +/-inf; they are
        # ignored when averaging.
        return self._finite_mean(bn_loss)

    def update(self):
        """Run one optimization step on the next training batch.

        Minimizes ``classification loss + hparams['lmbda'] * bn_loss``.
        """
        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()
        device = X.device
        _, Y_split = self.dataset.split_tensor_by_domains(X, Y, self.num_train_domains)

        all_pred, all_z = self.model(X)

        # Domain index of every label, built from the per-domain split.
        # NOTE(review): assumes concatenating the splits reproduces the sample
        # order of Y -- confirm against `split_tensor_by_domains`.
        all_d = torch.cat([
            torch.full(y.shape, i, dtype=torch.int64, device=device)
            for i, y in enumerate(Y_split)
        ])

        bn_loss = self.bn_loss(all_z, Y, all_d)
        clf_loss = self.dataset.loss(all_pred, Y)
        total_loss = clf_loss + self.hparams['lmbda'] * bn_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

class CAD(AbstractCAD):
    """Contrastive Adversarial Domain (CAD) bottleneck.

    Properties:
    - Minimize I(D;Z)
    - Require access to domain labels but not task labels
    """

    def __init__(self, model, dataset, optimizer, hparams):
        # Unconditional variant: the bottleneck is not conditioned on the label.
        super().__init__(
            model=model,
            dataset=dataset,
            optimizer=optimizer,
            hparams=hparams,
            is_conditional=False,
        )


class CondCAD(AbstractCAD):
    """Conditional Contrastive Adversarial Domain (CondCAD) bottleneck.

    Properties:
    - Minimize I(D;Z|Y)
    - Require access to both domain labels and task labels
    """

    def __init__(self, model, dataset, optimizer, hparams):
        # Conditional variant: the bottleneck is conditioned on the label.
        super().__init__(
            model=model,
            dataset=dataset,
            optimizer=optimizer,
            hparams=hparams,
            is_conditional=True,
        )
