import os
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from PIL import Image
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb

#from byol_depth_recompute_batch import compute_depth_maps
# code for kNN prediction from here:
# https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
def knn_predict(feature, feature_bank, feature_labels, classes: int, knn_k: int, knn_t: float):
    """Helper method to run kNN predictions on features based on a feature bank.

    Args:
        feature: Tensor of shape [B, D] holding B query features. The dot
            product below equals cosine similarity only if these (and the
            bank) are L2-normalized.
        feature_bank: Tensor of shape [D, N] (already transposed) holding the
            N reference features used for kNN.
        feature_labels: Tensor of shape [N] with the label of each bank feature.
        classes: Number of classes (e.g. 10 for CIFAR-10).
        knn_k: Number of nearest neighbours that vote. Clamped to N when the
            bank holds fewer than ``knn_k`` features.
        knn_t: Temperature; each neighbour's vote is weighted by
            ``exp(similarity / knn_t)``, so smaller values favour the closest
            neighbours more strongly.

    Returns:
        Tensor of shape [B, classes] with class indices sorted from highest to
        lowest weighted vote; column 0 is the predicted class.
    """
    batch_size = feature.size(0)
    # topk raises if k exceeds the number of bank entries, so never request
    # more neighbours than the bank holds.
    knn_k = min(knn_k, feature_bank.size(-1))
    # similarity between each query and each bank feature ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # K most similar bank entries per query ---> [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # labels of those neighbours ---> [B, K]
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1), dim=-1, index=sim_indices)
    # temperature-scaled exponential reweighting of the similarities
    sim_weight = (sim_weight / knn_t).exp()
    # scatter requires int64 indices
    sim_labels = sim_labels.to(torch.int64)
    # one-hot encode every neighbour label ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted vote per class ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(batch_size, -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels

def generate_embeddings(model, dataloader, device=None):
    """Generate L2-normalized representations for all images in the dataloader.

    Args:
        model: Object exposing a ``backbone`` callable that maps an image batch
            to features. This function does not switch the model to eval mode;
            the caller is responsible for that.
        dataloader: Iterable yielding ``(img, label, fnames)`` triples.
        device: Device the images are moved to before the forward pass. When
            None, the device of the backbone's first parameter is used, falling
            back to ``'cuda'`` if the backbone exposes no parameters.

    Returns:
        Tuple ``(embeddings, filenames, targets)``:
        - embeddings: CPU tensor of shape [D, N] (features transposed, one
          column per sample), each column L2-normalized;
        - filenames: list of all file names, in dataloader order;
        - targets: labels of all batches concatenated along dim 0.
    """
    if device is None:
        # Follow the backbone's weights; fall back to CUDA (the previously
        # hard-coded target) when there are no parameters to inspect.
        try:
            device = next(model.backbone.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda')

    embeddings = []
    filenames = []
    targets = []
    with torch.no_grad():
        for img, label, fnames in dataloader:
            img = img.to(device)
            emb = model.backbone(img).flatten(start_dim=1)
            # unit-norm features so dot products are cosine similarities
            feature = F.normalize(emb, dim=1)
            embeddings.append(feature)
            filenames.extend(fnames)
            targets.append(label)

    # [N, D] -> [D, N], the layout expected as a kNN feature bank
    embeddings = torch.cat(embeddings, 0).t().contiguous().cpu()
    targets = torch.cat(targets, 0).contiguous()
    return embeddings, filenames, targets

def generate_embeddings_depth(model, dataloader, dpt_model, dpt_transform, after_transform,
                              device=None, depth_fn=None):
    """Generate L2-normalized representations for image+depth inputs.

    For every batch a depth map is computed from the raw images, the images
    are passed through ``after_transform``, and the two are concatenated along
    dim 1 before being fed to ``model.backbone``. Presumably dim 1 is the
    channel dimension (RGB + depth); TODO confirm against the backbone.

    Args:
        model: Object exposing a ``backbone`` callable. This function does not
            switch the model to eval mode; the caller is responsible for that.
        dataloader: Iterable yielding ``(img, label, fnames)`` triples.
        dpt_model: Depth model forwarded to the depth function.
        dpt_transform: Transform forwarded to the depth function.
        after_transform: Callable applied to the raw image batch.
        device: Device the concatenated input is moved to. When None, the
            device of the backbone's first parameter is used, falling back to
            ``'cuda'`` if the backbone exposes no parameters.
        depth_fn: Callable ``depth_fn(img, dpt_model, dpt_transform)`` returning
            depth maps. When None, the module-level ``compute_depth_maps`` is
            used.

    Returns:
        Tuple ``(embeddings, filenames, targets)`` with the same layout as the
        non-depth variant: CPU embeddings of shape [D, N], list of file names,
        concatenated labels.
    """
    if depth_fn is None:
        # NOTE(review): compute_depth_maps must be in module scope. The import
        # of byol_depth_recompute_batch near the top of this file is commented
        # out, so this lookup raises NameError unless it is restored or
        # depth_fn is passed explicitly.
        depth_fn = compute_depth_maps
    if device is None:
        try:
            device = next(model.backbone.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda')

    embeddings = []
    filenames = []
    targets = []
    with torch.no_grad():
        for img, label, fnames in dataloader:
            # depth is estimated from the untransformed images
            d0 = depth_fn(img, dpt_model, dpt_transform)
            new_x0 = after_transform(img)
            img = torch.cat((new_x0, d0), 1)
            img = img.to(device)
            emb = model.backbone(img).flatten(start_dim=1)
            # unit-norm features so dot products are cosine similarities
            feature = F.normalize(emb, dim=1)
            embeddings.append(feature)
            filenames.extend(fnames)
            targets.append(label)

    # [N, D] -> [D, N], the layout expected as a kNN feature bank
    embeddings = torch.cat(embeddings, 0).t().contiguous().cpu()
    targets = torch.cat(targets, 0).contiguous()
    return embeddings, filenames, targets


def get_image_as_np_array(filename: str):
    """Return the image stored at ``filename`` as a numpy array.

    The file is opened in a context manager so its handle is always released.
    Pillow may otherwise keep the file open after loading, for example for
    multi-frame formats. ``np.asarray`` runs while the file is still open,
    which ensures the pixel data is read before the handle is closed.
    """
    with Image.open(filename) as img:
        return np.asarray(img)


