#!/usr/bin/env python
"""
Accelerate-based, Batched & Sharded preprocessor for CC3M/MSCOCO

Key features
- Uses 🤗 Accelerate for device placement, DDP, and mixed precision
- High-throughput batching with DataLoader workers and pinned memory
- Saves large shard files (`shard_XXXXXX_rankR.pt`) vs. millions of tiny files
- Handles multiple captions per sample (pooling or store-all)
- Safe, atomic saves; resumable at shard granularity per-rank

Launch examples
--------------
Single GPU:
    accelerate launch precompute_features_accel.py \
      --dataset cc3m --split train --resolution 256 \
      --batch_size 128 --num_workers 16 --pin_memory \
      --shard_size 20000 --mixed_precision fp16 --save_images

Multi-GPU (single node, 8 GPUs):
    accelerate launch --multi_gpu --num_processes=8 precompute_features_accel.py \
      --dataset cc3m --split train --resolution 256 \
      --batch_size 128 --num_workers 16 --pin_memory \
      --shard_size 20000 --mixed_precision fp16

On SLURM (example):
    srun --ntasks=8 --gpus-per-task=1 --cpus-per-task=8 \
      accelerate launch --multi_gpu precompute_features_accel.py ...

Outputs per shard (.pt):
{
  'indices': List[int],            # dataset-global indices covered in this shard
  'moments': FloatTensor [N, ...], # autoencoder moments per image (fp32 on disk)
  'clip': FloatTensor [N, ...] | List[Tensor], # fp32 CLIP features pooled over each sample's
                                               # captions, or (with --store_all_clip) one tensor
                                               # per sample holding all of its caption embeddings
  'captions': List[List[str]],     # one list of caption strings per sample
}
"""

import argparse
import math
import os
import re
from typing import Any, Dict, List, Tuple

import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.data.dataloader import default_collate
from torchvision.utils import save_image
from tqdm import tqdm
from accelerate import Accelerator

# Project imports
from .libs.autoencoder import get_model
from .libs.clip import FrozenCLIPEmbedder
from .datasets import MSCOCODatabase, CC3MDataset

# Hard-coded dataset locations on the cluster filesystem, keyed by the --dataset choice.
# build_dataset() instantiates 'dataset_class' with paths derived from these entries;
# for MSCOCO, 'annFile' is the directory holding the captions_*.json annotation files.
DATASET = dict(
    mscoco=dict(
        dataset_class=MSCOCODatabase,
        root_dir='/gpfs/projects/bsc70/hpai/storage/data/datasets/raw/MSCOCO',
        annFile='/gpfs/projects/bsc70/hpai/storage/data/datasets/raw/MSCOCO/annotations',
    ),
    cc3m=dict(
        dataset_class=CC3MDataset,
        root_dir='/gpfs/projects/bsc70/bsc131047/data/cc3m',
    ),
)


def parse_args(argv=None):
    """Parse the command-line options of the feature-precomputation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list lets
            the parser be exercised without touching the real command line.

    Returns:
        argparse.Namespace with dataset, performance, sharding, precision,
        CLIP-pooling and model-checkpoint options.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', choices=['mscoco', 'cc3m'], default='mscoco')
    p.add_argument('--split', choices=['train', 'val'], default='train')
    p.add_argument('--resolution', type=int, default=256)

    # Performance knobs
    p.add_argument('--batch_size', type=int, default=64)
    p.add_argument('--num_workers', type=int, default=8)
    p.add_argument('--prefetch_factor', type=int, default=4)
    p.add_argument('--persistent_workers', action='store_true')
    p.add_argument('--pin_memory', action='store_true')

    # Sharding & output
    p.add_argument('--shard_size', type=int, default=10000)
    p.add_argument('--out_root', type=str, default='/gpfs/projects/bsc70/bsc193242/Data')
    p.add_argument('--save_images', action='store_true')
    p.add_argument('--images_subdir', type=str, default='images')
    p.add_argument('--overwrite', action='store_true')

    # Precision
    p.add_argument('--mixed_precision', choices=['no', 'fp16', 'bf16'], default='no',
                   help='Accelerate mixed precision mode')
    # Kept only so existing launch commands keep parsing: the previous help text
    # claimed "default: True", but a store_true flag defaults to False and the
    # moments are written as fp32 regardless of this flag.
    p.add_argument('--vae_on_disk_fp32', action='store_true',
                   help='No-op, kept for backward compatibility: VAE moments are always stored as fp32')

    # CLIP options
    p.add_argument('--clip_pool', choices=['mean', 'sum', 'none'], default='mean',
                   help="How to reduce a sample's caption embeddings to one; "
                        "'none' keeps only the first caption (ignored with --store_all_clip)")
    p.add_argument('--store_all_clip', action='store_true',
                   help='Store a variable-length list of embeddings per sample instead of pooling')

    # Models
    p.add_argument('--autoencoder_ckpt', type=str,
                   default='/gpfs/projects/bsc70/bsc193242/Models/stable-diffusion/autoencoder_kl.pth')

    return p.parse_args(argv)


def build_dataset(args):
    """Instantiate the dataset selected by ``args`` and create its feature output dir.

    CC3M is read from ``<root_dir>/<split>`` without augmentation. MSCOCO 'train'
    uses the 2017 images/captions, while every other split uses the 2014 val set.
    The output directory is ``<out_root>/<cc3m|coco><resolution>_features/<split>``
    and is created if missing.

    Returns:
        (dataset, out_dir) tuple.

    Raises:
        NotImplementedError: for an unknown ``args.dataset``.
    """
    if args.dataset == 'cc3m':
        spec = DATASET['cc3m']
        dataset = spec['dataset_class'](
            root_dir=f"{spec['root_dir']}/{args.split}",
            size=args.resolution,
            augmentation=False,
        )
        feature_tag = f"cc3m{args.resolution}_features"
    elif args.dataset == 'mscoco':
        spec = DATASET['mscoco']
        coco_split = 'train2017' if args.split == 'train' else 'val2014'
        dataset = spec['dataset_class'](
            root=f"{spec['root_dir']}/{coco_split}",
            annFile=f"{spec['annFile']}/captions_{coco_split}.json",
            size=args.resolution,
        )
        feature_tag = f"coco{args.resolution}_features"
    else:
        raise NotImplementedError

    out_dir = os.path.join(args.out_root, feature_tag, args.split)
    os.makedirs(out_dir, exist_ok=True)
    return dataset, out_dir


def init_models(args, device):
    """Load the autoencoder and the frozen CLIP text encoder onto ``device``.

    Both models are switched to eval mode because they are only used for
    inference-time feature extraction.

    Returns:
        (vae, clip) tuple.
    """
    # eval() mirrors the CLIP line below. If get_model() already returns an
    # eval-mode module this is a no-op; otherwise it keeps dropout/normalisation
    # layers from running in training mode while encoding.
    vae = get_model(args.autoencoder_ckpt).eval().to(device)
    clip = FrozenCLIPEmbedder().eval().to(device)
    return vae, clip


def _collate_image_captions(batch):
    """Collate ``(image, captions)`` samples, keeping one caption list per sample.

    torch's default collate transposes per-sample lists (the i-th output list
    would hold the i-th caption of every sample) and rejects lists of different
    lengths, which scrambles multi-caption samples. Images still go through the
    default collate rule. Each caption entry becomes a ``List[str]``; a bare
    string is wrapped in a one-element list.

    Defined at module level so worker processes can pickle it.
    """
    images = default_collate([sample[0] for sample in batch])
    captions = [[sample[1]] if isinstance(sample[1], str) else list(sample[1]) for sample in batch]
    return images, captions


def make_loader(ds, args):
    """Build the non-shuffled DataLoader used for feature extraction.

    Batches are ``(images, captions)``. ``images`` is collated by torch's
    default rule, and ``captions`` is a ``List[List[str]]`` with one entry per
    sample. The order follows ``ds`` exactly (``shuffle=False``), and the last
    partial batch is kept (``drop_last=False``).

    Args:
        ds: Map-style dataset yielding ``(image, caption_or_caption_list)`` pairs.
        args: Namespace providing batch_size, num_workers, pin_memory,
            persistent_workers and prefetch_factor.
    """
    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        drop_last=False,
        collate_fn=_collate_image_captions,
    )
    if args.num_workers > 0:
        # Worker-only options. DataLoader raises if persistent_workers=True is
        # combined with num_workers=0, and older torch releases also reject any
        # non-default prefetch_factor (including None) in that case.
        loader_kwargs['persistent_workers'] = args.persistent_workers
        loader_kwargs['prefetch_factor'] = args.prefetch_factor
    return DataLoader(ds, **loader_kwargs)


def find_next_shard_id(out_dir: str, rank: int) -> Tuple[int, List[str]]:
    """Scan ``out_dir`` for shard files and pick this rank's next shard id.

    Returns:
        ``(next_id, all_files)``. ``next_id`` is one past the highest id among
        ``shard_NNNNNN_rank{rank}.pt`` files, or 0 if this rank has none.
        ``all_files`` lists every shard file of any rank, including legacy
        rank-less ``shard_NNNNNN.pt`` names, in ``os.listdir`` order.
    """
    own_shard = re.compile(rf'^shard_(\d{{6}})_rank{rank}\.pt$')
    any_shard = re.compile(r'^shard_(\d{6})(?:_rank\d+)?\.pt$')

    entries = os.listdir(out_dir)
    all_files = [name for name in entries if any_shard.match(name)]
    own_ids = [int(m.group(1)) for m in map(own_shard.match, entries) if m]

    next_id = max(own_ids) + 1 if own_ids else 0
    return next_id, all_files


def atomic_torch_save(obj: Dict[str, Any], path: str):
    """Save ``obj`` with ``torch.save`` so that ``path`` never holds a partial file.

    The payload is written to ``path + '.tmp'``, flushed and fsync'ed, and then
    moved into place with ``os.replace``. The rename is atomic because the temp
    file lives in the same directory. If anything fails (including an
    interrupt), the temp file is removed and the exception is re-raised.
    """
    tmp = path + '.tmp'
    try:
        with open(tmp, 'wb') as f:
            torch.save(obj, f)
            f.flush()
            # Make sure the bytes reach disk before the rename publishes the file.
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a (possibly multi-GB) partial temp file behind; re-raise.
        try:
            os.remove(tmp)
        except FileNotFoundError:
            pass
        raise


def ensure_img_dir(out_dir: str, subdir: str) -> str:
    """Return ``out_dir/subdir``, creating the directory (and parents) if needed."""
    target = os.path.join(out_dir, subdir)
    if not os.path.isdir(target):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(target, exist_ok=True)
    return target


def _rank_slice(num_samples: int, rank: int, world_size: int) -> Tuple[int, int]:
    """Return the contiguous ``[start, end)`` range of dataset indices owned by ``rank``.

    The ranges of ranks ``0..world_size-1`` are disjoint, differ in size by at
    most one, and together cover ``range(num_samples)``.
    """
    world_size = max(1, world_size)
    start = (num_samples * rank) // world_size
    end = (num_samples * (rank + 1)) // world_size
    return start, end


def _completed_indices(out_dir: str, rank: int) -> set:
    """Return the dataset indices already stored in this rank's shard files.

    Each ``shard_XXXXXX_rank{rank}.pt`` is loaded on CPU only to read its
    ``'indices'`` list. This is used when resuming, i.e. without --overwrite.
    """
    own_shard = re.compile(rf'^shard_\d{{6}}_rank{rank}\.pt$')
    done = set()
    for name in sorted(os.listdir(out_dir)):
        if own_shard.match(name):
            shard = torch.load(os.path.join(out_dir, name), map_location='cpu')
            done.update(int(i) for i in shard['indices'])
            del shard  # free the moments/clip tensors before loading the next shard
    return done


def _normalize_captions(captions) -> List[List[str]]:
    """Coerce a batch's captions to one ``List[str]`` per sample (bare strings are wrapped)."""
    return [[c] if isinstance(c, str) else list(c) for c in captions]


def _encode_clip(clip, captions: List[List[str]], accelerator, clip_pool: str, store_all_clip: bool):
    """Encode all captions of a batch in one CLIP call and regroup them per sample.

    Returns:
        With ``store_all_clip``, a list holding one fp32 CPU tensor per sample
        (all of that sample's caption embeddings). Otherwise, one fp32 CPU tensor
        stacked over samples, where each sample's captions are reduced with
        ``clip_pool``: 'mean', 'sum', or anything else = keep the first caption.
    """
    flat_caps = [cap for caps in captions for cap in caps]
    bounds: List[Tuple[int, int]] = []
    cursor = 0
    for caps in captions:
        bounds.append((cursor, cursor + len(caps)))
        cursor += len(caps)

    # Fresh autocast context on every call: accelerator.autocast() is a one-shot,
    # generator-based context manager and cannot be re-entered.
    with accelerator.autocast():
        feats = clip.encode(flat_caps)  # first dim indexes flat_caps
    feats = feats.float()  # pool in fp32, not in the autocast dtype

    if store_all_clip:
        return [feats[s:e].cpu() for s, e in bounds]

    pooled = []
    for s, e in bounds:
        group = feats[s:e]
        if clip_pool == 'mean':
            pooled.append(group.mean(dim=0))
        elif clip_pool == 'sum':
            pooled.append(group.sum(dim=0))
        else:
            pooled.append(group[0])
    return torch.stack(pooled, dim=0).cpu()


def _save_shard(out_dir: str, shard_id: int, rank: int, indices: List[int],
                captions: List[List[str]], moments: List[torch.Tensor],
                clip_feats: List[torch.Tensor], store_all_clip: bool,
                *, log: bool, final: bool = False) -> None:
    """Concatenate the buffered batches into one shard dict and save it atomically."""
    shard_obj = {
        'indices': indices,
        'moments': torch.cat(moments, dim=0),
        'clip': clip_feats if store_all_clip else torch.cat(clip_feats, dim=0),
        'captions': captions,
    }
    shard_name = os.path.join(out_dir, f"shard_{shard_id:06d}_rank{rank}.pt")
    if log:
        tag = "FINAL shard" if final else "shard"
        print(f"Saving {tag} {shard_id:06d} (rank {rank}) with {len(indices)} samples -> {shard_name}")
    atomic_torch_save(shard_obj, shard_name)


def main():
    """Precompute VAE moments and CLIP text features into per-rank shard files.

    Each process (rank) owns a disjoint, contiguous slice of dataset indices and
    writes ``shard_XXXXXX_rank{rank}.pt`` files for it. The stored ``'indices'``
    therefore always match the samples that were actually encoded.

    Resuming: without ``--overwrite``, indices already present in this rank's
    shard files are skipped, and numbering continues after the highest existing
    shard id. This assumes the same number of processes as the run that wrote
    those shards. With ``--overwrite``, the rank re-encodes its whole slice and
    numbers shards from 0, replacing same-named files. Higher-numbered leftovers
    from an earlier, longer run are not deleted.
    """
    args = parse_args()

    accelerator = Accelerator(mixed_precision=None if args.mixed_precision == 'no' else args.mixed_precision)
    device = accelerator.device
    rank = accelerator.process_index
    world_size = accelerator.num_processes

    ds, out_dir = build_dataset(args)

    if accelerator.is_main_process:
        print(f"Dataset size: {len(ds):,}")
        print(f"Output dir: {out_dir}")
        if args.overwrite:
            print("[WARN] --overwrite set; existing shards with same id will be replaced if encountered.")

    next_shard_id, all_shards = find_next_shard_id(out_dir, rank)
    if args.overwrite:
        next_shard_id = 0
        done = set()
    else:
        done = _completed_indices(out_dir, rank)
    if accelerator.is_main_process:
        print(f"Found {len(all_shards)} total shard files. Rank {rank} will start at shard id {next_shard_id:06d}.")

    start, end = _rank_slice(len(ds), rank, world_size)
    todo = [i for i in range(start, end) if i not in done]
    if done:
        stray = sum(1 for i in done if not start <= i < end)
        print(f"[rank {rank}] resuming: {len(done) - stray:,} of {end - start:,} samples in "
              f"[{start}, {end}) already saved; {len(todo):,} left.")
        if stray:
            print(f"[WARN] rank {rank}: {stray:,} previously saved indices fall outside this rank's slice "
                  "(was the previous run launched with a different number of processes?).")

    # The loader is deliberately NOT passed through accelerator.prepare(). That
    # would re-shard its batches across ranks (round-robin, padding the tail with
    # repeated samples), so the indices tracked below would no longer match the
    # images actually encoded. Sharding is done explicitly through `todo`.
    loader = make_loader(Subset(ds, todo), args)

    # The inference-only models are not prepared either. init_models() already
    # places them on `device`, and prepare() could wrap them (e.g. in DDP), which
    # would hide methods such as `clip.encode`. Mixed precision comes from
    # accelerator.autocast() at each call site.
    vae, clip = init_models(args, device)

    # Every rank exports the images of its own slice. File names are global
    # dataset indices, so two ranks never write the same file.
    img_dir = ensure_img_dir(out_dir, args.images_subdir) if args.save_images else None

    buf_indices: List[int] = []
    buf_captions: List[List[str]] = []
    buf_moments: List[torch.Tensor] = []
    buf_clip: List[torch.Tensor] = []  # stacked blocks if pooled; one tensor per sample if store_all_clip
    cursor = 0  # position in `todo` of the next batch's first sample (loader is not shuffled)

    # Progress bar counts this rank's batches; only the main process displays it.
    pbar = tqdm(total=len(loader), disable=not accelerator.is_main_process)

    with torch.no_grad():
        for images, captions in loader:
            captions = _normalize_captions(captions)
            batch_size = images.size(0)
            if len(captions) != batch_size:
                raise RuntimeError(
                    f"Got {len(captions)} caption entries for a batch of {batch_size} images; "
                    "captions must be collated as one list per sample."
                )
            batch_indices = todo[cursor:cursor + batch_size]
            cursor += batch_size

            if img_dir is not None:
                for img, idx in zip(images, batch_indices):
                    save_image(img.cpu(), os.path.join(img_dir, f"{idx}.png"))

            images = images.to(device, non_blocking=args.pin_memory)
            # Fresh autocast context per batch: accelerator.autocast() is one-shot.
            with accelerator.autocast():
                moments = vae(images, fn='encode_moments')
            # Moments are always stored as fp32 on disk (--vae_on_disk_fp32 does not change this).
            buf_moments.append(moments.float().cpu())

            clip_out = _encode_clip(clip, captions, accelerator, args.clip_pool, args.store_all_clip)
            if args.store_all_clip:
                buf_clip.extend(clip_out)
            else:
                buf_clip.append(clip_out)

            buf_indices.extend(batch_indices)
            buf_captions.extend(captions)

            # Flush a shard once the per-rank buffer reaches shard_size samples.
            if len(buf_indices) >= args.shard_size:
                _save_shard(out_dir, next_shard_id, rank, buf_indices, buf_captions,
                            buf_moments, buf_clip, args.store_all_clip,
                            log=accelerator.is_main_process)
                next_shard_id += 1
                buf_indices, buf_captions, buf_moments, buf_clip = [], [], [], []

            pbar.update(1)

        # Final flush of whatever is left in the buffers.
        if buf_indices:
            _save_shard(out_dir, next_shard_id, rank, buf_indices, buf_captions,
                        buf_moments, buf_clip, args.store_all_clip,
                        log=accelerator.is_main_process, final=True)

    pbar.close()
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        print('All ranks finished. Done.')


if __name__ == '__main__':
    main()
