import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


class OODMetrics:
    """OOD evaluation metrics plus a per-epoch history of selected values.

    ``metrics_history`` maps metric names to lists of recorded values.
    ``update_history`` appends only keys already present in that dict, and
    ``plot_metrics`` draws curves from it.
    """

    def __init__(self):
        self.metrics_history = {
            'accuracy': [],
            'f1_score': [],
            'auroc': [],
            'invariant_acc': [],
            'variant_acc': [],
            # Keys returned by compute_invariance_metrics /
            # compute_disentanglement_metrics, so their outputs can be passed
            # straight to update_history and then plotted.
            'invariance_score': [],
            'disentanglement_score': [],
        }

    # ------------------------------------------------------------------
    # Conversion helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if isinstance(x, torch.Tensor):
            # .numpy() raises on tensors that require grad, so detach first.
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _as_probabilities(scores):
        """Return per-row class probabilities for a 2-D score array.

        Rows that already sum to 1 (the same check sklearn applies before its
        multi-class AUROC) are returned unchanged. Otherwise the rows are
        treated as logits and passed through a numerically stable softmax.
        """
        raw = np.asarray(scores)
        if np.allclose(1, raw.sum(axis=1)):
            return raw
        shifted = raw.astype(np.float64) - raw.max(axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        return exp_scores / exp_scores.sum(axis=1, keepdims=True)

    def _compute_auroc(self, scores, labels):
        """Return AUROC for scores of shape (N, C), or NaN when it is undefined.

        - C == 2: ROC AUC of the class-1 probability.
        - C > 2: macro-averaged one-vs-one ROC AUC.

        NaN is returned when fewer than two distinct labels are present. It is
        also returned in the multi-class case when the number of distinct
        labels differs from C, because sklearn cannot score that configuration.
        """
        n_present = np.unique(labels).size
        if n_present < 2:
            return float('nan')
        probs = self._as_probabilities(scores)
        n_classes = probs.shape[1]
        if n_classes == 2:
            return roc_auc_score(labels, probs[:, 1])
        if n_present != n_classes:
            return float('nan')
        return roc_auc_score(labels, probs, multi_class='ovo', average='macro')

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------
    def compute_classification_metrics(self, predictions, labels):
        """Compute accuracy, weighted F1 and AUROC.

        Args:
            predictions: (N, C) per-class scores (probabilities or logits),
                as a tensor or array-like.
            labels: (N,) integer class labels, as a tensor or array-like.

        Returns:
            dict with 'accuracy', 'f1_score' (weighted average) and 'auroc'.
            'auroc' is NaN when it is undefined for the given labels.
        """
        scores = self._to_numpy(predictions)
        y_true = self._to_numpy(labels)
        y_pred = scores.argmax(axis=1)

        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1_score': f1_score(y_true, y_pred, average='weighted'),
            'auroc': self._compute_auroc(scores, y_true),
        }

    def compute_invariance_metrics(self, env_predictions, labels):
        """Measure how stable accuracy is across environments.

        Args:
            env_predictions: mapping env_id -> (N, C) scores. Each entry is
                compared against the same ``labels``.
            labels: (N,) class labels, as a tensor or array-like.

        Returns:
            dict with the mean and population std (ddof=0) of the
            per-environment accuracies, and 'invariance_score' = 1 - std.
            All three values are NaN when ``env_predictions`` is empty.
        """
        y_true = self._to_numpy(labels)  # loop-invariant, convert once
        accuracies = [
            accuracy_score(y_true, self._to_numpy(preds).argmax(axis=1))
            for preds in env_predictions.values()
        ]

        if not accuracies:
            nan = float('nan')
            return {'mean_accuracy': nan, 'std_accuracy': nan, 'invariance_score': nan}

        mean_acc = np.mean(accuracies)
        std_acc = np.std(accuracies)
        return {
            'mean_accuracy': mean_acc,
            'std_accuracy': std_acc,
            # 1 - standard deviation of per-environment accuracy.
            'invariance_score': 1.0 - std_acc,
        }

    def compute_disentanglement_metrics(self, invariant_embs, variant_embs):
        """Score the separation between invariant and variant embeddings.

        Args:
            invariant_embs: mapping env_id -> 2-D embeddings (rows = samples).
            variant_embs: mapping containing every env_id of
                ``invariant_embs``. Rows are paired with the invariant rows by
                ``pearsonr``, so both entries must have the same number of rows.

        Returns:
            dict with:
              'avg_mutual_info': mean over environments of the average absolute
                  Pearson correlation between same-index dimensions. This is a
                  linear proxy, not a true mutual-information estimate. NaN if
                  no environment has a dimension to compare.
              'env_discrimination': the result of compute_environment_discrimination.
              'disentanglement_score': 1 - avg_mutual_info (higher is better).
        """
        from scipy.stats import pearsonr

        mi_scores = []
        for env_id, inv_embs in invariant_embs.items():
            inv = self._to_numpy(inv_embs)
            var = self._to_numpy(variant_embs[env_id])
            n_dims = min(inv.shape[1], var.shape[1])
            if n_dims == 0:
                continue  # nothing to correlate; avoids dividing by zero
            total = 0.0
            for dim in range(n_dims):
                corr, _ = pearsonr(inv[:, dim], var[:, dim])
                # pearsonr yields NaN for a constant column. A constant
                # dimension has no linear dependence, so it contributes 0.
                if not np.isnan(corr):
                    total += abs(corr)
            mi_scores.append(total / n_dims)

        avg_mi = np.mean(mi_scores) if mi_scores else float('nan')
        return {
            'avg_mutual_info': avg_mi,
            'env_discrimination': self.compute_environment_discrimination(variant_embs),
            'disentanglement_score': 1.0 - avg_mi,  # lower correlation is better
        }

    def compute_environment_discrimination(self, variant_embs):
        """Return the mean cross-validated accuracy of a linear SVM predicting env_id.

        Args:
            variant_embs: mapping env_id -> 2-D embeddings (rows = samples).

        Returns:
            Mean accuracy over up to 5 stratified folds. The fold count is
            reduced when the largest environment has fewer than 5 samples.
            Returns 0.0 when fewer than two environments have samples, or when
            no environment has at least two samples.
        """
        features = []
        env_labels = []
        for env_id, embs in variant_embs.items():
            arr = self._to_numpy(embs)
            features.append(arr)
            env_labels.extend([env_id] * len(arr))

        if not features:
            return 0.0

        X = np.vstack(features)
        y = np.array(env_labels)

        _, counts = np.unique(y, return_counts=True)
        if counts.size < 2:
            return 0.0

        # Stratified k-fold raises when n_splits exceeds every class's count.
        # Shrink the fold count for tiny inputs instead of crashing.
        n_folds = min(5, int(counts.max()))
        if n_folds < 2:
            return 0.0

        from sklearn.model_selection import cross_val_score
        from sklearn.svm import SVC

        scores = cross_val_score(SVC(kernel='linear'), X, y, cv=n_folds)
        return scores.mean()

    def compute_prototype_alignment(self, invariant_embs, labels, prototypes):
        """Score how close each class's embeddings are to its prototype.

        Args:
            invariant_embs: mapping env_id -> embedding tensor. Every entry is
                indexed with the same boolean mask built from ``labels``.
                This assumes each environment's rows are aligned with
                ``labels`` (TODO: confirm against callers).
            labels: class labels compared against 0..len(prototypes)-1.
            prototypes: indexable by class id; each entry has a 'mean' vector.

        Returns:
            Mean over classes present in ``labels`` of 1 / (1 + mean L2
            distance to that class's prototype), or 0.0 if no class is present.
        """
        alignment_scores = []

        for class_id in range(len(prototypes)):
            class_mask = labels == class_id
            if not class_mask.any():
                continue

            # Gather this class's rows from every environment.
            class_embs = [env_embs[class_mask] for env_embs in invariant_embs.values()]
            if not class_embs:
                continue
            all_class_embs = torch.cat(class_embs, dim=0)

            # Move the prototype onto the embeddings' device to avoid a
            # cross-device subtraction error.
            prototype = torch.as_tensor(prototypes[class_id]['mean'], device=all_class_embs.device)
            with torch.no_grad():
                distances = torch.norm(all_class_embs - prototype, dim=1)
            avg_distance = distances.mean().item()

            # Map distance to (0, 1]; smaller distance gives a higher score.
            alignment_scores.append(1.0 / (1.0 + avg_distance))

        return np.mean(alignment_scores) if alignment_scores else 0.0

    # ------------------------------------------------------------------
    # History / plotting
    # ------------------------------------------------------------------
    def update_history(self, metrics_dict):
        """Append each value in ``metrics_dict`` whose key is tracked in ``metrics_history``.

        Keys that are not already in ``metrics_history`` are ignored.
        """
        for key, value in metrics_dict.items():
            if key in self.metrics_history:
                self.metrics_history[key].append(value)

    def plot_metrics(self, save_path=None, show=True):
        """Plot recorded metric curves on a 2x2 grid.

        Panels: accuracy, F1 score, invariance score and disentanglement
        score. The invariance panel plots 'invariance_score' when it has
        recorded values and otherwise falls back to 'invariant_acc'. The
        disentanglement panel is drawn only when values have been recorded.

        Args:
            save_path: if truthy, the figure is saved there at 300 dpi.
            show: whether to call ``plt.show()``. The figure is closed
                afterwards unless it is being shown in interactive mode, so
                repeated calls do not accumulate open figures.
        """
        history = self.metrics_history
        invariance_key = 'invariance_score' if history.get('invariance_score') else 'invariant_acc'

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # (axis, history key, title, y-label)
        panels = [
            (axes[0, 0], 'accuracy', 'Accuracy', 'Accuracy'),
            (axes[0, 1], 'f1_score', 'F1 Score', 'F1 Score'),
            (axes[1, 0], invariance_key, 'Invariance Score', 'Score'),
        ]
        if history.get('disentanglement_score'):
            panels.append((axes[1, 1], 'disentanglement_score', 'Disentanglement Score', 'Score'))

        for ax, key, title, ylabel in panels:
            ax.plot(history.get(key, []))
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        if show:
            plt.show()

        # In non-interactive sessions (scripts, Agg backend) pyplot would keep
        # the figure alive indefinitely; release it once it is no longer needed.
        if not (show and plt.isinteractive()):
            plt.close(fig)
