import torch

import random
import numpy as np
from copy import deepcopy
from sklearn.cluster import KMeans


class MispredictionAnalyzer:
    """Collect a model's mis-predictions on a dataset and group them into typical mistakes.

    After construction the instance holds three position-aligned NumPy arrays:

    * ``mis_predicted_inputs`` -- the inputs whose prediction was wrong,
    * ``mis_predicted_labels`` -- their true labels (in the shape the dataset provided),
    * ``losses`` -- one loss value per mis-predicted instance.

    All three are empty arrays when every sampled instance is predicted correctly.
    ``find_typical_mistakes`` then clusters the mis-predicted inputs with loss-weighted k-means.
    """

    def __init__(self, model, dataset, criterion, sample_size=None, task='classification', regression_threshold=1e-5):
        """Run ``model`` over (a sample of) ``dataset`` and keep the mis-predicted instances.

        Args:
            model: a ``torch.nn.Module``; it is run in eval mode under ``torch.no_grad()`` and its
                original train/eval mode is restored afterwards.
            dataset: a sized, indexable collection whose items unpack as ``(input_tensor, label)``.
            criterion: loss used to score each mis-prediction. If it has a ``reduction`` attribute,
                a deep copy with ``reduction='none'`` is used so one value per instance is produced.
            sample_size: number of instances drawn at random, or ``None`` for the whole dataset.
            task: ``'classification'`` (argmax over the last output dim vs. label) or
                ``'regression'`` (absolute error above ``regression_threshold``).
            regression_threshold: tolerance deciding whether a regression output is wrong.

        Raises:
            ValueError: invalid ``task`` or ``sample_size``, an empty dataset, or a criterion that
                returns a single scalar instead of per-instance losses.
        """
        if task not in ('classification', 'regression'):
            # fail fast instead of after loading the data and running the model
            raise ValueError("Invalid value for parameter \"task\": {}. Must be \"classification\" or \"regression\".".format(task))
        self.task = task

        # obtain all / a subset of inputs and their corresponding true labels from the dataset
        inputs, true_labels = self._load_samples(dataset, sample_size)

        # collect all corresponding outputs from the existing model given the sampled dataset inputs
        outputs = self._collect_model_outputs(model, inputs)

        assert len(inputs) == len(outputs) == len(true_labels), \
            "Unexpected error: length mismatch ({} - {} - {}).".format(len(inputs), len(outputs), len(true_labels))

        # screen out all the mis-predicted instances and collect their indices
        mis_predicted_indices = self._collect_mis_predictions(outputs, true_labels, task, regression_threshold=regression_threshold)

        if len(mis_predicted_indices) == 0:
            # no mis-predicted instance among the sampled inputs
            print("All {} sampled instances are correctly predicted.".format(len(inputs)))
            self.mis_predicted_inputs = np.array([])
            self.mis_predicted_labels = np.array([])
            self.losses = np.array([])
            return

        # ensure no reduction (such as mean or sum) will be applied when calculating loss function given batch inputs
        loss_function = deepcopy(criterion)
        if hasattr(loss_function, 'reduction'):
            loss_function.reduction = 'none'

        mis_predicted_outputs = outputs[mis_predicted_indices]
        mis_predicted_labels = true_labels[mis_predicted_indices]
        mis_predicted_inputs = inputs[mis_predicted_indices]

        # the stored labels keep their original shape; only the loss sees the aligned targets
        loss_targets = self._align_loss_targets(mis_predicted_outputs, mis_predicted_labels, task)

        with torch.no_grad():
            losses = loss_function(mis_predicted_outputs, loss_targets)
            if losses.dim() == 0:
                raise ValueError("The criterion returned a single scalar loss; a per-instance loss is required "
                                 "(e.g. a loss module with reduction='none').")
            # handle multi-dimensional losses (e.g., per-class losses): one value per instance
            if losses.dim() > 1:
                losses = losses.mean(dim=tuple(range(1, losses.dim())))

        self.mis_predicted_inputs = mis_predicted_inputs.cpu().numpy()
        self.mis_predicted_labels = mis_predicted_labels.cpu().numpy()
        self.losses = losses.cpu().numpy()

        assert len(self.losses) == len(self.mis_predicted_inputs) == len(self.mis_predicted_labels), \
            "Numbers of losses ({}) inconsistent with mis_predicted_inputs ({}) and mis_predicted_labels ({}).".format(
                len(self.losses), len(self.mis_predicted_inputs), len(self.mis_predicted_labels)
            )

    def find_typical_mistakes(self, n_clusters=20, standardization=False, softmax=True, method='largest_cluster', representative='cluster_centroid'):
        """Yield clusters of typical mis-predicted instances (TMPI), most typical first.

        Mis-predicted inputs (flattened to one row per instance) are clustered with k-means,
        weighted by their normalised losses.

        Args:
            n_clusters: number of k-means clusters requested.
            standardization: standardise every input feature before clustering.
            softmax: normalise losses into weights with a softmax (``True``) or min-max scaling.
            method: cluster ranking -- ``'largest_cluster'`` (most members first) or
                ``'highest_loss'`` (highest mean raw loss first).
            representative: how one instance per cluster is chosen -- ``'cluster_centroid'``
                (closest to the k-means centre) or ``'highest_loss'``.

        Yields:
            ``(typical_mispredictions, representative_instance, offset)``:
            a list of ``(input, label)`` pairs in the cluster, one such pair chosen as its
            representative, and the cluster's rank (0 = most typical).

        Raises:
            ValueError: raised on first iteration (this is a generator) when there are fewer
                mis-predictions than ``n_clusters`` or ``method``/``representative`` is invalid.
        """
        if len(self.mis_predicted_inputs) == 0:
            print("No mispredicted instances to cluster.")
            return

        if len(self.mis_predicted_inputs) < n_clusters:
            raise ValueError("Insufficient mis-predicted instances ({}) to be clustered, given the number of clusters {}.".format(
                len(self.mis_predicted_inputs), n_clusters)
            )
        # validate the options before paying for the clustering
        if method not in ('highest_loss', 'largest_cluster'):
            raise ValueError("Invalid method: \"{}\". Must be \"highest_loss\" or \"largest_cluster\".".format(method))
        if representative not in ('cluster_centroid', 'highest_loss'):
            raise ValueError("Invalid representative: \"{}\". Must be \"cluster_centroid\" or \"highest_loss\".".format(representative))

        inputs = self.mis_predicted_inputs.copy()
        raw_losses = self.losses.copy()

        # KMeans needs one row per instance: flatten high-dimensional inputs, lift scalar ones
        if inputs.ndim != 2:
            inputs = inputs.reshape(inputs.shape[0], -1)

        if standardization:
            # standardize inputs among different input dimensions
            from sklearn.preprocessing import StandardScaler
            inputs = StandardScaler().fit_transform(inputs)

        weights = self._loss_weights(raw_losses, softmax)

        # conduct weighted k-means clustering on the input data,
        # with the normalized loss function values of their corresponding predictions as the weights
        kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(inputs, sample_weight=weights)
        cluster_labels = kmeans.labels_

        # clusters that received no member do not appear in the labels
        unique_labels, label_counts = np.unique(cluster_labels, return_counts=True)
        n_actual_clusters = len(unique_labels)
        if n_actual_clusters != n_clusters:
            print("Actual number of clusters found: {} (requested: {})".format(n_actual_clusters, n_clusters))

        if method == 'highest_loss':
            # rank by the mean *raw* loss: softmax weights are a non-linear transform of the
            # losses, so their per-cluster mean can order clusters differently
            cluster_stats = [(label, np.mean(raw_losses[cluster_labels == label])) for label in unique_labels]
            cluster_stats.sort(key=lambda stat: stat[1], reverse=True)
            cluster_order = [label for label, _ in cluster_stats]
        else:  # 'largest_cluster': most members first
            cluster_order = unique_labels[np.argsort(label_counts)[::-1]]

        for offset, cluster_index in enumerate(cluster_order):
            member_indices = np.where(cluster_labels == cluster_index)[0]  # np.where() returns a tuple
            typical_mispredictions = self._get_pairs_by_indices(member_indices)

            if representative == 'cluster_centroid':
                # k-means labels index cluster_centers_ directly; looking the label up by its
                # position in unique_labels would pick the wrong centre once a cluster is empty
                centroid = kmeans.cluster_centers_[cluster_index]
                distances = np.linalg.norm(inputs[member_indices] - centroid, axis=1)
                chosen_index = member_indices[np.argmin(distances)]
            else:  # 'highest_loss'
                chosen_index = member_indices[np.argmax(raw_losses[member_indices])]
            representative_instance = self._get_pairs_by_indices([chosen_index])[0]

            yield typical_mispredictions, representative_instance, offset

    @staticmethod
    def _loss_weights(losses, softmax):
        """Turn raw per-instance losses into non-negative k-means sample weights."""
        if softmax:
            # shift by the max for numerical stability before exponentiating
            exp_losses = np.exp(losses - np.max(losses))
            weights = exp_losses / (exp_losses.sum() + 1e-10)
        else:
            from sklearn.preprocessing import MinMaxScaler
            weights = MinMaxScaler().fit_transform(losses.reshape(-1, 1)).flatten()
        if not np.any(weights > 0):
            # min-max scaling maps constant losses to all zeros, which leaves k-means without
            # any weight; fall back to uniform weights
            weights = np.ones_like(weights)
        return weights

    @staticmethod
    def _align_loss_targets(outputs, labels, task):
        """Reshape ``labels`` so an element-wise loss does not silently broadcast against ``outputs``.

        Regression targets take the outputs' shape; classification targets take the outputs'
        shape without the trailing class dimension. Applied only when element counts agree.
        """
        target_shape = outputs.shape if task == 'regression' else outputs.shape[:-1]
        if labels.shape != target_shape and labels.numel() == target_shape.numel():
            return labels.reshape(target_shape)
        return labels

    def _load_samples(self, dataset, sample_size):
        """Draw ``(input, label)`` pairs from ``dataset`` (optionally a random subset) and stack them.

        Returns:
            ``(inputs, true_labels)`` tensors whose first dimension indexes the sampled instances.
            Non-tensor labels become ``long`` for classification and ``float32`` for regression.

        Raises:
            ValueError: invalid ``sample_size`` or an empty dataset.
        """
        dataset_size = len(dataset)
        if sample_size is None:
            sample_indices = range(dataset_size)
        else:
            if not isinstance(sample_size, int) or sample_size <= 0:
                raise ValueError("Invalid value for parameter \"sample_size\": {}".format(sample_size))
            if sample_size >= dataset_size:
                print("Warning: sample_size ({}) >= dataset size ({}). Using entire dataset.".format(sample_size, dataset_size))
                sample_indices = range(dataset_size)
            else:
                sample_indices = random.sample(range(dataset_size), sample_size)

        if len(sample_indices) == 0:
            # torch.stack would otherwise fail on an empty list with a less helpful message
            raise ValueError("The dataset is empty; there is nothing to analyze.")

        pairs = [dataset[i] for i in sample_indices]
        inputs = torch.stack([x for x, _ in pairs], dim=0)
        true_labels = [y for _, y in pairs]

        # handle different label types appropriately
        if isinstance(true_labels[0], torch.Tensor):
            true_labels = torch.stack(true_labels, dim=0)
        else:
            dtype = torch.long if self.task == 'classification' else torch.float32
            true_labels = torch.tensor(true_labels, dtype=dtype)

        return inputs, true_labels

    def _collect_model_outputs(self, model, inputs, batch_size=64):
        """Return ``model(inputs)`` on the CPU, retrying in batches of ``batch_size`` on out-of-memory.

        The model runs in eval mode under ``torch.no_grad()`` on the device of its first parameter
        (CPU for a parameterless model); its previous train/eval mode is restored afterwards.
        """
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                try:
                    return model(inputs.to(device)).cpu()
                except RuntimeError as e:
                    if 'out of memory' not in str(e).lower():
                        raise
                    print("Out of memory error. Processing in batches of {}...".format(batch_size))

                    # Clear cache if using CUDA
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    output_list = [model(batch.to(device)).cpu() for batch in torch.split(inputs, batch_size)]
                    return torch.cat(output_list, dim=0)
        finally:
            # do not leave the caller's model switched to eval mode
            model.train(was_training)

    def _collect_mis_predictions(self, outputs, true_labels, task, regression_threshold=1e-5):
        """Return a 1-D tensor of unique, ascending instance indices whose prediction is wrong.

        Classification compares ``argmax(outputs, dim=-1)`` with the labels; regression flags an
        instance when any output component differs from its label by more than ``regression_threshold``.

        Raises:
            ValueError: unsupported output rank for the task, or an invalid ``task``.
        """
        if task == 'classification':
            if outputs.dim() <= 1:
                raise ValueError("For classification, expected outputs with >1 dimensions, got {}.".format(outputs.dim()))
            predictions = torch.argmax(outputs, dim=-1)
            expected = true_labels
            if expected.shape != predictions.shape and expected.numel() == predictions.numel():
                # e.g. labels stacked as (N, 1): compare element-wise instead of broadcasting to (N, N)
                expected = expected.reshape(predictions.shape)
            mismatch = predictions != expected
            if mismatch.dim() > 1:
                # one flag per instance, so indices are never repeated
                mismatch = mismatch.flatten(start_dim=1).any(dim=1)
            return torch.where(mismatch)[0]  # torch.where() returns a tuple

        elif task == 'regression':
            if outputs.dim() > 2:
                raise ValueError("For regression, expected 1D or 2D outputs, got {} dimensions.".format(outputs.dim()))
            # compare one row per instance so (N,), (N, 1) and (N, D) shapes line up
            # without broadcasting to (N, N), and each instance is reported at most once
            predicted = outputs.unsqueeze(1) if outputs.dim() == 1 else outputs
            expected = true_labels.unsqueeze(1) if true_labels.dim() == 1 else true_labels.flatten(start_dim=1)
            mask = (torch.abs(predicted - expected) > regression_threshold).any(dim=1)
            return torch.where(mask)[0]  # torch.where() returns a tuple

        else:
            raise ValueError("Invalid value for parameter \"task\": {}. Must be \"classification\" or \"regression\".".format(task))

    def _get_pairs_by_indices(self, indices):
        """Return ``[(input, label), ...]`` for the mis-predicted instances at ``indices``."""
        return [(self.mis_predicted_inputs[i], self.mis_predicted_labels[i]) for i in indices]