from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from data.sampler import SubsetSequentialSampler
from tqdm import tqdm
from functions import *
import torch
import numpy as np

def GMM_Log_Likelihood(GMM_model, embedding, sample_size, pi, mean, logvar):
    """Return the per-sample log-likelihood of ``embedding`` under a Gaussian mixture.

    For every row ``x`` of ``embedding`` this computes
    ``logsumexp_k(gaussian_log_prob(x, mean_k, logvar_k) + log(pi_k))``.

    Args:
        GMM_model: Object exposing ``component_size`` (number of components K)
            and ``gaussian_log_prob(x, mean, logvar)``.
        embedding: 2-D tensor of shape ``(batch_size, latent_size)``.
        sample_size: Unused; kept so existing callers keep working.
        pi: 1-D tensor of mixture weights, presumably of length K.
        mean: 2-D tensor of component means (presumably ``(K, latent_size)``).
        logvar: Component log-variances; must have the same shape as ``mean``.

    Returns:
        Tensor with the mixture log-likelihood of each embedding (the
        component axis, assumed to be the last axis of ``gaussian_log_prob``'s
        output, is reduced away), or ``None`` if ``gaussian_log_prob``
        returned ``None``.
    """
    assert mean.shape == logvar.shape
    batch_size, _latent_size = embedding.shape
    n_components = GMM_model.component_size

    # Pair every embedding with every component: embedding becomes (batch, K, latent).
    component_log_probs = GMM_model.gaussian_log_prob(
        embedding[:, None, :].repeat(1, n_components, 1),
        mean[None, :, :].repeat(batch_size, 1, 1),
        logvar[None, :, :].repeat(batch_size, 1, 1),
    )
    # `is None`, not `== None`: `==` on a tensor is an elementwise comparison.
    if component_log_probs is None:
        return None

    # Add log mixture weights (broadcast over the batch) and marginalize over
    # components in log space, which is numerically stable.
    return torch.logsumexp(component_log_probs + torch.log(pi)[None, :], dim=-1)

def _cat_flat_numpy(chunks):
    """Concatenate per-batch tensors along dim 0 and return a flat NumPy array on the CPU."""
    return torch.cat(chunks).reshape(-1).cpu().numpy()


def Likelihood_Sampling_DAAL(data_train, unlabeled_set, BATCH, models, GMM_model, Phis=None, **kwargs):
    """Rank the unlabeled pool for active-learning selection.

    Each unlabeled sample gets three signals. Each signal is passed through
    ``Ranking`` (defined elsewhere), and the three results are summed:

    * the log-likelihood of its embedding under the unlabeled-pool GMM minus
      the log-likelihood under the labeled-pool GMM;
    * the output of ``models['critic']`` on the features built by
      ``GetDistInfo_V2``;
    * the entropy of the backbone's softmax prediction.

    Samples are then ordered by descending combined score.

    Args:
        data_train: Dataset indexed by ``unlabeled_set``; must expose either
            ``targets`` or ``labels``.
        unlabeled_set: Sequence of dataset indices forming the unlabeled pool.
        BATCH: Inference batch size.
        models: Dict with ``'backbone'`` (called as
            ``scores, embedding, features = backbone(inputs)``) and ``'critic'``.
        GMM_model: Passed through to ``GMM_Log_Likelihood`` and ``GetDistInfo_V2``.
        Phis: Six-tuple ``(L_pi, L_mean, L_var, UL_pi, UL_mean, UL_var)`` of
            mixture parameters for the labeled and unlabeled pools. Required
            despite the ``None`` default.
        **kwargs: Accepted and ignored.

    Returns:
        ``(arg, UL_arg_preds, UL_arg_labels, UL_arg_scores)``:
        * ``arg`` is a torch tensor of dataset indices sorted by descending score.
        * ``UL_arg_preds`` and ``UL_arg_labels`` are the matching predictions and
          labels, as NumPy arrays.
        * ``UL_arg_scores`` holds the matching raw backbone scores, as a torch
          tensor left on the device the backbone produced it on.

    Raises:
        TypeError: If ``Phis`` is ``None``.
    """
    if Phis is None:
        raise TypeError("Likelihood_Sampling_DAAL requires Phis (six GMM parameter tensors)")
    L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var, Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var = Phis
    # Loop-invariant parameter bundles consumed by GetDistInfo_V2.
    psi_xl = (L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var)
    psi_xul = (Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var)

    UL_selection_loader = DataLoader(
        data_train,
        batch_size=BATCH,
        sampler=SubsetSequentialSampler(unlabeled_set),
        pin_memory=True,
    )
    if hasattr(data_train, 'targets'):
        ALL_labels = np.array(data_train.targets)
    else:
        ALL_labels = np.asarray(data_train.labels)

    UL_acces, UL_labels, UL_preds, UL_scores, UL_entropies = [], [], [], [], []
    UL_critics, UL_likelihoods_L, UL_likelihoods_UL = [], [], []
    idxFromIndices = []

    with torch.no_grad():
        with tqdm(total=len(UL_selection_loader)) as pbar:
            for i, data in enumerate(UL_selection_loader):
                pbar.update(1)
                # Assumes the sampler visits unlabeled_set in order, so batch i is
                # this slice of indices; the label assertion below checks it.
                idxFromIndex = unlabeled_set[i * BATCH:(i + 1) * BATCH]
                inputs = data[0].cuda()
                label = data[1].cuda()
                assert np.array_equal(ALL_labels[idxFromIndex], label.cpu().numpy())

                scores, embedding, _features = models['backbone'](inputs)
                UL_pred = scores.argmax(dim=-1)
                UL_scores.append(scores)
                UL_preds.append(UL_pred)
                UL_labels.append(label)
                UL_acces.append((UL_pred == label).long())
                idxFromIndices.extend(idxFromIndex)

                # ============ distCat as input for Critic ===========#
                distCat = GetDistInfo_V2(psi_xl, psi_xul, embedding.unsqueeze(dim=1), GMM_model)
                UL_critics.append(models['critic'](distCat))

                UL_likelihoods_L.append(GMM_Log_Likelihood(
                    GMM_model, embedding, 1, L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var))
                UL_likelihoods_UL.append(GMM_Log_Likelihood(
                    GMM_model, embedding, 1, Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var))

                # log_softmax keeps log-probabilities finite, so a class whose
                # probability underflows to 0 contributes 0 rather than NaN (0 * -inf).
                log_prob = torch.log_softmax(scores, dim=1)
                UL_entropies.append(-(log_prob.exp() * log_prob).sum(dim=1))

        # torch.cat (unlike torch.stack) also accepts a smaller final batch, which
        # occurs whenever len(unlabeled_set) is not a multiple of BATCH.
        UL_critics = _cat_flat_numpy(UL_critics)
        UL_preds = _cat_flat_numpy(UL_preds)
        UL_labels = _cat_flat_numpy(UL_labels)
        UL_acces = _cat_flat_numpy(UL_acces)
        UL_entropies = _cat_flat_numpy(UL_entropies)
        UL_L1 = _cat_flat_numpy(UL_likelihoods_UL) - _cat_flat_numpy(UL_likelihoods_L)
        # Reshape by the backbone's own class dimension rather than the GMM component count.
        UL_scores = torch.cat(UL_scores).reshape(-1, UL_scores[0].shape[-1])
        idxFromIndices = torch.tensor(idxFromIndices)

        UL_L1 = Ranking(UL_L1)
        UL_critics = Ranking(UL_critics)
        UL_entropies = Ranking(UL_entropies)
        Info = UL_L1 + UL_critics + UL_entropies

        pre_arg = np.argsort(-Info)  # descending combined score
        arg = idxFromIndices[pre_arg]
        UL_arg_preds = UL_preds[pre_arg]
        UL_arg_labels = UL_labels[pre_arg]
        UL_arg_scores = UL_scores[pre_arg]
        assert np.array_equal(UL_arg_labels, ALL_labels[arg.numpy()])

    print(f'UL_acc (calculated during active selection) is {UL_acces.mean()}')
    return arg, UL_arg_preds, UL_arg_labels, UL_arg_scores

