"""
Core implementation of OnlineRDS class
"""

import numpy as np
import torch
from sklearn import metrics
from .thresholds import find_threshold_otsu, filter_outliers

class OnlineRDS:
    """
    Calculate Online Relative Distance Score (RDS)

    Calculate the relative distance score for incoming data in batches,
    incrementally learning ID/OOD centers, and using them to compute the relative distance score.

    For each sample the score is ``d_id / (d_id + d_ood)``, where ``d_id`` and ``d_ood`` are the
    Euclidean distances to the current ID and OOD centers (lower means closer to ID).
    Centers are initialized from logit-based scores split by an Otsu threshold and are then
    refined batch by batch with an exponential moving average (EMA).
    """

    def __init__(self,
                 layer_name=None,
                 confidence_threshold=0.9,
                 use_confidence=False,
                 outlier_removal=True,
                 iqr_factor=1.5,
                 ema_alpha=0.9,
                 feature_dim=640,
                 space='feature',
                 auto_correction=False,
                 init_methods=None,
                 temperature=5.0,
                 flip_weight=1.2,
                 correct_flip=True,
                 check_interval=10,
                 init_method='energy',
                 num_init_batches=1):
        """
        Args:
            layer_name (str, optional): identifier stored on the instance; not used by the scoring logic
            confidence_threshold (float): stored for reference; confidence-based selection currently
                keeps samples at or above the 10th percentile of confidence instead of this value
            use_confidence (bool): drop the least confident samples before center updates and
                prototype checks (default: False)
            outlier_removal (bool): IQR-based outlier filtering of center updates (default: True)
            iqr_factor (float): IQR multiplier for outlier detection
            ema_alpha (float): EMA coefficient for center updates (higher means more retention of previous values)
            feature_dim (int): dimension of feature vector (stored only)
            space (str): space for distance calculation ('feature' or 'logit')
            auto_correction (bool): flag for a confidence-based correction mechanism; stored only,
                no correction is applied by this class
            init_methods (list[str], optional): scores combined by the ensemble initializer, any of
                'msp', 'energy', 'entropy' (default: all three)
            temperature (float): softmax temperature for the MSP / max-prob scores
            flip_weight (float): margin applied to the Euclidean test in `check_prototype_flip`
            correct_flip (bool): swap ID/OOD centers when a prototype flip is detected
            check_interval (int): check for a prototype flip every this many processed batches
                (0 or None disables the check)
            init_method (str): score used by `initialize_batch` ('energy' or 'max_prob'; 'msp' is
                accepted as an alias of 'max_prob')
            num_init_batches (int): number of batches buffered before the centers are initialized

        Raises:
            ValueError: if num_init_batches < 1
        """
        if num_init_batches < 1:
            raise ValueError(f"num_init_batches must be >= 1, got {num_init_batches}")

        self.layer_name = layer_name
        self.confidence_threshold = confidence_threshold
        self.use_confidence = use_confidence
        self.outlier_removal = outlier_removal
        self.iqr_factor = iqr_factor
        self.alpha = ema_alpha
        self.space = space

        # Centers (numpy arrays) are set during initialization
        self.id_center = None
        self.ood_center = None
        self.initialized = False

        # Variables for tracking results
        self.batch_aurocs = []
        self.distance_history = {'id': [], 'ood': []}
        self.threshold_history = []
        self.confidence_history = {'id': [], 'ood': []}

        self.feature_dim = feature_dim
        self.logit_dim = None  # set from the first batch of logits

        self.auto_correction = auto_correction
        self.corrections_applied = []  # one entry per batch; always False in this class
        # Copy the methods so instances never share (and mutate) one default list.
        self.init_methods = (['msp', 'energy', 'entropy'] if init_methods is None
                             else list(init_methods))
        self.init_method = init_method
        self.temperature = temperature
        self.flip_weight = flip_weight
        self.correct_flip = correct_flip

        # Results of the most recently evaluated batch
        self.auroc = None
        self.fpr = None
        self.tpr = None

        # Buffers used while collecting initialization batches
        self.num_init_batches = num_init_batches
        self.init_features_buffer = []
        self.init_logits_buffer = []
        self.init_batch_count = 0

        # Prototype flip checking
        self.check_interval = check_interval
        self.processed_batch_count = 0  # number of batches processed after initialization
        self.prototype_flips = []  # processed_batch_count values at which a flip was corrected

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _distances(vectors, center):
        """Return the Euclidean distance of every row of `vectors` to `center`."""
        return np.linalg.norm(vectors - center, axis=1)

    @staticmethod
    def _cosine_distance(a, b):
        """Return 1 - cos(a, b); the epsilon keeps zero-norm vectors from producing NaN."""
        return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)

    def _select_vectors(self, features, logits):
        """Return the numpy vectors used for distances: features unless space is 'logit' and logits exist."""
        if self.space == 'feature' or logits is None:
            return features.detach().cpu().numpy()
        return logits.detach().cpu().numpy()

    def _robust_otsu_threshold(self, scores, outlier_mask):
        """Otsu threshold over non-outlier scores, falling back to all scores when <= 10 remain."""
        non_outlier_scores = scores[~outlier_mask]
        if len(non_outlier_scores) > 10:  # only if enough samples
            return find_threshold_otsu(non_outlier_scores)
        return find_threshold_otsu(scores)  # include outliers

    def _centers_from_masks(self, vectors, id_mask, ood_mask):
        """Set both centers from the masked means; return False (and change nothing) if a mask is empty."""
        if id_mask.sum() == 0 or ood_mask.sum() == 0:
            return False
        self.id_center = np.mean(vectors[id_mask], axis=0)
        self.ood_center = np.mean(vectors[ood_mask], axis=0)
        return True

    def _centers_from_halves(self, vectors):
        """Fallback when separation fails: first half of the batch -> ID, second half -> OOD."""
        mid = len(vectors) // 2
        self.id_center = np.mean(vectors[:mid], axis=0)
        self.ood_center = np.mean(vectors[mid:], axis=0)

    def _drop_distance_outliers(self, vectors, mask, center):
        """Return a copy of `mask` with IQR distance outliers (w.r.t. `center`) cleared; needs > 10 samples."""
        if np.sum(mask) <= 10:  # only if enough samples
            return mask
        distances = self._distances(vectors[mask], center)
        normal_mask = filter_outliers(distances, self.iqr_factor)
        mask = mask.copy()
        mask[np.where(mask)[0][~normal_mask]] = False
        return mask

    # ------------------------------------------------------------ initialization

    def initialize_batch(self, features, logits):
        """
        Separate the first batch of ID/OOD with a single score and initialize centers

        The score is chosen by ``self.init_method``: 'energy' (logsumexp of the logits) or
        'max_prob' / 'msp' (maximum softmax probability at ``self.temperature``). Samples whose
        score is above the Otsu threshold are treated as ID.

        Args:
            features (torch.Tensor): feature vector (feature space)
            logits (torch.Tensor): logit vector (logit space)

        Returns:
            tuple: (predictions, scores) - ID(1)/OOD(0) labels and the scores (higher is ID)

        Raises:
            ValueError: if ``self.init_method`` is unknown
        """
        if self.logit_dim is None:
            self.logit_dim = logits.size(1)

        if self.init_method == 'energy':
            temperature = 1.0
            energy = -torch.logsumexp(logits / temperature, dim=1).detach().cpu().numpy()
            scores = -energy  # higher is ID
            print("Initializing with energy scores")
        elif self.init_method in ('max_prob', 'msp'):
            probs = torch.softmax(logits / self.temperature, dim=1).detach().cpu().numpy()
            scores = np.max(probs, axis=1)  # higher is ID
            print("Initializing with maximum probability scores")
        else:
            raise ValueError(f"Unknown initialization method: {self.init_method}")

        # histogram-based threshold; higher score = ID
        threshold = find_threshold_otsu(scores)
        self.threshold_history.append(threshold)
        predictions = (scores > threshold).astype(np.int32)

        # no automatic correction is applied; keep the per-batch record aligned
        self.corrections_applied.append(False)

        if self.space == 'feature':
            vectors = features.detach().cpu().numpy()
        else:  # 'logit'
            vectors = logits.detach().cpu().numpy()

        if not self._centers_from_masks(vectors, predictions == 1, predictions == 0):
            self._centers_from_halves(vectors)
        self.initialized = True

        return predictions, scores

    def initialize_batch_ensemble(self, features, logits, scores_init=None):
        """
        Separate ID/OOD with ensemble scores (excluding feature outliers) and initialize centers

        Args:
            features (torch.Tensor): feature vector (feature space)
            logits (torch.Tensor): logit vector (logit space)
            scores_init (array-like, optional): precomputed per-sample scores (higher is ID);
                when None, `_compute_ensemble_scores` is used

        Returns:
            tuple: (predictions, scores) - ID(1)/OOD(0) labels (outliers forced to 0) and scores
        """
        if self.logit_dim is None:
            self.logit_dim = logits.size(1)

        if scores_init is None:
            scores = self._compute_ensemble_scores(features, logits)
        else:
            scores = np.asarray(scores_init)

        outlier_mask, features_np = self._filter_iqr_outliers(features, self.iqr_factor)
        threshold = self._robust_otsu_threshold(scores, outlier_mask)
        self.threshold_history.append(threshold)

        # higher score = ID; feature-space outliers are forced to OOD
        predictions = (scores > threshold).astype(np.int32)
        predictions[outlier_mask] = 0

        # no automatic correction is applied; keep the per-batch record aligned
        self.corrections_applied.append(False)

        if self.space == 'feature':
            vectors = features_np
        else:  # 'logit'
            vectors = logits.detach().cpu().numpy()

        id_mask = predictions == 1
        ood_mask = predictions == 0
        # Prefer centers without outliers, then retry with them, then split by halves.
        if not self._centers_from_masks(vectors, id_mask & ~outlier_mask, ood_mask & ~outlier_mask):
            if not self._centers_from_masks(vectors, id_mask, ood_mask):
                self._centers_from_halves(vectors)
        self.initialized = True

        return predictions, scores

    # ------------------------------------------------------------------ scoring

    def compute_rds(self, features, logits=None):
        """
        Calculate relative distance score

        Args:
            features (torch.Tensor): feature vector, one row per sample
            logits (torch.Tensor, optional): logit vector (used when space is 'logit')

        Returns:
            tuple: (rds, id_distances, ood_distances) numpy arrays; lower rds means closer to ID

        Raises:
            RuntimeError: if the centers have not been initialized
        """
        if self.id_center is None or self.ood_center is None:
            raise RuntimeError("OnlineRDS centers are not initialized; call initialization first")

        vectors = self._select_vectors(features, logits)
        id_distances = self._distances(vectors, self.id_center)
        ood_distances = self._distances(vectors, self.ood_center)

        # lower value means closer to ID; epsilon guards the 0/0 case
        rds = id_distances / (id_distances + ood_distances + 1e-10)
        return rds, id_distances, ood_distances

    def compute_confidence(self, scores, threshold):
        """
        Calculate confidence based on scores

        Args:
            scores (np.ndarray): calculated score array
            threshold (float): current threshold

        Returns:
            np.ndarray: |score - threshold| scaled by its batch maximum (0~1)
        """
        margin = np.abs(scores - threshold)
        return margin / (np.max(margin) + 1e-10)

    def update_centers(self, features, logits, predictions, confidence=None):
        """
        Update centers with new data (EMA)

        Args:
            features (torch.Tensor): feature vector
            logits (torch.Tensor, optional): logit vector
            predictions (np.ndarray): ID(1)/OOD(0) prediction
            confidence (np.ndarray, optional): confidence for each prediction
        """
        vectors = self._select_vectors(features, logits)
        id_mask = predictions == 1
        ood_mask = predictions == 0

        if confidence is not None and self.use_confidence:
            # drop the least confident 10% (percentile-based, not self.confidence_threshold)
            conf_mask = confidence >= np.percentile(confidence, 10)
            id_mask = id_mask & conf_mask
            ood_mask = ood_mask & conf_mask

        if self.outlier_removal:
            id_mask = self._drop_distance_outliers(vectors, id_mask, self.id_center)
            ood_mask = self._drop_distance_outliers(vectors, ood_mask, self.ood_center)

        # EMA: alpha weights the previous center
        if np.sum(id_mask) > 0:
            new_id_center = np.mean(vectors[id_mask], axis=0)
            self.id_center = self.alpha * self.id_center + (1 - self.alpha) * new_id_center

        if np.sum(ood_mask) > 0:
            new_ood_center = np.mean(vectors[ood_mask], axis=0)
            self.ood_center = self.alpha * self.ood_center + (1 - self.alpha) * new_ood_center

    def check_prototype_flip(self, features, logits, scores_flipDet=None):
        """
        Compare the current prototypes with a baseline ID center and swap them if flipped

        A baseline ID center is built from this batch's ensemble-score split. A flip is
        detected when the current ID center is farther from it than the OOD center by both
        Euclidean distance (scaled by ``flip_weight``) and cosine distance.

        Args:
            features (torch.Tensor): feature vector
            logits (torch.Tensor): logit vector
            scores_flipDet (array-like, optional): precomputed scores (higher is ID)

        Returns:
            bool: True only if a flip was detected and corrected (requires ``correct_flip``)
        """
        print(f"Batch {self.processed_batch_count}: checking prototype flip...")

        if scores_flipDet is None:
            scores = self._compute_ensemble_scores(features, logits)
        else:
            scores = np.asarray(scores_flipDet)

        outlier_mask, features_np = self._filter_iqr_outliers(features, self.iqr_factor)
        threshold = self._robust_otsu_threshold(scores, outlier_mask)

        # ID/OOD classification; outliers stay 0 and are excluded from both masks below
        predictions = np.zeros_like(scores, dtype=np.int32)
        predictions[~outlier_mask] = (scores[~outlier_mask] > threshold).astype(np.int32)
        confidence = self.compute_confidence(scores, threshold)

        if self.space == 'feature':
            vectors = features_np
        else:  # 'logit'
            vectors = logits.detach().cpu().numpy()

        id_mask = (predictions == 1) & ~outlier_mask
        ood_mask = (predictions == 0) & ~outlier_mask

        if self.use_confidence:
            confidence_threshold = np.percentile(confidence, 10)
            conf_mask = confidence >= confidence_threshold
            id_mask = id_mask & conf_mask
            ood_mask = ood_mask & conf_mask
            print(f"Applied confidence-based filtering: threshold {confidence_threshold}, remaining ID samples {np.sum(id_mask)}, OOD samples {np.sum(ood_mask)}")

        if np.sum(id_mask) < 10 or np.sum(ood_mask) < 10:
            print(f"Not enough ID ({np.sum(id_mask)}) or OOD ({np.sum(ood_mask)}) samples, skipping prototype check")
            return False

        baseline_id_center = np.mean(vectors[id_mask], axis=0)

        id_to_baseline_id = np.linalg.norm(self.id_center - baseline_id_center)
        ood_to_baseline_id = np.linalg.norm(self.ood_center - baseline_id_center)
        id_to_baseline_id_cos = self._cosine_distance(self.id_center, baseline_id_center)
        ood_to_baseline_id_cos = self._cosine_distance(self.ood_center, baseline_id_center)

        flip_detected = bool(id_to_baseline_id > ood_to_baseline_id * self.flip_weight
                             and id_to_baseline_id_cos > ood_to_baseline_id_cos)

        if not self.correct_flip:
            # correction disabled: always report "no flip applied" as a bool
            return False

        print("Prototype flip detected:" if flip_detected else "Prototype normal:")
        print(f"- ID-Baseline ID Euclidean distance: {id_to_baseline_id:.2f}, OOD-Baseline ID Euclidean distance: {ood_to_baseline_id:.2f}")
        print(f"- ID-Baseline ID cosine distance: {id_to_baseline_id_cos:.2f}, OOD-Baseline ID cosine distance: {ood_to_baseline_id_cos:.2f}")

        if flip_detected:
            self.id_center, self.ood_center = self.ood_center.copy(), self.id_center.copy()
            self.prototype_flips.append(self.processed_batch_count)
        return flip_detected

    def _compute_ensemble_scores(self, features, logits, methods=None):
        """
        Calculate ensemble scores

        Each method's scores are min-max normalized to [0, 1] (constant scores become 0.5)
        and the normalized scores are averaged with uniform weights.

        Args:
            features (torch.Tensor): feature vector (unused; kept for interface symmetry)
            logits (torch.Tensor): logit vector
            methods (list): methods to use ('msp', 'energy', 'entropy'); default is self.init_methods

        Returns:
            np.ndarray: ensemble scores (0~1 range, higher is ID)

        Raises:
            ValueError: if `methods` is empty or contains an unknown method
        """
        if methods is None:
            methods = self.init_methods
        if len(methods) == 0:
            raise ValueError("At least one scoring method is required for the ensemble")

        # Untempered (T=1) softmax for the entropy score. MSP uses its own tempered softmax
        # so the entropy result does not depend on the order of `methods`.
        probs = torch.softmax(logits, dim=1).detach().cpu().numpy()

        scores_list = []
        for method in methods:
            if method == 'energy':
                temperature = 1.0
                energy = -torch.logsumexp(logits / temperature, dim=1).detach().cpu().numpy()
                method_scores = -energy  # higher is ID
                print("Using energy scores for ensemble")
            elif method == 'msp':
                msp_probs = torch.softmax(logits / self.temperature, dim=1).detach().cpu().numpy()
                method_scores = np.max(msp_probs, axis=1)  # higher is ID
                print("Using maximum probability scores for ensemble")
            elif method == 'entropy':
                epsilon = 1e-10
                entropy = -np.sum(probs * np.log(probs + epsilon), axis=1)
                method_scores = -entropy  # lower entropy = higher confidence = ID
                print("Using entropy scores for ensemble")
            else:
                raise ValueError(f"Unknown method: {method}")

            min_val, max_val = np.min(method_scores), np.max(method_scores)
            if max_val > min_val:
                normalized_scores = (method_scores - min_val) / (max_val - min_val)
            else:
                normalized_scores = np.full_like(method_scores, 0.5)
            scores_list.append(normalized_scores)

        scores = sum(scores_list) / len(scores_list)

        print(f"Ensemble completed ({', '.join(methods)} methods used)")
        return scores

    def _filter_iqr_outliers(self, features, iqr_factor=1.5):
        """
        IQR-based outlier detection on distances to the batch mean

        Args:
            features (torch.Tensor): feature vector, one row per sample
            iqr_factor (float): samples farther than Q3 + iqr_factor * IQR are outliers

        Returns:
            tuple: (outlier mask, feature vector numpy array)
        """
        features_np = features.detach().cpu().numpy()
        batch_mean = np.mean(features_np, axis=0)
        distances_to_mean = self._distances(features_np, batch_mean)

        q1, q3 = np.percentile(distances_to_mean, [25, 75])
        outlier_mask = distances_to_mean > (q3 + iqr_factor * (q3 - q1))

        if np.any(outlier_mask):
            print(f"Detected {np.sum(outlier_mask)} outliers and excluded")

        return outlier_mask, features_np

    # ------------------------------------------------------------------ entry

    def __call__(self, features, logits=None, update=True, true_labels=None, scores_flipDet=None):
        """
        Calculate RDS and optionally update centers

        Args:
            features (torch.Tensor): feature vector
            logits (torch.Tensor, optional): logit vector (required for initialization)
            update (bool): whether to update centers
            true_labels (np.ndarray, optional): true labels for the current batch (for evaluation)
            scores_flipDet (array-like, optional): precomputed scores used for initialization
                and prototype flip checks instead of the ensemble scores

        Returns:
            tuple: (predictions, scores, auroc, separability)
                predictions: ID(1)/OOD(0) labels for the current batch
                scores: per-sample scores, higher is ID
                auroc: AUROC against `true_labels`, or None without labels
                separability: from `find_threshold_otsu` after initialization; 0.5 while
                    initialization data is being collected, 1 on the initialization batch

        Raises:
            ValueError: if `logits` is None before initialization
        """
        batch_size = features.size(0)

        if not self.initialized:
            if logits is None:
                raise ValueError("logits is required for initialization")

            if self.init_batch_count < self.num_init_batches:
                self.init_features_buffer.append(features.detach().clone())
                self.init_logits_buffer.append(logits.detach().clone())
                self.init_batch_count += 1

                print(f"Collecting initialization data: {self.init_batch_count}/{self.num_init_batches} batches")

                if self.init_batch_count < self.num_init_batches:
                    # Not enough data yet: provisional MSP predictions with a median split
                    probs = torch.softmax(logits, dim=1).detach().cpu().numpy()
                    scores = np.max(probs, axis=1)
                    threshold = np.percentile(scores, 50)
                    predictions = (scores > threshold).astype(np.int32)

                    auroc = None
                    if true_labels is not None:
                        try:
                            auroc = metrics.roc_auc_score(true_labels, scores)
                        except ValueError:
                            auroc = 0.5  # e.g. only one class present in true_labels

                    return predictions, scores, auroc, 0.5

                print("All initialization data collected. Starting center initialization...")

                all_features = torch.cat(self.init_features_buffer, dim=0)
                all_logits = torch.cat(self.init_logits_buffer, dim=0)
                predictions, scores = self.initialize_batch_ensemble(
                    all_features, all_logits, scores_init=scores_flipDet)

                # release the buffered tensors
                self.init_features_buffer = []
                self.init_logits_buffer = []

                auroc = None
                if true_labels is not None:
                    # labels belong to the current (last) batch only
                    try:
                        auroc = metrics.roc_auc_score(true_labels, scores[-len(true_labels):])
                    except ValueError:
                        auroc = 0.5  # e.g. only one class present in true_labels
                    else:
                        if update:
                            self.batch_aurocs.append(auroc)

                return predictions[-batch_size:], scores[-batch_size:], auroc, 1

        if self.initialized:
            self.processed_batch_count += 1

        # check prototype flip at fixed intervals (0/None disables the check)
        prototype_flipped = False
        if (self.initialized and logits is not None and self.check_interval
                and self.processed_batch_count % self.check_interval == 0):
            prototype_flipped = self.check_prototype_flip(features, logits, scores_flipDet=scores_flipDet)

        rds, id_distances, ood_distances = self.compute_rds(features, logits)

        threshold, separability = find_threshold_otsu(rds, return_separability=True)
        self.threshold_history.append(threshold)

        # lower rds = closer to the ID center
        predictions = (rds < threshold).astype(np.int32)
        confidence = self.compute_confidence(rds, threshold)
        scores = -rds  # higher is ID

        # no automatic correction is applied; keep the per-batch record aligned
        self.corrections_applied.append(False)

        if update:
            self.update_centers(features, logits, predictions, confidence)

            self.distance_history['id'].append(id_distances.mean())
            self.distance_history['ood'].append(ood_distances.mean())

            id_conf = confidence[predictions == 1].mean() if np.any(predictions == 1) else 0
            ood_conf = confidence[predictions == 0].mean() if np.any(predictions == 0) else 0
            self.confidence_history['id'].append(id_conf)
            self.confidence_history['ood'].append(ood_conf)

        auroc = None
        if true_labels is not None:
            auroc = metrics.roc_auc_score(true_labels, scores)
            self.auroc = auroc
            fpr, tpr, _ = metrics.roc_curve(true_labels, scores)
            self.fpr, self.tpr = fpr, tpr
            self.batch_aurocs.append(auroc)

        return predictions, scores, auroc, separability
    