import math
from tqdm import tqdm
import numpy as np
import torch
import numpy as np
from pytorch_fid import fid_score
from ..datasets.loaders import get_loader,get_empty_loader
from ..datasets.supervised_dataset import SupervisedDataset
from ..id_estimator.factory import get_id_estimator
from ..clusterer.factory import get_clusterer
from ..utils import clip_dataset,pickle_exists,save_pickle,load_pickle
import prdc

from .metrics_helpers import InceptionHelper
from .ood_helpers import ood_acc

import pdb
from tqdm import tqdm
import torchvision
import matplotlib.pyplot as plt
import os

def _stack_subset_inputs(subset):
    """Stack the inputs selected by ``subset.indices`` from ``subset.dataset.inputs``."""
    return torch.stack([subset.dataset.inputs[i] for i in subset.indices])


def _subset_centroid(subset):
    """Return the mean of a subset's inputs after flattening every input to 1-D."""
    datapoints = _stack_subset_inputs(subset)
    return datapoints.reshape(datapoints.shape[0], -1).mean(0)


def _greedy_match_clusters(distances):
    """Greedily pair rows with columns of a 2-D ``distances`` tensor.

    The globally smallest remaining entry is picked repeatedly; its row and
    column are then excluded by setting them to infinity. Only
    ``min(n_rows, n_cols)`` pairs are produced, so no row or column is ever
    reused when the two sides have a different number of clusters.

    Returns:
        A list of ``(row_index, column_index)`` tuples in the order they were picked.
    """
    distances = distances.clone()  # do not clobber the caller's tensor
    matchings = []
    for _ in range(min(distances.shape)):
        min_index = (distances == distances.min()).nonzero()[0]
        row, col = min_index[0].item(), min_index[1].item()
        matchings.append((row, col))
        distances[row, :] = math.inf
        distances[:, col] = math.inf
    return matchings


def _save_sample_grid(imgs, save_name, num_samples=64, images_per_row=8):
    """Save the first ``num_samples`` images as a grid under ``./sample_vis/``."""
    save_name = "./sample_vis/" + save_name
    os.makedirs(os.path.dirname(save_name), exist_ok=True)

    # imgs.clamp_(module.data_min, module.data_max)
    grid = torchvision.utils.make_grid(
        imgs[:num_samples], nrow=images_per_row, pad_value=1, normalize=True, scale_each=True
    )
    grid_permuted = grid.permute((1, 2, 0))  # channels-first grid -> channels-last for imshow

    fig = plt.figure()
    try:
        plt.axis("off")
        plt.imshow(grid_permuted.detach().cpu().numpy())
        plt.savefig(save_name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)


def clustered_id(module, eval_loader, cluster_cfg, max_class_samples=20000, train_loader=None, gen_samples=-1, gen_batch_size=64,
        cache=None, default_id_estimator="mle", default_cluster_method="kmeans"):
    """Compare per-cluster ID estimates of real data against those of generated samples.

    The reference data and a freshly generated sample set are each clustered,
    an ID estimate is computed for every cluster, clusters are paired greedily
    by squared distance between their centroids, and the per-pair differences
    of the estimates are reported.

    Args:
        module: Model exposing ``device`` and ``get_sample_dataset(n, batch_size)``.
        eval_loader: Loader used as reference data when ``train_loader`` is None.
        cluster_cfg: Config mapping. Keys read here: ``"id_estimator"``,
            ``"cluster_method"``, ``"cluster_id_metric_dataset_save"``,
            ``"latent_k"`` and ``"clustered_id_samples_save"``. NOTE: it is
            mutated in place when ``"id_estimator"`` is None or
            ``"cluster_method"`` is ``"class"``.
        max_class_samples: Passed to ``clip_dataset`` for every cluster dataset.
        train_loader: Takes precedence over ``eval_loader`` when given.
        gen_samples: Number of samples to generate; ``-1`` means
            ``len(dataloader.dataset)``.
        gen_batch_size: Batch size for generation and for the generated-sample loader.
        cache: Unused.
        default_id_estimator: Estimator name used when the config has none.
        default_cluster_method: Replacement when the config asks for ``"class"``.

    Returns:
        dict with ``gen_id_{k}``, ``data_id_{k}`` and ``id_diff_{k}`` for every
        matched pair ``k``, plus ``"clustered_id_diff"``, the sum of the
        absolute differences.
    """
    #TODO: diff dataset sizes for diff metrics

    dataloader = eval_loader if train_loader is None else train_loader
    if cluster_cfg["id_estimator"] is None:
        cluster_cfg["id_estimator"] = default_id_estimator

    # Can't cluster by class with samples
    #TODO implement class clustering by assigning class to be cluster model
    if cluster_cfg["cluster_method"] == "class":
        print("Default cluster method class cannot be used for clustered_id metric, replacing with " + default_cluster_method)
        cluster_cfg["cluster_method"] = default_cluster_method

    # Reference clusters and their ID estimates are cached on disk.
    # NOTE(review): a cached entry is reused without checking max_class_samples.
    if pickle_exists(cluster_cfg["cluster_id_metric_dataset_save"], "./runs/metric_dataset_saves"):
        print(f"Loading clustered_id metric dataset from {cluster_cfg['cluster_id_metric_dataset_save']}")
        cluster_id_metric_dict = load_pickle(cluster_cfg["cluster_id_metric_dataset_save"], "./runs/metric_dataset_saves")
        datasets = cluster_id_metric_dict["datasets"]
        data_id_estimates = cluster_id_metric_dict["data_id_estimates"]
    else:
        # Get cluster_id estimates from the dataset
        clusterer = get_clusterer(cluster_cfg, writer=None, device=module.device)
        clusterer.set_super_dataloaders(dataloader, get_empty_loader(), get_empty_loader())
        clusterer.set_partitions(dataloader, get_empty_loader(), get_empty_loader())
        clusterer.set_dataloaders()
        datasets = [clip_dataset(dl["train"].dataset, max_class_samples) for dl in clusterer.cluster_dataloaders]

        id_estimator = get_id_estimator(cluster_cfg, writer=None)
        data_id_estimates = [id_estimator.estimate_id(dataset, dataset=True)[cluster_cfg["latent_k"]] for dataset in datasets]
        cluster_id_metric_dict = {
            "datasets": datasets,
            "data_id_estimates": data_id_estimates
        }
        print(f"Saving clustered_id metric dataset to {cluster_cfg['cluster_id_metric_dataset_save']}")
        save_pickle(cluster_cfg["cluster_id_metric_dataset_save"], "./runs/metric_dataset_saves", cluster_id_metric_dict)

    # Get sample dataset
    if gen_samples == -1:
        gen_samples = len(dataloader.dataset)

    gendataset = module.get_sample_dataset(gen_samples, gen_batch_size)
    genloader = get_loader(gendataset, module.device, gen_batch_size, drop_last=False, set_device=True)

    # Get cluster_id sample estimates
    genclusterer = get_clusterer(cluster_cfg, writer=None, device=module.device)
    genclusterer.set_super_dataloaders(genloader, get_empty_loader(), get_empty_loader())
    genclusterer.set_partitions(genloader, get_empty_loader(), get_empty_loader())
    genclusterer.set_dataloaders()
    gendatasets = [clip_dataset(dl["train"].dataset, max_class_samples) for dl in genclusterer.cluster_dataloaders]

    id_estimator = get_id_estimator(cluster_cfg, writer=None)
    gen_id_estimates = [id_estimator.estimate_id(dataset, dataset=True)[cluster_cfg["latent_k"]] for dataset in gendatasets]

    # Get matching clusters: rows are generated clusters, columns are data clusters
    gendataset_centroids = [_subset_centroid(dataset) for dataset in gendatasets]
    dataset_centroids = [_subset_centroid(dataset) for dataset in datasets]
    distances = torch.zeros((len(gendataset_centroids), len(dataset_centroids)))

    for gen_idx, gendataset_centroid in enumerate(gendataset_centroids):
        for data_idx, dataset_centroid in enumerate(dataset_centroids):
            distances[gen_idx, data_idx] = ((gendataset_centroid - dataset_centroid) ** 2).sum()

    matchings = _greedy_match_clusters(distances)

    # Save samples
    if cluster_cfg["clustered_id_samples_save"] is not None:
        print(f"Saving samples to {cluster_cfg['clustered_id_samples_save']}")

        for match_idx, (gen_cluster, data_cluster) in enumerate(matchings):
            gen_dset, data_dset = gendatasets[gen_cluster], datasets[data_cluster]
            _save_sample_grid(_stack_subset_inputs(gen_dset), f"{cluster_cfg['clustered_id_samples_save']}/matching_{match_idx}/gensamples.png")
            _save_sample_grid(_stack_subset_inputs(data_dset), f"{cluster_cfg['clustered_id_samples_save']}/matching_{match_idx}/datasamples.png")

    # Compute id variation
    id_metrics = {}
    sum_difference = 0
    for match_idx, (gen_cluster, data_cluster) in enumerate(matchings):
        gen_id, data_id = gen_id_estimates[gen_cluster], data_id_estimates[data_cluster]
        id_difference = data_id - gen_id
        id_metrics[f"gen_id_{match_idx}"] = gen_id
        id_metrics[f"data_id_{match_idx}"] = data_id
        id_metrics[f"id_diff_{match_idx}"] = id_difference
        sum_difference += abs(id_difference)
    id_metrics["clustered_id_diff"] = sum_difference

    return id_metrics

def id(module, eval_loader, cluster_cfg, max_samples=20000, train_loader=None, gen_samples=-1, gen_batch_size=64,
        cache=None, default_id_estimator="mle"):
    """Compare ID estimates of real data, generated samples and their union.

    NOTE: the function name shadows the ``id`` builtin inside this module; it is
    kept unchanged because callers refer to the metric by this name.

    Args:
        module: Model exposing ``get_sample_dataset(n, batch_size)``.
        eval_loader: Loader whose dataset is used when ``train_loader`` is None.
        cluster_cfg: Config mapping. Keys read here: ``"id_estimator"`` and
            ``"metric_dataset_save"``. NOTE: ``"id_estimator"`` is set in place
            to ``default_id_estimator`` when it is None.
        max_samples: Passed to ``clip_dataset``; a cached dataset must contain
            exactly this many samples.
        train_loader: Takes precedence over ``eval_loader`` when given.
        gen_samples: Number of samples to generate; ``-1`` means ``len(dataset)``.
        gen_batch_size: Batch size used for generation.
        cache: Unused.
        default_id_estimator: Estimator name used when the config has none.

    Returns:
        dict holding, for every index ``i`` of the zipped estimate sequences, the
        generated / data / combined estimates and their signed and absolute
        differences.

    Raises:
        ValueError: If a cached metric dataset exists but its number of samples
            differs from ``max_samples``.
    """
    print("MAX SAMPLES:", max_samples)
    if cluster_cfg["id_estimator"] is None:
        cluster_cfg["id_estimator"] = default_id_estimator
    id_estimator = get_id_estimator(cluster_cfg=cluster_cfg, writer=None)

    # Cache dataset ID estimates
    if pickle_exists(cluster_cfg["metric_dataset_save"], "./runs/metric_dataset_saves"):
        print(f"Loading metric dataset from {cluster_cfg['metric_dataset_save']}")
        metric_dataset_dict = load_pickle(cluster_cfg["metric_dataset_save"], "./runs/metric_dataset_saves")
        dataset = metric_dataset_dict["dataset"]
        data_id_estimates = metric_dataset_dict["data_id_estimate"]

        # Fail loudly on a stale cache instead of dropping into a debugger.
        cached_indices = getattr(dataset, "indices", None)
        num_cached = len(dataset) if cached_indices is None else len(cached_indices)
        if num_cached != max_samples:
            raise ValueError(
                "cached dataset has incorrect number of samples: "
                f"expected {max_samples}, found {num_cached} "
                f"(cache entry: {cluster_cfg['metric_dataset_save']})"
            )
    else:
        dataset = eval_loader.dataset if train_loader is None else train_loader.dataset
        dataset = clip_dataset(dataset, max_samples)
        data_id_estimates = id_estimator.estimate_id(dataset, dataset=True)
        metric_dataset_dict = {
            "dataset": dataset,
            "data_id_estimate": data_id_estimates
        }
        print(f"Saving metric dataset to {cluster_cfg['metric_dataset_save']}")
        save_pickle(cluster_cfg["metric_dataset_save"], "./runs/metric_dataset_saves", metric_dataset_dict)

    if gen_samples == -1:
        gen_samples = len(dataset)

    gendataset = module.get_sample_dataset(gen_samples, gen_batch_size)

    # Use the dataset's own ``inputs`` when it has them; otherwise treat it as a
    # subset-style view and gather the selected inputs item by item.
    try:
        data_inputs = dataset.inputs
    except AttributeError:
        data_inputs = torch.stack([dataset.dataset[i][0] for i in dataset.indices])
    combined_inputs = torch.cat((data_inputs, gendataset.inputs), 0)

    combined_dataset = SupervisedDataset("combined_dataset", "test", combined_inputs)
    combined_id_estimates = id_estimator.estimate_id(combined_dataset, dataset=True)

    gen_id_estimates = id_estimator.estimate_id(gendataset, dataset=True)

    id_metrics = {}
    estimate_triples = zip(gen_id_estimates, data_id_estimates, combined_id_estimates)
    for idx, (gen_id_estimate, data_id_estimate, combined_id_estimate) in enumerate(estimate_triples):
        id_metrics.update({
            f"gen_id_estimate_{idx}": gen_id_estimate,
            f"data_id_estimate_{idx}": data_id_estimate,
            f"id_estimate_distance_{idx}": abs(gen_id_estimate - data_id_estimate),
            f"id_estimate_difference_{idx}": data_id_estimate - gen_id_estimate,
            f"combined_id_estimate_{idx}": combined_id_estimate,
            f"combined_id_estimate_difference_{idx}": data_id_estimate - combined_id_estimate,
            f"combined_id_estimate_distance_{idx}": abs(data_id_estimate - combined_id_estimate),
        })

    return id_metrics
     

def fid(module, eval_loader=None, train_loader=None, gen_samples=50000, gen_batch_size=32,
        cache=None):
    """
    Following Heusel et al. (2017), compute FID against the training set if provided
    (otherwise against ``eval_loader``).

    NOTE(review): the ``gen_samples`` and ``gen_batch_size`` arguments are ignored;
    they are overwritten below with the reference dataset size and 256.

    When ``cache`` is a dict, ground-truth Inception features and statistics are
    stored in / reused from its ``"gt_feats"`` and ``"gt_stats"`` entries.
    """
    reference_loader = train_loader if train_loader is not None else eval_loader
    gen_samples = reference_loader.dataset.inputs.shape[0]
    gen_batch_size = 256
    inception = InceptionHelper(module, reference_loader, gen_samples, gen_batch_size)

    # Called without a loader: statistics of the model's generated samples.
    gen_mu, gen_sigma = inception.compute_inception_stats()

    if cache is None:
        gt_mu, gt_sigma = inception.compute_inception_stats(reference_loader)
    elif "gt_feats" in cache and "gt_stats" in cache:
        gt_mu, gt_sigma = cache["gt_stats"]
    else:
        # Features are (re)computed only when missing; statistics are then
        # derived from them and cached alongside.
        if "gt_feats" not in cache:
            cache["gt_feats"] = inception.get_inception_features(reference_loader)
        gt_feats = cache["gt_feats"]
        gt_mu = np.mean(gt_feats, axis=0)
        gt_sigma = np.cov(gt_feats, rowvar=False)
        cache["gt_stats"] = gt_mu, gt_sigma

    return fid_score.calculate_frechet_distance(gen_mu, gen_sigma, gt_mu, gt_sigma)
    
# def fid(module, eval_loader=None, train_loader=None, max_samples=20000, gen_batch_size=64,
#         cache=None):
#     """
#     Following Heusel et al. (2017), compute FID from the training set if provided.
#     """
    
#     dataloader = eval_loader if train_loader is None else train_loader
#     inception = fid_score.InceptionV3().to(module.device) # TODO: submodule pytorch_fid?

#     dataset = clip_dataset(dataloader.dataset, max_samples=max_samples)
#     dataloader = get_loader(dataset, module.device, gen_batch_size, drop_last=False)

#     def get_inception_features(im_loader, generated):
#         # Compute mean and covariance for generated and ground truth iterables
#         if generated:
#             loader_len = max_samples // gen_batch_size
#             loader_type = "generated"
#         else:
#             loader_len = len(dataloader)
#             loader_type = "ground truth"

#         feats = []
        
#         for batch, _, _ in tqdm(im_loader, desc=f"Getting {loader_type} features", leave=False, total=loader_len):

#             # Convert grayscale to RGB
#             if batch.ndim == 3:
#                 batch.unsqueeze_(1)
#             if batch.shape[1] == 1:
#                 batch = batch.repeat(1, 3, 1, 1)
            
#             with torch.no_grad():
#                 batch_feats = inception(batch)[0]

#             batch_feats = batch_feats.squeeze().cpu().numpy()
#             feats.append(batch_feats)

#         return np.concatenate(feats)

#     def compute_inception_stats(im_loader, generated):
#         feats = get_inception_features(im_loader, generated)
#         mu = np.mean(feats, axis=0)
#         sigma = np.cov(feats, rowvar=False)

#         return mu, sigma

#     gendataset = module.get_sample_dataset(max_samples, gen_batch_size)
#     genloader = get_loader(gendataset, module.device, gen_batch_size, drop_last=False, set_device=True)
#     gen_mu, gen_sigma = compute_inception_stats(genloader, True)

#     if cache is None:
#         gt_mu, gt_sigma = compute_inception_stats(dataloader, False)
#     elif "gt_stats" not in cache:
#         gt_mu, gt_sigma = compute_inception_stats(dataloader, False)
#         cache["gt_stats"] = gt_mu, gt_sigma
#     else:
#         gt_mu, gt_sigma = cache["gt_stats"]

#     return fid_score.calculate_frechet_distance(gen_mu, gen_sigma, gt_mu, gt_sigma)


def precision_recall_density_coverage(module, eval_loader=None, train_loader=None, gen_samples=50000, gen_batch_size=64, nearest_k=5,
        cache=None):
    """
    Following Naeem et al. (2020), compute Precision, Recall, Density, Coverage
    against the training set if provided (otherwise against ``eval_loader``).

    When ``cache`` is a dict, ground-truth Inception features are stored in /
    reused from its ``"gt_feats"`` entry. Returns whatever ``prdc.compute_prdc``
    returns for the ground-truth and generated features.
    """
    reference_loader = train_loader if train_loader is not None else eval_loader
    inception = InceptionHelper(module, reference_loader, gen_samples, gen_batch_size)

    # Called without a loader: features of the model's generated samples.
    gen_feats = inception.get_inception_features()

    if cache is not None and "gt_feats" in cache:
        gt_feats = cache["gt_feats"]
    else:
        gt_feats = inception.get_inception_features(reference_loader)
        if cache is not None:
            cache["gt_feats"] = gt_feats

    return prdc.compute_prdc(gt_feats, gen_feats, nearest_k)


def tabular_fid(module, eval_loader=None, train_loader=None, cache=None):
    """
    Placeholder for a tabular FID following Heusel et al. (2017), meant to be
    computed from the training set if provided. Not implemented: returns None.
    """
    # TODO: implement
    reference_loader = train_loader if train_loader is not None else eval_loader  # selected but not used yet
    return None


def log_likelihood(module, dataloader, cache=None):
    """Return the mean of ``module.log_prob(dataloader)``, evaluated without gradient tracking.

    ``cache`` is accepted but not used.
    """
    with torch.no_grad():
        log_probs = module.log_prob(dataloader)
        return log_probs.mean()


def l2_reconstruction_error(module, dataloader, cache=None):
    """Return the mean of ``module.rec_error(dataloader)``, evaluated without gradient tracking.

    ``cache`` is accepted but not used.
    """
    with torch.no_grad():
        rec_errors = module.rec_error(dataloader)
        return rec_errors.mean()


def loss(module, dataloader, cache=None):
    """Return the mean of ``module.loss(dataloader)``, evaluated without gradient tracking.

    ``cache`` is accepted but not used.
    """
    with torch.no_grad():
        losses = module.loss(dataloader)
        return losses.mean()

def disc_loss(module, dataloader, cache=None):
    """Return the mean of ``module.disc_loss(dataloader)``, evaluated without gradient tracking.

    ``cache`` is accepted but not used.
    """
    with torch.no_grad():
        disc_losses = module.disc_loss(dataloader)
        return disc_losses.mean()

def null_metric(module, dataloader, cache=None):
    """Placeholder metric: ignores all of its arguments and always returns 0."""
    placeholder_value = 0
    return placeholder_value


def likelihood_ood_acc(
        module,
        is_test_loader,
        oos_test_loader,
        is_train_loader,
        oos_train_loader,
        savedir,
        cache=None,
    ):
    """Delegate to ``ood_acc`` with ``low_dim=False``.

    The four loaders and ``savedir`` are forwarded positionally in the order
    they are received; ``cache`` is forwarded by keyword.
    """
    loaders = (is_test_loader, oos_test_loader, is_train_loader, oos_train_loader)
    return ood_acc(module, *loaders, savedir, low_dim=False, cache=cache)


def likelihood_ood_acc_low_dim(
        module,
        is_test_loader,
        oos_test_loader,
        is_train_loader,
        oos_train_loader,
        savedir,
        cache=None,
    ):
    """Delegate to ``ood_acc`` with ``low_dim=True``.

    The four loaders and ``savedir`` are forwarded positionally in the order
    they are received; ``cache`` is forwarded by keyword.
    """
    loaders = (is_test_loader, oos_test_loader, is_train_loader, oos_train_loader)
    return ood_acc(module, *loaders, savedir, low_dim=True, cache=cache)
