import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from typing import Optional, Any
from tqdm import tqdm
from tqdm.auto import trange
import wandb
import faiss
import math
from collections import defaultdict
from src.mapper.strategy.base import VectorMapper
from src.embeddings.memmap_dataset import MultiMemmapDatasetLoader
import logging
import os
# Import-time, process-wide silencing of the logging module.
# NOTE(review): every statement below affects the whole interpreter, not only
# this module -- confirm that importing this file is meant to mute all logging.
logging.disable(logging.CRITICAL)  # suppress records at CRITICAL level and below
# NOTE(review): PYTHONWARNINGS is normally read at interpreter start-up, so this
# presumably only reaches child processes started later -- confirm the intent.
os.environ['PYTHONWARNINGS'] = 'ignore'
# Iterate over a copy of the handler list because removeHandler mutates it.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
# Raise the root logger threshold and then switch the root logger off entirely.
logging.getLogger().setLevel(logging.CRITICAL)
logging.getLogger().disabled = True
class LazyEmbeddingDataset(Dataset):
    """Map-style dataset that fetches one source/target embedding pair per access.

    Rows are looked up through ``indices`` on every ``__getitem__`` call, so the
    backing arrays are only read one row at a time.
    """

    def __init__(self, source_embeddings: np.ndarray, target_embeddings: np.ndarray, indices: np.ndarray, return_global_id: bool = False):
        """Keep references to the backing arrays and the row ids to serve.

        Args:
            source_embeddings: Array indexed by the values stored in ``indices``.
            target_embeddings: Array indexed by the same row ids.
            indices: Row ids exposed by this dataset, in serving order.
            return_global_id: When True, every item also carries its row id.
        """
        self.source_embeddings = source_embeddings
        self.target_embeddings = target_embeddings
        self.indices = indices
        self.return_global_id = return_global_id

    def __len__(self):
        """Return the number of served rows (the length of ``indices``)."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(source, target)`` float32 tensors, plus the row id if requested."""
        row = self.indices[idx]
        # Copy each row before wrapping it so the tensor never aliases the
        # backing array.
        source_tensor = torch.from_numpy(self.source_embeddings[row].copy()).float()
        target_tensor = torch.from_numpy(self.target_embeddings[row].copy()).float()
        if not self.return_global_id:
            return source_tensor, target_tensor
        return source_tensor, target_tensor, torch.tensor(row, dtype=torch.long)
class SimpleLinearModel(nn.Module):
    """Stack of ``nn.Linear`` layers with one shared activation between them.

    ``layer_num == 1`` builds a single ``input_dim -> output_dim`` projection.
    Any other value builds an input projection to ``hidden_dim``,
    ``layer_num - 1`` hidden-to-hidden layers, and an output projection.
    The ``dropout`` argument is accepted but not used anywhere in this class.
    """

    def __init__(
        self, 
        input_dim: int, 
        output_dim: int, 
        hidden_dim: int = 512, 
        layer_num: int = 2,
        activation: str = 'relu',
        dropout: float = 0.1
    ):
        """Create the layers, pick the activation and initialise the weights.

        Raises:
            ValueError: If ``activation`` is not one of ``'relu'``, ``'gelu'``,
                ``'tanh'`` or ``'leaky_relu'``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.layer_num = layer_num

        # Describe the stack as (in_features, out_features) pairs first, then
        # instantiate the layers in that order.
        if layer_num == 1:
            shapes = [(input_dim, output_dim)]
        else:
            shapes = [(input_dim, hidden_dim)]
            shapes += [(hidden_dim, hidden_dim)] * (layer_num - 1)
            shapes.append((hidden_dim, output_dim))
        self.layers = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in shapes])

        activation_classes = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
            'leaky_relu': nn.LeakyReLU,
        }
        if activation not in activation_classes:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_classes[activation]()

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Xavier-uniform every linear weight and zero every bias."""
        for module in self.layers:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer in order; the activation follows all but the last."""
        final_position = len(self.layers) - 1
        for position, module in enumerate(self.layers):
            x = module(x)
            if position != final_position:
                x = self.activation(x)
        return x
class SimpleLinearMapper(VectorMapper):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 512,
        layer_num: int = 2,
        activation: str = 'relu',
        dropout: float = 0.1,
        device: torch.device = None,
        learning_rate: float = 1e-4,
        num_epochs: int = 50,
        batch_size: int = 4028,
        gradient_clip: float = 1.0,
        weight_decay: float = 1e-5,
        scheduler_patience: int = 5,
        scheduler_factor: float = 0.5,
        early_stopping_patience: int = 10,
        min_delta: float = 1e-6,
        use_local_distill: bool = False,
        local_k: int = 50,
        local_tau: float = 0.1,
        local_weight: float = 0.5,
        faiss_use_float32: bool = True,
        knn_recompute_epochs: int = 0,
        global_weight: float = 0.5,
        use_structure_preserving: bool = False,
        struct_lambda: float = 0.1,
        struct_k: int = 10,
        struct_pair_sampling: str = 'knn',
    ):
        """Build the MLP, its optimizer and scheduler, and store every setting.

        Args:
            input_dim, output_dim, hidden_dim, layer_num, activation, dropout:
                Forwarded unchanged to ``SimpleLinearModel``.
            device: Target device; when None, CUDA is chosen if available,
                otherwise CPU.
            learning_rate, weight_decay: Passed to ``torch.optim.AdamW``.
            scheduler_patience, scheduler_factor: Passed to
                ``ReduceLROnPlateau`` (``mode='min'``).
            num_epochs, batch_size: Stored for the training and transform loops.
            gradient_clip, early_stopping_patience, min_delta,
                struct_pair_sampling: Stored on the instance only.
            use_local_distill, local_k, local_tau, local_weight,
                faiss_use_float32, knn_recompute_epochs: Settings for the
                local-neighbourhood distillation term.
            global_weight: Mixing weight for distillation from a global model.
            use_structure_preserving, struct_lambda, struct_k: Settings for
                the structure-preserving term.
        """
        super().__init__()

        if device is None:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device

        # --- architecture and optimisation settings ---
        self.hidden_dim = hidden_dim
        self.layer_num = layer_num
        self.activation = activation
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.gradient_clip = gradient_clip
        self.weight_decay = weight_decay
        self.scheduler_patience = scheduler_patience
        self.scheduler_factor = scheduler_factor
        self.early_stopping_patience = early_stopping_patience
        self.min_delta = min_delta

        # --- local distillation settings and their lazily filled state ---
        self.use_local_distill = use_local_distill
        self.local_k = local_k
        self.local_tau = local_tau
        self.local_weight = local_weight
        self.faiss_use_float32 = faiss_use_float32
        self.knn_recompute_epochs = knn_recompute_epochs
        self.global_weight = global_weight
        self._ref_idx_np = None
        self._teacher_ref = None
        self._faiss_index = None
        self._knn_idx = None
        self._knn_sim = None
        self._student_cache = {}

        # --- structure-preserving settings and their lazily filled state ---
        self.use_structure_preserving = use_structure_preserving
        self.struct_lambda = struct_lambda
        self.struct_k = struct_k
        self.struct_pair_sampling = struct_pair_sampling
        self._struct_knn_idx = None
        self._struct_knn_dist = None
        self._source_ref = None

        # --- arrays remembered by fit() for later KNN recomputation ---
        self._cached_source_embeddings = None
        self._cached_target_embeddings = None
        self._cached_reference_indices = None

        network = SimpleLinearModel(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            layer_num=layer_num,
            activation=activation,
            dropout=dropout
        )
        self.model = network.to(device)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=scheduler_patience,
            factor=scheduler_factor
        )

        self.mse_criterion = nn.MSELoss()
        self.cosine_weight = 0.5
        self.distill_temperature = 1.0

        parameter_count = sum(p.numel() for p in self.model.parameters())
        print(f"SimpleLinearMapper (MLP) initialized on device: {device}")
        print(f"MLP Architecture: {input_dim} -> {hidden_dim} (x{layer_num-1}) -> {output_dim}")
        print(f"Model parameters: {parameter_count}")
    def _compute_combined_loss(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        mse_loss = self.mse_criterion(outputs, targets)
        cos_loss = 1 - torch.nn.functional.cosine_similarity(outputs, targets, dim=1).mean()
        return mse_loss + self.cosine_weight * cos_loss
    def _compute_distillation_loss(self, student_outputs: torch.Tensor, teacher_outputs: torch.Tensor) -> torch.Tensor:
        student_log_probs = torch.log_softmax(student_outputs / self.distill_temperature, dim=1)
        teacher_probs = torch.softmax(teacher_outputs / self.distill_temperature, dim=1)
        kl_div = torch.nn.functional.kl_div(
            student_log_probs, 
            teacher_probs, 
            reduction='batchmean'
        )
        return (self.distill_temperature ** 2) * kl_div
    def _build_faiss_index(self, vecs_np: np.ndarray):
        """Return a flat inner-product FAISS index holding every row of ``vecs_np``."""
        dimension = vecs_np.shape[1]
        flat_index = faiss.IndexFlatIP(dimension)
        flat_index.add(vecs_np)
        return flat_index
    @torch.no_grad()
    def _precompute_teacher_knn(self, target_embeddings: np.ndarray, reference_indices: np.ndarray):
        """Index the reference targets and store each one's nearest neighbours.

        The selected target rows are L2-normalised and searched against
        themselves with an inner-product index. Neighbour ids are stored as
        global row ids in ``_knn_idx`` and the similarities in ``_knn_sim``.
        ``_student_cache`` is emptied as a side effect.
        """
        teacher = target_embeddings[reference_indices]
        if self.faiss_use_float32:
            teacher = teacher.astype(np.float32)
        # The epsilon guards against division by zero for all-zero rows.
        row_norms = np.linalg.norm(teacher, axis=1, keepdims=True) + 1e-12
        teacher = teacher / row_norms

        self._ref_idx_np = np.asarray(reference_indices, dtype=np.int64)
        self._teacher_ref = teacher
        self._faiss_index = self._build_faiss_index(teacher)

        k = min(self.local_k, teacher.shape[0])
        similarities, local_positions = self._faiss_index.search(teacher, k)
        # FAISS answers with positions inside the reference set; translate them
        # back to global row ids.
        self._knn_idx = self._ref_idx_np[local_positions]
        self._knn_sim = similarities
        self._student_cache.clear()
        print(f"Precomputed teacher KNN: {len(reference_indices)} references, k={k}")
        print(f"KNN similarity range: [{self._knn_sim.min():.4f}, {self._knn_sim.max():.4f}]")
    @torch.no_grad()
    def _precompute_structure_knn(self, source_embeddings: np.ndarray, reference_indices: np.ndarray):
        """Store each reference source row's nearest neighbours and distances.

        The selected source rows are kept in ``_source_ref`` and searched
        against themselves with a flat L2 index. Neighbour ids are translated
        through ``self._ref_idx_np`` (which must already be set) and stored in
        ``_struct_knn_idx``. ``_struct_knn_dist`` holds the square root of the
        distances returned by the index.
        """
        reference_sources = source_embeddings[reference_indices]
        if self.faiss_use_float32:
            reference_sources = reference_sources.astype(np.float32)
        self._source_ref = reference_sources

        l2_index = faiss.IndexFlatL2(reference_sources.shape[1])
        l2_index.add(reference_sources)

        k = min(self.struct_k, reference_sources.shape[0])
        raw_distances, local_positions = l2_index.search(reference_sources, k)
        self._struct_knn_idx = self._ref_idx_np[local_positions]
        # NOTE(review): the sqrt presumably converts squared L2 distances to
        # Euclidean ones -- confirm against the FAISS IndexFlatL2 docs.
        self._struct_knn_dist = np.sqrt(raw_distances)
        print(f"Precomputed structure KNN: {len(reference_indices)} references, k={k}")
        print(f"KNN distance range: [{self._struct_knn_dist.min():.4f}, {self._struct_knn_dist.max():.4f}]")
    def _map_global_to_refpos(self, global_ids: np.ndarray) -> np.ndarray:
        if not hasattr(self, "_gid2pos"):
            self._gid2pos = {int(g): i for i, g in enumerate(self._ref_idx_np.tolist())}
        return np.asarray([self._gid2pos[int(g)] for g in global_ids], dtype=np.int64)
    def _local_distill_kl(self,
                          anchor_global_ids: np.ndarray,
                          batch_student: torch.Tensor,
                          source_embeddings: np.ndarray) -> torch.Tensor:
        device = batch_student.device
        B, d = batch_student.shape
        k = min(self.local_k, self._knn_idx.shape[1])
        nbr_ids = self._knn_idx[self._map_global_to_refpos(anchor_global_ids)][:, :k]
        uniq_nbr_ids = np.unique(nbr_ids.reshape(-1))
        to_compute = [gid for gid in uniq_nbr_ids if gid not in self._student_cache]
        if to_compute:
            src_batch = torch.from_numpy(source_embeddings[to_compute]).float().to(device)
            with torch.no_grad():
                out = self.model(src_batch)
                out = torch.nn.functional.normalize(out, dim=-1)
            for gid, vec in zip(to_compute, out):
                self._student_cache[gid] = vec.detach()
        S_anchor = torch.nn.functional.normalize(batch_student, dim=-1)
        S_nb = torch.stack([torch.stack([self._student_cache[int(g)]
                                         for g in nbr_ids[i]], dim=0)
                            for i in range(B)], dim=0).to(device)
        s_sim = torch.einsum("bd,bkd->bk", S_anchor, S_nb)
        s_logp = torch.log_softmax(s_sim / self.local_tau, dim=1)
        t_sim = torch.from_numpy(
            self._knn_sim[self._map_global_to_refpos(anchor_global_ids)][:, :k]
        ).to(device)
        t_p = torch.softmax(t_sim / self.local_tau, dim=1)
        kl = torch.sum(t_p * (torch.log(t_p + 1e-9) - s_logp), dim=1).mean()
        return (self.local_tau ** 2) * kl
    def _compute_structure_preserving_loss(self,
                                           anchor_global_ids: np.ndarray,
                                           batch_outputs: torch.Tensor,
                                           batch_source: torch.Tensor) -> torch.Tensor:
        device = batch_outputs.device
        B = batch_outputs.shape[0]
        k = min(self.struct_k, self._struct_knn_idx.shape[1])
        ref_pos = self._map_global_to_refpos(anchor_global_ids)
        nbr_gids = self._struct_knn_idx[ref_pos][:, :k]
        src_dists = self._struct_knn_dist[ref_pos][:, :k]
        src_dists_t = torch.from_numpy(src_dists).float().to(device)
        uniq_nbr_gids = np.unique(nbr_gids.reshape(-1))
        nbr_outputs = {}
        to_compute = [int(gid) for gid in uniq_nbr_gids if int(gid) not in self._student_cache]
        if to_compute:
            nbr_src_embs = []
            for gid in to_compute:
                pos = self._gid2pos[gid]
                nbr_src_embs.append(self._source_ref[pos])
            nbr_src_batch = torch.from_numpy(np.array(nbr_src_embs)).float().to(device)
            with torch.no_grad():
                nbr_out = self.model(nbr_src_batch)
                for gid, vec in zip(to_compute, nbr_out):
                    nbr_outputs[gid] = vec.detach()
        for gid in uniq_nbr_gids:
            gid_int = int(gid)
            if gid_int in self._student_cache:
                nbr_outputs[gid_int] = self._student_cache[gid_int] * \
                    torch.norm(self._student_cache[gid_int]) if torch.norm(self._student_cache[gid_int]) < 0.99 else self._student_cache[gid_int]
        try:
            nbr_out_tensor = torch.stack([
                torch.stack([nbr_outputs[int(gid)] for gid in nbr_gids[i]], dim=0)
                for i in range(B)
            ], dim=0).to(device)
        except KeyError as e:
            print(f"KeyError in structure loss: {e}, falling back to direct computation")
            all_nbr_gids = nbr_gids.reshape(-1)
            all_nbr_pos = [self._gid2pos[int(gid)] for gid in all_nbr_gids]
            all_nbr_src = torch.from_numpy(self._source_ref[all_nbr_pos]).float().to(device)
            with torch.no_grad():
                all_nbr_out = self.model(all_nbr_src)
            nbr_out_tensor = all_nbr_out.view(B, k, -1)
        tgt_dists = torch.norm(
            batch_outputs.unsqueeze(1) - nbr_out_tensor,
            dim=2
        )
        distortion = (tgt_dists - src_dists_t) ** 2
        struct_loss = distortion.mean()
        return struct_loss
    def _create_dataloader(
        self, 
        source_emb: np.ndarray, 
        target_emb: np.ndarray,
        indices: np.ndarray,
        shuffle: bool = True,
        return_global_id: bool = False
    ) -> DataLoader:
        source_tensor = torch.from_numpy(source_emb[indices]).float()
        target_tensor = torch.from_numpy(target_emb[indices]).float()
        if return_global_id:
            global_id_tensor = torch.from_numpy(indices).long()
            dataset = TensorDataset(source_tensor, target_tensor, global_id_tensor)
        else:
            dataset = TensorDataset(source_tensor, target_tensor)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
        )
        return dataloader
    def _create_lazy_dataloader(
        self,
        source_embeddings: np.ndarray,
        target_embeddings: np.ndarray,
        indices: np.ndarray,
        shuffle: bool = True,
        return_global_id: bool = False
    ) -> DataLoader:
        """Wrap the arrays in a ``LazyEmbeddingDataset`` and return its loader.

        Unlike ``_create_dataloader``, no rows are copied up front; each item
        is read from the backing arrays when the loader requests it.
        """
        return DataLoader(
            LazyEmbeddingDataset(
                source_embeddings,
                target_embeddings,
                indices,
                return_global_id=return_global_id,
            ),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
        )
    def _validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch_source, batch_target in val_loader:
                batch_source = batch_source.to(self.device)
                batch_target = batch_target.to(self.device)
                outputs = self.model(batch_source)
                loss = self._compute_combined_loss(outputs, batch_target)
                total_loss += loss.item()
                num_batches += 1
        return total_loss / num_batches if num_batches > 0 else float('inf')
    def fit(
        self, 
        source_embeddings: Optional[np.ndarray] = None,
        target_embeddings: Optional[np.ndarray] = None,
        reference_indices: Optional[np.ndarray] = None,
        train_loader: Optional[Any] = None,
        validation_split: float = 0.0,
        global_model: Optional[SimpleLinearModel] = None,
        **kwargs
    ) -> None:
        self._log_training_info(source_embeddings, target_embeddings, reference_indices)
        train_loader = self._prepare_dataloader(
            source_embeddings, target_embeddings, reference_indices, train_loader
        )
        self._setup_preprocessing(source_embeddings, target_embeddings, reference_indices)
        self._train(train_loader, global_model)
    def _prepare_dataloader(
        self,
        source_embeddings: Optional[np.ndarray],
        target_embeddings: Optional[np.ndarray],
        reference_indices: Optional[np.ndarray],
        train_loader: Optional[Any]
    ) -> Any:
        has_embeddings = all(x is not None for x in [source_embeddings, target_embeddings, reference_indices])
        has_loader = train_loader is not None
        if not (has_embeddings or has_loader):
            raise ValueError(
                "Must provide either (source_embeddings, target_embeddings, reference_indices) or train_loader"
            )
        if has_embeddings and has_loader:
            raise ValueError("Cannot provide both embeddings and train_loader. Choose one.")
        if has_loader:
            return train_loader
        n_samples = len(reference_indices)
        need_global_id = self.use_local_distill or self.use_structure_preserving
        create_fn = self._create_lazy_dataloader if n_samples > 1e5 else self._create_dataloader
        return create_fn(
            source_embeddings, target_embeddings, reference_indices,
            shuffle=True, return_global_id=need_global_id
        )
    def _setup_preprocessing(
        self,
        source_embeddings: Optional[np.ndarray],
        target_embeddings: Optional[np.ndarray],
        reference_indices: Optional[np.ndarray]
    ) -> None:
        self._cached_source_embeddings = source_embeddings
        self._cached_target_embeddings = target_embeddings
        self._cached_reference_indices = reference_indices
        if source_embeddings is None:
            print("Dataloader mode: Skipping KNN precomputation")
            return
        if self.use_local_distill:
            print("Precomputing teacher KNN for local distance distillation...")
            self._precompute_teacher_knn(target_embeddings, reference_indices)
        if self.use_structure_preserving:
            print("Precomputing source KNN for structure-preserving loss...")
            if self._ref_idx_np is None:
                self._ref_idx_np = np.asarray(reference_indices, dtype=np.int64)
            self._precompute_structure_knn(source_embeddings, reference_indices)
    def _log_training_info(
        self,
        source_embeddings: Optional[np.ndarray],
        target_embeddings: Optional[np.ndarray],
        reference_indices: Optional[np.ndarray]
    ) -> None:
        print(f"Device: {self.device}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"Model is on device: {next(self.model.parameters()).device}")
        if source_embeddings is not None and target_embeddings is not None and reference_indices is not None:
            print(f"Training with {len(reference_indices)} reference samples")
            print(f"Source shape: {source_embeddings.shape}, Target shape: {target_embeddings.shape}")
        else:
            print(f"Training with pre-built dataloader")
    def _train(
        self,
        train_loader: Any,
        global_model: Optional[SimpleLinearModel]
    ) -> None:
        """Run ``num_epochs`` training epochs over ``train_loader``.

        When ``global_model`` is given, the local model is first initialised
        from ``global_model.model``'s weights. Every tenth epoch the mean
        epoch loss is printed and logged to wandb. After each epoch the
        teacher KNN is refreshed if the recompute schedule says so.

        Args:
            train_loader: Iterable of training batches.
            global_model: Optional object exposing a ``.model`` attribute
                whose ``state_dict`` is compatible with ``self.model``.
        """
        if global_model is not None:
            self.model.load_state_dict(global_model.model.state_dict())
        self.model.train()
        # tqdm must be called directly: subscripting the class (``tqdm[int]``)
        # raises TypeError at runtime.
        for epoch in tqdm(range(self.num_epochs), desc="Training epochs"):
            epoch_loss = self._train_one_epoch(train_loader, global_model, epoch)
            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {epoch_loss:.6f}")
                wandb.log({"loss": epoch_loss, "epoch": epoch})
            self._maybe_recompute_knn(epoch)
        print("Training completed!")
    def _train_one_epoch(
        self, 
        train_loader: Any, 
        global_model: Optional[SimpleLinearModel],
        epoch: int
    ) -> float:
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0
        for ib, batch_data in enumerate(train_loader):
            if self._batch_has_nan(batch_data):
                print("Found nan or inf in batch data, skipping")
                continue
            batch_source, batch_target, batch_gid = self._extract_batch(batch_data)
            if epoch == 0 and ib == 0:
                print(f"First batch - device: {batch_source.device}, shape: {batch_source.shape}")
            loss = self._compute_batch_loss(batch_source, batch_target, batch_gid, global_model)
            if np.isnan(loss.item()):
                print("Batch loss is nan, skipping")
                continue
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1
        return epoch_loss / num_batches if num_batches > 0 else float('inf')
    def _batch_has_nan(self, batch_data: tuple) -> bool:
        return (torch.isnan(batch_data[0]).any() or torch.isinf(batch_data[0]).any() or
                torch.isnan(batch_data[1]).any() or torch.isinf(batch_data[1]).any())
    def _extract_batch(self, batch_data: tuple) -> tuple:
        need_global_id = self.use_local_distill or self.use_structure_preserving
        if need_global_id:
            batch_source, batch_target, batch_gid = batch_data
            batch_gid = batch_gid.cpu().numpy()
        else:
            batch_source, batch_target = batch_data
            batch_gid = None
        return (batch_source.to(self.device), 
                batch_target.to(self.device), 
                batch_gid)
    def _compute_batch_loss(
        self,
        batch_source: torch.Tensor,
        batch_target: torch.Tensor,
        batch_gid: Optional[np.ndarray],
        global_model: Optional[SimpleLinearModel]
    ) -> torch.Tensor:
        outputs = self.model(batch_source)
        loss = self._compute_combined_loss(outputs, batch_target)
        if global_model is not None and self.global_weight > 0:
            global_model.model.eval()
            with torch.no_grad():
                global_outputs = global_model.model(batch_source)
            distill_loss = self._compute_distillation_loss(outputs, global_outputs.detach())
            loss = (1 - self.global_weight) * loss + self.global_weight * distill_loss
        if self.use_local_distill and batch_gid is not None and self._cached_source_embeddings is not None:
            local_kl = self._local_distill_kl(
                anchor_global_ids=batch_gid,
                batch_student=outputs,
                source_embeddings=self._cached_source_embeddings
            )
            loss = loss + self.local_weight * local_kl
        if self.use_structure_preserving and batch_gid is not None:
            struct_loss = self._compute_structure_preserving_loss(
                anchor_global_ids=batch_gid,
                batch_outputs=outputs,
                batch_source=batch_source
            )
            loss = loss + self.struct_lambda * struct_loss
        return loss
    def _maybe_recompute_knn(self, epoch: int) -> None:
        should_recompute = (
            self.use_local_distill and 
            self.knn_recompute_epochs > 0 and 
            (epoch + 1) % self.knn_recompute_epochs == 0 and 
            self._cached_target_embeddings is not None and 
            self._cached_reference_indices is not None
        )
        if should_recompute:
            print(f"Recomputing teacher KNN at epoch {epoch + 1}")
            self._precompute_teacher_knn(self._cached_target_embeddings, self._cached_reference_indices)
    def transform(self, embeddings: np.ndarray) -> np.ndarray:
        """Map ``embeddings`` through the trained model in chunks.

        The model is put into eval mode and rows are processed in chunks of
        ``4 * batch_size`` without gradients. An input with zero rows is still
        passed through the model once, so the result is an empty array with
        the model's output width instead of a ``ValueError`` from
        ``np.concatenate``.

        Args:
            embeddings: 2-D array of source embeddings, one row per sample.

        Returns:
            Array with one mapped row per input row.
        """
        self.model.eval()
        n_samples = embeddings.shape[0]
        batch_size = self.batch_size * 4
        results = []
        with torch.no_grad():
            # max(..., 1) guarantees a single (empty) chunk when n_samples == 0.
            for start in tqdm(range(0, max(n_samples, 1), batch_size), desc="Transforming embeddings"):
                stop = min(start + batch_size, n_samples)
                chunk = torch.from_numpy(embeddings[start:stop]).float().to(self.device)
                results.append(self.model(chunk).cpu().numpy())
        return np.concatenate(results, axis=0)
    def save_model(self, path: str) -> None:
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, path)
        print(f"Model saved to {path}")
    def load_model(self, path: str) -> None:
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print(f"Model loaded from {path}")
    def fit_with_loader(self, train_loader: Any) -> None:
        """Train for ``num_epochs`` on a loader yielding ``(source, target)``.

        Only the combined MSE/cosine loss is optimised here. Every batch is
        moved to the training device and cast to float32. Inputs are asserted
        to be free of NaN/Inf. Batches with a non-finite loss, or a loss that
        carries no gradient, are skipped with a warning. Every 100th batch the
        loss is printed and logged to wandb.
        """
        self.model.train()
        # Make sure every parameter is trainable before optimising.
        self.model.requires_grad_(True)

        for epoch in trange(self.num_epochs):
            for batch_idx, (batch_source, batch_target) in enumerate(train_loader):
                # .float() is a no-op for tensors that are already float32.
                batch_source = batch_source.to(self.device).float()
                batch_target = batch_target.to(self.device).float()

                assert torch.isfinite(batch_source).all(), "Batch source contains NaN or Inf"
                assert torch.isfinite(batch_target).all(), "Batch target contains NaN or Inf"

                self.optimizer.zero_grad()
                outputs = self.model(batch_source)
                loss = self._compute_combined_loss(outputs, batch_target)

                if not torch.isfinite(loss):
                    print(f"Warning: Invalid loss (NaN/Inf) at epoch {epoch}, batch {batch_idx}")
                    continue
                if not loss.requires_grad:
                    trainable = any(p.requires_grad for p in self.model.parameters())
                    print(f"Warning: Loss does not require gradients! Loss value: {loss.item()}")
                    print(f"Outputs requires_grad: {outputs.requires_grad}")
                    print(f"Model parameters require_grad: {trainable}")
                    continue

                loss.backward()
                self.optimizer.step()

                if batch_idx % 100 != 0:
                    continue
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}")
                wandb.log({
                    "loss": loss.item(),
                    "epoch": epoch,
                    "batch": batch_idx,
                    "batch_loss": loss.item()
                })
        print("Training completed!")
    def fit_multi(self, multi_dataloader) -> None:
        """Train on ``multi_dataloader`` by delegating to ``fit_with_loader``.

        The loader is passed through unchanged, so it must yield the
        ``(source, target)`` batches that ``fit_with_loader`` unpacks.
        """
        self.fit_with_loader(multi_dataloader)
