import torch
import torch.nn.functional as F

from tqdm import tqdm


class WNet(torch.nn.Module):
    """Bias-free linear map from a feature vector to per-group scores.

    Attributes:
        feature_dim: Size of the last dimension of the expected input.
        num_groups: Number of output scores produced per sample.
        model: ``Sequential`` wrapping the single ``Linear`` layer.
    """

    def __init__(self, feature_dim, num_groups):
        super().__init__()
        self.feature_dim, self.num_groups = feature_dim, num_groups

        # Kept inside a Sequential so the parameter is named "model.0.weight".
        projection = torch.nn.Linear(feature_dim, num_groups, bias=False)
        self.model = torch.nn.Sequential(projection)

    def forward(self, x):
        """Return the group scores for ``x`` (last dim ``feature_dim``)."""
        return self.model(x)


def calibrate_with_tau_and_w_logits(logits,
                                    features,
                                    tau,
                                    hard,
                                    w_net=None,
                                    w_logits=None):
    """Apply per-group temperature scaling to binary logits.

    Every sample's logit is divided by each group's temperature. The group
    scores (either ``w_logits`` or ``w_net(features)``) then decide how the
    per-group results are combined:

    * ``hard=True``: each sample uses only the temperature of its
      highest-scoring group, so the result is ``logit / tau[group]``.
    * ``hard=False``: the positive-class probability is the mixture
      ``sum_g softmax(scores)_g * sigmoid(logit / tau_g)`` and the result
      is the log-odds of that mixture.

    In both modes the return value is a real logit, i.e. ``sigmoid`` of it
    is the calibrated positive-class probability, which is what
    ``binary_cross_entropy_with_logits`` expects.

    Args:
        logits: Tensor with ``N`` elements (one binary logit per sample).
        features: Input handed to ``w_net``; unused when ``w_logits`` is given.
        tau: Tensor with one temperature per group.
        hard: Whether to use hard (argmax) group assignment.
        w_net: Module mapping ``features`` to ``(N, num_groups)`` scores.
        w_logits: Precomputed ``(N, num_groups)`` group scores.

    Returns:
        Tensor of shape ``(N,)`` holding the calibrated logits.

    Raises:
        ValueError: If not exactly one of ``w_net`` / ``w_logits`` is given.
    """
    # Explicit raise (not assert) so the check survives `python -O`.
    if (w_logits is None) == (w_net is None):
        raise ValueError(
            "exactly one of `w_net` and `w_logits` must be provided")

    N = logits.shape[0]
    # Both sources go through the same code path so they cannot diverge.
    group_scores = w_logits if w_logits is not None else w_net(features)
    num_groups = group_scores.shape[1]
    group_scores = group_scores.reshape((N, num_groups))

    # (N, num_groups): every sample scaled by every group's temperature.
    temp_logits = logits.reshape((N, 1)) / tau.reshape((1, num_groups))

    if hard:
        # argmax of the raw scores equals argmax of their (log-)softmax.
        group_argmax = torch.argmax(group_scores, dim=1, keepdim=True)
        return torch.gather(temp_logits, 1, group_argmax).squeeze(1)

    group_log_prob = torch.log_softmax(group_scores, dim=1)
    # log P(y=1) and log P(y=0) of the mixture, both computed in log space;
    # sigmoid(-z) = 1 - sigmoid(z) gives the negative-class term.
    log_pos = torch.logsumexp(
        group_log_prob + F.logsigmoid(temp_logits), dim=1)
    log_neg = torch.logsumexp(
        group_log_prob + F.logsigmoid(-temp_logits), dim=1)
    return log_pos - log_neg


def optimize_group_fn(
        features,
        logits,
        labels,
        w_net,
        hard_group,
        method_config):
    """Fit per-group temperatures (and optionally the grouping net) with L-BFGS.

    Args:
        features: Tensor whose second dimension is the feature size; its
            device is used for all optimised tensors. The grouping net is
            converted to double precision, so ``features`` is presumably
            float64 as well -- TODO confirm against callers.
        logits: Binary logits to calibrate, one per sample.
        labels: Binary targets, one per sample.
        w_net: Either a string, in which case a fresh ``WNet`` is created and
            trained jointly with the temperatures, or an existing
            ``torch.nn.Module`` that is used as-is and only the temperatures
            are optimised. A passed module is moved/converted in place.
        hard_group: Forwarded as ``hard`` to
            ``calibrate_with_tau_and_w_logits``.
        method_config: Config providing ``num_groups``, ``optimizer.steps``
            (L-BFGS ``max_iter``) and ``w_net.weight_decay``.

    Returns:
        Tuple ``(tau, w_net)``: the detached temperatures on CPU and the
        grouping module moved to CPU (in double precision).
    """
    if isinstance(w_net, str):
        train_w = True
        assert isinstance(method_config.num_groups, int)
        w_net = WNet(feature_dim=features.shape[1],
                     num_groups=method_config.num_groups)
    else:
        train_w = False
        assert isinstance(w_net, torch.nn.Module)

    # Move/convert the module BEFORE handing its parameters to the optimizer,
    # so the optimizer is built from the parameters in their final
    # device/dtype rather than relying on the conversion being in place.
    W_gpu = w_net.to(features.device).double()

    # One temperature per group, all starting at 0.901.
    tau = torch.nn.Parameter(torch.tensor(
        [0.901] * method_config.num_groups,
        requires_grad=True, device=features.device))

    if train_w:
        params = [tau] + list(W_gpu.parameters())
    else:
        params = [tau]

    optimizer = torch.optim.LBFGS(params,
                                  line_search_fn="strong_wolfe",
                                  max_iter=method_config.optimizer.steps)

    def closure():
        """Recompute the regularised loss and its gradients for L-BFGS."""
        optimizer.zero_grad()

        # Weight-decay term: mean squared value of every "weight" parameter.
        reg_weight_decay = 0
        for name, param in W_gpu.named_parameters():
            if "weight" in name:
                reg_weight_decay += torch.mean(param ** 2)
        reg_weight_decay_loss = reg_weight_decay * method_config.w_net.weight_decay

        # Data term: binary cross-entropy on the calibrated logits.
        calibrated_logits = calibrate_with_tau_and_w_logits(
            logits=logits,
            features=features,
            tau=tau,
            w_net=W_gpu,
            hard=hard_group
        )

        # Match the targets to the logits' dtype so the loss is not computed
        # with float32 targets mixed into a float64 pipeline.
        targets = labels.to(calibrated_logits.dtype)
        main_loss = F.binary_cross_entropy_with_logits(calibrated_logits, targets)

        _loss = main_loss + reg_weight_decay_loss

        _loss.backward()
        return _loss

    # A single step() runs up to `max_iter` L-BFGS iterations internally.
    optimizer.step(closure=closure)

    return tau.detach().cpu(), w_net.cpu()


def train_partitions(features, logits, labels, w_net, method_config):
    """Train ``method_config.num_partitions`` grouping networks.

    Each run delegates to ``optimize_group_fn`` with the same data and the
    same ``w_net`` argument; the fitted temperatures are discarded and only
    the returned grouping module of every run is kept.

    Returns:
        List with one grouping module per partition, in training order.
    """
    def _fit_single_partition():
        # Index 1 of the result is the grouping module; index 0 is tau.
        return optimize_group_fn(
            features,
            logits,
            labels,
            w_net=w_net,
            hard_group=method_config.hard_group,
            method_config=method_config,
        )[1]

    print("Generating partitions...")
    return [
        _fit_single_partition()
        for _ in tqdm(range(method_config.num_partitions))
    ]




def calibrate(val_features,
                      val_logits,
                      val_labels,
                      test_train_features,
                      test_train_logits,
                      test_train_labels,
                      test_test_features,
                      test_test_logits,
                      method_config,
                      base_calibrate_fn,
                      seed,
                      cfg,
                      *args, **kwargs):
    """Calibrate test logits with an ensemble of learned partitions.

    Grouping networks are trained on the validation split. For every trained
    network, both the calibration-train and the test samples are assigned to
    their arg-max group, ``base_calibrate_fn`` is fitted per group on the
    calibration-train samples and applied to the test samples of the same
    group. The per-partition results are averaged.

    Args:
        val_features, val_logits, val_labels: Data used to train the
            grouping networks.
        test_train_features, test_train_logits, test_train_labels: Data used
            to fit ``base_calibrate_fn`` inside each group.
        test_test_features, test_test_logits: Data to be calibrated.
        method_config: Config providing ``num_groups`` plus whatever
            ``train_partitions`` reads from it.
        base_calibrate_fn: Callable taking ``train_logits``, ``train_labels``
            and ``test_logits`` keyword arguments and returning a mapping
            whose ``'logits'`` entry holds the calibrated values for
            ``test_logits``.
        seed, cfg, *args, **kwargs: Accepted but not used here.

    Returns:
        Dict with ``"logits"`` (the calibrated values averaged over all
        partitions, same shape as ``test_test_logits``) and ``'group'`` (the
        test group ids of the last partition).

    Raises:
        ValueError: If no partition was trained.
    """
    w_net_list = train_partitions(val_features,
                                  val_logits,
                                  val_labels,
                                  w_net="linear",
                                  method_config=method_config)
    if not w_net_list:
        # Without this, the failure would surface later as an opaque
        # torch.stack error on an empty list.
        raise ValueError(
            "no partition was trained; method_config.num_partitions "
            "must be at least 1")

    calibrated_probs = []
    test_groups_id = None

    for trained_w_net in tqdm(w_net_list):

        # Hard group assignment; argmax needs no autograd graph.
        with torch.no_grad():
            train_groups_id = torch.argmax(
                trained_w_net(test_train_features), dim=1)
            test_groups_id = torch.argmax(
                trained_w_net(test_test_features), dim=1)

        _calibrated_probs = torch.zeros_like(test_test_logits)

        for _g in range(method_config.num_groups):
            test_group_mask = test_groups_id == _g
            if not torch.any(test_group_mask):
                # Nothing to calibrate for this group; skip fitting on it.
                continue
            train_group_mask = train_groups_id == _g

            # NOTE(review): a group may contain test samples but no
            # calibration-train samples; handling that is left to
            # base_calibrate_fn.
            _group_calibrated_prob = base_calibrate_fn(
                train_logits=test_train_logits[train_group_mask],
                train_labels=test_train_labels[train_group_mask],
                test_logits=test_test_logits[test_group_mask]
            )['logits']

            # Match the output buffer's dtype/device instead of hard-coding
            # float64; as_tensor also accepts tensors without the
            # copy-construct warning that torch.tensor(tensor) emits.
            _calibrated_probs[test_group_mask] = torch.as_tensor(
                _group_calibrated_prob,
                dtype=_calibrated_probs.dtype,
                device=_calibrated_probs.device,
            ).detach()

        calibrated_probs.append(_calibrated_probs.detach())

    # Ensemble: average over partitions.
    calibrated_probs = torch.stack(calibrated_probs, dim=0).mean(0)

    return {"logits": calibrated_probs, 'group': test_groups_id}