import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import itertools
import time
import traceback  # For detailed error printing if needed outside of removed blocks

# NOTE(review): this switches on autograd anomaly detection for the whole process as an
# import-time side effect. PyTorch documents this mode as a debugging aid that slows down
# execution, so confirm it is still wanted for regular training runs.
torch.autograd.set_detect_anomaly(True)  # Enable anomaly detection for debugging
from models.losses import infoNCE, local_infoNCE  # Assuming these are your custom loss functions
import tasks  # Assuming tasks module contains evaluation logic (e.g., eval_forecasting)
from models.basicaug import split_with_nan, centerize_vary_length_series, \
    torch_pad_nan  # Basic augmentation/processing utils
from models.module import AugSem  # Core semantic augmentation module for ProSAR
from models.CoTs_encoder import CoSTEncoder as TSEncoder1  # Encoder from CoST paper
from models.encoder import TSEncoder as TSEncoder2  # Another TSEncoder, possibly from TS2Vec
from models.module import hierarchical_contrastive_loss, cluster_loss  # Other loss functions
from models.module import hierarchical_clustering  # Hierarchical clustering (e.g., FINCH-like)
from sklearn.cluster import KMeans  # For Input-Space Grounding
import os
import matplotlib.pyplot as plt


# NOTE(review): "LAEGE_NUM" looks like a misspelling of "LARGE_NUM"; the name is left unchanged
# here because other modules may reference it. Neither constant is used in the visible part of
# this file.
LAEGE_NUM = 1e7  # A large number, possibly for masking purposes
criterion = nn.CrossEntropyLoss()  # Global criterion, e.g. for cluster_loss or downstream tasks


# PrototypeDecoder: decodes latent prototypes back to the time domain.
# Corresponds to the "decoding consistency mechanism" in ProSAR.
class PrototypeDecoder(nn.Module):
    """Single linear layer that maps latent prototype vectors to time-domain prototypes.

    The output layout is decided by ``args.prototype_multivariate`` (treated as
    ``False`` when the attribute is absent) together with ``num_prototype_features``:

    * multivariate flag set and more than one feature -> ``(N, seq_len, features)``
    * otherwise                                       -> ``(N, seq_len)``
    """

    def __init__(self, d_model, output_seq_len, num_prototype_features, args, **kwargs):
        """Create the decoder.

        Args:
            d_model: Size of the latent vectors fed to the linear layer.
            output_seq_len: Sequence length of every decoded prototype.
            num_prototype_features: Feature count of a prototype (1 for univariate).
            args: Configuration object; only ``prototype_multivariate`` is read.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.d_model = d_model
        self.output_seq_len = output_seq_len
        self.num_prototype_features = num_prototype_features
        self.args = args

        # The linear layer emits one flat vector per prototype; its width depends on
        # whether a feature axis has to be packed into it.
        flat_width = output_seq_len
        if self._decodes_multivariate():
            flat_width = output_seq_len * self.num_prototype_features
        self.decoder_fc = nn.Linear(d_model, flat_width)

    def _decodes_multivariate(self):
        """Return a truthy value when decoded prototypes carry a separate feature axis."""
        return bool(getattr(self.args, 'prototype_multivariate', False)) and self.num_prototype_features > 1

    def forward(self, latent_centroids, output_len=None):
        """Decode latent centroids.

        Args:
            latent_centroids: Tensor of shape ``(N, d_model)``; a 1-D tensor is
                treated as a single prototype and gets a leading axis of size 1.
            output_len: Sequence length used when reshaping the output. Defaults to
                ``output_seq_len``. The linear layer's width is fixed at construction,
                so any other value makes the final ``view`` fail.

        Returns:
            Tensor of shape ``(N, output_len)`` or ``(N, output_len, features)``.
        """
        seq_len = self.output_seq_len if output_len is None else output_len

        batch = latent_centroids.unsqueeze(0) if latent_centroids.ndim == 1 else latent_centroids
        flat = self.decoder_fc(batch)

        target_shape = [batch.size(0), seq_len]
        if self._decodes_multivariate():
            target_shape.append(self.num_prototype_features)
        return flat.view(*target_shape)


class AugProtoCL:  # This class implements the core ProSAR framework
    def __init__(
            self,
            args,  # args object from argparse, assumed to be processed
            device='cuda',
    ):
        """Build the encoder, augmenter, prototypes, optimizer and LR scheduler.

        Args:
            args: Configuration namespace (e.g. from argparse). It is stored as
                ``self.args`` and a fixed list of boolean-like options is normalised
                in place on it (``'True'``/``'False'`` strings become real booleans).
                Read without a default: ``batch_size``, ``backbone``, ``input_dims``,
                ``output_dims``, ``encoder_hidden_dims``, ``encoder_depth``,
                ``num_prototypes``, ``proto_len``, ``prototype_multivariate``,
                ``learnable_prototypes``, ``use_prototype_decoder``,
                ``prototype_momentum``, ``lr``, ``eval_every_epoch``,
                ``eval_start_epoch`` (plus ``kernels`` for the ``"cost"`` backbone).
                Optional: ``max_train_length``, ``contrastive_loss_type``,
                ``epochs`` and the ``lr_*`` scheduler options.
            device: Device on which the networks and the prototype tensor are placed.

        Raises:
            ValueError: If ``args.backbone`` is neither ``"cost"`` nor ``"ts"``.
        """
        self.args = args

        # --- 1. Convert string booleans from argparse to Python booleans ---
        # Only the exact (case-insensitive) string 'true' maps to True; any other
        # string maps to False. Non-string, non-bool values go through bool().
        attrs_to_convert_to_bool = [
            'prototype_multivariate', 'learnable_prototypes',
            'use_prototype_decoder', 'update_prototypes_at_start',
            'use_cluster_loss', 'contrast_augmented_views',
            'clustering_verbose',
            'ts_cluster_use_kmeans',
            'hard_mask'  # From AutoTCL args, ensure it's bool if used by ProSAR
        ]
        for attr in attrs_to_convert_to_bool:
            if hasattr(self.args, attr):
                val = getattr(self.args, attr)
                if isinstance(val, str):
                    setattr(self.args, attr, val.lower() == 'true')
                elif not isinstance(val, bool):
                    setattr(self.args, attr, bool(val))

        # --- 2. Get specific configurations from args ---
        self.contrastive_loss_type = getattr(self.args, 'contrastive_loss_type', 'nopool')  # Default to 'nopool'

        # --- 3. Initialize model components ---
        self.device = device
        self.batch_size = self.args.batch_size
        self.max_train_length = getattr(self.args, 'max_train_length', None)

        # Backbone encoder; stored as _net, with an averaged copy kept in net.
        if self.args.backbone == "cost":
            self._net = TSEncoder1(  # CoSTEncoder
                input_dims=self.args.input_dims,
                output_dims=self.args.output_dims,
                kernels=self.args.kernels,
                # CoST requires a length. Use the already-resolved optional value so a
                # missing max_train_length falls back to 512 instead of raising.
                length=self.max_train_length if self.max_train_length is not None else 512,
                hidden_dims=self.args.encoder_hidden_dims,
                depth=self.args.encoder_depth,
            ).to(self.device)
        elif self.args.backbone == "ts":
            self._net = TSEncoder2(  # models.encoder.TSEncoder
                input_dims=self.args.input_dims,
                output_dims=self.args.output_dims,
                hidden_dims=self.args.encoder_hidden_dims,
                depth=self.args.encoder_depth,
            ).to(self.device)
        else:
            raise ValueError(f"Unknown backbone: {self.args.backbone}")

        # Averaged (SWA-style) copy of the backbone, seeded with the current weights.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        # Semantic augmenter; it receives the same args object.
        self.dtw_augmenter = AugSem(args=self.args).to(self.device)

        # Time-domain prototype tensor.
        self.num_prototypes = self.args.num_prototypes
        self.prototype_seq_len = self.args.proto_len
        self.input_num_features = self.args.input_dims

        if self.args.prototype_multivariate:
            # Shape (num_prototypes, num_features, sequence_length), random normal init.
            initial_prototypes = torch.randn(
                self.num_prototypes, self.input_num_features, self.prototype_seq_len)
        else:
            # Shape (num_prototypes, sequence_length), zero init (overwritten by the
            # Xavier init below when the prototypes are not learnable).
            initial_prototypes = torch.zeros(self.num_prototypes, self.prototype_seq_len)

        # Move the data to the device BEFORE wrapping it in nn.Parameter. Calling
        # .to(device) on an existing Parameter returns a plain non-leaf tensor when the
        # device changes, which the optimizer rejects ("can't optimize a non-leaf
        # Tensor") and which would not be updated as a parameter.
        self.prototypes = nn.Parameter(
            initial_prototypes.to(self.device),
            requires_grad=self.args.learnable_prototypes,
        )

        if not self.args.learnable_prototypes:  # If prototypes are not learned via gradient, use Xavier init
            torch.nn.init.xavier_uniform_(self.prototypes.data)

        self.latent_prototypes = None  # Latent-space prototypes, filled in by the clustering step
        self.cluster_results_dict = None  # Stores clustering results
        self.prototype_momentum = self.args.prototype_momentum  # EMA update momentum

        # Optional decoder from latent prototypes to time-domain prototypes.
        if self.args.use_prototype_decoder:
            decoder_num_prototype_features = self.input_num_features if self.args.prototype_multivariate else 1
            self.prototype_decoder = PrototypeDecoder(
                d_model=self.args.output_dims,  # Encoder output dim is decoder input dim
                output_seq_len=self.prototype_seq_len,
                num_prototype_features=decoder_num_prototype_features,
                args=self.args
            ).to(self.device)
        else:
            self.prototype_decoder = None

        # Collect parameters for optimization
        params_to_optimize = list(self._net.parameters())
        if self.prototype_decoder is not None and any(p.requires_grad for p in self.prototype_decoder.parameters()):
            params_to_optimize.extend(list(self.prototype_decoder.parameters()))
        if self.args.learnable_prototypes:  # If time-domain prototypes are learnable, add to optimizer
            params_to_optimize.append(self.prototypes)

        self.optimizer = torch.optim.AdamW(params_to_optimize, lr=self.args.lr)

        # Training state variables
        self.n_epochs = 0
        self.n_iters = 0
        self.eval_every_epoch = self.args.eval_every_epoch
        self.eval_start_epoch = self.args.eval_start_epoch

        # Loss modules kept on the instance.
        self.CE = torch.nn.CrossEntropyLoss()
        self.BCE = torch.nn.BCEWithLogitsLoss()
        self.c = 0  # Overwritten with the first value returned by hierarchical_clustering

        # --- 4. Optional learning-rate scheduler ---
        self.lr_scheduler = None
        # A missing or None lr_scheduler_type means "no scheduler".
        scheduler_type = str(getattr(self.args, 'lr_scheduler_type', 'none') or 'none').lower()
        configured_epochs = getattr(self.args, 'epochs', None)
        if scheduler_type == 'cosine':
            t_max_default = configured_epochs if configured_epochs is not None else 100
            t_max = getattr(self.args, 'lr_cosine_t_max', t_max_default)
            eta_min = getattr(self.args, 'lr_cosine_eta_min', 0.0)

            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=t_max,
                eta_min=eta_min
            )
            print(f"INFO: Initialized CosineAnnealingLR scheduler with T_max={t_max}, eta_min={eta_min}")
        elif scheduler_type == 'step':
            step_size = getattr(self.args, 'lr_step_size', 10)
            gamma = getattr(self.args, 'lr_step_gamma', 0.1)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=step_size,
                gamma=gamma
            )
            print(f"INFO: Initialized StepLR scheduler with step_size={step_size}, gamma={gamma}")
        elif scheduler_type == 'multistep':
            if configured_epochs is not None:
                milestones_default = [int(0.5 * configured_epochs), int(0.75 * configured_epochs)]
            else:
                milestones_default = [30, 80]
            milestones = getattr(self.args, 'lr_multistep_milestones', milestones_default)
            gamma = getattr(self.args, 'lr_multistep_gamma', 0.1)
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=milestones,
                gamma=gamma
            )
            print(f"INFO: Initialized MultiStepLR scheduler with milestones={milestones}, gamma={gamma}")
        elif scheduler_type != 'none':
            print(f"WARNING: Unknown lr_scheduler_type '{scheduler_type}'. No LR scheduler will be used.")

        # Per-step loss histories.
        self.loss_log = []
        self.intra_loss_log_weighted = []
        self.inst_loss_log_weighted = []
        self.proto_loss_log_weighted = []




    def get_dataloader(self, data_np, shuffle=False, drop_last=False):
        """Turn a NumPy array into a ``(array, TensorDataset, DataLoader)`` triple.

        Works on a copy of ``data_np``. For 3-D input whose second axis is longer
        than ``self.max_train_length`` the series is either split into sections
        that are stacked along axis 0 (when at least two sections fit) or cropped
        to a random window of ``max_train_length`` steps. For 3-D input, samples
        that are NaN everywhere are dropped. Remaining NaNs are replaced by 0.

        Args:
            data_np: Input array; the 3-D specific steps only run when ``ndim == 3``.
            shuffle: Forwarded to the ``DataLoader``.
            drop_last: Forwarded to the ``DataLoader``.

        Returns:
            Tuple ``(processed_array, dataset, loader)``. When no samples remain, an
            empty dataset and a loader over it are returned after printing a warning.
        """
        arr = data_np.copy()
        limit = self.max_train_length

        # Over-long 3-D series: split into sections, or crop a random window.
        if limit is not None and arr.ndim == 3 and arr.shape[1] > limit:
            n_sections = arr.shape[1] // limit
            if n_sections >= 2:
                pieces = split_with_nan(arr, n_sections, axis=1)
                arr = np.concatenate(pieces, axis=0)
            else:
                start = np.random.randint(0, arr.shape[1] - limit + 1)
                arr = arr[:, start:start + limit]

        if arr.ndim == 3:
            # A step counts as missing if, for any sample, every feature is NaN there.
            step_missing = np.isnan(arr).all(axis=-1).any(axis=0)
            if step_missing.any() and (step_missing[0] or step_missing[-1]):
                # Missing steps touch the start or end of the series: centerize.
                arr = centerize_vary_length_series(arr)
            # Drop samples that are NaN at every step and feature.
            sample_all_nan = np.isnan(arr).all(axis=2).all(axis=1)
            arr = arr[~sample_all_nan]

        arr = np.nan_to_num(arr)  # Remaining NaNs become 0

        if arr.shape[0] == 0:
            print("Warning: No valid data after preprocessing.")
            trailing_dims = arr.shape[1:] if arr.ndim > 1 else (0,)
            empty_dataset = TensorDataset(torch.empty(0, *trailing_dims))
            return arr, empty_dataset, DataLoader(empty_dataset, batch_size=self.batch_size)

        dataset = TensorDataset(torch.from_numpy(arr).to(torch.float))

        # Never ask for a batch larger than the dataset; a resulting size of 0
        # (configured batch size of 0) falls back to one batch holding everything.
        effective_batch = min(self.batch_size, len(dataset))
        if effective_batch == 0:
            effective_batch = len(dataset)

        loader = DataLoader(dataset, batch_size=effective_batch, shuffle=shuffle, drop_last=drop_last)
        return arr, dataset, loader





    def get_dtw_augmented_views_and_features(self, x_batch_original_L_V, current_prototypes_for_aug):
        """Encode a batch together with two prototype-guided augmented views of it.

        Steps:
          1. Swap the last two axes of the batch and call ``self.dtw_augmenter.aug``
             with ``freq_aug=False``; it returns a pair of tensors, referred to in
             this file as the aligned (semantic) part and the non-aligned part.
          2. Call ``aug`` again on the aligned part with ``freq_aug=True``.
          3. View 1 is the result of step 2; view 2 is that result plus the
             non-aligned part from step 1.
          4. Swap the axes back and run the original batch and both views through
             ``self._net``, in that order.

        Args:
            x_batch_original_L_V: 3-D input batch; per the naming used in this file
                its layout is (batch, length, vars) - TODO confirm against callers.
            current_prototypes_for_aug: Time-domain prototypes passed to the augmenter
                as ``prototype``.

        Returns:
            Tuple ``(out_original, out_view1, out_view2)`` of encoder outputs.
        """
        swap_last_two = (0, 2, 1)
        augmenter = self.dtw_augmenter

        # The augmenter is fed the batch with its last two axes exchanged.
        batch_swapped = x_batch_original_L_V.permute(*swap_last_two)

        # Pass 1: split the input into an aligned part and a non-aligned part.
        aligned_part, non_aligned_part = augmenter.aug(
            batch_swapped,
            prototype=current_prototypes_for_aug,
            original_prototype_was_multivariate=self.args.prototype_multivariate,
            freq_aug=False,
        )

        # Pass 2: transform only the aligned part, this time with freq_aug enabled.
        aligned_transformed = augmenter.aug(
            aligned_part,
            prototype=current_prototypes_for_aug,
            original_prototype_was_multivariate=self.args.prototype_multivariate,
            freq_aug=True,
        )

        # View 1 = transformed aligned part; view 2 = view 1 + non-aligned part.
        views_swapped = (aligned_transformed, aligned_transformed + non_aligned_part)

        # Back to the layout of the incoming batch for the encoder.
        view_one, view_two = (view.permute(*swap_last_two) for view in views_swapped)

        encoder = self._net
        return encoder(x_batch_original_L_V), encoder(view_one), encoder(view_two)

    def _update_prototypes_from_batch_features(self,
                                               current_batch_features_cpu_tensor,
                                               source_time_series_data_cpu_tensor=None,
                                               # Original time series for this batch
                                               is_batch_update=True):  # Flag for batch-level vs global update
        # Implements ProSAR's "Synergistic Dual-Prototype Refinement Strategy". [cite: 32, 79]
        # 1. Update Latent Prototypes (p_k^l) via Clustering. [cite: 109]
        # 2. Update Time-Domain Prototypes (p_k^t) via Decoding Consistency & Input-Space Grounding. [cite: 117, 121, 118]

        is_clustering_verbose = self.args.clustering_verbose  # Already bool

        # Prepare features for clustering (ensure 2D: [num_samples, feature_dim])
        features_np = current_batch_features_cpu_tensor.numpy()
        if features_np.shape[0] == 0: print(
            "Warning (prototype update): Input features are empty, skipping update."); return
        if features_np.ndim == 3:  # If features are [Batch, Time, FeatDim], average over time
            features_np = features_np.mean(axis=1)
        if features_np.ndim != 2 or features_np.shape[0] == 0:
            print(
                f"Warning (prototype update): Feature shape {features_np.shape} not suitable for clustering, skipping update.");
            return

        if features_np.shape[0] < self.args.num_prototypes and features_np.shape[0] > 0:  # pragma: no cover
            if is_clustering_verbose: print(
                f"Debug (prototype update): Feature sample count ({features_np.shape[0]}) < target prototype count ({self.args.num_prototypes}).")

        # 1. Update latent prototypes: Hierarchical clustering on current batch features.
        # ProSAR paper suggests FINCH algorithm. [cite: 110]
        cluster_results_dict_current = None
        num_actual_feat_clusters = 0
        self.latent_prototypes = None  # Reset latent_prototypes

        if is_clustering_verbose: print(
            f"Debug (prototype update): Calling hierarchical_clustering, feature shape: {features_np.shape}")  # pragma: no cover
        # c: cluster assignments, _: cluster members, ... cluster_results_dict_current: dict with centroids etc.
        # current_batch_features_cpu_tensor may include features from original and augmented views. [cite: 110]
        c_val, _, _, _, _, cluster_results_dict_current = hierarchical_clustering(
            features_np, initial_rank=None, distance=self.args.clustering_distance,
            ensure_early_exit=True, verbose=is_clustering_verbose,
            ann_threshold=getattr(self.args, 'ann_threshold', 40000),  # ANN threshold if applicable
            layers=getattr(self.args, 'clustering_layers', 1) + 1  # Layers for FINCH-like algorithm
        )
        self.c = c_val  # Store finest-grain cluster assignments or count

        if cluster_results_dict_current: self.cluster_results_dict = cluster_results_dict_current  # Update global cluster results

        # Extract latent centroids (p_k^l) from specified clustering layer [cite: 114]
        target_layer_idx = self.args.clustering_target_layer
        if cluster_results_dict_current and \
                cluster_results_dict_current.get('centroids') and \
                isinstance(cluster_results_dict_current.get('centroids'), list) and \
                0 <= target_layer_idx < len(cluster_results_dict_current['centroids']) and \
                cluster_results_dict_current['centroids'][target_layer_idx] is not None:

            current_latent_centroids_tensor = cluster_results_dict_current['centroids'][target_layer_idx]
            if not isinstance(current_latent_centroids_tensor, torch.Tensor):
                current_latent_centroids_tensor = torch.from_numpy(current_latent_centroids_tensor)

            if current_latent_centroids_tensor.shape[0] > 0:
                self.latent_prototypes = current_latent_centroids_tensor.to(
                    self.device).float()  # These are the updated p_k^l
                num_actual_feat_clusters = self.latent_prototypes.shape[0]
            else:  # pragma: no cover
                if is_clustering_verbose: print("Debug (prototype update): Target layer centroids are empty.")
        else:  # pragma: no cover
            if is_clustering_verbose: print(
                "Debug (prototype update): Could not extract valid target layer centroids from clustering results.")

        if self.latent_prototypes is None:
            if is_clustering_verbose: print(
                "Warning (prototype update): Latent prototypes not generated after clustering.")  # pragma: no cover

        if is_clustering_verbose or (
                num_actual_feat_clusters == 0 and self.latent_prototypes is not None):  # pragma: no cover
            print(
                f"Debug (prototype update) - Actual number of latent feature clusters: {num_actual_feat_clusters if self.latent_prototypes is not None else 0}")

        # 2. Update time-domain prototypes (self.prototypes), only if not directly learned via gradient
        if not self.args.learnable_prototypes:  # learnable_prototypes is bool
            # Initialize candidate prototypes from different sources
            decoded_prototypes_option = None  # Source 1: Decoder output
            direct_avg_feat_cluster_prototypes_option = None  # Source 2: Avg of time series in feature clusters
            direct_ts_cluster_prototypes_option = None  # Source 3: Centroids from direct TS clustering (ISG)

            # Check if source time series data is needed and available
            needs_source_data_for_feat_avg = getattr(self.args, 'w_direct_avg_feat_source', 0.0) > 1e-6
            needs_source_data_for_ts_clust = getattr(self.args, 'w_direct_ts_cluster_source', 0.0) > 1e-6
            if (needs_source_data_for_feat_avg or needs_source_data_for_ts_clust) and \
                    source_time_series_data_cpu_tensor is None:  # pragma: no cover
                print(
                    f"Warning (prototype update): Raw time series data needed for prototype update sources but not provided. Relevant sources will be skipped.")

            # --- Source 1: Latent-to-Time-Domain Decoding Consistency --- [cite: 121]
            # Use prototype_decoder (D_psi) to decode latent centroids (p_k^l) to time-domain (hat_p_k^t) [cite: 122]
            if getattr(self.args, 'w_decoder_source',
                       0.0) > 1e-6 and self.latent_prototypes is not None and num_actual_feat_clusters > 0:
                if self.prototype_decoder is not None:
                    original_decoder_mode = self.prototype_decoder.training
                    self.prototype_decoder.eval()  # Switch to eval mode
                    with torch.no_grad():
                        decoded_prototypes_option = self.prototype_decoder(self.latent_prototypes,
                                                                           output_len=self.prototype_seq_len).cpu()
                    self.prototype_decoder.train(original_decoder_mode)  # Restore mode
                    if decoded_prototypes_option.numel() == 0: decoded_prototypes_option = None  # pragma: no cover
                else:  # pragma: no cover
                    print("Warning (prototype update): w_decoder_source > 0 but prototype_decoder is not initialized.")

            # --- Source 2: Direct average of time series samples corresponding to feature clusters ---
            if getattr(self.args, 'w_direct_avg_feat_source', 0.0) > 1e-6 and \
                    self.latent_prototypes is not None and num_actual_feat_clusters > 0 and \
                    source_time_series_data_cpu_tensor is not None and \
                    cluster_results_dict_current:  # Need cluster results for assignments

                assignments_target_layer = None
                # Get sample-to-cluster assignments from clustering results
                if cluster_results_dict_current.get('im2cluster') and \
                        isinstance(cluster_results_dict_current['im2cluster'], list) and \
                        0 <= target_layer_idx < len(cluster_results_dict_current['im2cluster']) and \
                        cluster_results_dict_current['im2cluster'][target_layer_idx] is not None:
                    assignments_tensor = cluster_results_dict_current['im2cluster'][target_layer_idx]
                    assignments_target_layer = assignments_tensor.cpu().numpy() if isinstance(assignments_tensor,
                                                                                              torch.Tensor) else assignments_tensor

                if assignments_target_layer is not None:
                    direct_avg_prototypes_list = []
                    original_source_data_N = source_time_series_data_cpu_tensor.shape[
                        0]  # Number of samples in original batch

                    for c_idx in range(num_actual_feat_clusters):  # For each found latent cluster
                        # Find indices of features belonging to current latent cluster c_idx
                        feature_indices_in_cluster = np.where(assignments_target_layer == c_idx)[0]

                        if len(feature_indices_in_cluster) > 0:
                            # Map feature indices back to original time series data indices
                            num_samples_in_features = current_batch_features_cpu_tensor.shape[0]

                            if is_batch_update and num_samples_in_features != original_source_data_N and \
                                    num_samples_in_features > original_source_data_N and \
                                    num_samples_in_features % original_source_data_N == 0:
                                # If features are from concatenated views, map indices modulo original batch size
                                original_sample_indices = [idx % original_source_data_N for idx in
                                                           feature_indices_in_cluster]
                            else:  # Features and original data samples correspond one-to-one
                                original_sample_indices = feature_indices_in_cluster.tolist()

                            # Ensure indices are valid
                            valid_original_indices = [idx for idx in original_sample_indices if
                                                      idx < original_source_data_N]

                            if not valid_original_indices:  # pragma: no cover
                                # If no valid samples, pad with zeros
                                shape_for_zero = (self.input_num_features,
                                                  self.prototype_seq_len) if self.args.prototype_multivariate else (
                                self.prototype_seq_len,)
                                direct_avg_prototypes_list.append(
                                    torch.zeros(shape_for_zero, dtype=source_time_series_data_cpu_tensor.dtype))
                                continue

                            # Get original time series for these samples and compute average
                            cluster_samples_ts = source_time_series_data_cpu_tensor[
                                valid_original_indices]  # [NumValid, L_data, C_data]
                            avg_ts_for_cluster = torch.mean(cluster_samples_ts, dim=0)  # [L_data, C_data]

                            # Adjust shape to match target time-domain prototype shape
                            target_proto_shape_is_multivariate = self.args.prototype_multivariate  # bool
                            target_proto_n_vars = self.input_num_features
                            target_proto_seq_len = self.prototype_seq_len

                            if target_proto_shape_is_multivariate:  # Target: [C_proto, L_proto]
                                if avg_ts_for_cluster.ndim == 2 and avg_ts_for_cluster.shape[
                                    1] == target_proto_n_vars:  # Input [L, C]
                                    avg_ts_for_cluster = avg_ts_for_cluster.permute(1, 0)  # -> [C, L]
                                elif avg_ts_for_cluster.ndim == 1 and target_proto_n_vars == 1:  # Input [L], target [1, L]
                                    avg_ts_for_cluster = avg_ts_for_cluster.unsqueeze(0)
                            else:  # Target: [L_proto]
                                if avg_ts_for_cluster.ndim == 2:  # Input [L, C]
                                    avg_ts_for_cluster = torch.mean(avg_ts_for_cluster,
                                                                    dim=1)  # -> [L] (average over features)

                            # Adjust sequence length
                            current_seq_len_avg = avg_ts_for_cluster.shape[-1]
                            if current_seq_len_avg != target_proto_seq_len:
                                if current_seq_len_avg > target_proto_seq_len:
                                    avg_ts_for_cluster = avg_ts_for_cluster[..., :target_proto_seq_len]
                                else:  # < target_proto_seq_len, need padding
                                    pad_len = target_proto_seq_len - current_seq_len_avg
                                    padding_config = [0, 0] * avg_ts_for_cluster.ndim  # e.g. [0,0,0,0] for 2D
                                    padding_config[1] = pad_len  # Pad right of the last dimension
                                    avg_ts_for_cluster = F.pad(avg_ts_for_cluster, tuple(padding_config))
                            direct_avg_prototypes_list.append(avg_ts_for_cluster)
                        else:  # pragma: no cover
                            # If cluster is empty, pad with zeros
                            shape_for_zero = (
                            self.input_num_features, self.prototype_seq_len) if self.args.prototype_multivariate else (
                            self.prototype_seq_len,)
                            direct_avg_prototypes_list.append(
                                torch.zeros(shape_for_zero, dtype=source_time_series_data_cpu_tensor.dtype))

                    if direct_avg_prototypes_list:
                        direct_avg_feat_cluster_prototypes_option = torch.stack(direct_avg_prototypes_list).cpu()
                    else:  # pragma: no cover
                        if is_clustering_verbose: print(
                            "Debug (prototype update): direct_avg_prototypes_list is empty.")
                else:  # pragma: no cover
                    if is_clustering_verbose: print(
                        f"Debug (prototype update): Could not get sample assignments ('im2cluster') for direct_avg_feat_source.")

            # --- Source 3: Input-Space Grounding (ISG) --- [cite: 118]
            # Directly cluster raw time series data (e.g., KMeans) and use centroids as candidates (c_k^isg) [cite: 119, 120]
            if getattr(self.args, 'w_direct_ts_cluster_source', 0.0) > 1e-6 and \
                    source_time_series_data_cpu_tensor is not None:

                ts_cluster_use_kmeans = self.args.ts_cluster_use_kmeans  # Already bool

                if KMeans is None and ts_cluster_use_kmeans:  # pragma: no cover
                    print(
                        "Warning (prototype update): KMeans not imported, but configured for use. Skipping w_direct_ts_cluster_source.")
                elif ts_cluster_use_kmeans:
                    ts_data_to_cluster_np = source_time_series_data_cpu_tensor.clone().numpy()  # [N, L_data, C_data]
                    N_ts, L_data, C_data = ts_data_to_cluster_np.shape
                    ts_data_flat = ts_data_to_cluster_np.reshape(N_ts, -1)  # [N, L_data * C_data]

                    n_clusters_for_ts = self.num_prototypes  # Target num_prototypes time-domain prototypes
                    ts_centroids_flat_np = None

                    if ts_data_flat.shape[0] < n_clusters_for_ts:  # pragma: no cover
                        if is_clustering_verbose: print(
                            f"Debug (prototype update - ISG): TS sample count ({ts_data_flat.shape[0]}) < target prototype count ({n_clusters_for_ts}). Using existing samples as centroids.")
                        # If not enough samples, use existing samples (may need padding later)
                        ts_centroids_flat_np = ts_data_flat
                    else:
                        n_init_val = 'auto' if hasattr(KMeans(n_clusters=min(2, N_ts), n_init='auto'),
                                                       'n_init') and isinstance(
                            KMeans(n_clusters=min(2, N_ts), n_init='auto').n_init, str) else 10  # Handle n_init='auto'
                        kmeans_ts = KMeans(n_clusters=n_clusters_for_ts, random_state=self.args.seed, n_init=n_init_val,
                                           max_iter=100)
                        kmeans_ts.fit(ts_data_flat)
                        ts_centroids_flat_np = kmeans_ts.cluster_centers_

                    if ts_centroids_flat_np is not None:
                        ts_centroids_flat = torch.from_numpy(
                            ts_centroids_flat_np).float()  # [n_clusters_for_ts or N_ts, L_data*C_data]
                        num_cand_ts_protos = ts_centroids_flat.shape[0]
                        reshaped_ts_centroids_list = []

                        target_proto_is_multi = self.args.prototype_multivariate
                        target_vars_proto = self.input_num_features
                        target_len_proto = self.prototype_seq_len

                        for k_idx in range(num_cand_ts_protos):
                            flat_centroid_k = ts_centroids_flat[k_idx]  # Single flattened centroid
                            # Reshape flattened centroid to [L_data, C_data] or [L_data]
                            reshaped_from_flat_k = flat_centroid_k.reshape(L_data,
                                                                           C_data) if C_data > 1 else flat_centroid_k.reshape(
                                L_data)

                            # Adjust to target prototype shape: [C_proto, L_proto] or [L_proto]
                            if target_proto_is_multi:  # Target: [C_proto, L_proto]
                                if reshaped_from_flat_k.ndim == 2 and reshaped_from_flat_k.shape[
                                    1] == target_vars_proto:  # Input [L_data, C_data(=C_proto)]
                                    temp_reshaped_k = reshaped_from_flat_k.permute(1, 0)  # -> [C_proto, L_data]
                                elif reshaped_from_flat_k.ndim == 1 and target_vars_proto == 1:  # Input [L_data], C_proto=1
                                    temp_reshaped_k = reshaped_from_flat_k.unsqueeze(0)  # -> [1, L_data]
                                else:  # pragma: no cover (Shape mismatch)
                                    if is_clustering_verbose: print(
                                        f"Debug (prototype update-ISG): KMeans centroid {k_idx} shape {reshaped_from_flat_k.shape} incompatible with multivariate target. Zero-padding.")
                                    temp_reshaped_k = torch.zeros((target_vars_proto, L_data),
                                                                  dtype=ts_centroids_flat.dtype)
                            else:  # Target: [L_proto] (univariate)
                                if reshaped_from_flat_k.ndim == 2:  # Input [L_data, C_data]
                                    temp_reshaped_k = torch.mean(reshaped_from_flat_k,
                                                                 dim=1)  # -> [L_data] (average over features)
                                elif reshaped_from_flat_k.ndim == 1:  # Input [L_data]
                                    temp_reshaped_k = reshaped_from_flat_k
                                else:  # pragma: no cover
                                    if is_clustering_verbose: print(
                                        f"Debug (prototype update-ISG): KMeans centroid {k_idx} shape {reshaped_from_flat_k.shape} incompatible with univariate target. Zero-padding.")
                                    temp_reshaped_k = torch.zeros(L_data, dtype=ts_centroids_flat.dtype)

                            # Adjust sequence length L_data -> L_proto
                            current_len_temp_k = temp_reshaped_k.shape[-1]
                            if current_len_temp_k != target_len_proto:
                                if current_len_temp_k > target_len_proto:
                                    temp_reshaped_k = temp_reshaped_k[..., :target_len_proto]
                                else:  # < target_len_proto, need padding
                                    pad_l_k = target_len_proto - current_len_temp_k
                                    padding_config_k = [0, 0] * temp_reshaped_k.ndim
                                    padding_config_k[1] = pad_l_k  # Pad right of last dimension
                                    temp_reshaped_k = F.pad(temp_reshaped_k, tuple(padding_config_k))
                            reshaped_ts_centroids_list.append(temp_reshaped_k)

                        if reshaped_ts_centroids_list:
                            direct_ts_cluster_prototypes_option = torch.stack(reshaped_ts_centroids_list).cpu()
                            # If KMeans produced fewer centroids than num_prototypes (e.g., due to few input samples)
                            # Pad with zeros to reach num_prototypes
                            if direct_ts_cluster_prototypes_option.shape[0] < self.num_prototypes:  # pragma: no cover
                                if is_clustering_verbose: print(
                                    f"Debug (prototype update-ISG): KMeans centroid count {direct_ts_cluster_prototypes_option.shape[0]} < target prototype count {self.num_prototypes}. Will pad with zeros.")
                                pad_needed = self.num_prototypes - direct_ts_cluster_prototypes_option.shape[0]
                                padding_protos = torch.zeros(pad_needed, *direct_ts_cluster_prototypes_option.shape[1:],
                                                             dtype=direct_ts_cluster_prototypes_option.dtype)
                                direct_ts_cluster_prototypes_option = torch.cat(
                                    [direct_ts_cluster_prototypes_option, padding_protos], dim=0)
                        else:  # pragma: no cover
                            if is_clustering_verbose: print(
                                "Debug (prototype update-ISG): reshaped_ts_centroids_list is empty.")
                    else:  # pragma: no cover
                        if is_clustering_verbose: print(
                            "Debug (prototype update-ISG): KMeans failed to produce centroids (ts_centroids_flat_np is None).")

            # --- EMA Update for Time-Domain Prototypes (self.prototypes) --- [cite: 115]
            # Combine candidates from the above sources using EMA to update self.prototypes.
            # This reflects "fusing information from both sources" [cite: 123] (here, up to 3 sources).
            updated_prototypes_count = 0
            w_dec = getattr(self.args, 'w_decoder_source', 0.0)
            w_feat_avg = getattr(self.args, 'w_direct_avg_feat_source', 0.0)
            w_ts_clust = getattr(self.args, 'w_direct_ts_cluster_source', 1.0)  # Default to 1.0 if others are 0

            total_weight_sum_config = w_dec + w_feat_avg + w_ts_clust
            if total_weight_sum_config < 1e-6:  # pragma: no cover
                if is_clustering_verbose: print(
                    "Warning (prototype update): All prototype source weights are near zero. Time-domain prototypes cannot be updated via EMA.")
            else:
                # Normalize weights
                norm_w_dec = w_dec / total_weight_sum_config
                norm_w_feat_avg = w_feat_avg / total_weight_sum_config
                norm_w_ts_clust = w_ts_clust / total_weight_sum_config

                # Iterate through each stored prototype self.prototypes[i] to update it
                for i in range(self.num_prototypes):
                    stored_prototype_i_cpu = self.prototypes.data[i].cpu()  # Current stored prototype
                    accumulated_candidate_i = torch.zeros_like(stored_prototype_i_cpu)  # For weighted sum of candidates
                    current_applied_weight_i = 0.0  # Actual sum of weights applied to this prototype

                    cand_dec_i, cand_feat_avg_i, cand_ts_clust_i = None, None, None

                    # Match candidates from Source 1 (decoder) and Source 2 (feature avg):
                    # These sources produce num_actual_feat_clusters candidates, which might differ from self.num_prototypes.
                    # Match self.prototypes[i] to these candidates (e.g., by min distance).
                    if num_actual_feat_clusters > 0:
                        # Select a reference set for matching (prefer decoder output, then feature avg)
                        matching_ref_for_feat_clusters = None
                        if decoded_prototypes_option is not None and decoded_prototypes_option.shape[
                            0] == num_actual_feat_clusters:
                            matching_ref_for_feat_clusters = decoded_prototypes_option
                        elif direct_avg_feat_cluster_prototypes_option is not None and \
                                direct_avg_feat_cluster_prototypes_option.shape[
                                    0] == num_actual_feat_clusters:  # pragma: no cover
                            matching_ref_for_feat_clusters = direct_avg_feat_cluster_prototypes_option

                        matched_actual_feat_cluster_idx = -1
                        if matching_ref_for_feat_clusters is not None:
                            flat_matching_options = matching_ref_for_feat_clusters.reshape(num_actual_feat_clusters, -1)
                            dist_to_actual_feat_clusters = torch.cdist(stored_prototype_i_cpu.reshape(1, -1),
                                                                       flat_matching_options)
                            matched_actual_feat_cluster_idx = torch.argmin(
                                dist_to_actual_feat_clusters.squeeze(0)).item()

                        if matched_actual_feat_cluster_idx != -1:
                            if norm_w_dec > 1e-6 and decoded_prototypes_option is not None and \
                                    matched_actual_feat_cluster_idx < decoded_prototypes_option.shape[0]:
                                cand_dec_i = decoded_prototypes_option[matched_actual_feat_cluster_idx]

                            if norm_w_feat_avg > 1e-6 and direct_avg_feat_cluster_prototypes_option is not None and \
                                    matched_actual_feat_cluster_idx < direct_avg_feat_cluster_prototypes_option.shape[
                                0]:
                                cand_feat_avg_i = direct_avg_feat_cluster_prototypes_option[
                                    matched_actual_feat_cluster_idx]

                    # Candidate from Source 3 (direct TS clustering):
                    # This source should produce self.num_prototypes candidates, so access by index i.
                    if norm_w_ts_clust > 1e-6 and direct_ts_cluster_prototypes_option is not None and \
                            i < direct_ts_cluster_prototypes_option.shape[0]:
                        cand_ts_clust_i = direct_ts_cluster_prototypes_option[i]

                    # Weighted accumulation of valid candidates
                    if cand_dec_i is not None:
                        accumulated_candidate_i += norm_w_dec * cand_dec_i
                        current_applied_weight_i += norm_w_dec
                    if cand_feat_avg_i is not None:
                        accumulated_candidate_i += norm_w_feat_avg * cand_feat_avg_i
                        current_applied_weight_i += norm_w_feat_avg
                    if cand_ts_clust_i is not None:
                        accumulated_candidate_i += norm_w_ts_clust * cand_ts_clust_i
                        current_applied_weight_i += norm_w_ts_clust

                    # If any valid candidate contributed, perform EMA update
                    if current_applied_weight_i > 1e-6:
                        final_candidate_for_update_i = accumulated_candidate_i / current_applied_weight_i
                        with torch.no_grad():  # EMA update does not require gradient
                            self.prototypes.data[i] = self.prototype_momentum * self.prototypes.data[i] + \
                                                      (1.0 - self.prototype_momentum) * final_candidate_for_update_i.to(
                                self.device)
                        updated_prototypes_count += 1

                if is_clustering_verbose:  # pragma: no cover
                    if updated_prototypes_count > 0:
                        print(
                            f"Debug (prototype update): EMA updated {updated_prototypes_count} time-domain prototypes.")
                    else:
                        print(
                            f"Debug (prototype update): No time-domain prototypes were EMA updated in this iteration (perhaps all source weights zero or no valid candidates).")
        # print("Debug: Exiting prototype update...")



    def fit(self, train_data_np, n_epochs=None, n_iters=None, task_type='pretraining',
            valid_dataset=None, miverbose=None, args=None):  # train_data_np is the numpy array

        if args is None:  # args from __init__ should already be processed
            args = self.args

        # Ensure train_data_np is 3D (N, L, C)
        if train_data_np.ndim == 2 and self.args.input_dims == 1:  # (N,L) univariate case
            train_data_np = np.expand_dims(train_data_np, axis=-1)  # -> (N,L,1)
        elif train_data_np.ndim == 2 and self.args.input_dims > 1 and task_type == 'forecasting':  # (C, L_total) for forecasting
            # Convert (C, L_total) to (1, L_total, C) to treat as a single long multivariate instance
            train_data_np = np.expand_dims(train_data_np.transpose(1, 0), axis=0)
        assert train_data_np.ndim == 3, f"Training data expected to be 3-dimensional (N, L, C), but got shape {train_data_np.shape}"

        # DataLoader for global prototype update (uses original data segments)
        # Should iterate over the entire dataset (or a large subset) without shuffling.
        processed_train_data_for_global_update, _, train_loader_for_global_proto_update = self.get_dataloader(
            train_data_np, shuffle=False, drop_last=False)
        # Main training DataLoader (shuffled, drop_last, segments processed within loop by get_dataloader)
        processed_train_data_for_loop, _, train_loader = self.get_dataloader(train_data_np, shuffle=True,
                                                                             drop_last=True)

        if self.args.update_prototypes_at_start and self.n_iters == 0:  # (Implicit) Perform global update at start [cite: 117]
            # Pass the dataloader that yields original data segments
            self._perform_global_prototype_update(train_loader_for_global_proto_update,
                                                  processed_train_data_for_global_update)

        loss_log = []
        # Lists to store evaluation metrics
        mses, maes, acc_log_placeholder, vx_log_placeholder, vy_log_placeholder = [], [], [], [], []
        do_valid = valid_dataset is not None

        # --- eval_downstream definition ---
        def eval_downstream(is_final_eval=False, s=True):
            """Run the downstream evaluation for ``task_type`` on ``valid_dataset``.

            ``self._net``, ``self.net`` and (when present) the prototype decoder are
            switched to eval mode for the duration of the call. Their previous
            train/eval flags are restored afterwards, even if the evaluation raises.

            Side effects: appends to ``mses``/``maes`` (forecasting) or to
            ``acc_log_placeholder``/``vx_log_placeholder``/``vy_log_placeholder``
            (classification), and prints the metrics.

            Args:
                is_final_eval: Accepted for backward compatibility but ignored. The
                    previous implementation unconditionally overwrote it with a
                    truthy value, so the "intermediate evaluation" path never ran;
                    that effective behavior is preserved here.
                s: Forecasting only. When truthy, ``valid_dataset`` is forwarded
                    unchanged to ``tasks.eval_forecasting``; otherwise element 5 is
                    replaced by a one-element list holding only its first entry
                    (presumably the list of prediction horizons -- confirm against
                    ``tasks.eval_forecasting``).

            Returns:
                dict: ``{'acc': <accuracy>}`` for classification, otherwise an empty
                dict (also when no validation data was supplied).
            """
            # NOTE(review): the message strings below look mis-encoded in the
            # source; they are runtime output and are deliberately left unchanged.
            if not do_valid:
                print("??????,?????")
                return {}

            # Capture the current train/eval flags before touching any module so
            # the `finally` clause can always put them back.
            original_net_training_state = self._net.training
            original_swa_net_training_state = self.net.training
            original_decoder_training_state = None
            eval_results_for_return = {}

            try:
                if self.prototype_decoder and hasattr(self.prototype_decoder, 'eval'):
                    original_decoder_training_state = self.prototype_decoder.training
                    self.prototype_decoder.eval()
                self._net.eval()
                self.net.eval()

                if task_type == 'forecasting':
                    # Forecasting evaluation is best-effort: a failure is reported
                    # and recorded as NaN instead of aborting training.
                    try:
                        if s:
                            forecasting_args = valid_dataset
                        else:
                            forecasting_args = (valid_dataset[0], valid_dataset[1], valid_dataset[2],
                                                valid_dataset[3], valid_dataset[4], [valid_dataset[5][0]],
                                                valid_dataset[6])
                        _, eval_res = tasks.eval_forecasting(self, *forecasting_args)

                        res = eval_res['ours']
                        # Average the normalized metrics over every entry of `res`.
                        mse = sum(res[t]['norm']['MSE'] for t in res) / len(res)
                        mae = sum(res[t]['norm']['MAE'] for t in res) / len(res)
                        mses.append(mse)
                        maes.append(mae)
                        for key in res:
                            print(key, res[key])
                        print("avg.", mse, mae)
                        print("avg. total", mse + mae)
                    except Exception as e:
                        print(f"???????????: {e}")
                        traceback.print_exc()
                        # Keep the logs non-empty so later code can index them.
                        if not mses:
                            mses.append(float('nan'))
                        if not maes:
                            maes.append(float('nan'))

                elif task_type == 'classification':
                    cls_train_data, cls_train_labels, cls_test_data, cls_test_labels = valid_dataset
                    print("--- ???????? ---")
                    _, eval_res = tasks.eval_classification(self, cls_train_data, cls_train_labels, cls_test_data,
                                                            cls_test_labels, eval_protocol='svm')
                    current_acc = eval_res['acc']
                    acc_log_placeholder.append(current_acc)
                    eval_results_for_return['acc'] = current_acc
                    vx_log_placeholder.append(float('nan'))
                    vy_log_placeholder.append(float('nan'))
                    print(f"Epoch {self.n_epochs}: ????? = {current_acc:.4f}")
                else:
                    print(f"???? '{task_type}' ??????????")
            finally:
                # Restore the modes even when evaluation raised, so training never
                # silently continues with modules stuck in eval mode.
                self._net.train(original_net_training_state)
                self.net.train(original_swa_net_training_state)
                if self.prototype_decoder and original_decoder_training_state is not None:
                    self.prototype_decoder.train(original_decoder_training_state)

            return eval_results_for_return

        # --- End of eval_downstream definition ---

        while True:  # Epoch loop
            if n_epochs is not None and self.n_epochs >= n_epochs: break
            epoch_cum_loss = 0.0;
            epoch_cum_intra_loss_unweighted = 0.0;
            epoch_cum_inter_inst_comp_unweighted = 0.0;
            epoch_cum_inter_proto_comp_unweighted = 0.0;


            n_epoch_iters = 0;
            interrupted = False
            self._net.train()  # Set underlying model to train mode
            if self.prototype_decoder: self.prototype_decoder.train()

            for batch_idx, batch_tuple in enumerate(train_loader):  # Batch loop
                if n_iters is not None and self.n_iters >= n_iters: interrupted = True; break
                x_batch = batch_tuple[0]
                # max_train_length slicing is handled by get_dataloader with drop_last=True
                x_batch = x_batch.to(self.device)
                if x_batch.size(0) == 0: continue  # Skip empty batches

                self.optimizer.zero_grad()
                # Time-domain prototypes (p_k^t) guide augmentation [cite: 87]
                current_p_for_aug = self.prototypes.detach() if not self.args.learnable_prototypes else self.prototypes
                if current_p_for_aug.numel() == 0 or current_p_for_aug.shape[
                    0] != self.num_prototypes:  # pragma: no cover
                    print(
                        f"Warning: Iter {self.n_iters}, time-domain prototypes invalid. Shape: {current_p_for_aug.shape}. Skipping batch.");
                    self.n_iters += 1;
                    continue

                # ProSAR core: Prototype-Guided Semantic View Generation [cite: 31, 87]
                out_original, out_dtw_view1, out_dtw_view2 = \
                    self.get_dtw_augmented_views_and_features(x_batch, current_p_for_aug)

                # Prototype Refinement (batch-level, if frequency allows) [cite: 32, 109]
                if self.n_iters >= 0 and self.n_iters % self.args.prototype_update_freq == 0:
                    original_net_training_state_batch_proto = self._net.training  # Save current _net state
                    if self.prototype_decoder: original_decoder_state_batch_proto = self.prototype_decoder.training; self.prototype_decoder.eval()
                    self._net.eval()  # Encoder should be in eval mode for feature extraction for clustering

                    # Features for clustering can be from original and/or augmented views [cite: 110]
                    current_batch_features_for_proto = torch.cat((out_original, out_dtw_view2),
                                                                 dim=0)  # Example: concat orig and view2
                    self._update_prototypes_from_batch_features(
                        current_batch_features_for_proto.detach().cpu(),
                        x_batch.cpu(),  # Pass original time series for this batch (for ISG, etc.)
                        is_batch_update=True
                    )
                    self._net.train(original_net_training_state_batch_proto)  # Restore _net state
                    if self.prototype_decoder: self.prototype_decoder.train(original_decoder_state_batch_proto)
                # --- End Prototype Refinement ---

                # Calculate Loss: L_total = lambda_intra * L_intra + lambda_inter * L_inter [cite: 132]
                loss_inst = torch.tensor(0.0,
                                         device=x_batch.device)  # L_intra: Intra-instance temporal contrastive loss [cite: 127]

                if self.contrastive_loss_type == 'pool':  # TS2Vec-like hierarchical contrast
                    if self.args.lambda_inst_orig_v1 > 1e-6:
                        loss_inst += self.args.lambda_inst_orig_v1 * hierarchical_contrastive_loss(out_original,
                                                                                                   out_dtw_view1,
                                                                                                   alpha=self.args.hcl_alpha,
                                                                                                   temporal_unit=self.args.hcl_temporal_unit)
                    if self.args.lambda_inst_orig_v2 > 1e-6:
                        loss_inst += self.args.lambda_inst_orig_v2 * hierarchical_contrastive_loss(out_original,
                                                                                                   out_dtw_view2,
                                                                                                   alpha=self.args.hcl_alpha,
                                                                                                   temporal_unit=self.args.hcl_temporal_unit)
                    if self.args.contrast_augmented_views and self.args.lambda_inst_v1_v2 > 1e-6:  # bool from args
                        loss_inst += self.args.lambda_inst_v1_v2 * hierarchical_contrastive_loss(out_dtw_view1,
                                                                                                 out_dtw_view2,
                                                                                                 alpha=self.args.hcl_alpha,
                                                                                                 temporal_unit=self.args.hcl_temporal_unit)
                elif self.contrastive_loss_type == 'nopool':  # InfoNCE + local_infoNCE, closer to ProSAR paper's L_intra [cite: 127]
                    if self.args.lambda_inst_orig_v1 > 1e-6:
                        loss_g1 = infoNCE(out_original, out_dtw_view1, temperature=self.args.temperature_nce,
                                          pooling='max')
                        loss_l1 = local_infoNCE(out_original,
                                                out_dtw_view1)  # local_infoNCE handles its own pooling/cropping
                        loss_inst += self.args.lambda_inst_orig_v1 * (
                                    self.args.global_weight * loss_g1 + self.args.local_weight * loss_l1)
                    if self.args.lambda_inst_orig_v2 > 1e-6:
                        loss_g2 = infoNCE(out_original, out_dtw_view2, temperature=self.args.temperature_nce,
                                          pooling='max')
                        loss_l2 = local_infoNCE(out_original, out_dtw_view2)
                        loss_inst += self.args.lambda_inst_orig_v2 * (
                                    self.args.global_weight * loss_g2 + self.args.local_weight * loss_l2)
                    if self.args.contrast_augmented_views and self.args.lambda_inst_v1_v2 > 1e-6:  # bool from args
                        loss_g12 = infoNCE(out_dtw_view1, out_dtw_view2, temperature=self.args.temperature_nce,
                                           pooling='max')
                        loss_l12 = local_infoNCE(out_dtw_view1, out_dtw_view2)
                        loss_inst += self.args.lambda_inst_v1_v2 * (
                                    self.args.global_weight * loss_g12 + self.args.local_weight * loss_l12)
                else:  # pragma: no cover
                    raise ValueError(f"Unknown contrastive_loss_type: {self.contrastive_loss_type}")

                loss_cluster = torch.tensor(0.0,device=x_batch.device)
                loss_inter_inst_comp = torch.tensor(0.0,device=x_batch.device)
                loss_inter_proto_comp = torch.tensor(0.0,device=x_batch.device)
                if self.args.use_cluster_loss and self.cluster_results_dict and self.c is not None:  # self.c is from clustering [cite: 129]
                    # cluster_loss needs global representations (B,D), so average over time dimension
                    # im_q: anchor (original view), im_k: positive key (an augmented view)
                    # ProSAR's L_inter contrasts instances based on their prototype association. [cite: 129]
                    output_logits, target_labels, output_proto_logits, target_proto_labels = cluster_loss(
                        # Renamed for clarity
                        im_q=out_original.mean(axis=1),  # Global representation for original view
                        im_k=out_dtw_view2.mean(axis=1),  # Global representation for an augmented view
                        cluster_result=self.cluster_results_dict,
                        c=self.c,  # Typically cluster assignments or number of clusters
                        index=torch.arange(out_original.size(0), device=x_batch.device)  # Indices for current batch
                    )
                    # Instance-wise component of L_inter (e.g., like MHCCL's instance-level cluster loss)
                    # `target_labels` should be Long type for CrossEntropyLoss
                    loss_inter_inst_comp = self.args.lambda_cluster_ins*criterion(output_logits.to(self.device,non_blocking=True),
                                         torch.tensor(target_labels).float().to(self.device,non_blocking=True))
                    loss_cluster += loss_inter_inst_comp

                    # Prototype-wise component of L_inter (e.g., like MHCCL's cluster-level prototype loss)
                    if output_proto_logits is not None and target_proto_labels is not None:
                        loss_inter_proto_comp = 0
                        # output_proto_logits and target_proto_labels are lists of tensors/labels for each partition
                        for proto_out_p, proto_target_p in zip(output_proto_logits, target_proto_labels):
                            loss_inter_proto_comp += self.args.lambda_proto*criterion(proto_out_p.to(self.device, non_blocking=True),
                                                        torch.tensor(proto_target_p).float().to(self.device, non_blocking=True))
                        loss_cluster +=  loss_inter_proto_comp

                total_loss = self.args.weight_loss_inst * loss_inst + self.args.weight_loss_clust * loss_cluster
                if torch.isnan(total_loss) or torch.isinf(total_loss):  # pragma: no cover
                    print(
                        f"Warning: Iter {self.n_iters} Loss is NaN/Inf. Instance Loss: {loss_inst.item()}, Cluster Loss: {loss_cluster.item()}")
                else:
                    total_loss.backward()
                    self.optimizer.step()

                self.net.update_parameters(self._net)  # Update SWA model with current _net parameters
                # epoch_cum_loss += total_loss.item() if not (torch.isnan(total_loss) or torch.isinf(total_loss)) else 0.0
                if not (torch.isnan(total_loss) or torch.isinf(total_loss)):
                    epoch_cum_loss += total_loss.item()
                if not (torch.isnan(loss_inst) or torch.isinf(loss_inst)):  # Accumulate unweighted L_intra
                    epoch_cum_intra_loss_unweighted += self.args.weight_loss_inst *loss_inst.item()
                if not (torch.isnan(loss_inter_inst_comp) or torch.isinf(loss_inter_inst_comp)):  # Accumulate unweighted L_inter
                    epoch_cum_inter_inst_comp_unweighted += self.args.weight_loss_clust *loss_inter_inst_comp.item()
                if not (torch.isnan(loss_inter_proto_comp) or torch.isinf(loss_inter_proto_comp)):  # Accumulate unweighted L_inter
                    epoch_cum_inter_proto_comp_unweighted += self.args.weight_loss_clust *loss_inter_proto_comp.item()

                n_epoch_iters += 1;
                self.n_iters += 1
            # --- Batch loop end ---

            self.n_epochs += 1
            # avg_epoch_loss = epoch_cum_loss / n_epoch_iters if n_epoch_iters > 0 else float('nan')
            # loss_log.append(avg_epoch_loss)

            avg_epoch_loss = epoch_cum_loss / n_epoch_iters if n_epoch_iters > 0 else float('nan')
            avg_epoch_intra_loss = epoch_cum_intra_loss_unweighted / n_epoch_iters if n_epoch_iters > 0 else float(
                'nan')
            avg_epoch_inst_loss = epoch_cum_inter_inst_comp_unweighted / n_epoch_iters if n_epoch_iters > 0 else float(
                'nan')
            avg_epoch_proto_loss = epoch_cum_inter_proto_comp_unweighted / n_epoch_iters if n_epoch_iters > 0 else float(
                'nan')



            self.loss_log.append(avg_epoch_loss)
            self.intra_loss_log_weighted.append(avg_epoch_intra_loss)
            self.inst_loss_log_weighted.append(avg_epoch_inst_loss)
            self.proto_loss_log_weighted.append(avg_epoch_proto_loss)


            if self.n_epochs%100==0:

                results_dict = {
                    'total_loss_log': self.loss_log,
                    'intra_loss_log': self.intra_loss_log_weighted,
                    'inter_instance_comp_log': self.inst_loss_log_weighted,
                    'inter_prototype_comp_log': self.proto_loss_log_weighted,
                    'epochs_completed': self.n_epochs
                }

                epochs_trained = results_dict.get('epochs_completed', args.epochs if args.epochs is not None else 0)
                if epochs_trained > 0:
                    plots_output_dir = os.path.join("plots_convergence", args.dataset)

                    required_logs_for_plotting = [
                        'total_loss_log',
                        'intra_loss_log',
                        'inter_instance_comp_log',
                        'inter_prototype_comp_log'
                    ]
                    if all(key in results_dict and results_dict[key] for key in required_logs_for_plotting):
                        self.plot_convergence_losses(
                            loss_logs=results_dict,
                            num_epochs=epochs_trained,
                            dataset_name=args.dataset,
                            output_dir=plots_output_dir
                        )
                    else:
                        print("Warning: Not all required loss logs are available for plotting convergence.")

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
                if 1:
                    current_lr = self.optimizer.param_groups[0]['lr']
                    print(f"Epoch {self.n_epochs}: LR scheduler stepped. Current LR: {current_lr:.6e}")


            if miverbose:  # Control verbosity
                print(
                    f"Epoch: {self.n_epochs}/{n_epochs if n_epochs else 'inf'}, Iter: {self.n_iters}/{n_iters if n_iters else 'inf'}, Avg Loss: {avg_epoch_loss:.4f}")

            # Perform evaluation at specified frequency or at the end of training
            if do_valid and self.n_epochs >= self.eval_start_epoch and \
                    (self.n_epochs % self.eval_every_epoch == 0 or (
                            n_epochs is not None and self.n_epochs == n_epochs)):
                eval_downstream(is_final_eval=False)  # Mid-training evaluation
            if interrupted: break
        # --- Training main loop end ---

        # Final evaluation after all epochs/iterations
        if do_valid:
            print("--- Final Evaluation ---")
            eval_downstream(is_final_eval=True)  # Use all prediction lengths for forecasting

        # Return appropriate metrics based on task type
        if task_type == 'forecasting':
            if not mses: mses.append(float('nan'));  # Ensure lists are not empty if eval failed early
            if not maes: maes.append(float('nan'))
            return mses, maes
        elif task_type == 'classification':
            if not acc_log_placeholder: acc_log_placeholder.append(float('nan'))
            if not vx_log_placeholder: vx_log_placeholder.append(float('nan'))  # Typically unused
            if not vy_log_placeholder: vy_log_placeholder.append(float('nan'))  # Typically unused
            return loss_log, acc_log_placeholder, vx_log_placeholder, vy_log_placeholder
        else:  # Pretraining or other tasks might just return loss log
            return loss_log, [], [], []




    def plot_convergence_losses(self, loss_logs, num_epochs, dataset_name, output_dir="."):
        """Plot the recorded per-epoch loss curves and save them as a PNG.

        Args:
            loss_logs (dict): Maps series keys to per-epoch value sequences. The keys
                'total_loss_log', 'intra_loss_log', 'inter_instance_comp_log' and
                'inter_prototype_comp_log' are drawn; missing or empty series are skipped.
            num_epochs (int): Expected number of epochs. Each series is plotted against
                its own length; this value is only used to warn about a mismatch.
            dataset_name (str): Used in the plot title and in the output file name.
            output_dir (str): Directory that receives the image; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)

        # (key in loss_logs, legend label, extra line styling), in drawing order.
        series_specs = [
            ('total_loss_log', r'Total Loss ($L_{total}$)', {'linewidth': 2}),
            ('intra_loss_log', r'$L_{intra}$ ', {'linestyle': '--'}),
            ('inter_instance_comp_log', r'$L_{inter\_inst}$  ', {'linestyle': ':'}),
            ('inter_prototype_comp_log', r'$L_{inter\_proto}$ ', {'linestyle': '-.'}),
        ]

        # pyplot.get_cmap is used instead of matplotlib.cm.get_cmap, which was
        # deprecated in Matplotlib 3.7 and removed in 3.9.
        colors = plt.get_cmap('tab10', len(series_specs))

        fig = plt.figure(figsize=(12, 7))
        try:
            for idx, (key, label, style) in enumerate(series_specs):
                values = loss_logs.get(key)
                if not values:
                    continue
                if len(values) != num_epochs:
                    print(f"Warning: '{key}' has {len(values)} entries but {num_epochs} epochs were expected; "
                          f"plotting the recorded entries only.")
                # Use the series' own length for the x axis so a length mismatch
                # cannot make matplotlib raise on differently sized x and y.
                plt.plot(range(1, len(values) + 1), values, label=label, color=colors(idx), **style)

            plt.xlabel("Epochs")
            plt.ylabel("Loss Value")
            plt.title(f"ProSAR Training Losses on {dataset_name}")
            plt.legend(loc='upper right')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.tight_layout()

            plot_filename = os.path.join(output_dir, f"convergence_losses_detail_{dataset_name}.png")
            plt.savefig(plot_filename)
            print(f"Detailed convergence plot saved to {plot_filename}")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    def save_loss_logs_to_file(self, output_dir="."):
        """
        Saves the recorded epoch-wise losses to a comma-separated text file.

        The file is written to ``<output_dir>/logs_text/<dataset>/losses_epoch_wise_<dataset>.txt``
        with one row per epoch in ``range(self.n_epochs)``. Entries missing from a log are
        written as 'N/A'; non-NaN floats are written with six decimals, anything else via ``str``.

        This is best-effort: problems (including failure to create the directory) are
        printed rather than raised.

        Args:
            output_dir (str): Base directory under which ``logs_text/<dataset>`` is created.
        """
        if not hasattr(self, 'args') or not hasattr(self.args, 'dataset'):
            print("Warning: Cannot save loss logs to file, dataset name not found in args.")
            return

        dataset_name = self.args.dataset
        num_epochs = self.n_epochs  # n_epochs should be updated correctly in fit

        if num_epochs == 0:
            print("Warning: No epochs were trained, skipping saving loss logs to file.")
            return

        def _format_entry(log, epoch_idx):
            # One cell of the table: placeholder, fixed-precision float, or plain str.
            if epoch_idx >= len(log):
                return 'N/A'
            value = log[epoch_idx]
            if isinstance(value, float) and not np.isnan(value):
                return f"{value:.6f}"
            return str(value)

        logs_output_dir = os.path.join(output_dir, "logs_text", dataset_name)
        log_filename = os.path.join(logs_output_dir, f"losses_epoch_wise_{dataset_name}.txt")

        try:
            # Directory creation lives inside the try so it is reported like any
            # other write failure instead of escaping this best-effort helper.
            os.makedirs(logs_output_dir, exist_ok=True)

            # NOTE: the third column is headed "L_intra_unweighted" but is filled from
            # self.intra_loss_log_weighted; the header text is kept for compatibility.
            columns = (
                self.loss_log,
                self.intra_loss_log_weighted,
                self.inst_loss_log_weighted,
                self.proto_loss_log_weighted,
            )
            lines = ["Epoch,Total_Loss,L_intra_unweighted,L_inter_inst_comp,L_inter_proto_comp\n"]
            for i in range(num_epochs):
                cells = ",".join(_format_entry(log, i) for log in columns)
                lines.append(f"{i + 1},{cells}\n")

            # Format everything first, then write once, so a formatting error
            # does not leave a truncated file behind.
            with open(log_filename, 'w', encoding='utf-8') as f:
                f.writelines(lines)

            print(f"Epoch-wise loss logs saved to {log_filename}")
        except Exception as e:
            print(f"Error saving loss logs to file: {e}")



    def cluster_criterion_requires_long_target(self):  # This method might not be directly used in fit loop anymore
        """Check whether the module-level ``criterion`` is one that needs Long type targets.

        Returns:
            bool: True when the global ``criterion`` is an instance of
            ``nn.CrossEntropyLoss`` or ``nn.NLLLoss``, False otherwise.
        """
        long_target_losses = (nn.CrossEntropyLoss, nn.NLLLoss)
        return isinstance(criterion, long_target_losses)  # criterion is global

    def encode(self, data, mask=None, batch_size=None):
        ''' Compute one representation per instance by max pooling over the time axis.

        The network is switched to eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards, even if inference raises.

        Args:
            data (numpy.ndarray): A 3-dimensional array, documented by the original author as (n_instance, n_timestamps, n_features) with missing data set to NaN.
            mask: Forwarded unchanged as the second argument of ``self.net``. The original docstring lists 'binomial', 'continuous', 'all_true', 'all_false' and 'mask_last' as accepted values - confirm against the encoder.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, ``self.batch_size`` is used.
        Returns:
            numpy.ndarray: The pooled representations, one row per instance (shape (n_instance, n_channels) assuming the network returns (batch, time, channels) with more than one time step or channel layout as such - TODO confirm).
        '''

        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        org_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                output = []
                for batch in loader:
                    x = batch[0]
                    out = self.net(x.to(self.device, non_blocking=True), mask)
                    # A kernel spanning the whole second axis reduces it to length 1,
                    # which is then squeezed away.
                    out = F.max_pool1d(out.transpose(1, 2), kernel_size=out.size(1)).transpose(1, 2).cpu()
                    output.append(out.squeeze(1))

                output = torch.cat(output, dim=0)
        finally:
            # Restore the caller's mode even when the forward pass fails.
            self.net.train(org_training)
        return output.numpy()

    def casual_encode(self, data, encoding_window=None, mask=None, sliding_length=None, sliding_padding=0,
                      batch_size=None):
        ''' Compute representations using the model, optionally over sliding windows with left-only context.

        The network is switched to eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards, even if inference raises.

        Args:
            data (numpy.ndarray): A 3-dimensional array, documented by the original author as (n_instance, n_timestamps, n_features) with missing data set to NaN.
            encoding_window (Union[str, int]): Forwarded to ``self._eval_with_pooling``. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
            mask: Forwarded unchanged to the network. The original docstring lists 'binomial', 'continuous', 'all_true', 'all_false' and 'mask_last' as accepted values - confirm against the encoder.
            sliding_length (Union[int, NoneType]): Window length along the time axis. When None, each batch is encoded in a single pass.
            sliding_padding (int): Number of preceding time steps prepended to every window as context. No context is added on the right because ``casual`` is fixed to True in this method.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, ``self.batch_size`` is used.
        Returns:
            numpy.ndarray: The representations for data.
        '''
        casual = True  # fixed: windows are never extended to the right
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        org_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                output = []
                for batch in loader:
                    x = batch[0]
                    if sliding_length is not None:
                        reprs = []
                        # With fewer samples than batch_size, several windows are
                        # stacked into one forward pass to fill a batch.
                        if n_samples < batch_size:
                            calc_buffer = []
                            calc_buffer_l = 0
                        for i in range(0, ts_l, sliding_length):
                            win_start = i - sliding_padding
                            win_end = i + sliding_length + (sliding_padding if not casual else 0)
                            # Clip the window to the series and NaN-pad whatever falls outside.
                            x_sliding = torch_pad_nan(
                                x[:, max(win_start, 0): min(win_end, ts_l)],
                                left=-win_start if win_start < 0 else 0,
                                right=win_end - ts_l if win_end > ts_l else 0,
                                dim=1
                            )
                            if n_samples < batch_size:
                                # Flush before the buffer would exceed batch_size.
                                if calc_buffer_l + n_samples > batch_size:
                                    out = self._eval_with_pooling(
                                        torch.cat(calc_buffer, dim=0),
                                        mask,
                                        slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                        encoding_window=encoding_window
                                    )
                                    reprs += torch.split(out, n_samples)
                                    calc_buffer = []
                                    calc_buffer_l = 0
                                calc_buffer.append(x_sliding)
                                calc_buffer_l += n_samples
                            else:
                                out = self._eval_with_pooling(
                                    x_sliding,
                                    mask,
                                    slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                    encoding_window=encoding_window
                                )
                                reprs.append(out)

                        # Flush whatever is still buffered after the last window.
                        if n_samples < batch_size and calc_buffer_l > 0:
                            out = self._eval_with_pooling(
                                torch.cat(calc_buffer, dim=0),
                                mask,
                                slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                encoding_window=encoding_window
                            )
                            reprs += torch.split(out, n_samples)

                        out = torch.cat(reprs, dim=1)
                        if encoding_window == 'full_series':
                            # NOTE(review): max_pool1d leaves a trailing axis of length 1, while
                            # squeeze(1) targets axis 1, so this looks like it does not flatten
                            # the result the way the non-sliding branch below does. Left as-is
                            # to keep the returned shape unchanged - confirm the intent.
                            out = F.max_pool1d(
                                out.transpose(1, 2).contiguous(),
                                kernel_size=out.size(1),
                            ).squeeze(1)
                    else:
                        out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
                        if encoding_window == 'full_series':
                            out = out.squeeze(1)

                    output.append(out)

                output = torch.cat(output, dim=0)
        finally:
            # Restore the caller's mode even when inference fails.
            self.net.train(org_training)
        return output.numpy()

    def save(self, fn):
        ''' Write the state dict of ``self.net`` to a file.

        Args:
            fn (str): Destination filename, passed to ``torch.save``.
        '''
        weights = self.net.state_dict()
        torch.save(weights, fn)

    def load(self, fn):
        ''' Restore the parameters of ``self.net`` from a file.

        The file is read with ``torch.load`` (tensors mapped onto ``self.device``) and
        applied with ``load_state_dict``. ``torch.load`` unpickles the file, so only
        load checkpoints from a trusted source.

        Args:
            fn (str): filename.
        '''
        self.net.load_state_dict(torch.load(fn, map_location=self.device))

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        """Run the network on ``x`` and pool its output along axis 1 as requested.

        Args:
            x (torch.Tensor): Input batch; moved to ``self.device`` before the forward pass.
            mask: Forwarded unchanged as the second argument of ``self.net``.
            slicing (Union[slice, NoneType]): Optional selection applied along axis 1.
                For 'full_series' it is applied before pooling, otherwise after.
            encoding_window (Union[str, int, NoneType]):
                'full_series' - max pool across the whole (sliced) axis 1.
                int - stride-1 max pooling with that kernel size and ``k // 2`` padding;
                    the extra trailing step produced by even kernels is dropped.
                'multiscale' - stride-1 max pooling with kernels 3, 5, 9, ... while
                    ``2**p + 1`` is smaller than the length of axis 1, concatenated
                    on the last axis.
                anything else - no pooling.
        Returns:
            torch.Tensor: The result, moved to the CPU.
        """
        feats = self.net(x.to(self.device, non_blocking=True), mask)

        def _crop(t):
            # Apply the optional selection along axis 1.
            return t if slicing is None else t[:, slicing]

        if encoding_window == 'full_series':
            feats = _crop(feats)
            pooled = F.max_pool1d(feats.transpose(1, 2), kernel_size=feats.size(1))
            return pooled.transpose(1, 2).cpu()

        if isinstance(encoding_window, int):
            pooled = F.max_pool1d(
                feats.transpose(1, 2),
                kernel_size=encoding_window,
                stride=1,
                padding=encoding_window // 2,
            ).transpose(1, 2)
            if encoding_window % 2 == 0:
                # Even kernels yield one step too many; drop the last one.
                pooled = pooled[:, :-1]
            return _crop(pooled).cpu()

        if encoding_window == 'multiscale':
            n_steps = feats.size(1)
            channels_first = feats.transpose(1, 2)
            scales = []
            level = 0
            while (1 << level) + 1 < n_steps:
                half_width = 1 << level
                pooled = F.max_pool1d(
                    channels_first,
                    kernel_size=2 * half_width + 1,
                    stride=1,
                    padding=half_width,
                ).transpose(1, 2)
                scales.append(_crop(pooled))
                level += 1
            return torch.cat(scales, dim=-1).cpu()

        return _crop(feats).cpu()