#!/usr/bin/env python3
"""
Fine-tune EnCodec (wmar-style) with token-match evaluation and optional wandb logging.
"""
import os
from pathlib import Path
from copy import deepcopy
import argparse
import json
import time
import datetime

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

from transformers import EncodecModel, AutoFeatureExtractor

# Reuse wmar dataloader & augmenter
from training.dataloader import AudioDataset, get_audio_dataloader
from training.augmenter import Augmenter

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Try to import losses
try:
    from training.losses import (
        SISNR,
        LogSTFTMagnitudeLoss,
        MRSTFTLoss,
        SpectralConvergenceLoss,
        STFTLoss,
        MelSpectrogramL1Loss,
        MultiScaleMelSpectrogramLoss,
        TFLoudnessRatio
    )
except ImportError:
    print("Warning: Could not import advanced losses from training.losses. Some loss types may fail.")
    SISNR = MRSTFTLoss = MultiScaleMelSpectrogramLoss = STFTLoss = TFLoudnessRatio = None

# RANK is exported by distributed launchers (e.g. torchrun); when it is absent
# we treat the run as non-distributed (-1). Rank 0 and non-distributed runs
# are the "main" process.
global_rank = int(os.environ.get("RANK", "-1"))
is_main_process = global_rank == -1 or global_rank == 0

def get_audio_loss(loss_type, sample_rate=24000):
    """Build the waveform reconstruction loss named by *loss_type*.

    ``"mse"`` and ``"l1"`` are plain torch losses. ``"sisnr"``, ``"multi_mel"``,
    ``"stft"``, ``"mrstft"`` and ``"tf_loudness"`` come from
    ``training.losses`` and are only available when that optional import
    succeeded (otherwise the module-level name is ``None``). ``sample_rate``
    is forwarded to the losses whose constructor takes it.

    Raises:
        ValueError: if *loss_type* is unknown or its loss class is unavailable.
    """
    if loss_type == "mse":
        return torch.nn.MSELoss()
    if loss_type == "l1":
        return torch.nn.L1Loss()

    # (name, deferred class lookup, constructor takes sample_rate).
    # The lookups are lambdas so the optional globals are read only when the
    # matching loss is actually requested.
    optional_losses = (
        ("sisnr", lambda: SISNR, True),
        ("multi_mel", lambda: MultiScaleMelSpectrogramLoss, True),
        ("stft", lambda: STFTLoss, False),
        ("mrstft", lambda: MRSTFTLoss, False),
        ("tf_loudness", lambda: TFLoudnessRatio, True),
    )
    for name, lookup_cls, takes_sample_rate in optional_losses:
        if loss_type == name:
            loss_cls = lookup_cls()
            if loss_cls:
                if takes_sample_rate:
                    return loss_cls(sample_rate=sample_rate)
                return loss_cls()
            break

    raise ValueError(f"Unknown or unavailable audio loss type: {loss_type}")

def get_code_loss(loss_type):
    """Build the latent ("code") reconstruction loss named by *loss_type*.

    Only the pointwise losses ``"mse"`` and ``"l1"`` are supported.

    Raises:
        ValueError: for any other *loss_type*.
    """
    for name, loss_cls in (("mse", torch.nn.MSELoss), ("l1", torch.nn.L1Loss)):
        if loss_type == name:
            return loss_cls()
    raise ValueError(f"Unknown code loss type: {loss_type}")


class EncodecFTWrapper(nn.Module):
    """
    Wrapper around a trainable EncodecModel and a frozen replica.

    ``forward`` encodes + quantizes the input with the replica (under
    ``torch.no_grad``), decodes those latents with the trainable model's
    decoder, optionally augments the reconstruction, and re-encodes it with
    the trainable encoder. The caller builds audio / latent losses from the
    returned dict.

    Handles different transformers versions (Tensor vs output objects,
    ``quantizer.forward`` vs ``quantizer.encode``/``quantizer.decode``).
    """

    def __init__(self, model: "EncodecModel", replica: "EncodecModel", augmenter=None, augmentation_start: int = -1):
        """
        Args:
            model: trainable model; its ``encoder``, ``decoder`` and
                ``quantizer`` submodules are used.
            replica: frozen copy providing the targets (``encoder`` and
                ``quantizer`` are used here).
            augmenter: optional callable applied to the reconstruction before
                it is re-encoded.
            augmentation_start: first epoch at which the augmenter is applied;
                a negative value applies it from the start.
        """
        super().__init__()
        self.model = model
        self.replica = replica
        self.augmenter = augmenter
        self.augmentation_start = augmentation_start

    def _get_tensor(self, output, attr_name=None, idx=0):
        """
        Extract a tensor from heterogeneous module outputs.

        Returns ``output`` itself if it is a tensor, else ``output.<attr_name>``
        when that attribute exists, else ``output[idx]`` for tuples/lists, and
        otherwise ``output`` unchanged.
        """
        if isinstance(output, torch.Tensor):
            return output
        if attr_name and hasattr(output, attr_name):
            return getattr(output, attr_name)
        if isinstance(output, (tuple, list)):
            return output[idx]
        return output

    def _quantize(self, quantizer, embs_pre_q):
        """
        Quantize ``embs_pre_q`` and return the quantized (post-q) embeddings.

        First tries calling the quantizer directly. If that raises or returns an
        unrecognised type, falls back to ``quantizer.encode`` -> ``quantizer.decode``.

        On the fallback path the decoded embeddings are routed through a
        straight-through estimator (forward value of the decoded embeddings,
        gradient of ``embs_pre_q``) whenever gradients are being tracked:
        the round trip goes through integer codes, which would otherwise cut
        the gradient to the encoder and make a post-q loss untrainable.
        """
        # 1. Direct call. Deliberately best-effort: e.g. transformers'
        #    quantizer defines no forward() and raises NotImplementedError;
        #    any failure routes to the explicit encode/decode flow below.
        try:
            out = quantizer(embs_pre_q)
        except Exception:
            out = None

        if out is not None:
            if hasattr(out, "quantized_inputs"):
                return out.quantized_inputs
            if hasattr(out, "latents"):
                return out.latents
            if isinstance(out, (tuple, list)):
                return out[0]
            if isinstance(out, torch.Tensor):
                return out

        # 2. Fallback: encode to discrete codes.
        encoded = quantizer.encode(embs_pre_q)
        if isinstance(encoded, torch.Tensor):
            # A bare tensor carries only the codes; there are no scales.
            codes, scales = encoded, None
        else:
            codes = self._get_tensor(encoded, "audio_codes", 0)
            scales = getattr(encoded, "audio_scales", None)
            if scales is None and isinstance(encoded, (tuple, list)) and len(encoded) > 1:
                scales = encoded[1]
        if isinstance(scales, torch.Tensor) and scales.numel() == 0:
            scales = None

        # 3. Decode codes back to embeddings. Signatures differ between
        #    versions (decode(codes) vs decode(codes, scales)): try the arity
        #    that matches what we have, then the other one.
        if scales is None:
            primary_args, fallback_args = (codes,), (codes, scales)
        else:
            primary_args, fallback_args = (codes, scales), (codes,)
        try:
            embs_post_q = quantizer.decode(*primary_args)
        except TypeError:
            embs_post_q = quantizer.decode(*fallback_args)

        # 4. Straight-through estimator so gradients reach the encoder.
        if (
            torch.is_grad_enabled()
            and isinstance(embs_post_q, torch.Tensor)
            and embs_pre_q.requires_grad
            and embs_post_q.shape == embs_pre_q.shape
        ):
            embs_post_q = embs_pre_q + (embs_post_q - embs_pre_q).detach()

        return embs_post_q

    def forward(self, audio: torch.Tensor, epoch: int = -1):
        """
        Run replica targets, model decode, optional augmentation and model re-encode.

        Args:
            audio: input waveform batch passed to the encoders
                (presumably (batch, channels, samples) — TODO confirm).
            epoch: current epoch; gates augmentation via ``augmentation_start``.

        Returns:
            dict with ``audio_recon_pred`` (model decoder output),
            ``audio_recon_pred_aug`` (after augmentation, or the same tensor),
            ``embs_pre_q_target`` / ``embs_post_q_target`` (replica latents,
            no grad), ``embs_pre_q_pred`` / ``embs_post_q_pred`` (model latents
            of the augmented reconstruction) and ``selected_aug`` (a label for
            the augmentation, ``"identity"`` when none was applied).
        """
        # 1. Target generation with the frozen replica.
        with torch.no_grad():
            enc_out = self.replica.encoder(audio)
            embs_pre_q_replica = self._get_tensor(enc_out, "last_hidden_state", 0)
            embs_post_q_replica = self._quantize(self.replica.quantizer, embs_pre_q_replica)

        # 2. Audio prediction: trainable decoder on the replica's quantized latents.
        dec_out = self.model.decoder(embs_post_q_replica)
        audio_recon_pred = self._get_tensor(dec_out, "audio_values", 0)

        # 3. Optional augmentation of the reconstruction.
        if self.augmenter is not None and (self.augmentation_start < 0 or epoch >= self.augmentation_start):
            res = self.augmenter(audio_recon_pred)
            if isinstance(res, tuple):
                audio_recon_pred_aug = res[0]
                # NOTE(review): assumes the augmenter puts its label at index 2.
                selected_aug = res[2] if len(res) > 2 else "unknown"
            else:
                audio_recon_pred_aug = res
                selected_aug = "identity"
        else:
            audio_recon_pred_aug = audio_recon_pred
            selected_aug = "identity"

        # 4. Latent prediction with the trainable encoder.
        enc_out_pred = self.model.encoder(audio_recon_pred_aug)
        embs_pre_q_pred = self._get_tensor(enc_out_pred, "last_hidden_state", 0)

        # Post-q latents of the prediction support the 'post_q' loss target.
        embs_post_q_pred = self._quantize(self.model.quantizer, embs_pre_q_pred)

        return {
            "audio_recon_pred": audio_recon_pred,
            "audio_recon_pred_aug": audio_recon_pred_aug,
            "embs_pre_q_target": embs_pre_q_replica,
            "embs_post_q_target": embs_post_q_replica,
            "embs_pre_q_pred": embs_pre_q_pred,
            "embs_post_q_pred": embs_post_q_pred,
            "selected_aug": selected_aug,
        }


def train_one_epoch(wrapper, dataloader, optimizer, device, epoch, steps_per_epoch, 
                    audio_loss_fn, code_loss_fn, audio_loss_weight, code_loss_weight,
                    audio_target_type="replica", code_target_type="pre_q", accum_steps=1):
    """
    Run one training epoch over at most ``steps_per_epoch`` batches.

    Per batch the loss is
    ``audio_loss_weight * audio_loss + code_loss_weight * code_loss`` where:

    * the audio loss compares ``audio_recon_pred`` with either the frozen
      replica's own reconstruction (``audio_target_type="replica"``) or the
      input audio (any other value), both trimmed to the shorter length;
    * the code loss compares predicted vs. target latents, post-quantisation
      for ``code_target_type="post_q"`` and pre-quantisation otherwise.

    Gradients are accumulated over ``accum_steps`` batches (values < 1 are
    treated as 1). A trailing partial accumulation group at the end of the
    epoch is still applied instead of being discarded.

    Returns:
        dict with the mean unscaled ``loss``, ``audio_loss`` and ``code_loss``
        over the processed batches, and the number of ``steps``.
    """
    accum_steps = max(1, int(accum_steps))

    wrapper.train()
    # train() recurses into every submodule, including the frozen replica;
    # put it back in eval mode so targets are produced in inference mode.
    replica = getattr(wrapper, "replica", None)
    if replica is not None:
        replica.eval()

    total_loss = 0.0
    total_audio_loss = 0.0
    total_code_loss = 0.0
    steps = 0
    has_pending_grads = False

    optimizer.zero_grad()

    for i, audio in enumerate(dataloader):
        if i >= steps_per_epoch:
            break
        audio = audio.to(device)
        out = wrapper(audio, epoch=epoch)

        # --- Audio loss ---
        pred_audio = out["audio_recon_pred"]
        if audio_target_type == "replica":
            # Decode the replica's quantized latents with the replica decoder.
            with torch.no_grad():
                dec_t = wrapper.replica.decoder(out["embs_post_q_target"])
                target_audio = wrapper._get_tensor(dec_t, "audio_values", 0)
        else:
            # "original" and any unrecognised value use the input audio.
            target_audio = audio

        if pred_audio.shape != target_audio.shape:
            m = min(pred_audio.shape[-1], target_audio.shape[-1])
            pred_audio = pred_audio[..., :m]
            target_audio = target_audio[..., :m]

        audio_loss = audio_loss_fn(pred_audio, target_audio)

        # --- Code loss ("indices" is not implemented and falls back to pre_q) ---
        if code_target_type == "post_q":
            code_pred = out["embs_post_q_pred"]
            code_target = out["embs_post_q_target"].detach()
        else:
            code_pred = out["embs_pre_q_pred"]
            code_target = out["embs_pre_q_target"].detach()
        code_loss = code_loss_fn(code_pred, code_target)

        loss = (audio_loss_weight * audio_loss) + (code_loss_weight * code_loss)

        # Scale only the backward pass so the logged loss stays comparable
        # with audio_loss / code_loss regardless of accum_steps.
        (loss / accum_steps).backward()
        has_pending_grads = True

        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            has_pending_grads = False

        total_loss += loss.item()
        total_audio_loss += audio_loss.item()
        total_code_loss += code_loss.item()
        steps += 1

    # Apply the leftover gradients of an incomplete accumulation group.
    if has_pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    return {
        "loss": total_loss / max(1, steps),
        "audio_loss": total_audio_loss / max(1, steps),
        "code_loss": total_code_loss / max(1, steps),
        "steps": steps
    }


def _extract_codes(encoded):
    """Return the code tensor from an ``encode()`` result (``.audio_codes`` or element 0)."""
    codes = encoded.audio_codes if hasattr(encoded, "audio_codes") else encoded[0]
    # Unwrap a list/tuple container, keeping its first entry.
    if isinstance(codes, (list, tuple)):
        codes = codes[0]
    return codes


def _token_match_counts(codes_target, codes_pred):
    """
    Count equal code indices per codebook.

    Both tensors are cropped to their common leading shape. The codebook axis
    is assumed to be the second-to-last dimension, e.g. (chunks, batch,
    codebooks, frames) or (batch, codebooks, frames) — TODO confirm for other
    transformers versions.

    Returns:
        (matches_per_codebook, elements_per_codebook): a float tensor with the
        number of matching positions per codebook, and the int number of
        compared positions per codebook.
    """
    ndim = min(codes_target.ndim, codes_pred.ndim)
    crop = tuple(slice(0, min(codes_target.shape[d], codes_pred.shape[d])) for d in range(ndim))
    matches = (codes_target[crop] == codes_pred[crop]).float()

    codebook_axis = matches.ndim - 2
    reduce_dims = tuple(d for d in range(matches.ndim) if d != codebook_axis)
    per_codebook = matches.sum(dim=reduce_dims)
    elems = matches.numel() // max(1, matches.shape[codebook_axis])
    return per_codebook, elems


def eval_one_epoch(wrapper, dataloader, device, epoch, eval_steps, 
                   audio_loss_fn, code_loss_fn, audio_loss_weight, code_loss_weight):
    """
    Evaluate over at most ``eval_steps`` batches without gradients.

    Metrics:
      * ``eval/audio_loss``: ``audio_loss_fn`` between the model's
        reconstruction and the *input* audio (trimmed to the shorter length);
      * ``eval/code_loss``: ``code_loss_fn`` on pre-quantisation latents;
      * ``eval/token_match`` and ``eval/token_match_ch{k}``: fraction of code
        indices from ``wrapper.model.encode(reconstruction)`` equal to
        ``wrapper.replica.encode(input)``, averaged over codebooks / per
        codebook. When no batch is evaluated these are NaN / absent.

    ``audio_loss_weight`` and ``code_loss_weight`` are accepted for signature
    parity with ``train_one_epoch`` and are unused.
    """
    wrapper.eval()
    total_audio_loss = 0.0
    total_code_loss = 0.0
    match_per_codebook = None
    elems_per_codebook = 0
    steps = 0

    with torch.no_grad():
        for i, audio in enumerate(dataloader):
            if i >= eval_steps:
                break
            audio = audio.to(device)
            out = wrapper(audio, epoch=epoch)

            pred_audio = out["audio_recon_pred"]
            m = min(pred_audio.shape[-1], audio.shape[-1])
            total_audio_loss += float(audio_loss_fn(pred_audio[..., :m], audio[..., :m]).cpu())
            total_code_loss += float(code_loss_fn(out["embs_pre_q_pred"], out["embs_pre_q_target"]).cpu())

            # --- Token match (idempotence) ---
            codes_target = _extract_codes(wrapper.replica.encode(audio))
            codes_pred = _extract_codes(wrapper.model.encode(pred_audio))
            per_codebook, elems = _token_match_counts(codes_target, codes_pred)

            if match_per_codebook is None:
                match_per_codebook = per_codebook
            else:
                match_per_codebook = match_per_codebook + per_codebook
            elems_per_codebook += elems

            steps += 1

    if match_per_codebook is None:
        # Nothing evaluated (empty loader or eval_steps <= 0).
        token_match_rates = []
        avg_token_match = float("nan")
    else:
        token_match_rates = (match_per_codebook / elems_per_codebook).cpu().tolist()
        avg_token_match = sum(token_match_rates) / len(token_match_rates)

    results = {
        "eval/audio_loss": total_audio_loss / max(1, steps),
        "eval/code_loss": total_code_loss / max(1, steps),
        "eval/token_match": avg_token_match,
    }
    for k, rate in enumerate(token_match_rates):
        results[f"eval/token_match_ch{k}"] = rate

    return results


def save_checkpoint(output_dir: Path, epoch: int, model: "EncodecModel", optimizer, scheduler=None, save_optimizer: bool = False):
    """
    Write ``checkpoint_epoch_{epoch}.pt`` into ``output_dir``.

    The file is a dict with ``epoch``, ``model_state`` and, when available,
    ``scheduler_state``. Optimizer state is only included when
    ``save_optimizer`` is True (off by default to keep checkpoints small).

    Args:
        output_dir: target directory (``str`` or ``Path``); must already exist.
        epoch: epoch index stored in the checkpoint and used in the filename.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose state is saved if ``save_optimizer``.
        scheduler: optional LR scheduler; saved best-effort.
        save_optimizer: include ``optimizer_state`` in the checkpoint.

    Returns:
        The ``Path`` of the written checkpoint.
    """
    sd = {
        "epoch": epoch,
        "model_state": model.state_dict(),
    }
    if save_optimizer and optimizer is not None:
        sd["optimizer_state"] = optimizer.state_dict()
    if scheduler is not None:
        try:
            sd["scheduler_state"] = scheduler.state_dict()
        except Exception as err:
            # Best-effort: an unserialisable scheduler must not block saving the weights.
            print(f"Warning: could not save scheduler state ({err}); continuing without it.")

    path = Path(output_dir) / f"checkpoint_epoch_{epoch}.pt"
    torch.save(sd, path)
    return path


def _str2bool(value):
    """
    Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no", ...).

    argparse's ``type=bool`` turns every non-empty string, including "False",
    into True, so boolean options use this converter instead. The empty
    string stays False, as it was with ``type=bool``.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the CLI parser (names and defaults kept close to finetune_mimi.py)."""
    parser = argparse.ArgumentParser(description="Fine-tune EnCodec (wmar-style) with token-match eval and wandb logging")

    parser.add_argument("--hf_repo", type=str, default="facebook/encodec_32khz", help="HuggingFace repo for model (encodec)")
    parser.add_argument("--output_dir", type=str, default="output", help="Directory to save outputs")

    # Dataset arguments
    parser.add_argument("--audio_dir", type=str, default=None, help="Directory containing audio files")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    parser.add_argument("--target_sr", type=int, default=24000, help="Target sample rate")
    parser.add_argument("--target_duration", type=float, default=5.0, help="Target audio duration in seconds")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of dataloader workers")
    parser.add_argument("--num_valid", type=int, default=100, help="Number of validation samples")
    parser.add_argument("--accum_steps", type=int, default=1, help="Gradient accumulation steps")

    # Training arguments
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs")
    parser.add_argument("--warmup_epochs", type=int, default=5, help="Number of linear LR warmup epochs before cosine decay")
    parser.add_argument("--steps_per_epoch", type=int, default=100, help="Maximum iterations per epoch")

    # Losses & weights
    parser.add_argument("--code_loss_type", type=str, default="mse", help="Code loss type")
    parser.add_argument("--audio_loss_type", type=str, default="mrstft", help="Audio loss type")
    parser.add_argument("--audio_loss_weight", type=float, default=1e-3, help="Weight for audio reconstruction loss")
    parser.add_argument("--code_loss_weight", type=float, default=1.0, help="Weight for code reconstruction loss")
    parser.add_argument("--audio_target_type", type=str, default="replica", help="Target for audio loss ('replica' or 'original')")
    parser.add_argument("--code_target_type", type=str, default="pre_q", help="Target for code loss ('post_q', 'pre_q', or indices)")

    # Fine-tuning specifics
    parser.add_argument("--resume_from", type=str, default=None, help="Path to a checkpoint to resume training from")
    parser.add_argument("--finetune_encoder", type=_str2bool, default=True, help="Fine-tune encoder (true/false)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to train on")
    parser.add_argument("--save_freq", type=int, default=10, help="Frequency of saving checkpoints (epochs)")
    parser.add_argument("--eval_freq", type=int, default=1, help="Frequency of evaluation (epochs)")
    parser.add_argument("--seed", type=int, default=42424242, help="Random seed (used for the train/valid split)")

    # Augmentation
    parser.add_argument("--augmentation_start", type=int, default=-1, help="Epoch to start applying augmentations")
    parser.add_argument("--augs", type=str, default="{}", help="JSON dict of augmentation weights")
    parser.add_argument("--augs_params", type=str, default="{}", help="JSON dict of augmentation parameters")
    parser.add_argument("--num_augmentations", type=int, default=1, help="Number of augmentations to apply sequentially")

    # Distributed / miscellaneous (accepted for CLI parity)
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training")
    parser.add_argument("--master_port", type=int, default=-1, help="Master port for DDP")
    parser.add_argument("--debug_slurm", type=_str2bool, default=False, help="Debug SLURM setup")
    parser.add_argument("--distributed", action="store_true", help="Whether to treat as distributed run (no full DDP here)")

    # Wandb logging
    parser.add_argument("--use_wandb", action="store_true", help="Log metrics to wandb")
    parser.add_argument("--wandb_entity", type=str, default="el18035", help="wandb entity (user or team)")
    parser.add_argument("--wandb_project", type=str, default="wmar-encodec", help="wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None, help="wandb run name")
    return parser


def _init_distributed():
    """
    Initialise torch.distributed when launched by torchrun.

    Returns ``(enabled, is_main)``. Without the RANK / WORLD_SIZE environment
    variables the run is treated as single-process (``(False, True)``).
    Uses NCCL (after selecting the LOCAL_RANK GPU) when CUDA is available and
    Gloo otherwise.
    """
    import torch.distributed as dist

    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False, True

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        backend = "nccl"
    else:
        backend = "gloo"
    dist.init_process_group(backend=backend)
    return True, dist.get_rank() == 0


def _build_scheduler(optimizer, epochs, warmup_epochs, eta_min=1e-8):
    """
    Per-epoch LR schedule: linear warmup from 1% of the base LR, then cosine decay to ``eta_min``.

    With ``warmup_epochs <= 0`` only the cosine phase is used. The cosine
    length is clamped to at least one epoch so ``epochs <= warmup_epochs``
    does not produce ``T_max=0``.
    """
    if warmup_epochs <= 0:
        return CosineAnnealingLR(optimizer, T_max=max(1, epochs), eta_min=eta_min)
    warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=eta_min)
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])


def _resume_from_checkpoint(path, model, optimizer, scheduler):
    """
    Restore training state from ``path`` and return the epoch to start at.

    Accepts a dict written by ``save_checkpoint`` (``model_state`` plus
    optional ``optimizer_state`` / ``scheduler_state`` / ``epoch``) or a bare
    model state dict, which restarts at epoch 0. A missing optimizer state
    leaves the freshly built optimizer untouched. A missing scheduler state
    is reconstructed by stepping the scheduler once per completed epoch.
    """
    ck = torch.load(path, map_location="cpu")
    if not (isinstance(ck, dict) and "model_state" in ck):
        model.load_state_dict(ck)
        return 0

    model.load_state_dict(ck["model_state"])
    if ck.get("optimizer_state"):
        optimizer.load_state_dict(ck["optimizer_state"])

    start_epoch = int(ck.get("epoch", -1)) + 1
    if ck.get("scheduler_state"):
        scheduler.load_state_dict(ck["scheduler_state"])
    else:
        for _ in range(start_epoch):
            scheduler.step()
    return start_epoch


def main():
    """
    Fine-tune EnCodec from the command line.

    Loads a pretrained EncodecModel and keeps a frozen replica as the target
    generator. It trains the decoder, and the encoder too when
    ``--finetune_encoder``, for ``--epochs`` epochs. Token match is evaluated
    every ``--eval_freq`` epochs. Rank 0 saves checkpoints every
    ``--save_freq`` epochs and at the end.

    NOTE(review): the wrapper is never wrapped in DistributedDataParallel, so
    under torchrun every rank optimises its own copy without gradient
    synchronisation and only rank 0's weights are saved.
    """
    import torch.distributed as dist

    args = _build_arg_parser().parse_args()

    if args.audio_dir is None:
        args.audio_dir = [
            "/storage/data/LibriTTS",
            "/storage/data/FMA"
        ]

    distributed, is_main = _init_distributed()

    encodec_repo = args.hf_repo

    # Unique output directory per run; only the main process writes to it.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    outdir = Path(args.output_dir) / timestamp
    if is_main:
        outdir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory set to: {outdir}")

    # Augmentation JSONs (single quotes tolerated; no further validation).
    augs = json.loads(args.augs.replace("'", '"')) if args.augs is not None else {}
    augs_params = json.loads(args.augs_params.replace("'", '"')) if args.augs_params is not None else {}

    device = torch.device(args.device)
    print("Using device:", device)

    # Optional wandb logging, main process only.
    wandb = None
    if args.use_wandb and is_main:
        import wandb  # user must have wandb installed and configured
        wandb.init(entity=args.wandb_entity, project=args.wandb_project, name=args.wandb_run_name, dir="outputs/finetune/wandb")
        wandb.config.update(vars(args))

    # Load EnCodec
    print("Loading EnCodec from", encodec_repo)
    encodec = EncodecModel.from_pretrained(encodec_repo)
    encodec.to(device)

    # Frozen replica that produces the training targets.
    encodec_replica = deepcopy(encodec)
    encodec_replica.eval()
    for p in encodec_replica.parameters():
        p.requires_grad = False

    # Freeze the trainable model's quantizer too (best-effort).
    if hasattr(encodec, "quantizer"):
        for p in encodec.quantizer.parameters():
            p.requires_grad = False

    augmenter = Augmenter(augs=augs, augs_params=augs_params, num_augs=args.num_augmentations, sample_rate=args.target_sr) if augs else None
    wrapper = EncodecFTWrapper(encodec, encodec_replica, augmenter=augmenter, augmentation_start=args.augmentation_start)
    wrapper.to(device)

    # Dataset + dataloaders. The seeded generator makes the split reproducible
    # and identical on every rank, so validation clips never end up in
    # another rank's training set.
    full_dataset = AudioDataset(args.audio_dir, target_sr=args.target_sr, target_duration=args.target_duration)
    total_size = len(full_dataset)
    num_valid = min(args.num_valid, total_size)
    train_size = max(0, total_size - num_valid)
    split_generator = torch.Generator().manual_seed(args.seed)
    train_dataset, valid_dataset = torch.utils.data.random_split(full_dataset, [train_size, num_valid], generator=split_generator)
    train_dataloader = get_audio_dataloader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, distributed=args.distributed)
    valid_dataloader = get_audio_dataloader(valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, distributed=args.distributed)

    # Optimizer: decoder params + optionally encoder.
    params_to_opt = list(getattr(encodec, "decoder", encodec).parameters())
    if args.finetune_encoder and hasattr(encodec, "encoder"):
        params_to_opt += list(encodec.encoder.parameters())
    optimizer = AdamW(params_to_opt, lr=args.learning_rate)

    eta_min = 1e-8
    scheduler = _build_scheduler(optimizer, args.epochs, args.warmup_epochs, eta_min=eta_min)
    print(f"Initialized Scheduler: Warmup ({args.warmup_epochs} eps) -> Cosine Decay (to {eta_min:g})")

    audio_loss_fn = get_audio_loss(args.audio_loss_type, sample_rate=args.target_sr).to(device)
    code_loss_fn = get_code_loss(args.code_loss_type).to(device)

    start_epoch = 0
    if args.resume_from is not None:
        start_epoch = _resume_from_checkpoint(args.resume_from, encodec, optimizer, scheduler)
        print("Resumed from", args.resume_from, "starting at epoch", start_epoch)

    for epoch in range(start_epoch, args.epochs):
        # A DistributedSampler (if get_audio_dataloader uses one) needs the
        # epoch to reshuffle; samplers without set_epoch are left alone.
        sampler = getattr(train_dataloader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        t0 = time.time()
        train_stats = train_one_epoch(
            wrapper, train_dataloader, optimizer, device, epoch, args.steps_per_epoch,
            audio_loss_fn, code_loss_fn, args.audio_loss_weight, args.code_loss_weight,
            audio_target_type=args.audio_target_type, code_target_type=args.code_target_type,
            accum_steps=args.accum_steps
        )
        print(f"Epoch {epoch} train:", train_stats, "elapsed:", time.time() - t0)

        # LR in effect during this epoch (read before scheduler.step()).
        epoch_lr = optimizer.param_groups[0]["lr"]
        if wandb is not None:
            wandb.log({
                "train/loss": train_stats["loss"],
                "train/audio_loss": train_stats["audio_loss"],
                "train/code_loss": train_stats["code_loss"],
                "train/lr": epoch_lr,
            }, step=epoch)

        if args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0:
            val_stats = eval_one_epoch(
                wrapper, valid_dataloader, device, epoch, min(args.num_valid, args.steps_per_epoch),
                audio_loss_fn, code_loss_fn, args.audio_loss_weight, args.code_loss_weight
            )
            print(f"Epoch {epoch} eval:", val_stats)
            if wandb is not None:
                wandb.log(val_stats, step=epoch)

        scheduler.step()

        if is_main and args.save_freq > 0 and (epoch + 1) % args.save_freq == 0:
            save_checkpoint(outdir, epoch, encodec, optimizer, scheduler)
            try:
                encodec.config.save_pretrained(str(outdir))
            except Exception as err:
                print(f"Warning: could not save model config: {err}")
            try:
                AutoFeatureExtractor.from_pretrained(encodec_repo).save_pretrained(str(outdir))
            except Exception as err:
                print(f"Warning: could not save feature extractor: {err}")
            print("Saved epoch checkpoint", epoch)

    # final save
    if is_main:
        save_checkpoint(outdir, args.epochs - 1, encodec, optimizer, scheduler)
    if wandb is not None:
        wandb.finish()
    if distributed:
        dist.destroy_process_group()
    print("Done.")


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
