from os.path import dirname, abspath, exists, join
import math
import os
import shutil

from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torchvision.utils import save_image
from scipy import linalg
from tqdm import tqdm
import torch
import numpy as np
import sklearn.metrics
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def frechet_inception_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Computes ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 sigma2)).

    Args:
        mu1, mu2: mean vectors (scalars are promoted to 1-D arrays).
        sigma1, sigma2: covariance matrices (promoted to at least 2-D).
        eps: diagonal jitter added to both covariances when the matrix square
            root of their product is not finite.
    Returns:
        The distance as a (numpy) float.
    Raises:
        AssertionError: if the means or covariances have mismatched shapes.
        ValueError: if the square root keeps a non-negligible imaginary part.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, \
        "Training and test mean vectors have different lengths."
    assert cov_a.shape == cov_b.shape, \
        "Training and test covariances have different dimensions."

    delta = mean_a - mean_b

    # sqrt(cov_a @ cov_b) can blow up when the product is close to singular;
    # retry with a small ridge on both diagonals in that case.
    root, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(root).all():
        jitter = eps * np.eye(cov_a.shape[0])
        root = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off may leave a tiny imaginary part; anything larger is an error.
    if np.iscomplexobj(root):
        diag_is_real = np.allclose(np.diagonal(root).imag, 0, atol=1e-3)
        if not diag_is_real:
            worst = np.max(np.abs(root.imag))
            raise ValueError("Imaginary component {}".format(worst))
        root = root.real

    trace_root = np.trace(root)
    return (delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * trace_root)

def calculate_moments(data_loader, eval_model, batch_size, quantize=False, world_size=1,
                      DDP=False, disable_tqdm=True, fake_feats=None, device="cuda"):
    """Return the mean and covariance of ``eval_model`` embeddings over a loader.

    Args:
        data_loader: iterable of ``(images, labels)`` batches that exposes
            ``.dataset`` (only its length is used).
        eval_model: object with ``eval()`` and ``get_outputs(images)`` returning
            ``(embeddings, logits)``; only the embeddings are used.
        batch_size: used only to decide how many batches to consume. It should
            match the loader's batch size; a larger value consumes too few
            batches. The loop also stops early if the loader is exhausted.
        quantize, fake_feats: unused; kept for signature compatibility.
        world_size, DDP: when ``DDP`` is true the batch budget is divided by
            ``world_size``.
        disable_tqdm: hide the progress bar.
        device: device the images are moved to before the forward pass
            (defaults to ``"cuda"`` as before).
    Returns:
        ``(mu, sigma)``: float64 mean vector and covariance matrix of the
        embeddings, computed with one row per sample (``rowvar=False``).
    Raises:
        ValueError: if the loader yields no batches.
    """
    eval_model.eval()
    total_instance = len(data_loader.dataset)

    samples_per_step = batch_size * world_size if DDP else batch_size
    num_batches = math.ceil(total_instance / samples_per_step)

    acts = []
    data_iter = iter(data_loader)
    for _ in tqdm(range(num_batches), disable=disable_tqdm):
        try:
            images, _labels = next(data_iter)
        except StopIteration:
            break
        # Labels are not needed for embedding, so they are not sent to the device.
        with torch.no_grad():
            embeddings, _logits = eval_model.get_outputs(images.to(device))
        # Offload each batch right away so device memory stays bounded by one batch.
        acts.append(embeddings.detach().cpu())

    if not acts:
        raise ValueError("data_loader yielded no batches; cannot compute moments")

    # Drop any rows beyond the dataset length before computing statistics.
    acts = torch.cat(acts, dim=0).numpy()[:total_instance].astype(np.float64)

    mu = np.mean(acts, axis=0)
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma

def calculate_features(data_loader, eval_model, batch_size, quantize=False, world_size=1,
                       DDP=False, disable_tqdm=True, fake_feats=None, device="cuda"):
    """Collect ``eval_model`` embeddings for the batches yielded by ``data_loader``.

    Args:
        data_loader: iterable of ``(images, labels)`` batches that exposes
            ``.dataset`` (only its length is used).
        eval_model: object with ``eval()`` and ``get_outputs(images)`` returning
            ``(embeddings, logits)``; only the embeddings are kept.
        batch_size: used only to decide how many batches to consume. It should
            match the loader's batch size; a larger value consumes too few
            batches. The loop also stops early if the loader is exhausted.
        quantize, fake_feats: unused; kept for signature compatibility.
        world_size, DDP: when ``DDP`` is true the batch budget is divided by
            ``world_size``.
        disable_tqdm: hide the progress bar.
        device: device the images are moved to before the forward pass
            (defaults to ``"cuda"`` as before).
    Returns:
        float64 numpy array of the embeddings concatenated along dim 0
        (presumably ``(num_samples, feature_dim)`` -- depends on ``get_outputs``).
    Raises:
        ValueError: if the loader yields no batches.
    """
    eval_model.eval()
    total_instance = len(data_loader.dataset)

    samples_per_step = batch_size * world_size if DDP else batch_size
    num_batches = math.ceil(total_instance / samples_per_step)

    acts = []
    data_iter = iter(data_loader)
    for _ in tqdm(range(num_batches), disable=disable_tqdm):
        try:
            images, _labels = next(data_iter)
        except StopIteration:
            break
        # Labels are not needed for embedding, so they are not sent to the device.
        with torch.no_grad():
            embeddings, _logits = eval_model.get_outputs(images.to(device))
        # Offload each batch right away so device memory stays bounded by one batch.
        acts.append(embeddings.detach().cpu())

    if not acts:
        raise ValueError("data_loader yielded no batches; cannot compute features")

    return torch.cat(acts, dim=0).numpy().astype(np.float64)


def calculate_fid(original_mnist_loader,
                  generated_mnist_loader,
                  eval_model,
                  args):
    """Compute the FID between real and generated samples.

    Args:
        original_mnist_loader: loader over the reference (real) samples.
        generated_mnist_loader: loader over the generated samples.
        eval_model: embedding model handed to ``calculate_moments``.
        args: object with a ``batch_size`` attribute.
    Returns:
        ``(fid_value, mu, sigma)`` where ``mu``/``sigma`` are the moments of
        the reference (real) embeddings.
    """
    eval_model.eval()

    moments = [
        calculate_moments(data_loader=loader,
                          eval_model=eval_model,
                          batch_size=args.batch_size)
        for loader in (original_mnist_loader, generated_mnist_loader)
    ]
    (real_mu, real_sigma), (fake_mu, fake_sigma) = moments

    distance = frechet_inception_distance(real_mu, real_sigma, fake_mu, fake_sigma)
    return distance, real_mu, real_sigma

def compute_metrics(original_mnist_loader, generated_mnist_loader, eval_model, args,
                    nearest_k=5):
    """Compute PRDC metrics between real and generated samples.

    Args:
        original_mnist_loader: loader over the reference (real) samples.
        generated_mnist_loader: loader over the generated samples.
        eval_model: embedding model handed to ``calculate_features``.
        args: object with a ``batch_size`` attribute.
        nearest_k: neighbourhood size for the k-NN balls (previously fixed at 5).
    Returns:
        ``(precision, recall, density, coverage)``.
    """
    real_embeds = calculate_features(original_mnist_loader, eval_model, args.batch_size)
    fake_embeds = calculate_features(generated_mnist_loader, eval_model, args.batch_size)
    metrics = compute_prdc(real_features=real_embeds, fake_features=fake_embeds,
                           nearest_k=nearest_k)
    return (metrics["precision"], metrics["recall"],
            metrics["density"], metrics["coverage"])


def compute_pairwise_distance(data_x, data_y=None):
    """Euclidean distance between every row of ``data_x`` and every row of ``data_y``.

    Args:
        data_x: numpy.ndarray([N, feature_dim])
        data_y: numpy.ndarray([M, feature_dim]); when omitted, ``data_x`` is
            compared with itself (the same object is passed for both sides).
    Returns:
        numpy.ndarray([N, M]) of pairwise distances.
    """
    other = data_x if data_y is None else data_y
    return sklearn.metrics.pairwise_distances(
        data_x, other, metric='euclidean', n_jobs=8)


def get_kth_value(unsorted, k, axis=-1):
    """Return the k-th smallest value (1-based) along ``axis``.

    Args:
        unsorted: numpy.ndarray of any dimensionality.
        k: int in ``[1, unsorted.shape[axis]]``; ``k=1`` gives the minimum.
        axis: axis to reduce over (default: last).
    Returns:
        Array with ``axis`` removed, holding the k-th smallest entries
        (a numpy scalar for 1-D input).
    Raises:
        ValueError: if ``k`` is outside ``[1, unsorted.shape[axis]]``.
    """
    unsorted = np.asarray(unsorted)
    size = unsorted.shape[axis]
    if not 1 <= k <= size:
        raise ValueError(f"k must be in [1, {size}] along axis {axis}, got {k}")
    # After partitioning at index k-1, that slot holds exactly the value a full
    # sort would place there, i.e. the k-th smallest. Selecting it with
    # np.take honours ``axis`` (the old ``[..., :k]`` slice always hit the last
    # axis) and also works for k == size.
    partitioned = np.partition(unsorted, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)


def compute_nearest_neighbour_distances(input_features, nearest_k):
    """Distance from each sample to its ``nearest_k``-th nearest neighbour in the same set.

    Args:
        input_features: numpy.ndarray([N, feature_dim])
        nearest_k: int
    Returns:
        numpy.ndarray([N]) of k-NN radii.
    """
    # Each row also contains the sample's distance to itself (zero), so the
    # k-th *other* neighbour is the (k + 1)-th smallest entry of the row.
    self_distances = compute_pairwise_distance(input_features)
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)


def compute_prdc(real_features, fake_features, nearest_k):
    """Compute precision, recall, density and coverage between two sample sets.

    Each sample gets a k-NN ball whose radius is the distance to its
    ``nearest_k``-th nearest neighbour within its own set.

    - precision: fraction of fake samples inside at least one real ball.
    - recall: fraction of real samples inside at least one fake ball.
    - density: average number of real balls containing each fake sample,
      divided by ``nearest_k``.
    - coverage: fraction of real samples whose nearest fake sample lies
      inside that real sample's own ball.

    Args:
        real_features: numpy.ndarray([N_real, feature_dim])
        fake_features: numpy.ndarray([N_fake, feature_dim])
        nearest_k: int.
    Returns:
        dict of precision, recall, density, and coverage.
    """
    # Errors propagate to the caller: previously they were printed and
    # swallowed, which then crashed with a NameError on the next line.
    real_nearest_neighbour_distances = compute_nearest_neighbour_distances(
        real_features, nearest_k)
    fake_nearest_neighbour_distances = compute_nearest_neighbour_distances(
        fake_features, nearest_k)
    distance_real_fake = compute_pairwise_distance(
        real_features, fake_features)

    # (N_real, N_fake) mask: fake j lies inside real i's ball. Shared by
    # precision and density, so it is computed once.
    inside_real_balls = distance_real_fake < np.expand_dims(
        real_nearest_neighbour_distances, axis=1)

    precision = inside_real_balls.any(axis=0).mean()

    recall = (
            distance_real_fake <
            np.expand_dims(fake_nearest_neighbour_distances, axis=0)
    ).any(axis=1).mean()

    density = (1. / float(nearest_k)) * inside_real_balls.sum(axis=0).mean()

    coverage = (
            distance_real_fake.min(axis=1) <
            real_nearest_neighbour_distances
    ).mean()

    return dict(precision=precision, recall=recall,
                density=density, coverage=coverage)


def compute_nearest_neighbour_distances_slice(input_features, nearest_k, K):
    """Memory-bounded k-NN radii: process rows in chunks of at most ``K`` samples.

    With more than ``K`` samples, only a (chunk x N) distance block is held at
    once instead of the full N x N matrix.

    Args:
        input_features: numpy.ndarray([N, feature_dim])
        nearest_k: int, which neighbour's distance to return (self excluded).
        K: int > 0, maximum number of rows per distance block.
    Returns:
        numpy.ndarray([N]) of distances to each sample's ``nearest_k``-th
        nearest neighbour.
    Raises:
        ValueError: if ``K`` is not positive.
    """
    if K < 1:
        raise ValueError(f"K must be a positive integer, got {K}")

    n_samples = input_features.shape[0]
    if n_samples <= K:
        distances = compute_pairwise_distance(input_features)
        # k + 1 because each row includes the sample's distance to itself.
        return np.array(get_kth_value(distances, k=nearest_k + 1, axis=-1))

    # Gather per-chunk radii and join once; calling np.hstack inside the loop
    # would re-copy the growing array every iteration (quadratic).
    chunks = []
    for start in range(0, n_samples, K):
        distances = compute_pairwise_distance(input_features[start:start + K],
                                              input_features)
        chunks.append(get_kth_value(distances, k=nearest_k + 1, axis=-1))
    return np.concatenate(chunks)

def compute_prdc_slice(real_features, fake_features, nearest_k, K, flag):
    """Memory-bounded precision, recall, density and coverage.

    Computes the same per-sample quantities as ``compute_prdc``. Every
    real-vs-fake distance block built here has at most ``K`` columns (fake
    chunks) or at most ``K`` rows (real chunks). The exception is the
    single-matrix path used when there are fewer than ``K`` fake samples.

    Args:
        real_features: numpy.ndarray([N_real, feature_dim])
        fake_features: numpy.ndarray([N_fake, feature_dim])
        nearest_k: int, neighbourhood size for the k-NN balls.
        K: int > 0, chunk size (samples per distance block).
        flag: if ``1``, each metric is reduced to its mean. Otherwise per-sample
            arrays are returned: boolean precision and float density per fake
            sample, boolean recall and coverage per real sample.
    Returns:
        dict of precision, recall, density, and coverage.
    Raises:
        ValueError: if ``K`` is not positive.
    """
    if K < 1:
        raise ValueError(f"K must be a positive integer, got {K}")

    real_radii = compute_nearest_neighbour_distances_slice(real_features, nearest_k, K)
    fake_radii = compute_nearest_neighbour_distances_slice(fake_features, nearest_k, K)
    real_radii_col = np.expand_dims(real_radii, axis=1)  # broadcasts across fake columns
    fake_radii_row = np.expand_dims(fake_radii, axis=0)  # broadcasts across real rows
    inv_k = 1. / float(nearest_k)

    if fake_features.shape[0] < K:
        # Few fake samples: the whole real x fake matrix fits, compute it once.
        distance_real_fake = compute_pairwise_distance(real_features, fake_features)
        inside_real = distance_real_fake < real_radii_col
        precision = inside_real.any(axis=0)
        density = inv_k * inside_real.sum(axis=0)
        recall = (distance_real_fake < fake_radii_row).any(axis=1)
        coverage = distance_real_fake.min(axis=1) < real_radii
    else:
        # Per-fake metrics: slice the fake set, test against every real ball.
        precision_parts, density_parts = [], []
        for start in range(0, fake_features.shape[0], K):
            block = compute_pairwise_distance(real_features,
                                              fake_features[start:start + K])
            inside_real = block < real_radii_col
            precision_parts.append(inside_real.any(axis=0))
            density_parts.append(inv_k * inside_real.sum(axis=0))

        # Per-real metrics: slice the real set, test against every fake ball.
        recall_parts, coverage_parts = [], []
        for start in range(0, real_features.shape[0], K):
            stop = start + K
            block = compute_pairwise_distance(real_features[start:stop], fake_features)
            recall_parts.append((block < fake_radii_row).any(axis=1))
            coverage_parts.append(block.min(axis=1) < real_radii[start:stop])

        # Join once at the end (np.hstack inside the loops was quadratic).
        precision = np.concatenate(precision_parts)
        density = np.concatenate(density_parts)
        recall = np.concatenate(recall_parts)
        coverage = np.concatenate(coverage_parts)

    if flag == 1:
        coverage = coverage.mean()
        recall = recall.mean()
        density = density.mean()
        precision = precision.mean()

    return dict(precision=precision, recall=recall,
                density=density, coverage=coverage)

def mytsne(X1,X2, save_dir, generation, w):
    """Embed two sample sets jointly with 2-D t-SNE and save a scatter plot.

    Samples of ``X1`` are drawn in red and samples of ``X2`` in blue.

    Args:
        X1, X2: torch tensors whose first dimension indexes samples; all other
            dimensions are flattened into one feature vector per sample.
        save_dir: path *prefix* for the output image. It is string-concatenated
            with the file name, so include a trailing separator to target a
            directory.
        generation, w: values interpolated into the plot title and file name.
    Returns:
        numpy array with the 2-D t-SNE coordinates of ``X1`` rows followed by
        ``X2`` rows.
    """
    # Flatten each sample to a vector. reshape (instead of squeeze + view)
    # keeps the sample dimension even for a single-sample set and accepts
    # non-contiguous tensors; detach() allows tensors that track gradients.
    X1 = X1.detach().reshape(X1.size(0), -1).cpu().numpy()
    X2 = X2.detach().reshape(X2.size(0), -1).cpu().numpy()

    tsne = TSNE(n_components = 2, random_state=1)
    tres = tsne.fit_transform(np.concatenate((X1, X2), axis=0))

    plt.close('all')
    colors = ['r'] * X1.shape[0] + ['b'] * X2.shape[0]
    plt.scatter(tres[:, 1], tres[:, 0], s=0.03, color=colors)
    # NOTE: the "genration" spelling is kept because it is part of the
    # generated file name that other tooling may look for.
    plt.title(f"All-genration={generation} w={w}")
    plt.savefig(save_dir + f"All-genration={generation}w={w}.png")

    return tres