# Borrowed from https://github.com/nicolas-dufour/diffusion/blob/master/metrics/sample_and_eval.py
import random
import clip
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
from .inception_metrics import MultiInceptionMetrics


def remap_image_torch(image):
    """Min-max rescale each image of a batch to ``uint8`` values in [0, 255].

    The minimum and maximum are taken jointly over the last three dimensions,
    so each entry along the leading (batch) dimension is stretched
    independently. Scaled values are truncated (not rounded) by the uint8 cast.

    Args:
        image: Tensor whose last three dims form one image. Assumes a 4-D,
            batch-first tensor (e.g. ``(B, C, H, W)``) — TODO confirm for
            other ranks (a 3-D input comes back with a leading dim of 1).

    Returns:
        A ``torch.uint8`` tensor. Constant images (max == min) map to all
        zeros instead of going through a 0/0 division.
    """
    # Equivalent to the chained .min(-1)[0] calls over the last three dims.
    min_norm = image.amin(dim=(-3, -2, -1)).view(-1, 1, 1, 1)
    max_norm = image.amax(dim=(-3, -2, -1)).view(-1, 1, 1, 1)
    span = max_norm - min_norm
    # Constant images would give 0/0 -> NaN, and casting NaN to uint8 is
    # undefined; dividing by 1 there yields 0 and leaves other images untouched.
    span = torch.where(span > 0, span, torch.ones_like(span))
    image_torch = (image - min_norm) / span * 255
    return torch.clip(image_torch, 0, 255).to(torch.uint8)

def normalize_image_torch(image):
    """Min-max rescale each image of a batch to floats in [0, 1].

    The minimum and maximum are taken jointly over the last three dimensions,
    so each entry along the leading (batch) dimension is normalized
    independently.

    Args:
        image: Tensor whose last three dims form one image. Assumes a 4-D,
            batch-first tensor (e.g. ``(B, C, H, W)``) — TODO confirm for
            other ranks (a 3-D input comes back with a leading dim of 1).

    Returns:
        A floating-point tensor with values in [0, 1]. Constant images
        (max == min) map to all zeros instead of NaN.
    """
    # Equivalent to the chained .min(-1)[0] calls over the last three dims.
    min_norm = image.amin(dim=(-3, -2, -1)).view(-1, 1, 1, 1)
    max_norm = image.amax(dim=(-3, -2, -1)).view(-1, 1, 1, 1)
    span = max_norm - min_norm
    # Avoid 0/0 -> NaN for constant images; other images are unaffected.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (image - min_norm) / span

class SampleAndEval:
    """Feed real and generated image batches to ``MultiInceptionMetrics``.

    The metric is configured for class-conditional metrics only (unconditional
    metrics disabled) and with ``reset_real_features=False``, so real-image
    features are extracted at most once per instance unless the metric reports
    that it resets them.
    """

    def __init__(self, device, num_images=50000, compute_per_class_metrics=False, num_classes=100):
        """Build the underlying metric object.

        Args:
            device: Stored as ``self.device``; not read elsewhere in this class.
            num_images: Stored as ``self.num_images``; not read elsewhere in
                this class.
            compute_per_class_metrics: Forwarded as
                ``compute_conditional_metrics_per_class``.
            num_classes: Number of classes forwarded to the metric.
        """
        super().__init__()
        self.inception_metrics = MultiInceptionMetrics(
            reset_real_features=False,
            compute_unconditional_metrics=False,
            compute_conditional_metrics=True,
            compute_conditional_metrics_per_class=compute_per_class_metrics,
            num_classes=num_classes,
            num_inception_chunks=10,
            manifold_k=5,
        )
        # NOTE(review): num_images and device are kept for callers; nothing in
        # this class reads them — confirm they are still needed.
        self.num_images = num_images
        self.true_features_computed = False
        self.device = device

    def compute_and_log_metrics(self, real_images, sample_images, labels):
        """Extract features for real and generated batches, then compute metrics.

        Args:
            real_images: Indexable/iterable collection of real-image batches.
            sample_images: Collection of generated-image batches aligned with
                ``labels``.
            labels: Collection of label batches; used as the labels for both
                the real and the generated batches.

        Returns:
            The value returned by ``self.inception_metrics.compute()``
            (presumably a mapping of metric name to value — confirm).
        """
        with torch.no_grad():
            # The metric is built with reset_real_features=False and this class
            # never resets it, so real features persist between calls.
            # Re-extracting them would duplicate real samples in the metric
            # state, skewing mean/covariance and the k-NN (manifold_k) stats.
            # Only extract them the first time, or every time if the metric
            # does reset them.
            if not self.true_features_computed or self.inception_metrics.reset_real_features:
                self.compute_images_features(real_images, labels, "real")
                self.true_features_computed = True
            # NOTE(review): generated features are never reset here, so repeated
            # calls accumulate samples from every call — confirm this is intended.
            self.compute_images_features(sample_images, labels, "conditional")
            return self.inception_metrics.compute()

    def compute_images_features(self, images_loader, labels_loader, image_type):
        """Update the metric with every image batch and its matching label batch.

        Each image batch is passed through ``remap_image_torch`` (per-image
        min-max rescale to uint8 [0, 255]) before ``update``. Per the original
        author's note, xqgan test images are roughly in [-1, 1] (not strictly),
        and eval images are in [0, 1].

        Args:
            images_loader: Collection of image batches.
            labels_loader: Collection of label batches aligned with
                ``images_loader``; extra trailing label batches are ignored.
            image_type: Forwarded as ``image_type`` to
                ``MultiInceptionMetrics.update`` (this class uses "real" and
                "conditional").

        Raises:
            ValueError: If there are fewer label batches than image batches.
        """
        if len(labels_loader) < len(images_loader):
            raise ValueError(
                f"labels_loader has {len(labels_loader)} batches but "
                f"images_loader has {len(images_loader)}"
            )
        # NOTE(review): per-image min-max stretching changes each image's
        # absolute intensity/contrast before feature extraction — confirm this
        # is preferred over a fixed range mapping.
        for images, labels in zip(images_loader, labels_loader):
            self.inception_metrics.update(
                remap_image_torch(images),
                labels,
                image_type=image_type,
            )

        