import torch
import numpy as np
from scipy.optimize import linear_sum_assignment as linear_assignment
import network
import timm
import torch.nn as nn

def _hungarian_match(flat_preds, flat_targets, preds_k, targets_k):
    assert isinstance(flat_preds, torch.Tensor) and isinstance(flat_targets, torch.Tensor) and flat_preds.is_cuda and flat_targets.is_cuda
  
    num_samples = flat_targets.shape[0]
  
    assert (preds_k == targets_k)  # one to one
    num_k = preds_k
    num_correct = np.zeros((num_k, num_k))
  
    for c1 in range(num_k):
      for c2 in range(num_k):
        # elementwise, so each sample contributes once
        votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
        num_correct[c1, c2] = votes
  
    # num_correct is small
    match = linear_assignment(num_samples - num_correct)
    
    # return as list of tuples, out_c to gt_c
    res = []
    for out_c, gt_c in zip(match[0], match[1]):
      res.append((out_c, gt_c))
  
    return res

def class_predictions_to_ground_truth(predicted_clusters, ground_truth_clusters, num_protos, class_num):
    '''
    Translate cluster assignments into ground-truth class labels.

    Uses ``_hungarian_match`` to find the best one-to-one cluster -> class
    mapping, then relabels every sample with its cluster's matched class.
    Can be called only if the number of clusters is the same as the number
    of ground-truth classes.

    :param predicted_clusters: tensor of predicted cluster indices.
    :param ground_truth_clusters: tensor of ground-truth labels.
    :param num_protos: number of predicted clusters.
    :param class_num: number of ground-truth classes.
    :return: CUDA tensor (same dtype as ``predicted_clusters``) of predictions
        translated to ground-truth labels; samples whose cluster is never
        matched keep the value 0.
    '''
    num_samples = predicted_clusters.shape[0]

    # Move to the GPU once; the same copy is used for matching and masking,
    # so the boolean mask lives on the same device as the output buffer.
    preds = predicted_clusters.cuda()
    match = _hungarian_match(preds, ground_truth_clusters.cuda(), num_protos, class_num)

    found = torch.zeros(num_protos)
    reordered_preds = torch.zeros(num_samples, dtype=predicted_clusters.dtype, device=preds.device)

    for pred_i, target_i in match:
        # target_i is a numpy integer from scipy; a plain int avoids a
        # host -> device -> host round-trip (and GPU sync) per cluster.
        reordered_preds[preds == int(pred_i)] = int(target_i)
        found[pred_i] = 1
    assert (found.sum() == num_protos)  # each output_k must get mapped

    return reordered_preds

def compute_accuracy(predicted_clusters, ground_truth_clusters, num_protos, class_num):
    '''
    Cluster accuracy after optimally relabeling clusters to classes.

    The predicted clusters are first mapped to ground-truth labels via
    ``class_predictions_to_ground_truth``; the score is the fraction of
    samples whose mapped label equals the ground truth.
    Requires the number of clusters to equal the number of classes.

    :param predicted_clusters: tensor of predicted cluster indices.
    :param ground_truth_clusters: tensor of ground-truth labels.
    :param num_protos: number of predicted clusters.
    :param class_num: number of ground-truth classes.
    :return: accuracy as a Python float in [0, 1].
    '''
    assert num_protos == class_num
    total = predicted_clusters.shape[0]

    mapped = class_predictions_to_ground_truth(
        predicted_clusters, ground_truth_clusters, num_protos, class_num
    )
    hits = (mapped.cuda() == ground_truth_clusters.cuda()).sum()
    return int(hits) / float(total)

def get_model(model_name, init_type):
    """Build a feature-extractor backbone with its classification head removed.

    - Names starting with ``'res'`` build ``network.ResBase`` (moved to CUDA);
      ``init_type`` is forwarded to it.
    - Otherwise ``init_type`` selects the source: ``"sup"`` loads a pretrained
      ``timm`` model named ``model_name``; ``"ssl"`` loads DINO ViT-B/16 from
      torch hub (``model_name`` is not used in that case).

    For non-ResBase models, ``backbone.in_features`` is set to the feature
    dimension and the classifier is replaced by an identity mapping.

    :param model_name: model identifier.
    :param init_type: ``"sup"`` or ``"ssl"`` for non-``res`` models.
    :return: the backbone module.
    :raises ValueError: if ``init_type`` is unknown for a non-``res`` model.
    """
    if model_name[0:3] == 'res':
        backbone = network.ResBase(res_name=model_name, init_type=init_type).cuda()
    else:
        if init_type == "sup":
            # load models from pytorch-image-models
            backbone = timm.create_model(model_name, pretrained=True)
        elif init_type == "ssl":
            backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
        else:
            # Previously fell through with `backbone` unbound, producing a
            # confusing UnboundLocalError from inside the fallback branch.
            raise ValueError(
                f"Unknown init_type {init_type!r} for model {model_name!r}; "
                "expected 'sup' or 'ssl'"
            )
        try:
            # timm API: read the head's input width, then drop head and pooling.
            backbone.in_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except Exception:
            # Models without the timm classifier API (e.g. the DINO hub model):
            # fall back to num_features and replace `head` with an identity.
            # `Exception` rather than a bare except so Ctrl-C is not swallowed.
            backbone.in_features = backbone.num_features
            backbone.head = nn.Identity()
    return backbone

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Per-iteration schedule: optional linear warmup followed by cosine decay.

    The warmup ramps linearly from ``start_warmup_value`` to ``base_value``
    over ``warmup_epochs * niter_per_ep`` steps (only if ``warmup_epochs > 0``).
    The remaining steps follow
    ``final + 0.5 * (base - final) * (1 + cos(pi * t / T))`` for
    ``t = 0 .. T-1``.

    :return: 1-D numpy array of length ``epochs * niter_per_ep``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    amplitude = base_value - final_value
    decay = final_value + 0.5 * amplitude * (1 + np.cos(np.pi * steps / len(steps)))

    full = np.concatenate((warmup, decay))
    assert len(full) == total_iters
    return full

def rand_bbox(size, lam):
    """Sample a random CutMix box whose area is roughly ``(1 - lam)`` of the image.

    :param size: tensor shape; ``size[2]`` and ``size[3]`` give the extents of
        the two spatial axes (the first box coordinate pair indexes axis 2,
        the second indexes axis 3).
    :param lam: mixing ratio in [0, 1]; the box side ratio is ``sqrt(1 - lam)``.
    :return: ``(bbx1, bby1, bbx2, bby2)``, clipped to the image bounds.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin int.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform box center
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

