import logging

import torch
import torch.nn.functional as F

from tqdm import tqdm


class WNet(torch.nn.Module):
    """Bias-free linear map from feature vectors to per-group scores."""

    def __init__(self, feature_dim, num_groups):
        """Build the single linear layer.

        Args:
            feature_dim: Size of the last dimension of the input features.
            num_groups: Number of group scores produced per sample.
        """
        super().__init__()
        self.feature_dim, self.num_groups = feature_dim, num_groups

        # Kept inside a Sequential so the parameter is still registered
        # under the "model.0.weight" state-dict key.
        projection = torch.nn.Linear(feature_dim, num_groups, bias=False)
        self.model = torch.nn.Sequential(projection)

    def forward(self, x):
        """Return the raw (unnormalised) group scores for ``x``."""
        return self.model(x)

def calibrate_with_tau_and_w_logits(logits,
                                    features,
                                    tau,
                                    b=None,
                                    hard=True,
                                    w_net=None,
                                    w_logits=None):
    """Apply per-group temperature (and optional bias) scaling to logits.

    Group scores come from exactly one of ``w_net`` (evaluated on
    ``features``) or precomputed ``w_logits``. Every logit is divided by each
    group's temperature, optionally shifted by ``b``, passed through
    ``logsigmoid`` and then combined across groups with the group assignment.

    Args:
        logits: Tensor with ``N`` elements; viewed as ``(N, 1)``.
        features: Input handed to ``w_net``; unused when ``w_logits`` is given.
        tau: Tensor with ``num_groups`` elements, one temperature per group.
        b: Optional bias added to the scaled ``(N, num_groups)`` logits.
        hard: If true, sample a one-hot group assignment per row with
            ``F.gumbel_softmax``; otherwise use softmax probabilities.
        w_net: Callable exposing ``num_groups`` that maps ``features`` to
            group scores.
        w_logits: Precomputed group scores of shape ``(N, num_groups)``.

    Returns:
        Tuple ``(calibrated_logits, group_assignment)``: a length-``N`` tensor
        and an ``(N, num_groups)`` tensor holding one-hot samples when
        ``hard`` is true and softmax probabilities otherwise.

    Raises:
        AssertionError: If both or neither of ``w_net`` / ``w_logits`` are
            supplied.
    """
    assert (w_logits is not None) != (w_net is not None)

    N = logits.shape[0]

    # Resolve the group scores once so that both the hard and the soft mode
    # accept either source (the hard mode used to require ``w_net``).
    if w_net is not None:
        num_groups = w_net.num_groups
        group_scores = w_net(features)
    else:
        num_groups = w_logits.shape[1]
        group_scores = w_logits

    # (N, 1) / (1, G) -> one temperature-scaled logit per sample and group.
    temp_logits = logits.view((N, 1)) / tau.view((1, num_groups))
    if b is not None:
        temp_logits = temp_logits + b
    temp_log_sigmoid = F.logsigmoid(temp_logits)

    if hard:
        group_log_softmax = torch.log_softmax(group_scores, dim=1)
        group_hard_prob = F.gumbel_softmax(group_log_softmax, tau=1, hard=True)
        calibrated_logits = torch.sum(temp_log_sigmoid * group_hard_prob, dim=1)
        return calibrated_logits, group_hard_prob

    # Both score sources now yield probabilities here; the ``w_logits`` path
    # previously returned log-probabilities, disagreeing with the ``w_net``
    # path for identical scores.
    group_softmax = torch.softmax(group_scores, dim=1).view((N, num_groups))
    calibrated_logits = torch.logsumexp(group_softmax * temp_log_sigmoid, dim=1)
    return calibrated_logits, group_softmax

def optimize_group_fn(
        features,
        logits,
        labels,
        w_net,
        hard_group,
        method_config):
    """Fit per-group temperatures/biases (and optionally the group network).

    Runs a single L-BFGS ``step`` on a loss that mixes binary cross-entropy,
    weight decay on the group network and a within-group variance term.

    Args:
        features: Inputs for the group network; ``features.shape[1]`` sets the
            input size when a new ``WNet`` is created.
        logits: Uncalibrated logits passed to
            ``calibrate_with_tau_and_w_logits``.
        labels: Targets; converted with ``.float()`` for the BCE term.
        w_net: Either a string (a fresh ``WNet`` is created and trained
            together with ``tau``/``b``) or an existing ``torch.nn.Module``
            (left out of the optimizer; only ``tau``/``b`` are fitted).
        hard_group: Forwarded as ``hard`` to the calibration function and
            selects which variance computation is used.
        method_config: Configuration object; reads ``num_groups``,
            ``optimizer.steps``, ``w_net.weight_decay``, ``var_on_prob`` and
            ``var_lambda``.

    Returns:
        Tuple ``(tau, b, w_net)``: detached CPU tensors and the group network
        moved to CPU.
    """
    if isinstance(w_net, str):
        train_w = True
        assert isinstance(method_config.num_groups, int)
        w_net = WNet(feature_dim=features.shape[1],
                     num_groups=method_config.num_groups)
    else:
        train_w = False
        assert isinstance(w_net, torch.nn.Module)

    # Move/cast the network BEFORE its parameters are handed to the optimizer,
    # as recommended by the torch.optim documentation, so the optimizer is
    # guaranteed to hold the converted parameters.
    # NOTE(review): the network is cast to float64, so `features` presumably
    # has to be float64 as well - confirm against callers.
    w_net = w_net.to(features.device).double()

    tau = torch.nn.Parameter(torch.tensor(
        [1.0] * method_config.num_groups,
        requires_grad=True, device=features.device))
    b = torch.nn.Parameter(torch.tensor(
        [0.1] * method_config.num_groups,
        requires_grad=True, device=features.device))

    if train_w:
        params = [tau, b] + list(w_net.parameters())
    else:
        params = [tau, b]

    optimizer = torch.optim.LBFGS(params,
                                  line_search_fn="strong_wolfe",
                                  max_iter=method_config.optimizer.steps)

    def closure():
        """Recompute the full loss and its gradients for L-BFGS."""
        optimizer.zero_grad()

        # Weight decay: mean squared value of every "weight" parameter.
        reg_weight_decay = 0
        for name, param in w_net.named_parameters():
            if "weight" in name:
                reg_weight_decay += torch.mean((param) ** 2)
        reg_weight_decay_loss = reg_weight_decay * method_config.w_net.weight_decay

        calibrated_logits, group_mask = calibrate_with_tau_and_w_logits(
            logits=logits,
            features=features,
            tau=tau,
            b=b,
            w_net=w_net,
            hard=hard_group
        )

        # NOTE(review): the calibration function builds its first return value
        # from logsigmoid outputs, yet it is consumed as a logit here -
        # confirm this is intended.
        main_loss = F.binary_cross_entropy_with_logits(calibrated_logits, labels.float())

        # Quantity whose within-group variance is penalised.
        be_vared = calibrated_logits if method_config.var_on_prob else labels

        if hard_group:
            # One-hot assignments: column sums are per-group sample counts.
            group_counts = group_mask.sum(dim=0)
            valid_groups = (group_counts > 1)

            be_vared = be_vared.view(-1, 1).expand(-1, group_mask.shape[1])
            group_means = (be_vared * group_mask).sum(dim=0) / (group_counts + 1e-6)
            group_vars = (group_mask * (be_vared - group_means) ** 2).sum(dim=0) / (group_counts + 1e-6)
            group_vars = group_vars[valid_groups]

            # Weight each retained group by its share of all samples.
            weights = group_counts[valid_groups] / group_counts.sum()
            var_loss = (weights * group_vars).sum()
        else:
            # Soft assignments: column sums are the total weight per group.
            group_weights = group_mask.sum(dim=0)
            valid_groups = (group_weights > 1e-6)

            means = (group_mask * be_vared.view(-1, 1)).sum(dim=0) / (group_weights + 1e-6)
            variances = (group_mask * (be_vared.view(-1, 1) - means) ** 2).sum(dim=0) / (group_weights + 1e-6)
            var_loss = (variances[valid_groups] * group_weights[valid_groups]).sum() / group_weights.sum()

        # Convex mix of (NLL + weight decay) and the variance term.
        _loss = (main_loss + reg_weight_decay_loss) * (1-method_config.var_lambda) + method_config.var_lambda * var_loss

        _loss.backward()
        return _loss

    optimizer.step(closure=closure)

    return tau.detach().cpu(), b.detach().cpu(), w_net.cpu()


def train_partitions(features,
                     logits,
                     labels,
                     w_net,
                     method_config):
    """Fit one group network per partition and collect them.

    Calls ``optimize_group_fn`` ``method_config.num_partitions`` times with
    the same inputs. The fitted temperature and bias of every run are dropped;
    only the returned group network is kept.

    Returns:
        List with the group network returned by each run, in run order.
    """
    print("Generating partitions...")

    def _fit_once():
        # Only the third element (the group network) is of interest here.
        _, _, fitted_net = optimize_group_fn(
            features=features,
            logits=logits,
            labels=labels,
            w_net=w_net,
            hard_group=method_config.hard_group,
            method_config=method_config,
        )
        return fitted_net

    return [_fit_once() for _ in tqdm(range(method_config.num_partitions))]


def calibrate(val_features,
              val_logits,
              val_labels,
              test_train_features,
              test_train_logits,
              test_train_labels,
              test_test_features,
              test_test_logits,
              method_config,
              base_calibrate_fn,
              seed,
              cfg,
              *args, **kwargs):
    """
    Convenience function for grouping with var-p-loss

    Trains ``method_config.num_partitions`` group networks on the validation
    data. For each network, samples are assigned to the group with the highest
    score, ``base_calibrate_fn`` is applied separately inside every group, and
    the per-partition results are averaged.

    Args:
        val_features, val_logits, val_labels: Data used to train the group
            networks.
        test_train_features, test_train_logits, test_train_labels: Data each
            per-group calibrator is fitted on.
        test_test_features, test_test_logits: Data to be calibrated.
        method_config: Configuration; ``num_groups`` is read here and the
            object is forwarded to ``train_partitions``.
        base_calibrate_fn: Callable taking ``train_logits``, ``train_labels``
            and ``test_logits`` keywords and returning a mapping with a
            ``'logits'`` entry.
        seed, cfg, *args, **kwargs: Accepted but not used by this function.

    Returns:
        Dict with ``"logits"`` (mean over partitions of the per-group
        calibrated outputs for ``test_test_logits``) and ``'group'`` (group
        ids of the test samples under the last partition).
    """
    w_net_list = train_partitions(val_features,
                                  val_logits,
                                  val_labels,
                                  w_net="linear",
                                  method_config=method_config)
    calibrated_probs = []

    for trained_w_net in tqdm(w_net_list):

        # Hard group assignment; inference only, so no autograd graph needed.
        with torch.no_grad():
            train_groups_id = torch.argmax(
                trained_w_net(test_train_features), dim=1)
            test_groups_id = torch.argmax(
                trained_w_net(test_test_features), dim=1)

        _calibrated_probs = torch.zeros_like(test_test_logits)

        for _g in range(method_config.num_groups):
            test_group_mask = test_groups_id == _g
            if not bool(test_group_mask.any()):
                # Nothing to calibrate for this group; the masked assignment
                # below would be a no-op, so skip fitting on empty input.
                continue

            train_group_mask = train_groups_id == _g
            # NOTE(review): a group may have test samples but no training
            # samples; base_calibrate_fn then receives empty training data -
            # confirm it copes with that.
            _group_calibrated_prob = base_calibrate_fn(
                train_logits=test_train_logits[train_group_mask],
                train_labels=test_train_labels[train_group_mask],
                test_logits=test_test_logits[test_group_mask]
            )['logits']

            # Masked assignment needs matching dtype/device; as_tensor also
            # accepts an existing tensor without the copy-construct warning.
            _calibrated_probs[test_group_mask] = torch.as_tensor(
                _group_calibrated_prob,
                dtype=_calibrated_probs.dtype,
                device=_calibrated_probs.device,
            ).detach()

        calibrated_probs.append(_calibrated_probs.detach())
    calibrated_probs = torch.stack(calibrated_probs, dim=0).mean(0)

    return {"logits": calibrated_probs, 'group': test_groups_id}
