
# =============================================
# File: train_audio_resnet34.py
# ---------------------------------------------
# - ResNet-34 (no pretrained), input 1×1000×128 spectrograms
# - BF16, AdamW, 25 epochs, 5 warmup (linear) → cosine decay
# - Effective batch size 512 via gradient accumulation
# - Save last/best checkpoints
# =============================================
from __future__ import annotations
import os, math, random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.models import resnet34

# Once the data loaders are split out into their own module, import them from there:
# from vggsound_dataloaders import AudioCfg, make_audio_loader

@dataclass
class TrainAudioCfg:
    """Hyper-parameters and paths consumed by the audio training run."""

    # --- locations ---------------------------------------------------------
    csv_path: str = "./csv"            # forwarded to the loader config
    data_path: str = "./data"          # forwarded to the loader config
    out_dir: str = "./ckpts_audio"     # directory receiving the checkpoints
    num_classes: int = 309             # width of the classifier head

    # --- schedule ----------------------------------------------------------
    epochs: int = 25                   # total number of training epochs
    warmup_epochs: int = 5             # epochs of linear LR warmup

    # --- batching ----------------------------------------------------------
    batch_size: int = 128              # samples per forward/backward pass
    effective_batch_size: int = 512    # samples per optimizer update (via accumulation)

    # --- optimisation / runtime -------------------------------------------
    lr: float = 1e-3                   # peak learning rate
    weight_decay: float = 1e-4         # AdamW weight decay
    num_workers: int = 8               # loader worker count
    seed: int = 42                     # seed handed to the RNGs
    # Resolved once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def set_seed(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's (CPU and all-CUDA) generators with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Call order is kept fixed: stdlib, NumPy, torch CPU, torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)


class ResNet34Audio(nn.Module):
    """ResNet-34 (built with ``weights=None``) adapted to single-channel input.

    The stem convolution is replaced by a 1-input-channel layer with the same
    geometry, and the final fully connected layer is resized to ``num_classes``.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        net = resnet34(weights=None)

        # Keep a handle on the 3-channel stem kernels before swapping the layer.
        rgb_kernels = net.conv1.weight.detach()
        mono_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            # Collapse the input-channel axis so the kernels fit a 1-channel input.
            mono_stem.weight.copy_(rgb_kernels.mean(dim=1, keepdim=True))
        net.conv1 = mono_stem

        # New classification head sized for this task.
        feat_dim = net.fc.in_features
        net.fc = nn.Linear(feat_dim, num_classes)

        self.backbone = net

    def forward(self, x):
        """Run ``x`` through the backbone and return its output (class logits)."""
        logits = self.backbone(x)
        return logits


class WarmupCosine:
    """Per-epoch learning-rate schedule: linear warmup followed by cosine decay.

    ``step_epoch(epoch)`` computes the rate for a 0-based ``epoch``, writes it
    into every parameter group of the wrapped optimizer, and returns it.
    """

    def __init__(self, opt, base_lr, epochs, warmup_epochs):
        self.opt, self.base_lr = opt, base_lr
        self.epochs, self.warmup = epochs, warmup_epochs

    def step_epoch(self, epoch: int):
        """Set and return the learning rate for the given 0-based epoch."""
        if epoch >= self.warmup:
            # Cosine phase: progress is measured from the end of warmup.
            decay_span = max(1, self.epochs - self.warmup)
            frac = (epoch - self.warmup) / decay_span
            new_lr = 0.5 * self.base_lr * (1 + math.cos(math.pi * frac))
        else:
            # Warmup phase: ramps so the last warmup epoch reaches base_lr.
            ramp_span = max(1, self.warmup)
            new_lr = self.base_lr * (epoch + 1) / ramp_span

        for group in self.opt.param_groups:
            group['lr'] = new_lr
        return new_lr


@torch.no_grad()
def evaluate_audio(model, loader, device) -> float:
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: Module called as ``model(x)``; its output is arg-maxed over dim 1.
        loader: Iterable of ``(x, y)`` batches.
        device: ``torch.device`` or device string (e.g. ``'cpu'``) the batches
            are moved to and under which BF16 autocast runs.

    Returns:
        Fraction of correctly predicted labels, or ``0.0`` if the loader yields
        no samples.

    Note:
        The model is switched to eval mode and is left in that mode on return.
    """
    # Accept plain strings too: the training config stores the device as ``str``.
    device = torch.device(device)
    model.eval()

    # Keep the running hit count on the device so the host synchronises once at
    # the end instead of once per batch.
    n_correct = torch.zeros((), dtype=torch.long, device=device)
    n_total = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits = model(x)
        pred = logits.argmax(1)
        n_correct += (pred == y).sum()
        n_total += y.numel()

    # max(1, ...) guards the empty-loader case against division by zero.
    return n_correct.item() / max(1, n_total)


def main_audio():
    """Train ``ResNet34Audio`` on the audio loaders and checkpoint every epoch.

    Uses the defaults of ``TrainAudioCfg``: AdamW, BF16 autocast, a per-epoch
    warmup/cosine learning-rate schedule, and gradient accumulation so that one
    optimizer update covers ``effective_batch_size // batch_size`` micro-batches.

    After each epoch the model is scored on the ``'test'`` split; the result is
    written to ``last.pt`` and, when it matches or beats the best accuracy seen
    so far, also to ``best.pt`` (both inside ``cfg.out_dir``).
    """
    # Single-file runs define the loader helpers in the executed script; fall
    # back to the stand-alone module for the case where the file has been split.
    try:
        from __main__ import AudioCfg, make_audio_loader
    except ImportError:
        from vggsound_dataloaders import AudioCfg, make_audio_loader

    cfg = TrainAudioCfg()
    os.makedirs(cfg.out_dir, exist_ok=True)
    set_seed(cfg.seed)

    train_loader = make_audio_loader(AudioCfg(cfg.csv_path, cfg.data_path, 'train'),
                                     batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
    test_loader = make_audio_loader(AudioCfg(cfg.csv_path, cfg.data_path, 'test'),
                                    batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

    device = torch.device(cfg.device)
    model = ResNet34Audio(num_classes=cfg.num_classes).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sch = WarmupCosine(opt, cfg.lr, cfg.epochs, cfg.warmup_epochs)
    criterion = nn.CrossEntropyLoss()

    accum_steps = max(1, cfg.effective_batch_size // cfg.batch_size)  # 512/128 = 4

    # The last accumulation group of an epoch is shorter whenever the number of
    # batches is not a multiple of accum_steps. Its losses must be divided by
    # its own length, otherwise that update is scaled down by tail/accum_steps.
    n_batches = len(train_loader)
    tail = n_batches % accum_steps
    tail_start = n_batches - tail  # steps after this index belong to the short group

    best = 0.0
    for epoch in range(cfg.epochs):
        model.train()
        lr_now = sch.step_epoch(epoch)
        opt.zero_grad(set_to_none=True)

        # Hit counter lives on the device: one host sync per epoch, not per step.
        correct = torch.zeros((), dtype=torch.long, device=device)
        total = 0
        pending = 0  # micro-batches back-propagated since the last optimizer step
        for step, (x, y) in enumerate(train_loader, start=1):
            x = x.to(device)
            y = y.to(device)
            group_size = tail if (tail and step > tail_start) else accum_steps
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                logits = model(x)
                loss = criterion(logits, y) / group_size
            loss.backward()
            pending += 1

            # metrics on-the-fly
            pred = logits.detach().argmax(1)
            correct += (pred == y).sum()
            total += y.numel()

            if step % accum_steps == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)
                pending = 0

        # flush remaining grads (only when something is actually accumulated)
        if pending:
            opt.step()
            opt.zero_grad(set_to_none=True)

        train_acc = correct.item() / max(1, total)
        val_acc = evaluate_audio(model, test_loader, device)

        ckpt = {
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'optimizer': opt.state_dict(),
            'config': vars(cfg),
            'val_acc': val_acc,
            'lr': lr_now,
        }
        torch.save(ckpt, os.path.join(cfg.out_dir, 'last.pt'))
        # ">=" keeps the most recent checkpoint among ties.
        if val_acc >= best:
            best = val_acc
            torch.save(ckpt, os.path.join(cfg.out_dir, 'best.pt'))

        print(f"[AUDIO] epoch {epoch+1}/{cfg.epochs} lr={lr_now:.3e} train_acc={train_acc:.4f} val_acc={val_acc:.4f}")