"""
SPA for CirCor Heart Murmur Detection
Paper-aligned implementation - ALL original interfaces preserved
"""

import sys
sys.path.append('../heart-murmur-detection')
sys.path.append('../heart-murmur-detection/ModelEvaluation')

from evar.common import (sys, np, pd, kwarg_cfg, Path, torch, logging, append_to_csv, RESULT_DIR)
import torchaudio
import fire
import torch.nn.functional as F
import math
import os
from copy import deepcopy
from tqdm import tqdm
from evar.data import create_dataloader
import evar
from lineareval import make_cfg
from finetune import TaskNetwork, finetune_main
from DataProcessing.find_and_load_patient_files import load_patient_data
from DataProcessing.helper_code import load_recordings
from ModelEvaluation.evaluate_model import evaluate_model

# ===================== SPA Core Components (Paper-Aligned) =====================
class SphericalBranch(torch.nn.Module):
    """Spherical branch (Eq. 3): a scaled cosine-similarity classifier.

    Both the inputs and the class weight vectors are L2-normalised, so each
    logit is ``s * cos(x, W_k)``. The learnable scale ``s`` is clamped to
    ``[1.0, max_s]`` on every forward pass.

    Args:
        in_f: input feature dimension.
        out_f: number of output classes.
        scale: initial value of the learnable scale ``s``.
        max_s: upper clamp bound for ``s``.
    """
    def __init__(self, in_f, out_f, scale=16.0, max_s=40.0):
        super().__init__()
        self.in_features, self.out_features = in_f, out_f
        # (out_f, in_f) weight matrix, initialised like torch.nn.Linear.
        self.W = torch.nn.Parameter(torch.Tensor(out_f, in_f))
        torch.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        self.s = torch.nn.Parameter(torch.tensor(scale))
        self.max_s = max_s

    def forward(self, x):
        """Return ``clamp(s) * cos(x, W)`` with shape (batch, out_features)."""
        unit_inputs = F.normalize(x, p=2, dim=1)
        unit_weights = F.normalize(self.W, p=2, dim=1)
        cosine = F.linear(unit_inputs, unit_weights)
        bounded_scale = self.s.clamp(min=1.0, max=self.max_s)
        return bounded_scale * cosine

class GeometricBranch(torch.nn.Module):
    """Geometric branch: Eq. 4-7 with Dynamic Procrustes.

    Buffers (all non-trainable, saved in the state dict):
        M: (d, K) target matrix built from the K-class simplex ETF
           ``sqrt(K/(K-1)) * (I - 1/K)``; zero-padded rows when d > K,
           truncated rows when d < K.
        R: (d, d) alignment matrix, initialised to identity.
        P: (K, d) per-class prototypes (momentum averages of normalised features).
        c: (K,) number of prototype updates seen per class.

    Args:
        d: feature dimension.
        K: number of classes (must be >= 2).
        m: momentum for the prototype update.

    Raises:
        ValueError: if ``K < 2`` (the ETF scale ``sqrt(K/(K-1))`` is undefined).
    """
    def __init__(self, d, K, m=0.99):
        super().__init__()
        if K < 2:
            raise ValueError(f'GeometricBranch needs at least 2 classes, got K={K}')
        self.d, self.K, self.m = d, K, m
        scale = np.sqrt(K / (K - 1))
        M_base = scale * (torch.eye(K) - torch.ones(K, K) / K)
        if d > K:
            M_base = torch.cat([M_base, torch.zeros(d - K, K)], dim=0)
        elif d < K:
            M_base = M_base[:d, :]
        self.register_buffer('M', M_base)
        self.register_buffer('R', torch.eye(d))
        self.register_buffer('P', torch.zeros(K, d))
        self.register_buffer('c', torch.zeros(K))

    def align(self, z, y):
        """Eq. 4 + Eq. 6: momentum prototype update followed by SVD Procrustes.

        Does nothing unless the module is in training mode.

        Args:
            z: (batch, d) features; normalised internally, gradients not tracked.
            y: class indices (batch,) or one-hot / score matrix (batch, K).
        """
        if y.dim() > 1:
            y = y.argmax(dim=1)
        if not self.training:
            return
        with torch.no_grad():
            z_norm = F.normalize(z, p=2, dim=1)
            for k in y.unique():
                k = k.long()
                mask = (y == k)
                if mask.sum() == 0:
                    continue
                mu = z_norm[mask].mean(0)
                if self.c[k] == 0:
                    # First observation of this class: take the batch mean as-is.
                    self.P[k] = mu
                else:
                    self.P[k] = self.m * self.P[k] + (1 - self.m) * mu
                self.c[k] += 1
            # Procrustes only once every class has a prototype.
            if torch.all(self.c > 0):
                H = torch.mm(self.M, self.P)
                try:
                    U, _, V = torch.svd(H)
                except RuntimeError:
                    # SVD failed to converge (torch.linalg.LinAlgError is a
                    # RuntimeError): keep the previous R for this step.
                    return
                R_new = torch.mm(U, V.t())
                # Half-step blend towards the new solution; note the blend of two
                # orthogonal matrices is not itself guaranteed to be orthogonal.
                self.R.copy_(0.5 * self.R + 0.5 * R_new)

    def forward(self, z):
        """Eq. 7: aligned ETF logits ``normalize(z @ R^T) @ M`` of shape (batch, K)."""
        z_rot = torch.mm(z, self.R.t())
        z_rot_norm = F.normalize(z_rot, p=2, dim=1)
        return torch.mm(z_rot_norm, self.M)

class SPALoss(torch.nn.Module):
    """Optional batch-hard margin loss on L2-normalised features.

    For every anchor the hardest positive is the least similar *other* sample
    with the same label. The hardest negative is the most similar sample with a
    different label. The per-anchor loss is ``relu(neg_max - pos_min + margin)``.
    The result is the mean over anchors, multiplied by ``scale``. Anchors that
    have no positive or no negative in the batch are skipped. If no anchor
    qualifies, a zero that stays attached to the graph is returned.

    Args:
        margin: required similarity gap between hardest positive and negative.
        scale: multiplier applied to the mean loss.
    """
    def __init__(self, margin=1.2, scale=20.0):
        super(SPALoss, self).__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, features, labels):
        """Compute the loss for (batch, dim) features and (batch,) or one-hot labels."""
        if labels.dim() > 1:
            labels = labels.argmax(dim=1)
        features = F.normalize(features, p=2, dim=1)
        sim_matrix = torch.mm(features, features.t())
        same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
        # Exclude self-pairs from the positives; negatives never contain them
        # because a sample always shares its own label.
        self_pairs = torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        pos_mask = same_label & ~self_pairs
        neg_mask = ~same_label
        valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
        if not bool(valid.any()):
            return sim_matrix.sum() * 0.0
        # Masked min/max works for any class balance (no reshape required).
        pos_min = sim_matrix.masked_fill(~pos_mask, float('inf')).min(dim=1)[0]
        neg_max = sim_matrix.masked_fill(~neg_mask, float('-inf')).max(dim=1)[0]
        loss = F.relu(neg_max[valid] - pos_min[valid] + self.margin)
        return self.scale * loss.mean()

class LabelSmoothingCrossEntropy(torch.nn.Module):
    """Cross-entropy against a label-smoothed one-hot target.

    The true class gets probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is spread evenly over the other ``n_classes - 1`` classes.
    ``target`` must hold integer class indices of shape (batch,).
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy of (batch, n_classes) logits."""
        num_classes = pred.size(1)
        log_probs = F.log_softmax(pred, dim=1)
        with torch.no_grad():
            off_value = self.smoothing / (num_classes - 1)
            soft_target = torch.full_like(pred, off_value)
            soft_target.scatter_(1, target.data.unsqueeze(1), 1.0 - self.smoothing)
        per_sample = (-soft_target * log_probs).sum(dim=1)
        return per_sample.mean()

class KnowledgeDistillationLoss(torch.nn.Module):
    """Eq. 8: self-alignment via temperature-scaled KL divergence.

    Computes ``KL(softmax(teacher/T) || softmax(student/T))`` averaged over the
    batch (``reduction='batchmean'``). The result is multiplied by ``T**2`` to
    keep gradient magnitudes comparable across temperatures.
    """
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        """Return the scaled KL loss for two (batch, n_classes) logit tensors."""
        T = self.T
        student_log_probs = F.log_softmax(student_logits / T, dim=1)
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        return kl * (T ** 2)

def calculate_ece(logits, labels, n_bins=15):
    """Expected Calibration Error of ``softmax(logits)`` against integer ``labels``.

    Max-softmax confidences are bucketed into ``n_bins`` equal-width bins over
    (0, 1]; each bin is half-open ``(lower, upper]``. For every non-empty bin,
    |mean confidence - accuracy| is weighted by the bin's share of samples, and
    these terms are summed. The result is returned as a Python float.
    """
    confidences, predictions = F.softmax(logits, dim=1).max(dim=1)
    correct = predictions.eq(labels)
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(1, device=logits.device)
    for i in range(n_bins):
        lower, upper = edges[i].item(), edges[i + 1].item()
        in_bin = confidences.gt(lower) & confidences.le(upper)
        share = in_bin.float().mean()
        if share.item() > 0:
            bin_confidence = confidences[in_bin].mean()
            bin_accuracy = correct[in_bin].float().mean()
            ece += torch.abs(bin_confidence - bin_accuracy) * share
    return ece.item()

class SPA_TaskNetwork(torch.nn.Module):
    """SPA framework for heart murmur detection.

    Wraps an audio representation ``ar`` with a small MLP adapter and two heads:
    a spherical (cosine) classifier whose logits are the model output, and a
    geometric ETF branch used as the distillation teacher during training.

    Args:
        cfg: experiment config (accepted for interface compatibility; not read here).
        ar: audio representation module; must expose ``feature_dim``.
        n_classes: number of output classes.
        dim: adapter / embedding dimension.
        scale: initial scale of the spherical branch.
        spa_margin, spa_scale: SPALoss hyper-parameters.
        distill_temp: temperature of the distillation loss.
        alpha_distill: weight of the distillation loss in the total loss.
        ma_update: enable the moving-average blend of backbone parameters.
        ma_beta: weight of the current backbone parameters in that blend.
    """
    def __init__(self, cfg, ar, n_classes=3, dim=256, scale=16.0, spa_margin=1.2, spa_scale=20.0, distill_temp=4.0, alpha_distill=0.5, ma_update=True, ma_beta=0.5):
        super().__init__()
        self.ar = ar
        self.dim = dim
        self.n_classes = n_classes
        self.alpha_distill = alpha_distill
        self.ma_update = ma_update
        self.ma_beta = ma_beta
        self.adapter = torch.nn.Sequential(
            torch.nn.Linear(ar.feature_dim, dim),
            torch.nn.LayerNorm(dim),
            torch.nn.ReLU(),
            torch.nn.Linear(dim, dim)
        )
        self.spherical_branch = SphericalBranch(dim, n_classes, scale=scale)
        self.geometric_branch = GeometricBranch(dim, n_classes)
        self.criterion_cls = LabelSmoothingCrossEntropy(smoothing=0.1)
        self.criterion_distill = KnowledgeDistillationLoss(temperature=distill_temp)
        self.criterion_spa = SPALoss(margin=spa_margin, scale=spa_scale)

    def forward(self, x, labels=None, return_features=False, training=True, ma_state_dict=None):
        """Run the backbone, adapter and both branches.

        Args:
            x: input batch passed straight to ``self.ar``.
            labels: class indices (batch,) or one-hot (batch, n_classes); optional.
            return_features: return the raw branch outputs and adapted features.
            training: when False the backbone runs without gradients and no loss
                is computed; when True the caller's grad mode is respected.
            ma_state_dict: optional mapping from parameter name (as in
                ``self.named_parameters()``) to tensor. When ``ma_update`` is set,
                every backbone parameter (name starting with ``'ar.'``) that has a
                same-shaped entry is replaced by
                ``ma_beta * param + (1 - ma_beta) * entry``.

        Returns:
            ``(logits_spherical, logits_geometric, adapted_features)`` if
            ``return_features``; otherwise a dict with ``logits``/``loss`` and the
            loss components when ``training`` and ``labels`` are given; otherwise
            the spherical logits.
        """
        if labels is not None and labels.dim() > 1:
            # The losses below expect class indices.
            labels = labels.argmax(dim=1)
        # Evaluation calls disable gradients for the backbone; training calls keep
        # the caller's grad mode (an outer torch.no_grad() is no longer overridden).
        with torch.set_grad_enabled(training and torch.is_grad_enabled()):
            features = self.ar(x)
            if isinstance(features, dict) and 'features' in features:
                features = features['features']
        if features.dim() == 3:
            # Mean-pool over dim 1 (presumably the time/frame axis — confirm per AR).
            features = features.mean(dim=1)
        adapted_features = self.adapter(features)
        logits_spherical = self.spherical_branch(adapted_features)
        if training and labels is not None:
            self.geometric_branch.align(adapted_features, labels)
        logits_geometric = self.geometric_branch(adapted_features)
        if return_features:
            return logits_spherical, logits_geometric, adapted_features
        if training and labels is not None:
            loss_cls = self.criterion_cls(logits_spherical, labels)
            loss_distill = self.criterion_distill(logits_spherical, logits_geometric)
            loss_spa = self.criterion_spa(adapted_features, labels)
            total_loss = loss_cls + self.alpha_distill * loss_distill + loss_spa
            if self.ma_update and ma_state_dict is not None:
                with torch.no_grad():
                    # Match by name: state_dict() also holds buffers, so zipping it
                    # with named_parameters() pairs unrelated tensors.
                    for name, param in self.named_parameters():
                        if not name.startswith('ar.') or name not in ma_state_dict:
                            continue
                        ma_param = ma_state_dict[name]
                        if ma_param.shape != param.shape:
                            continue
                        param.data = self.ma_beta * param.data + (1 - self.ma_beta) * ma_param.data
            return {
                'logits': logits_spherical,
                'loss': total_loss,
                'loss_cls': loss_cls,
                'loss_distill': loss_distill,
                'loss_spa': loss_spa
            }
        return logits_spherical

def _patient_recording_logits(model, recordings, frequencies, sample_rate, unit_samples, device):
    """Run ``model`` on fixed-length segments of every recording of one patient.

    Each recording is scaled by 1/32768 and resampled from its own rate to
    ``sample_rate``. It is then zero-padded to at least ``unit_samples`` and
    cut into non-overlapping ``unit_samples``-long segments; a trailing
    partial segment is dropped. Segment logits are averaged per recording.

    Returns:
        A CPU tensor with one averaged logit row per recording, or ``None`` when
        there are no recordings. Assumes the model returns (1, n_classes) logits
        for a single segment — TODO confirm.
    """
    per_recording = []
    for rec, freq in zip(recordings, frequencies):
        wav = torch.tensor(rec / 32768.).to(torch.float)
        wav = torchaudio.transforms.Resample(freq, sample_rate)(wav)
        if len(wav) < unit_samples:
            wav = F.pad(wav, (0, unit_samples - len(wav)))
        segment_logits = []
        for pos in range(0, len(wav) - unit_samples + 1, unit_samples):
            x = wav[pos:pos + unit_samples].unsqueeze(0).to(device)
            with torch.no_grad():
                # training=False: backbone runs without autograd during inference.
                segment_logits.append(model(x, training=False).detach().cpu())
        if segment_logits:
            per_recording.append(torch.stack(segment_logits).mean(0))
    return torch.vstack(per_recording) if per_recording else None


def infer_and_eval_spa(cfg, model, test_root, eval_mode='follow_prior_work'):
    """Inference and evaluation with an SPA model on a CirCor test folder.

    Args:
        cfg: config providing ``sample_rate`` and ``unit_samples``.
        model: SPA model (typically DataParallel-wrapped) accepting ``training=``.
        test_root: folder with ``<pid>_*.wav`` recordings and ``<pid>.txt`` files.
        eval_mode: ``None``/``'follow_prior_work'`` (Panah et al. rule) or
            ``'normal'`` (argmax of the patient probabilities).

    Returns:
        ``(results, (wav_probabilities, probabilities))``, where ``results`` is
        ``evaluate_model``'s output as a list with the ECE appended.
    """
    model.eval()
    device = next(model.parameters()).device
    pids = sorted({f.stem.split('_')[0] for f in Path(test_root).glob('*.wav')})
    txt_files = [os.path.join(test_root, pid + '.txt') for pid in pids]
    print('Test file folder:', test_root)
    print('Test files:', pids[:2], txt_files[:2])
    probabilities, wav_probabilities = [], []
    patient_logits, logit_patient_idx = [], []

    for idx, txt in enumerate(tqdm(txt_files)):
        data = load_patient_data(txt)
        recordings, frequencies = load_recordings(test_root, data, get_frequencies=True)
        logits = _patient_recording_logits(model, recordings, frequencies, cfg.sample_rate, cfg.unit_samples, device)
        if logits is None:
            # No recordings: all-zero probabilities (argmax -> class 0).
            probabilities.append(torch.zeros(3))
            wav_probabilities.append(torch.zeros(1, 3))
            continue
        # Column reorder so that index 0/1/2 = Present/Unknown/Absent as used by
        # label_decision_rule; presumably the loader's class order is
        # (Absent, Present, Unknown) — confirm. Assumes exactly 3 classes.
        logits = logits[:, [1, 2, 0]]
        wav_probabilities.append(logits.softmax(1))
        mean_logits = logits.mean(0, keepdim=True)
        probabilities.append(mean_logits.softmax(1)[0])
        patient_logits.append(mean_logits)
        logit_patient_idx.append(idx)
    probabilities = torch.stack(probabilities)

    def label_decision_rule(wav_probs):
        cidxs = torch.argmax(wav_probs, dim=1)
        PRESENT, UNKNOWN, ABSENT = 0, 1, 2
        if PRESENT in cidxs:
            final_label = PRESENT
        elif UNKNOWN in cidxs:
            final_label = UNKNOWN
        else:
            final_label = ABSENT
        return final_label

    if eval_mode is None or eval_mode == 'follow_prior_work':
        print('Label decision follows: Panah et al.')
        cidxs = torch.tensor([label_decision_rule(wav_probs) for wav_probs in wav_probabilities])
    elif eval_mode == 'normal':
        print('Label decision is: torch.argmax(probabilities, dim=1)')
        cidxs = torch.argmax(probabilities, dim=1)
    else:
        assert False, f'Unknown eval_mode: {eval_mode}'
    labels = torch.nn.functional.one_hot(cidxs, num_classes=3)
    wav_probabilities = [p.numpy() for p in wav_probabilities]
    probabilities = probabilities.numpy()
    labels = labels.numpy()
    results = evaluate_model(test_root, probabilities, labels)
    if patient_logits:
        # NOTE(review): `labels` here are the model's own decisions (cidxs), not the
        # ground-truth murmur labels, so this "ECE" measures agreement between the
        # mean-logit argmax and the decision rule. Confirm this is intended.
        ece_logits = torch.cat(patient_logits, dim=0)
        # Keep only patients that produced logits so rows stay aligned.
        ece_labels = torch.from_numpy(labels).argmax(dim=1)[logit_patient_idx]
        ece = calculate_ece(ece_logits, ece_labels)
    else:
        ece = 0.0
    results = list(results)
    results.append(ece)
    return results, (wav_probabilities, probabilities)

def eval_main_spa(config_file, task, checkpoint, options='', seed=42, lr=None, hidden=(), epochs=None, early_stop_epochs=None, warmup_epochs=None, mixup=None, freq_mask=None, time_mask=None, rrc=None, training_mask=None, batch_size=None, optim='sgd', unit_sec=None, verbose=False, data_path='work', eval_mode=None, save_prob=None, spa_margin=1.2, spa_scale=20.0, distill_temp=4.0, feat_dim=256, scale=16.0, alpha_distill=0.5, ma_update=True, ma_beta=0.5):
    """Evaluation with SPA model - keeps all original parameters.

    Builds the config and data loaders, rebuilds the SPA network, and loads
    ``checkpoint``; keys may be saved with or without the DataParallel
    ``module.`` prefix. The model is then evaluated on
    ``../heart-murmur-detection/data/stratified_data<N>/test_data/``, where
    ``N`` is the last character of ``task``.

    Scores are appended to ``RESULT_DIR/circor-scores.csv`` (or
    ``circor-scores-wo-rule.csv`` for ``eval_mode='normal'``). When
    ``save_prob`` is given, the probabilities are also saved as ``.npy`` files.

    Returns:
        ``(weighted_accuracy, uar, ece)``.
    """
    cfg, n_folds, balanced = make_cfg(config_file, task, options, extras={}, abs_unit_sec=unit_sec)
    lr = lr or cfg.ft_lr
    cfg.mixup = mixup if mixup is not None else cfg.mixup
    cfg.ft_early_stop_epochs = early_stop_epochs if early_stop_epochs is not None else cfg.ft_early_stop_epochs
    cfg.warmup_epochs = warmup_epochs if warmup_epochs is not None else cfg.warmup_epochs
    cfg.ft_epochs = epochs or cfg.ft_epochs
    cfg.ft_freq_mask = freq_mask if freq_mask is not None else cfg.ft_freq_mask
    cfg.ft_time_mask = time_mask if time_mask is not None else cfg.ft_time_mask
    cfg.ft_rrc = rrc if rrc is not None else (cfg.ft_rrc if 'ft_rrc' in cfg else False)
    cfg.training_mask = training_mask if training_mask is not None else (cfg.training_mask if 'training_mask' in cfg else 0.0)
    cfg.ft_bs = batch_size or cfg.ft_bs
    cfg.optim = optim
    cfg.unit_sec = unit_sec
    cfg.data_path = data_path
    train_loader, valid_loader, test_loader, multi_label = create_dataloader(cfg, fold=n_folds - 1, seed=seed, batch_size=cfg.ft_bs, always_one_hot=True, balanced_random=balanced)
    print('Classes:', train_loader.dataset.classes)
    cfg.eval_checkpoint = checkpoint
    cfg.runtime_cfg = kwarg_cfg(lr=lr, seed=seed, hidden=hidden, mixup=cfg.mixup, bs=cfg.ft_bs, freq_mask=cfg.ft_freq_mask, time_mask=cfg.ft_time_mask, rrc=cfg.ft_rrc, epochs=cfg.ft_epochs, early_stop_epochs=cfg.ft_early_stop_epochs, n_class=len(train_loader.dataset.classes))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # cfg.audio_repr is a (possibly dotted) attribute path under the evar package,
    # taken from the trusted experiment config.
    ar = eval('evar.' + cfg.audio_repr)(cfg).to(device)
    if hasattr(train_loader, 'lms_mode') and train_loader.lms_mode:
        ar.precompute_lms(device, train_loader)
    else:
        ar.precompute(device, train_loader)
    n_classes = len(train_loader.dataset.classes)
    task_model = SPA_TaskNetwork(cfg, ar, n_classes=n_classes, dim=feat_dim, scale=scale, spa_margin=spa_margin, spa_scale=spa_scale, distill_temp=distill_temp, alpha_distill=alpha_distill, ma_update=ma_update, ma_beta=ma_beta).to(device)
    task_model_dp = torch.nn.DataParallel(task_model).to(device)
    print('Using checkpoint', checkpoint)
    state_dict = torch.load(checkpoint, map_location=device)
    # A DataParallel-saved state dict has every key prefixed with 'module.';
    # check the prefix rather than a substring so inner modules named 'module'
    # are not mistaken for the wrapper.
    first_key = next(iter(state_dict), '')
    if first_key.startswith('module.'):
        task_model_dp.load_state_dict(state_dict)
    else:
        task_model_dp.module.load_state_dict(state_dict)
    task_model_dp.eval()
    circor_no = task[-1]
    stratified_data = f'../heart-murmur-detection/data/stratified_data{circor_no}/test_data/'
    results, probs = infer_and_eval_spa(cfg, task_model_dp, stratified_data, eval_mode=eval_mode)
    (classes, auroc, auprc, auroc_classes, auprc_classes, f_measure, f_measure_classes, accuracy, accuracy_classes, weighted_accuracy, uar, ece) = results
    name = f'{cfg.id}{"" if cfg.weight_file != "" else "/rnd"}-'
    report = f'SPA Finetuning {name} on {task} -> weighted_accuracy: {weighted_accuracy:.5f}, UAR: {uar:.5f}, recall per class: {accuracy_classes}, ECE: {ece:.4f}'
    report += f', best weight: {checkpoint}, config: {cfg}'
    logging.info(report)
    result_df = pd.DataFrame({
        'representation': [cfg.id.split('_')[-2]],
        'task': [task],
        'wacc': [weighted_accuracy],
        'uar': [uar],
        'r_Present': [accuracy_classes[0]],
        'r_Unknown': [accuracy_classes[1]],
        'r_Absent': [accuracy_classes[2]],
        'ece': [ece],
        'weight_file': [cfg.weight_file],
        'run_id': [cfg.id],
        'report': [report],
        'method': ['spa']
    })
    csv_name = {None: 'circor-scores.csv', 'follow_prior_work': 'circor-scores.csv', 'normal': 'circor-scores-wo-rule.csv'}[eval_mode]
    os.makedirs(RESULT_DIR, exist_ok=True)
    append_to_csv(f'{RESULT_DIR}/{csv_name}', result_df)
    if save_prob is not None:
        save_prob_path = Path(save_prob)
        save_prob_path.mkdir(parents=True, exist_ok=True)
        # '_1' -> per-recording probabilities, '_2' -> per-patient probabilities.
        for i, var in zip(['_1', '_2'], probs):
            prob_name = save_prob_path / str(checkpoint).replace('/', '-').replace('.pth', i + '.npy')
            np.save(prob_name, np.array(var, dtype=object))
            print('Probabilities saved as:', prob_name)
    return weighted_accuracy, uar, ece

def _unpack_batch(batch):
    """Split a loader batch into ``(waveform, target)``; accepts ``(x, y)`` pairs or dicts."""
    if isinstance(batch, (list, tuple)) and len(batch) == 2:
        return batch[0], batch[1]
    return batch['waveform'], batch['target']


def _build_spa_optimizer(model, cfg, lr, optim):
    """Create the optimizer for an unwrapped SPA_TaskNetwork.

    For BEATs backbones (``'beats'`` in ``cfg.audio_repr``) the backbone gets
    0.1x the base learning rate. Otherwise all parameters share ``lr``. ``optim``
    selects Adam, AdamW or (default) SGD with momentum 0.9; all use weight decay 1e-6.
    """
    if 'beats' in cfg.audio_repr.lower():
        params = [
            {'params': model.ar.parameters(), 'lr': lr * 0.1},
            {'params': model.spherical_branch.parameters(), 'lr': lr},
            {'params': model.adapter.parameters(), 'lr': lr},
            {'params': model.geometric_branch.parameters(), 'lr': lr},
        ]
    else:
        params = list(model.parameters())
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=1e-6)
    if optim == 'adamw':
        return torch.optim.AdamW(params, lr=lr, weight_decay=1e-6)
    return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-6)


def finetune_circor_spa(config_file, task, options='', seed=42, lr=None, hidden=(), epochs=None, early_stop_epochs=None, warmup_epochs=None, mixup=None, freq_mask=None, time_mask=None, rrc=None, training_mask=None, batch_size=None, optim='sgd', unit_sec=None, verbose=False, data_path='work', eval_only=None, eval_mode=None, save_prob='probs', spa_margin=1.2, spa_scale=20.0, distill_temp=4.0, feat_dim=256, scale=16.0, alpha_distill=0.5, ma_update=True, ma_beta=0.5):
    """Finetune with SPA algorithm - keeps all original parameters.

    Unless ``eval_only`` is given (a checkpoint path), this trains for
    ``cfg.ft_epochs`` epochs with cosine LR annealing. The checkpoint with the
    best validation accuracy is kept; near-ties within 0.001 are broken by lower
    ECE. It is saved to ``checkpoints/<cfg.id>_spa_best.pth``. That checkpoint
    is then evaluated with ``eval_main_spa``, and its result is returned.

    Raises:
        AssertionError: if ``task`` is not ``circor1``..``circor3``.
        RuntimeError: if training finished without saving any checkpoint
            (e.g. ``ft_epochs`` is 0).
    """
    assert task in [f'circor{n}' for n in range(1, 4)]
    if eval_only is None:
        cfg, n_folds, balanced = make_cfg(config_file, task, options, extras={}, abs_unit_sec=unit_sec)
        lr = lr or cfg.ft_lr
        cfg.mixup = mixup if mixup is not None else cfg.mixup
        cfg.ft_early_stop_epochs = early_stop_epochs if early_stop_epochs is not None else cfg.ft_early_stop_epochs
        cfg.warmup_epochs = warmup_epochs if warmup_epochs is not None else cfg.warmup_epochs
        cfg.ft_epochs = epochs or cfg.ft_epochs
        # Same defaults as eval_main_spa (cfg.ft_*), which runs after training anyway.
        cfg.ft_freq_mask = freq_mask if freq_mask is not None else cfg.ft_freq_mask
        cfg.ft_time_mask = time_mask if time_mask is not None else cfg.ft_time_mask
        cfg.ft_rrc = rrc if rrc is not None else (cfg.ft_rrc if 'ft_rrc' in cfg else False)
        cfg.training_mask = training_mask if training_mask is not None else (cfg.training_mask if 'training_mask' in cfg else 0.0)
        cfg.ft_bs = batch_size or cfg.ft_bs
        cfg.optim = optim
        cfg.unit_sec = unit_sec
        cfg.data_path = data_path
        train_loader, valid_loader, test_loader, multi_label = create_dataloader(cfg, fold=n_folds - 1, seed=seed, batch_size=cfg.ft_bs, always_one_hot=True, balanced_random=balanced)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ar = eval('evar.' + cfg.audio_repr)(cfg).to(device)
        if hasattr(train_loader, 'lms_mode') and train_loader.lms_mode:
            ar.precompute_lms(device, train_loader)
        else:
            ar.precompute(device, train_loader)
        n_classes = len(train_loader.dataset.classes)
        model = SPA_TaskNetwork(cfg, ar, n_classes=n_classes, dim=feat_dim, scale=scale, spa_margin=spa_margin, spa_scale=spa_scale, distill_temp=distill_temp, alpha_distill=alpha_distill, ma_update=ma_update, ma_beta=ma_beta).to(device)
        optimizer = _build_spa_optimizer(model, cfg, lr, optim)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.ft_epochs)
        best_path = None
        best_score = 0
        best_ece = float('inf')
        model = torch.nn.DataParallel(model).to(device)
        n_batches = max(len(train_loader), 1)
        for epoch in range(cfg.ft_epochs):
            model.train()
            total_loss = total_loss_cls = total_loss_distill = total_loss_spa = 0.0
            for batch in tqdm(train_loader):
                x, y = _unpack_batch(batch)
                x = x.to(device)
                # Move targets to the model device whether they are one-hot or indices.
                y = y.to(device)
                if y.dim() > 1:
                    y = y.argmax(dim=1)
                optimizer.zero_grad()
                # NOTE(review): this snapshot is taken right before the forward pass,
                # so it equals the current weights; the MA blend in the model's
                # forward therefore leaves the backbone (numerically) unchanged while
                # costing a full state_dict copy per batch. Confirm the intended
                # moving-average source (e.g. previous-step or running-average weights).
                ma_state_dict = deepcopy(model.module.state_dict()) if ma_update else None
                outputs = model(x, labels=y, training=True, ma_state_dict=ma_state_dict)
                loss = outputs['loss']
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_loss_cls += float(outputs.get('loss_cls', 0))
                total_loss_distill += float(outputs.get('loss_distill', 0))
                total_loss_spa += float(outputs.get('loss_spa', 0))
            scheduler.step()

            model.eval()
            probs_list, logits_list, labels_list = [], [], []
            with torch.no_grad():
                for batch in valid_loader:
                    x, y = _unpack_batch(batch)
                    if y.dim() > 1:
                        y = y.argmax(dim=1)
                    logits = model(x.to(device), return_features=False, training=False)
                    probs_list.append(F.softmax(logits, dim=1).cpu())
                    logits_list.append(logits.cpu())
                    labels_list.append(y.cpu())
            all_probs = torch.cat(probs_list, dim=0)
            all_logits = torch.cat(logits_list, dim=0)
            all_labels = torch.cat(labels_list, dim=0)
            preds = all_probs.argmax(dim=1)
            acc = (preds == all_labels).float().mean().item()
            val_ece = calculate_ece(all_logits, all_labels)
            print(f'Epoch {epoch}: Loss={total_loss / n_batches:.4f} (Cls:{total_loss_cls / n_batches:.4f}, Distill:{total_loss_distill / n_batches:.4f}, SPA:{total_loss_spa / n_batches:.4f}), Val Acc={acc:.4f}, Val ECE={val_ece:.4f}')
            # Best accuracy wins; near-ties (within 0.001) are broken by lower ECE.
            if acc > best_score or (abs(acc - best_score) < 0.001 and val_ece < best_ece):
                best_score = acc
                best_ece = val_ece
                os.makedirs('checkpoints', exist_ok=True)
                best_path = f'checkpoints/{cfg.id}_spa_best.pth'
                torch.save(model.module.state_dict(), best_path)
                print(f'Best model saved: {best_path} (Acc:{best_score:.4f}, ECE:{best_ece:.4f})')
        print(f'Training completed. Best accuracy: {best_score:.4f}, Best ECE: {best_ece:.4f}')
        if best_path is None:
            raise RuntimeError('No checkpoint was saved during fine-tuning; check ft_epochs and the validation data.')
    else:
        best_path = eval_only
    return eval_main_spa(config_file, task, best_path, options=options, seed=seed, lr=lr, hidden=hidden, epochs=epochs, early_stop_epochs=early_stop_epochs, warmup_epochs=warmup_epochs, mixup=mixup, freq_mask=freq_mask, time_mask=time_mask, rrc=rrc, training_mask=training_mask, batch_size=batch_size, optim=optim, unit_sec=unit_sec, verbose=verbose, data_path=data_path, eval_mode=eval_mode, save_prob=save_prob, spa_margin=spa_margin, spa_scale=spa_scale, distill_temp=distill_temp, feat_dim=feat_dim, scale=scale, alpha_distill=alpha_distill, ma_update=ma_update, ma_beta=ma_beta)

if __name__ == '__main__':
    # Command-line entry points exposed through python-fire.
    commands = {
        'finetune_circor': finetune_circor_spa,
        'eval': eval_main_spa,
    }
    fire.Fire(commands)