import itertools

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm


def get_embeddings(dataloader, model, device='cuda'):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Each batch is indexed as ``batch[0]`` (model input) and ``batch[1]``
    (labels); any further items are ignored. Outputs and labels are written
    into tensors on ``device`` that are preallocated from
    ``len(dataloader.dataset)`` and the first batch's column counts.

    This function does not change the model's train/eval mode; set it
    beforehand if that matters for the model.

    Args:
        dataloader: Iterable of batches that also exposes ``.dataset``.
        model: Callable mapping an input batch to a 2-D tensor whose rows
            correspond to samples (``q.size(1)`` is read as the width).
        device: Device the batches are moved to and results are stored on.

    Returns:
        ``(all_q, labels)``: one row per sample actually yielded by the
        loader. 1-D label batches are stored as a single column.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    start = 0
    all_q = labels = None
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs, label = batch[0].to(device), batch[1].to(device)
            q = model(inputs)
            if label.dim() == 1:
                label = label.unsqueeze(1)
            if all_q is None:
                # Allocate once, sized for the whole dataset, using the first
                # batch to learn the output/label widths and dtypes.
                n_rows = len(dataloader.dataset)
                labels = torch.zeros(
                    n_rows,
                    label.size(1),
                    device=device,
                    dtype=label.dtype,
                )
                all_q = torch.zeros(
                    n_rows,
                    q.size(1),
                    device=device,
                    dtype=q.dtype,
                )
            end = start + q.size(0)
            all_q[start:end] = q
            labels[start:end] = label
            start = end
    if all_q is None:
        raise ValueError("get_embeddings: dataloader yielded no batches")
    # Trim to the rows actually filled: with drop_last=True or a sampler that
    # covers only part of the dataset, the tail would otherwise be zero
    # padding indistinguishable from real embeddings.
    return all_q[:start], labels[:start]


def train_sgd(train_loader, model, optimizers, loss_func, mining_func, device, verbose=True):
    """Run one pass over ``train_loader``, updating ``model`` via its ``embedding``.

    For each ``(data, labels)`` batch, the function does the following:

    1. Zero all optimizers' gradients.
    2. Compute ``model.embedding(data)``.
    3. Mine tuples with ``mining_func(embeddings, labels)``.
    4. Compute ``loss_func(embeddings, labels, indices_tuple)``.
    5. Backpropagate and step every optimizer, but only if the first element
       of the mined tuple is non-empty.

    Args:
        train_loader: Iterable of ``(data, labels)`` tensor pairs.
        model: Module exposing an ``embedding`` sub-callable; it is put in
            train mode.
        optimizers: Iterable of optimizers; falsy entries (e.g. ``None``)
            are skipped.
        loss_func: Callable ``(embeddings, labels, indices_tuple) -> loss``.
        mining_func: Callable ``(embeddings, labels) -> indices_tuple``. When
            ``verbose`` is true it must also expose ``num_triplets``
            (presumably the count from its most recent call; confirm against
            the miner in use).
        device: Device the batches are moved to.
        verbose: If true, print the last batch's loss and
            ``mining_func.num_triplets`` after the pass. Nothing is printed
            when the loader yielded no batches.

    Returns:
        The same ``model`` object.
    """
    model.train()
    active_opts = [opt for opt in optimizers if opt]
    loss = None
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        for opt in active_opts:
            opt.zero_grad()
        embeddings = model.embedding(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)

        # No mined tuples means nothing to learn from this batch, so the
        # backward pass and optimizer steps are skipped.
        if len(indices_tuple[0]) > 0:
            loss.backward()
            for opt in active_opts:
                opt.step()

    # ``loss`` is still None only if the loader was empty; printing then would
    # reference a loss that was never computed.
    if verbose and loss is not None:
        print("Loss = {}, Number of mined triplets = {}".format(
            loss, mining_func.num_triplets))
    return model


class MetricModel(nn.Module):
    """Compare two inputs by embedding each with one shared function.

    ``forward(x, y)`` applies the same ``embedding`` to both ``x`` and ``y``
    (in that order) and returns ``metric(embedding(x), embedding(y))``.

    Args:
        embedding: Callable (typically an ``nn.Module``, which is then
            registered as a submodule) used for both inputs.
        metric: Callable taking the two embeddings and returning a score.
    """

    def __init__(self, embedding, metric):
        super(MetricModel, self).__init__()
        # The attribute names matter: ``train_sgd`` in this file calls
        # ``model.embedding`` directly.
        self.embedding = embedding
        self.metric = metric

    def forward(self, x, y):
        embed = self.embedding
        return self.metric(embed(x), embed(y))


def gridsearch_model_params(objective, param_dict, verbose=False):
    """Evaluate ``objective`` on every combination in ``param_dict``; keep the best.

    Combinations are generated with ``itertools.product`` over the value
    lists, in ``param_dict`` key order, and passed to ``objective`` as keyword
    arguments.

    Selection rules:

    * Only a strictly greater score replaces the current best, so the first
      combination wins ties.
    * Scores that never compare greater than ``-inf`` (e.g. NaN) are never
      selected.
    * If no score beats ``-inf``, the result is ``(-inf, {})``.

    Args:
        objective: Callable accepting the parameter names as keywords and
            returning a comparable score (higher is better).
        param_dict: Mapping from parameter name to an iterable of candidate
            values.
        verbose: If true, print each score together with its combination.

    Returns:
        ``(best_score, best_params)``.
    """
    names = list(param_dict.keys())
    grids = list(param_dict.values())
    best = (-np.inf, {})

    candidates = (dict(zip(names, point)) for point in itertools.product(*grids))
    for candidate in candidates:
        result = objective(**candidate)
        if result > best[0]:
            best = (result, candidate)
        if verbose:
            print(f'{result} at {candidate}')
    return best
