"""
SPA (Spherical Procrustes Alignment) for ICBHI Dataset
Paper-aligned implementation - ALL original interfaces preserved
"""

import os
import sys
import json
import random
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from copy import deepcopy

from utils import (
    AverageMeter, accuracy, get_score, set_optimizer,
    adjust_learning_rate, warmup_learning_rate, save_model,
    update_json, SpecAugment, get_backbone_class,
    update_moving_average, calculate_ece
)
from datasets import ICBHIDataset

# ===================== SPA Core Components (Paper-Aligned) =====================
class SphericalClassifier(nn.Module):
    """Cosine classifier with a learnable scale (spherical branch, Eq. 3 in paper).

    Both the inputs and the class weight rows are L2-normalised, so the raw
    output is a cosine similarity; it is then multiplied by the scale ``s``
    clamped to ``[1.0, max_s]``.
    """

    def __init__(self, in_f, out_f, scale=16.0, max_s=40.0):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.max_s = max_s
        # Register W before s so the parameter / state_dict order is stable.
        self.W = nn.Parameter(torch.empty(out_f, in_f))
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        self.s = nn.Parameter(torch.tensor(scale))

    def forward(self, x):
        """Return scaled cosine logits for ``x`` (rows are samples)."""
        unit_x = F.normalize(x, p=2, dim=1)
        unit_w = F.normalize(self.W, p=2, dim=1)
        cosine = F.linear(unit_x, unit_w)
        bounded_scale = torch.clamp(self.s, min=1.0, max=self.max_s)
        return bounded_scale * cosine

class GeometricClassifier(nn.Module):
    """Geometric branch: Eq. 4-7 in paper with Dynamic Procrustes Alignment.

    Buffers:
        M: (d, K) fixed simplex-ETF matrix ``sqrt(K/(K-1)) * (I - 11^T / K)``,
           zero-padded to ``d`` rows when ``d > K`` or truncated when ``d < K``.
        R: (d, d) alignment matrix, initialised to identity.
        P: (K, d) momentum-averaged, L2-normalised class prototypes.
        c: (K,) number of batches in which each class has been observed.
    """

    def __init__(self, d, K, m=0.9):
        super().__init__()
        self.d, self.K = d, K
        self.m = m
        scale = np.sqrt(K / (K - 1))
        M_base = scale * (torch.eye(K) - torch.ones(K, K) / K)
        if d > K:
            M_base = torch.cat([M_base, torch.zeros(d - K, K)], dim=0)
        elif d < K:
            M_base = M_base[:d, :]
        self.register_buffer('M', M_base)
        self.register_buffer('R', torch.eye(d))
        self.register_buffer('P', torch.zeros(K, d))
        self.register_buffer('c', torch.zeros(K))

    def align_prototypes(self, z, y):
        """Eq. 4 (momentum update) + Eq. 6 (SVD-based Procrustes).

        Updates the prototypes ``P`` from the batch and, once every class has
        been seen at least once, blends ``R`` halfway towards the orthogonal
        Procrustes solution ``U @ Vh`` of ``H = M @ P``. Does nothing in eval
        mode.

        Args:
            z: (batch, d) features.
            y: (batch,) class indices, or a 2-D score/one-hot matrix that is
               reduced with ``argmax(dim=1)``.
        """
        if not self.training:
            return
        if y.dim() > 1:
            y = y.argmax(dim=1)
        # The whole update runs in float32 with autocast disabled: when this
        # is called inside an AMP region, torch.mm would otherwise return a
        # reduced-precision matrix that SVD does not support, and the rotation
        # would never be updated.
        with torch.no_grad(), torch.autocast(device_type=z.device.type, enabled=False):
            z_norm = F.normalize(z.detach().float(), p=2, dim=1)
            for k in y.unique():
                k = k.long()
                mu = z_norm[y == k].mean(0)
                if self.c[k] == 0:
                    # First observation of this class: initialise directly.
                    self.P[k] = mu
                else:
                    self.P[k] = self.m * self.P[k] + (1 - self.m) * mu
                self.c[k] += 1
            if not torch.all(self.c > 0):
                return
            H = torch.mm(self.M.float(), self.P.float())
            try:
                U, _, Vh = torch.linalg.svd(H)
            except RuntimeError as err:
                # Best effort: keep the previous R if the decomposition fails,
                # but make the failure visible instead of swallowing it.
                import warnings
                warnings.warn(f'Procrustes SVD failed; keeping previous R ({err})')
                return
            R_new = torch.mm(U, Vh)
            self.R.copy_(0.5 * self.R + 0.5 * R_new)

    def forward(self, z):
        """Eq. 7: logits via aligned ETF (``normalize(z @ R^T) @ M``)."""
        z_rot = torch.mm(z, self.R.t())
        z_rot_norm = F.normalize(z_rot, p=2, dim=1)
        return torch.mm(z_rot_norm, self.M)

class SPALoss(nn.Module):
    """SPA contrastive loss: hardest-positive / hardest-negative margin loss.

    For each anchor the loss term is
    ``relu(max_negative_cosine - min_positive_cosine + margin)``; the mean
    over anchors is multiplied by ``scale``.
    """

    def __init__(self, margin=1.2, scale=20.0):
        super(SPALoss, self).__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: (batch, dim) embeddings; L2-normalised internally.
            labels: (batch,) class indices.

        Returns:
            Scalar tensor. Anchors with no same-class partner or no
            other-class sample in the batch are skipped; if no anchor
            qualifies the result is a zero that stays attached to the graph.
        """
        features = F.normalize(features, p=2, dim=1)
        sim_matrix = torch.mm(features, features.t())
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
        # Negatives are derived before the diagonal is removed, so an anchor's
        # similarity with itself is never counted as a negative.
        neg_mask = ~same_class
        diagonal = torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        pos_mask = same_class & ~diagonal
        # Masked reductions handle classes of unequal size within a batch.
        pos_min = sim_matrix.masked_fill(~pos_mask, float('inf')).min(dim=1)[0]
        neg_max = sim_matrix.masked_fill(~neg_mask, float('-inf')).max(dim=1)[0]
        valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
        if not valid.any():
            return features.sum() * 0.0
        loss = F.relu(neg_max[valid] - pos_min[valid] + self.margin)
        return self.scale * loss.mean()

# ===================== Original Loss Functions (Unchanged) =====================
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives ``1 - smoothing``; every other class receives
    ``smoothing / (n_classes - 1)``.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the batch-mean loss for logits ``pred`` and indices ``target``."""
        num_classes = pred.size(1)
        log_probs = F.log_softmax(pred, dim=1)
        with torch.no_grad():
            off_value = self.smoothing / (num_classes - 1)
            soft_target = torch.full_like(pred, off_value)
            soft_target.scatter_(1, target.data.unsqueeze(1), 1.0 - self.smoothing)
        per_sample = torch.sum(-soft_target * log_probs, dim=1)
        return per_sample.mean()

class KnowledgeDistillationLoss(nn.Module):
    """Temperature-scaled KL divergence between two logit sets (Eq. 8).

    Both logit tensors are divided by the temperature ``T`` before the
    softmax; the batch-mean KL is multiplied by ``T ** 2``.
    """

    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        """Return ``T^2 * KL(softmax(teacher / T) || softmax(student / T))``."""
        temperature = self.T
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        return divergence * (temperature ** 2)

class PAFALoss(nn.Module):
    """PAFA loss (for compatibility): prototype-softmax cross-entropy.

    Features are L2-normalised, compared with every prototype by dot
    product, divided by the temperature ``T`` and passed through a
    log-softmax; the loss is the mean negative log-probability of each
    sample's own prototype.
    """

    def __init__(self, temperature=3.0):
        super(PAFALoss, self).__init__()
        self.T = temperature

    def forward(self, features, labels, prototype=None):
        """Compute the loss.

        Args:
            features: (batch, dim) embeddings.
            labels: (batch,) class indices.
            prototype: optional (num_classes, dim) tensor whose row ``k`` is
                the prototype of class ``k``. When omitted, prototypes are
                the per-class means of the normalised batch features.

        Returns:
            Scalar loss tensor.
        """
        features = F.normalize(features, p=2, dim=1)
        if prototype is None:
            # Batch prototypes are ordered by sorted unique label, so targets
            # must be positions in that list rather than raw label values
            # (labels present in a batch need not be 0..n-1).
            classes, targets = torch.unique(labels, return_inverse=True)
            prototype = torch.stack(
                [features[labels == c].mean(dim=0) for c in classes], dim=0
            )
        else:
            targets = labels.long()
        sim = torch.mm(features, prototype.t()) / self.T
        log_sim = F.log_softmax(sim, dim=1)
        return F.nll_loss(log_sim, targets)

# ===================== SPA Framework (Paper-Aligned, Interface-Compatible) =====================
class PAFA_SPA_Framework(nn.Module):
    """Dual-branch SPA model wrapping a backbone (original interface kept).

    The backbone feature feeds two heads: a spherical (cosine) classifier,
    optionally behind a LayerNorm, and a geometric classifier behind a small
    MLP adapter. Two buffers hold momentum-averaged class prototypes and the
    number of updates each class has received.
    """

    def __init__(self, base, dim, K, layernorm=False, scale=16.0):
        super().__init__()
        self.base = base
        self.use_layernorm = layernorm
        self.dim = dim
        if layernorm:
            self.ln = nn.LayerNorm(dim)
        self.spherical_classifier = SphericalClassifier(dim, K, scale=scale)
        adapter_layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        ]
        self.feature_adapter = nn.Sequential(*adapter_layers)
        self.geometric_classifier = GeometricClassifier(dim, K)
        self.register_buffer('global_prototype', torch.zeros(K, dim))
        self.register_buffer('proto_count', torch.zeros(K))

    def update_global_prototype(self, features, labels):
        """Update the per-class prototypes from one batch (no gradients).

        A class seen for the first time takes the batch mean of its
        normalised features; afterwards the prototype is blended as
        ``0.9 * old + 0.1 * batch_mean``.
        """
        with torch.no_grad():
            unit_features = F.normalize(features, p=2, dim=1)
            for cls in torch.unique(labels):
                cls = cls.long()
                members = labels == cls
                if members.sum() == 0:
                    continue
                batch_mean = unit_features[members].mean(dim=0)
                if self.proto_count[cls] == 0:
                    self.global_prototype[cls] = batch_mean
                else:
                    blended = 0.9 * self.global_prototype[cls] + 0.1 * batch_mean
                    self.global_prototype[cls] = blended
                self.proto_count[cls] += 1

    @staticmethod
    def _pool_backbone_output(output):
        """Reduce a backbone output to a feature tensor.

        A dict with a ``'features'`` key or a list/tuple is unwrapped first
        (last element for sequences); a 3-D tensor is then averaged over
        dim 1. Anything else is returned untouched.
        """
        if isinstance(output, dict) and 'features' in output:
            output = output['features']
        elif isinstance(output, (list, tuple)):
            output = output[-1]
        elif not isinstance(output, torch.Tensor):
            return output
        if output.dim() == 3:
            return output.mean(dim=1)
        return output

    def forward(self, x, args=None, training=False, ret_feats=False):
        """Run the backbone and both heads.

        Returns ``(logits_spherical, logits_geometric)``; with
        ``ret_feats=True`` the tuple is extended with the spherical-branch
        input, the adapted feature and the raw pooled feature.
        """
        # `args` is forwarded only to backbones whose class name does not
        # contain "beats".
        is_beats = 'beats' in str(type(self.base)).lower()
        if args is None or is_beats:
            backbone_out = self.base(x, training=training)
        else:
            backbone_out = self.base(x, args=args, training=training)
        raw_feat = self._pool_backbone_output(backbone_out)

        feat_for_spherical = self.ln(raw_feat) if self.use_layernorm else raw_feat
        logits_spherical = self.spherical_classifier(feat_for_spherical)
        z_adapted = self.feature_adapter(raw_feat)
        logits_geometric = self.geometric_classifier(z_adapted)
        if not ret_feats:
            return logits_spherical, logits_geometric
        return logits_spherical, logits_geometric, feat_for_spherical, z_adapted, raw_feat

# ===================== Training Functions (Original Interface) =====================
def train(train_loader, model, criterion, optimizer, epoch, args, scaler=None):
    """Train for one epoch.

    Args:
        train_loader: iterable yielding ``(images, labels)``; ``labels[0]``
            holds the class labels.
        model: SPA framework, optionally wrapped so that the real model is
            reachable as ``model.module``.
        criterion: ``[cls, self_alignment, pafa, spa]`` loss modules.
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch number (used for LR warm-up and logging).
        args: run configuration (``ma_update``, ``ma_beta``,
            ``alpha_distill``, ``pafa_weight``, ``print_freq`` are read here).
        scaler: optional AMP gradient scaler. When given, the forward pass
            runs under autocast and the scaler drives backward/step; when
            ``None``, training runs in full precision.

    Returns:
        ``(average total loss, average top-1 accuracy, ECE over the epoch)``.
    """
    model.train()
    use_amp = scaler is not None
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_cls = AverageMeter()
    losses_sa = AverageMeter()
    losses_pafa = AverageMeter()
    losses_spa = AverageMeter()
    losses_total = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    epoch_logits = []
    epoch_labels = []

    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        class_labels = labels[0].cuda(non_blocking=True)
        bsz = class_labels.shape[0]
        if args.ma_update:
            # Snapshot of the weights before this step, for the moving average.
            ma_ckpt = deepcopy(model.state_dict())
        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits_spherical, logits_geometric, feat_for_spherical, z_adapted, raw_feat = model(
                images, args=args, training=True, ret_feats=True
            )
            # Unwrap a parallel wrapper to reach the framework's own members.
            core = model.module if hasattr(model, 'module') else model
            core.geometric_classifier.align_prototypes(z_adapted, class_labels)
            loss_cls = criterion[0](logits_spherical, class_labels)
            loss_sa = criterion[1](logits_spherical, logits_geometric)
            loss_pafa = criterion[2](feat_for_spherical, class_labels, core.global_prototype)
            loss_spa = criterion[3](feat_for_spherical, class_labels)
            loss = loss_cls + args.alpha_distill * loss_sa + args.pafa_weight * (loss_pafa + loss_spa)
            epoch_logits.append(logits_spherical.detach())
            epoch_labels.append(class_labels.detach())
        losses_cls.update(loss_cls.item(), bsz)
        losses_sa.update(loss_sa.item(), bsz)
        losses_pafa.update(loss_pafa.item(), bsz)
        losses_spa.update(loss_spa.item(), bsz)
        losses_total.update(loss.item(), bsz)
        [acc1], _ = accuracy(logits_spherical[:bsz], class_labels, topk=(1,))
        top1.update(acc1[0], bsz)
        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        if args.ma_update:
            with torch.no_grad():
                model = update_moving_average(args.ma_beta, model, ma_ckpt)
        # `model` may have been rebound by the moving-average update, so the
        # wrapper is resolved again here.
        core = model.module if hasattr(model, 'module') else model
        core.update_global_prototype(feat_for_spherical, class_labels)
        batch_time.update(time.time() - end)
        end = time.time()
        if (idx + 1) % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'Loss_cls {loss_cls.val:.3f}\t'
                  'Loss_sa {loss_sa.val:.3f}\t'
                  'Loss_pafa {loss_pafa.val:.3f}\t'
                  'Loss_spa {loss_spa.val:.3f}\t'
                  'Acc@1 {top1.val:.3f}'.format(
                epoch, idx + 1, len(train_loader), loss_cls=losses_cls,
                loss_sa=losses_sa, loss_pafa=losses_pafa, loss_spa=losses_spa, top1=top1))
    epoch_logits = torch.cat(epoch_logits, dim=0)
    epoch_labels = torch.cat(epoch_labels, dim=0)
    ece = calculate_ece(epoch_logits, epoch_labels)
    print(f'Epoch {epoch} Train ECE: {ece:.4f}')
    return losses_total.avg, top1.avg, ece

def validate(val_loader, model, criterion, args, best_acc, best_model=None):
    """Evaluate the model and track the best checkpoint.

    Args:
        val_loader: iterable yielding ``(images, labels)``; ``labels[0]``
            holds the class labels.
        model: SPA framework, optionally wrapped (``model.module``).
        criterion: ``[cls, self_alignment, pafa, spa]`` loss modules.
        args: run configuration (``n_cls``, ``two_cls_eval``,
            ``alpha_distill``, ``pafa_weight``, ``print_freq`` are read here).
        best_acc: best metrics so far, laid out as
            ``[sp, se, score, f1_normal, ece]``.
        best_model: state dict of the best model so far, or ``None``.

    Returns:
        ``(best_acc, best_model, save_bool)``; ``save_bool`` is True when
        this call produced a new best (score improved and ``se > 0.1``).
    """
    save_bool = False
    model.eval()
    batch_time = AverageMeter()
    losses_cls = AverageMeter()
    losses_sa = AverageMeter()
    losses_pafa = AverageMeter()
    losses_spa = AverageMeter()
    losses_total = AverageMeter()
    top1 = AverageMeter()
    hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls
    epoch_logits = []
    epoch_labels = []

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            class_labels = labels[0].cuda(non_blocking=True)
            bsz = class_labels.shape[0]
            with torch.cuda.amp.autocast():
                logits_spherical, logits_geometric, feat_for_spherical, _, raw_feat = model(
                    images, args=args, training=False, ret_feats=True
                )
                core = model.module if hasattr(model, 'module') else model
                loss_cls = criterion[0](logits_spherical, class_labels)
                loss_sa = criterion[1](logits_spherical, logits_geometric)
                loss_pafa = criterion[2](feat_for_spherical, class_labels, core.global_prototype)
                loss_spa = criterion[3](feat_for_spherical, class_labels)
                loss = loss_cls + args.alpha_distill * loss_sa + args.pafa_weight * (loss_pafa + loss_spa)
            losses_cls.update(loss_cls.item(), bsz)
            losses_sa.update(loss_sa.item(), bsz)
            losses_pafa.update(loss_pafa.item(), bsz)
            losses_spa.update(loss_spa.item(), bsz)
            losses_total.update(loss.item(), bsz)
            [acc1], _ = accuracy(logits_spherical, class_labels, topk=(1,))
            top1.update(acc1[0], bsz)
            epoch_logits.append(logits_spherical.detach())
            epoch_labels.append(class_labels.detach())
            _, preds = torch.max(logits_spherical, 1)
            # One host transfer per batch instead of several per sample.
            pred_list = preds.tolist()
            label_list = class_labels.tolist()
            for pred, label in zip(pred_list, label_list):
                counts[label] += 1.0
                if args.two_cls_eval:
                    # Binary view: class 0 must be predicted as 0, any other
                    # class counts as a hit when the prediction is non-zero.
                    correct = (pred == 0) if label == 0 else (pred > 0)
                else:
                    correct = pred == label
                if correct:
                    hits[label] += 1.0
            sp, se, sc, f1_normal = get_score(hits, counts)
            batch_time.update(time.time() - end)
            end = time.time()
            if (idx + 1) % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Acc@1 {top1.val:.3f}\t'
                      'S_p {sp:.3f}\t'
                      'S_e {se:.3f}'.format(
                    idx + 1, len(val_loader), top1=top1, sp=sp, se=se))
    epoch_logits = torch.cat(epoch_logits, dim=0)
    epoch_labels = torch.cat(epoch_labels, dim=0)
    ece = calculate_ece(epoch_logits, epoch_labels)
    # Index 2 is the score slot of best_acc; index -2 would be f1_normal in
    # the five-element layout.
    if sc > best_acc[2] and se > 0.1:
        save_bool = True
        best_acc = [sp, se, sc, f1_normal, ece]
        best_model = deepcopy(model.state_dict())
    print(' * S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f}'.format(sp, se, sc))
    print(' * Acc@1 {top1.avg:.2f}'.format(top1=top1))
    print(' * Val ECE: {:.4f}'.format(ece))
    return best_acc, best_model, save_bool

def set_loader(args):
    """Build the ICBHI train/validation data loaders.

    Both splits use a single resize transform to
    ``(int(h * resz), int(w * resz))``. The training loader shuffles and
    drops the last incomplete batch; the validation loader does neither.

    Returns:
        ``(train_loader, val_loader, args)``.
    """
    target_size = (int(args.h * args.resz), int(args.w * args.resz))

    def make_dataset(is_train):
        # Each split gets its own transform pipeline instance.
        pipeline = transforms.Compose([transforms.Resize(size=target_size)])
        return ICBHIDataset(
            train_flag=is_train,
            transform=pipeline,
            args=args,
            print_flag=True,
            data_folder=args.data_folder,
        )

    train_dataset = make_dataset(True)
    val_dataset = make_dataset(False)

    shared_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)
    return train_loader, val_loader, args


def set_model(args):
    """Build the backbone, the PAFA+SPA framework, the losses and the optimizer.

    Args:
        args: run configuration; ``model`` selects the backbone
            (``'ast'``, ``'beats'`` or ``'cnn6'``).

    Returns:
        ``(model, criterion, optimizer)`` where ``criterion`` is
        ``[classification, self_alignment, pafa, spa]``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported backbones.
    """
    kwargs = {}
    if args.model == 'ast':
        kwargs['input_fdim'] = int(args.h * args.resz)
        kwargs['input_tdim'] = int(args.w * args.resz)
        kwargs['label_dim'] = args.n_cls
        kwargs['imagenet_pretrain'] = args.from_sl_official
        kwargs['audioset_pretrain'] = args.audioset_pretrained
        feat_dim = 768
        use_layernorm = False
    elif args.model == 'beats':
        if args.nospec:
            kwargs['spec_transform'] = None
        else:
            kwargs['spec_transform'] = SpecAugment(args)
        feat_dim = 768
        use_layernorm = True
    elif args.model == 'cnn6':
        kwargs['n_mels'] = args.n_mels
        kwargs['pretrained'] = args.from_sl_official
        feat_dim = 768  # Adjust according to actual CNN6 output dimension
        use_layernorm = False
    else:
        # Without this, feat_dim / use_layernorm would be undefined below.
        raise ValueError(f'Unsupported model: {args.model}')

    base_model = get_backbone_class(args.model)(**kwargs)

    if args.model not in ['ast', 'beats'] and args.from_sl_official:
        base_model.load_sl_official_weights()
        print(f'{args.model} pretrained weights loaded')

    if args.pretrained and args.pretrained_ckpt is not None:
        ckpt = torch.load(args.pretrained_ckpt, map_location='cpu')
        state_dict = ckpt['model']
        new_state_dict = {}
        for k, v in state_dict.items():
            # Strip wrapper prefixes and drop the old MLP head weights.
            if "module." in k:
                k = k.replace("module.", "")
            if "backbone." in k:
                k = k.replace("backbone.", "")
            if 'mlp_head' not in k:
                new_state_dict[k] = v
        base_model.load_state_dict(new_state_dict, strict=False)
        print(f'Pretrained model loaded from: {args.pretrained_ckpt}')

    # Initialize PAFA+SPA framework
    model = PAFA_SPA_Framework(
        base=base_model,
        dim=feat_dim,
        K=args.n_cls,
        layernorm=use_layernorm,
        scale=args.init_scale
    ).cuda()

    # Loss functions (integrate PAFA+SPA)
    criterion_cls = LabelSmoothingCrossEntropy(smoothing=0.1).cuda()
    criterion_self_alignment = KnowledgeDistillationLoss(temperature=args.temp).cuda()
    criterion_pafa = PAFALoss(temperature=args.pafa_temperature).cuda()
    criterion_spa = SPALoss(margin=args.spa_margin, scale=args.spa_scale).cuda()
    criterion = [criterion_cls, criterion_self_alignment, criterion_pafa, criterion_spa]

    # Optimizer (adaptive layer-wise learning rate for BEATS)
    if args.model == 'beats':
        backbone_lr = args.learning_rate * 0.1
        spa_lr = args.learning_rate
        # Attribute names must match those defined by PAFA_SPA_Framework.
        head_modules = [
            model.spherical_classifier,
            model.feature_adapter,
            model.geometric_classifier,
        ]
        if model.use_layernorm:
            # The LayerNorm in front of the spherical head is trainable too.
            head_modules.append(model.ln)
        params = [{'params': list(model.base.parameters()), 'lr': backbone_lr}]
        for module in head_modules:
            module_params = list(module.parameters())
            # Skip modules with no trainable parameters (the geometric head
            # only registers buffers) to avoid empty parameter groups.
            if module_params:
                params.append({'params': module_params, 'lr': spa_lr})
    else:
        params = list(model.parameters())
    # NOTE(review): assumes set_optimizer accepts a list of param-group dicts
    # as well as a flat parameter list -- confirm against utils.set_optimizer.
    optimizer = set_optimizer(args, params)

    return model, criterion, optimizer

def parse_args():
    """Parse command-line options and derive run-dependent settings.

    Side effects: creates ``args.save_folder`` on disk. When no
    ``--save_folder`` is given, ``model_name`` and ``save_folder`` are derived
    from the dataset, model and seed. With ``--warm`` the attributes
    ``warmup_from`` and ``warmup_to`` are added.

    Note: flags declared with ``action='store_true'`` and ``default=True``
    (``--cosine``, ``--from_sl_official``, ``--audioset_pretrained``,
    ``--ma_update``) are always True; passing them has no further effect.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Finetuning with PAFA+SPA ')
    add = parser.add_argument

    # Basic parameters (Shell script compatible)
    add('--tag', type=str, default='seed1_spa')
    add('--dataset', type=str, default='icbhi')
    add('--seed', type=int, default=1)
    add('--class_split', type=str, default='lungsound')
    add('--n_cls', type=int, default=4)
    add('--epochs', type=int, default=300)
    add('--batch_size', type=int, default=128)
    add('--desired_length', type=float, default=5.0)
    add('--optimizer', type=str, default='adam')
    add('--learning_rate', type=float, default=1e-3)
    add('--weight_decay', type=float, default=1e-6)
    add('--cosine', action='store_true', default=True)

    # Model parameters
    add('--model', type=str, default='cnn6', choices=['cnn6', 'ast', 'beats'])
    add('--test_fold', type=str, default='official')
    add('--pad_types', type=str, default='repeat')
    add('--resz', type=float, default=1.0)
    add('--n_mels', type=int, default=128)
    add('--from_sl_official', action='store_true', default=True)
    add('--audioset_pretrained', action='store_true', default=True)
    add('--nospec', action='store_true', default=False)
    add('--h', type=int, default=128)
    add('--w', type=int, default=1000)

    # SPA/SPD parameters
    add('--init_scale', type=float, default=16.0)
    add('--temp', type=float, default=2.0)
    add('--alpha_distill', type=float, default=0.5)
    add('--ma_update', action='store_true', default=True)
    add('--ma_beta', type=float, default=0.5)

    # Training control parameters
    add('--num_workers', type=int, default=4)
    add('--print_freq', type=int, default=10)
    add('--save_freq', type=int, default=50)
    add('--save_dir', type=str, default='./results_spa')
    add('--save_folder', type=str, default='')  # Dynamically generated
    add('--model_name', type=str, default='cnn6_spa')
    add('--resume', type=str, default='')
    add('--eval', action='store_true', default=False)
    add('--pretrained', action='store_true', default=False)
    add('--pretrained_ckpt', type=str, default=None)
    add('--two_cls_eval', action='store_true', default=False)

    # Data path (Shell script compatible)
    add('--data_folder', type=str, default='/home/xxxx/datasets/ICBHI_final_database')

    # Warmup parameters (Shell script compatible)
    add('--warm', action='store_true', default=False)
    add('--warm_epochs', type=int, default=10)

    # PAFA+SPA specific parameters (passed by Shell script)
    add('--pafa_weight', type=float, default=0.3)
    add('--spa_margin', type=float, default=1.2)
    add('--spa_scale', type=float, default=20.0)
    add('--pafa_temperature', type=float, default=3.0)

    args = parser.parse_args()

    # Derive the output folder per seed unless one was given explicitly.
    if not args.save_folder:
        args.model_name = f'{args.dataset}_{args.model}_PAFA_SPA'
        run_dir = f"{args.model_name}_seed{args.seed}"
        args.save_folder = os.path.join(args.save_dir, run_dir)
    os.makedirs(args.save_folder, exist_ok=True)

    if not args.warm:
        return args

    # Warm-up runs from 10% of the base LR up to the LR the schedule would
    # reach at the end of the warm-up epochs.
    base_lr = args.learning_rate
    args.warmup_from = base_lr * 0.1
    if args.cosine:
        eta_min = base_lr * (0.1 ** 3)
        cosine_term = 1 + math.cos(math.pi * args.warm_epochs / args.epochs)
        args.warmup_to = eta_min + (base_lr - eta_min) * cosine_term / 2
    else:
        args.warmup_to = base_lr
    return args

# ===================== Summarize Multi-Seed Results =====================
def update_summary_and_calc_stats(args, best_metrics):
    """Summarize multi-seed results and generate summary_result.json.

    Stores ``best_metrics`` under ``"<model_name>_seed<seed>"`` in
    ``<save_dir>/summary_result.json`` and prints the mean and standard
    deviation of the score (index 2) and ECE (index 4) over all stored seed
    runs of the same ``model_name``.

    Args:
        args: run configuration; ``save_dir``, ``model_name`` and ``seed``
            are read.
        best_metrics: JSON-serialisable sequence laid out as
            ``[sp, se, score, f1_normal, ece]``.
    """
    summary_path = os.path.join(args.save_dir, "summary_result.json")

    # Read existing data; an unreadable or non-dict file starts a fresh summary.
    data = {}
    if os.path.exists(summary_path):
        try:
            with open(summary_path, 'r') as f:
                loaded = json.load(f)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass.
            loaded = None
        if isinstance(loaded, dict):
            data = loaded

    # Build seed-corresponding Key
    run_key = f"{args.model_name}_seed{args.seed}"
    data[run_key] = best_metrics

    # Write to file
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    with open(summary_path, 'w') as f:
        json.dump(data, f, indent=4)

    # Calculate multi-seed statistics. Only the first five metrics of each
    # entry are used so that entries of different lengths still stack.
    target_config = args.model_name
    n_metrics = 5
    metrics_matrix = [
        list(val[:n_metrics])
        for key, val in data.items()
        if key.startswith(target_config)
        and "_seed" in key
        and isinstance(val, (list, tuple))
        and len(val) >= n_metrics
    ]

    if metrics_matrix:
        metrics_np = np.array(metrics_matrix, dtype=float)
        means = np.mean(metrics_np, axis=0)
        stds = np.std(metrics_np, axis=0)
        print("=" * 50)
        print(f"Multi-Seed Stats for {target_config}:")
        print(f"Avg Score: {means[2]:.2f} ± {stds[2]:.2f}")
        print(f"Avg ECE: {means[4]:.4f} ± {stds[4]:.4f}")
        print("=" * 50)



# ===================== Main Function (Unchanged) =====================
def main():
    """Entry point: train (or only evaluate) the PAFA+SPA model and record results.

    Parses the CLI arguments, seeds the RNGs, builds loaders/model/optimizer,
    optionally resumes from a checkpoint, then either runs the training loop
    with per-epoch validation and checkpointing or a single validation pass.
    The best metrics are finally written to the summary JSON files.
    """
    args = parse_args()
    with open(os.path.join(args.save_folder, 'train_args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # NOTE(review): benchmark=True lets cuDNN auto-select algorithms, which
    # may work against deterministic=True -- confirm this is intended.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    best_model = None
    best_acc = [0, 0, 0, 0, 0]
    train_loader, val_loader, args = set_loader(args)
    model, criterion, optimizer = set_model(args)
    # Default start epoch; overridden only when a checkpoint is actually
    # loaded, so a missing --resume file falls back to training from scratch.
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"Loaded checkpoint {args.resume}")
        else:
            print(f"No checkpoint found at {args.resume}")
    scaler = torch.cuda.amp.GradScaler()
    print('*' * 50)
    print(f'Experiment: {args.model_name} (PAFA+SPA)')
    print('*' * 50)
    if not args.eval:
        print(f'Training for {args.epochs} epochs')
        ece_log = []
        for epoch in range(args.start_epoch, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            time1 = time.time()
            loss, acc, train_ece = train(
                train_loader, model, criterion, optimizer, epoch, args, scaler
            )
            time2 = time.time()
            print(f'Epoch {epoch}, time {time2 - time1:.2f}s, acc:{acc:.2f}, ECE:{train_ece:.4f}')
            best_acc, best_model, save_bool = validate(
                val_loader, model, criterion, args, best_acc, best_model
            )
            # NOTE: 'val_ece' is the ECE stored with the best metrics so far,
            # since validate() does not return the current epoch's ECE.
            ece_log.append({
                'epoch': epoch,
                'train_ece': train_ece,
                'val_ece': best_acc[4]
            })
            with open(os.path.join(args.save_folder, 'ece_log.json'), 'w') as f:
                json.dump(ece_log, f, indent=4)
            if save_bool:
                save_file = os.path.join(args.save_folder, f'best_epoch_{epoch}.pth')
                save_model(model, optimizer, args, epoch, save_file)
            if epoch % args.save_freq == 0:
                save_file = os.path.join(args.save_folder, f'epoch_{epoch}.pth')
                save_model(model, optimizer, args, epoch, save_file)
        save_file = os.path.join(args.save_folder, 'best.pth')
        if best_model is not None:
            model.load_state_dict(best_model)
        else:
            # No epoch satisfied the best-model criteria (or no epoch ran).
            print('No best model was selected; saving the final weights as best.pth')
        save_model(model, optimizer, args, args.epochs, save_file)
    else:
        best_acc, _, _ = validate(val_loader, model, criterion, args, best_acc, best_model)
    update_summary_and_calc_stats(args, best_acc)
    update_json(args.model_name, best_acc, path=os.path.join(args.save_dir, 'results_spa.json'))
    print(f'Experiment finished (Best ECE: {best_acc[4]:.4f})')

if __name__ == '__main__':
    main()
