import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from sklearn.semi_supervised import LabelSpreading
from copy import deepcopy
import numpy as np

# Main LP Class for Label Propagation
class OT3A(nn.Module):
    """Binary classifier wrapper that re-estimates the class prior per batch.

    For each batch, the wrapped model's probabilities are turned into pseudo-labels.
    Low-margin ("tail") pseudo-labels are flipped, and a consistency filter keeps
    only samples whose label agrees with their feature-space neighbourhood. Those
    labels are spread to the whole batch with ``LabelSpreading``. The resulting
    class frequencies, blended with the mean model probabilities, give
    ``self.prior``. The output probabilities are re-weighted by
    ``self.prior / self.source_y``.
    """

    def __init__(self, model, optimizer_type, prior, lr=1e-4, device=None, tp=0.25, tc=0.7):
        """
        Initializes the Label Propagation Model.

        Args:
            model: The base model used for prediction; a deep copy is stored.
                Assumed to emit one logit per sample, shape (n_samples, 1)
                (forward concatenates along dim=1) -- TODO confirm.
            optimizer_type: The optimizer type to be used, e.g., torch.optim.Adam.
                It is built over the copied model's parameters but is not stepped
                anywhere in this class.
            prior: Initial prior distribution. Also stored as ``source_y``, the
                source-domain label distribution used as the denominator in the
                output re-weighting.
            lr: Learning rate for the optimizer.
            device: Device name ("cuda" or "cpu"). NOTE: it is only recorded on the
                instance; the model is not moved to it here.
            tp: Tail proportion for pseudo-label adjustment.
            tc: Consistency quantile used for sample filtering.
        """
        super().__init__()
        self.base_model1 = deepcopy(model)
        self.optimizer = optimizer_type(self.base_model1.parameters(), lr=lr)
        self.prior = prior
        self.source_y = prior  # Saves the source domain label distribution
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")  # Set default device
        self.tp = tp
        self.tc = tc
        self.optimizer_type = optimizer_type
        print(f"Model initialized on device: {self.device}")

    # Helper Functions
    def calculate_distances(self, features):
        """
        Computes pairwise (sklearn default: Euclidean) distances between samples.

        Args:
            features: Feature tensor, shape (n_samples, n_features); any device.

        Returns:
            batch_distances: numpy array of pairwise distances,
            shape (n_samples, n_samples).
        """
        return pairwise_distances(features.detach().cpu().numpy())

    def filter_consistent_samples(self, features, pseudo_labels, given_indices, tc=0.3):
        """
        Filters samples based on pseudo-label consistency with their neighbours.

        A sample's neighbourhood is every sample (including itself) whose distance
        is not greater than the batch's mean pairwise distance. Consistency is the
        absolute gap between the sample's pseudo-label and the mean pseudo-label of
        its neighbourhood; smaller means more consistent.

        Args:
            features: Input feature matrix, shape (n_samples, n_features).
            pseudo_labels: 0/1 pseudo labels for samples, shape (n_samples,).
            given_indices: 1-D long tensor of sample indices to check.
            tc: Per-class quantile of the consistency score; candidates at or
                below it are kept.

        Returns:
            remain_indices: 1-D long tensor of kept indices (class-0 survivors
            first, then class-1), on the same device as ``given_indices``.
            Empty when there is nothing to keep.
        """
        # Distances come back as a CPU numpy array, so all bookkeeping happens on
        # CPU; this also avoids mixing CUDA tensors with numpy arrays.
        batch_distances = torch.as_tensor(self.calculate_distances(features))
        nearby_neighbors = (~(batch_distances > batch_distances.mean())).to(torch.float64)
        # Self-distance is 0, so each row normally has at least one neighbour;
        # the clamp only guards against a zero count.
        denominator = nearby_neighbors.sum(dim=1).clamp_min(1e-8)

        labels = pseudo_labels.detach().cpu()
        candidates = given_indices.detach().cpu()
        nn_labels = (nearby_neighbors * labels.to(torch.float64)).sum(dim=1) / denominator
        consistency = (nn_labels - labels).abs()

        kept = []
        for cls in (0, 1):
            members = candidates[labels[candidates] == cls]
            if members.numel() == 0:
                # torch.quantile rejects empty input; this class keeps nothing.
                continue
            scores = consistency[members]
            gamma = torch.quantile(scores, tc)
            kept.append(members[scores <= gamma])

        if not kept:
            return given_indices.new_empty(0)
        return torch.cat(kept).to(given_indices.device)

    def adjust_pseudo_labels(self, logits_p, tp=0.25):
        """
        Adjusts pseudo labels by flipping low-confidence tail samples.

        Samples are grouped by their argmax class. Within each group, those whose
        probability margin is below the ``tp`` quantile form the tail (their label
        is flipped). Those above the ``1 - tp`` quantile form the head. A class
        with no members contributes no tail or head indices.

        Args:
            logits_p: Model output probabilities, shape (n_samples, 2).
            tp: Tail proportion threshold for pseudo-label adjustment.

        Returns:
            y_hat: Adjusted pseudo labels (tail entries flipped).
            ind_tail: Indices of tail samples adjusted (class 0 first, then class 1).
            ind_head: Indices of head samples (class 0 first, then class 1).
        """
        y_hat = logits_p.argmax(dim=1)  # Initial pseudo-labels
        margin = torch.abs(logits_p[:, 0] - logits_p[:, 1])  # Difference in two class probabilities

        tail_parts, head_parts = [], []
        for cls in (0, 1):
            # Groups are formed from the argmax labels *before* any flipping.
            members = torch.where(y_hat == cls)[0]
            if members.numel() == 0:
                # torch.quantile rejects empty input; skip the empty class.
                continue
            member_margin = margin[members]
            tail_parts.append(members[member_margin < torch.quantile(member_margin, tp)])
            head_parts.append(members[member_margin > torch.quantile(member_margin, 1 - tp)])

        empty = y_hat.new_empty(0)
        ind_tail = torch.cat(tail_parts) if tail_parts else empty
        ind_head = torch.cat(head_parts) if head_parts else empty

        y_hat[ind_tail] = 1 - y_hat[ind_tail]  # Flip tail pseudo-labels
        return y_hat, ind_tail, ind_head

    def label_propagation(self, features, labels):
        """
        Executes label propagation (sklearn ``LabelSpreading``, RBF gamma=1.0).

        Args:
            features: Input feature matrix, shape (n_samples, n_features).
            labels: Labels for samples, -1 marking unlabelled samples. At least
                one sample must be labelled for there to be anything to spread.

        Returns:
            label_lp: CPU tensor of labels predicted after propagation.
        """
        LP = LabelSpreading(gamma=1.0, max_iter=1000)
        LP.fit(features.detach().cpu().numpy(), labels.cpu().numpy())
        return torch.tensor(LP.predict(features.detach().cpu().numpy()))

    def compute_zeta(self, label_mixed):
        """
        Computes the class imbalance metric `zeta`.

        `zeta` measures the imbalance between the number of samples assigned to
        class 0 and class 1 among valid pseudo-labels. It is 0 for a perfectly
        balanced set and 1 when only one class is present.

        Args:
            label_mixed: Mixed pseudo-labels after filtering, where valid labels
                are >= 0 and invalid labels are -1.

        Returns:
            zeta: 0-d tensor; 0.0 when there are no valid labels.
        """
        K = (label_mixed >= 0).sum()  # Total valid pseudo-labels
        if K == 0:  # Avoid division by zero
            return torch.tensor(0.0)

        K0 = (label_mixed == 0).sum()  # Count of class 0 samples
        K1 = (label_mixed == 1).sum()  # Count of class 1 samples
        zeta = abs(K0 / K - 0.5) + abs(K1 / K - 0.5)  # Deviation from balanced distribution
        return zeta

    # Forward Pass
    def forward(self, x):
        """
        Forward pass for label propagation and pseudo-label adjustment.

        Side effect: overwrites ``self.prior`` with the newly estimated prior.

        Args:
            x: Input feature matrix; also used as the feature space for the
                consistency filter and label propagation.

        Returns:
            Detached class-1 probabilities after prior re-weighting,
            shape (n_samples,).
        """
        # The result is returned detached, so drop the autograd graph right away
        # instead of keeping it alive through self.prior.
        prob_1 = torch.sigmoid(self.base_model1(x)).detach()
        logits_p = torch.cat([1 - prob_1, prob_1], dim=1)

        # Adjust pseudo-labels and select high-confidence samples
        y_hat, ind_tail, ind_head = self.adjust_pseudo_labels(logits_p, self.tp)
        consistent_indices = torch.cat([ind_tail, ind_head])
        consistent_samples = self.filter_consistent_samples(x, y_hat, consistent_indices, tc=self.tc)

        # Label bookkeeping feeds sklearn, so it is done on CPU regardless of the
        # model's device.
        y_hat_cpu = y_hat.cpu()
        keep = consistent_samples.cpu()
        mixed_labels = torch.full(y_hat_cpu.size(), -1, dtype=torch.long)
        mixed_labels[keep] = y_hat_cpu[keep]
        zeta = self.compute_zeta(mixed_labels)

        if keep.numel() > 0:
            propagated_labels = self.label_propagation(x, mixed_labels)
        else:
            # Nothing survived filtering, so LabelSpreading would have no labelled
            # sample to spread from; fall back to the pseudo-labels themselves.
            propagated_labels = y_hat_cpu

        n = len(propagated_labels)
        prior = torch.stack([
            (propagated_labels == 0).sum() / n,
            (propagated_labels == 1).sum() / n,
        ]).to(logits_p.device)

        # Blend: the more imbalanced the kept labels (larger zeta), the more weight
        # goes to the mean model probabilities instead of the propagated labels.
        self.prior = (1 - zeta) * prior + zeta * logits_p.mean(0)
        source_y = torch.as_tensor(self.source_y, device=logits_p.device)
        adjusted_logits = F.normalize(logits_p * self.prior / source_y, p=1)

        return adjusted_logits[:, 1].detach()