import os
import gc
import json
import glob
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM
import numpy as np

try:
    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import get_model_name_from_path
except ImportError:
    pass 

from dataset import ExtractionDataset, DataCollator
from util import DataArguments, compute_token_loss
from tqdm import tqdm

def setup_ddp():
    """Initialise the NCCL process group and bind this process to its GPU.

    When the script is started without a distributed launcher (``RANK`` is
    not in the environment) a single-process group is configured so the same
    code path works on one GPU.

    Returns:
        int: the value of ``LOCAL_RANK``, which is also the CUDA device index
        selected via ``torch.cuda.set_device``.
    """
    if "RANK" not in os.environ:
        # Single-GPU run: describe a world of exactly one process.
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_RANK"] = "0"
        # Only fill in the rendezvous endpoint when the caller has not chosen
        # one; overwriting it would make it impossible to run two single-GPU
        # jobs on the same host (both would bind the same port).
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

def cleanup_ddp():
    """Destroy the default process group if one is currently initialised.

    The guard makes the call a no-op when the group was never created or has
    already been destroyed, so it is safe on early-exit paths and when called
    more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def main(args):
    """Run the extraction pass and write per-sample token losses to disk.

    Pipeline:
      1. Initialise the process group and pick this rank's GPU.
      2. Load the LLaVA checkpoint at ``args.model_path``, then replace its
         ``mm_projector`` weights and its LLM submodules with the projector
         checkpoint / vicuna model selected from the model path.
      3. Build the dataset and a ``DistributedSampler``-backed loader. With
         ``args.build_cache_only`` the loader is merely iterated and the
         function returns.
      4. For every batch call ``compute_token_loss`` and collect the results.
      5. Each rank writes one shard (``.npy`` and, with ``args.save_json``,
         ``.jsonl``) into ``args.shard_dir``.
      6. Unless ``args.no_gather`` is set, rank 0 merges all shards,
         de-duplicates by ``save_idx``, sorts, and writes
         ``<output_prefix>_only_label.npy`` plus, with ``args.save_json``,
         ``<output_prefix>.json`` (with ``dataset.text_dataset`` appended).

    Args:
        args: namespace carrying ``model_path``, ``data_path``,
            ``image_folder``, ``cache_dir``, ``cache_preprocessed``,
            ``build_cache_only``, ``batch_size``, ``num_workers``,
            ``prefetch_factor``, ``pin_memory``, ``save_json``,
            ``output_prefix``, ``loss_dtype``, ``shard_dir`` and
            ``no_gather``. ``args.shard_dir`` is overwritten with a temporary
            directory when it is empty.
    """
    # -------- DDP init --------
    local_rank = setup_ddp()
    rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # -------- perf knobs --------
    # TF32 is a big win on Ampere+ for matmuls (often safe for inference-like scoring).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Deliberately best-effort: the knob is optional and must not abort the run.
        pass

    # =========================
    # 1. Model
    # =========================
    # -------- model paths --------
    # Anything whose path does not contain "7b" is treated as the 13b variant.
    model_path = args.model_path
    if "7b" in model_path:
        mm_proj_ckpt = "checkpoints/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin"
        llm_ckpt = "lmsys/vicuna-7b-v1.5"
    else:
        mm_proj_ckpt = "checkpoints/llava-v1.5-mlp2x-336px-pretrain-vicuna-13b-v1.5/mm_projector.bin"
        llm_ckpt = "lmsys/vicuna-13b-v1.5"

    # -------- load LLaVA (vision part) --------
    model_name = get_model_name_from_path(model_path)

    # Load base LLaVA
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path,
        None,
        model_name,
        False,
        False,
        device_map=None,  # placed on the GPU manually further down
        torch_dtype=torch.float16,
    )

    # Prefer fast attention if available (best-effort; may be ignored by the model)
    try:
        if hasattr(model.config, "attn_implementation"):
            model.config.attn_implementation = "flash_attention_2"
    except Exception:
        pass

    # Replace the MM projector weights with the ones from mm_proj_ckpt.
    print(f"[Rank {rank}] Swapping MM Projector...")
    mm_projector = torch.load(mm_proj_ckpt, map_location="cpu")
    new_sd = {
        k.replace("model.mm_projector.", ""): v
        for k, v in mm_projector.items()
        if k.startswith("model.mm_projector.")
    }
    model.model.mm_projector.load_state_dict(new_sd)
    del mm_projector  # release the checkpoint dict
    gc.collect()

    # Replace the LLM submodules: drop the old ones first so their memory can
    # be reclaimed before the replacement model is loaded.
    print(f"[Rank {rank}] Swapping LLM...")
    for sub in ["embed_tokens", "lm_head", "layers", "norm", "rotary_emb"]:
        if hasattr(model.model, sub):
            setattr(model.model, sub, None)

    torch.cuda.empty_cache()
    gc.collect()

    # Load new LLM
    llm = AutoModelForCausalLM.from_pretrained(
        llm_ckpt,
        device_map='cpu',
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,  # Optimize loading and RAM
    )

    model.model.embed_tokens = llm.model.embed_tokens
    # NOTE(review): the head is attached under model.model, not on the outer
    # model -- confirm that is where compute_token_loss looks for it.
    model.model.lm_head = llm.lm_head
    model.model.layers = llm.model.layers
    model.model.norm = llm.model.norm
    model.model.rotary_emb = llm.model.rotary_emb

    # Drop the donor wrapper; its submodules stay alive through `model`.
    del llm
    gc.collect()
    torch.cuda.empty_cache()

    # Move to GPU
    model.to(device)
    model.eval()
    # For full-sequence scoring, KV cache usually doesn't help and can increase memory.
    model.config.use_cache = False
    model.config.vig_extraction = True
    # Reuse loss object (avoid per-step allocations)
    criterion = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    np_loss_dtype = torch.float16 if args.loss_dtype == "fp16" else torch.float32

    # =========================
    # 2. dataset / loader
    # =========================
    data_args = DataArguments()
    data_args.data_path = args.data_path
    # Dataset perf options
    data_args.image_folder = args.image_folder
    data_args.cache_dir = args.cache_dir
    data_args.cache_preprocessed = args.cache_preprocessed

    dataset = ExtractionDataset(
        image_processor=image_processor,
        tokenizer=tokenizer,
        args=data_args,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        drop_last=False
    )

    data_collator = DataCollator(
        tokenizer=tokenizer
    )

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        persistent_workers=(args.num_workers > 0),
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        collate_fn=data_collator,
    )

    # Optional: build cache only (no model forward is executed in this branch)
    if args.build_cache_only:
        if rank == 0:
            print("Building dataset cache only (no model forward)...")
            if not args.cache_preprocessed or not args.cache_dir:
                print("⚠️  --build_cache_only requires --cache_preprocessed and --cache_dir.")
        iterator = tqdm(loader, disable=(rank != 0), desc="Cache build")
        with torch.inference_mode():
            for data in iterator:
                _ = data["image"]  # touch to trigger dataset-side caching
        dist.barrier()
        if rank == 0:
            print("Cache build done.")
        cleanup_ddp()
        return

    # =========================
    # 3. inference for VIG
    # =========================
    # Without --shard_dir a temporary directory is used so results go through
    # disk instead of a rank-0 gather (avoids OOM).
    is_temp_shard_dir = False
    if not args.shard_dir:
        import tempfile
        # The directory must be the SAME on every rank: rank 0 creates it and
        # broadcasts the path. If each rank called mkdtemp on its own, rank 0
        # would later merge only its own shard and drop everyone else's.
        # NOTE: on multi-node runs the temp location must be on a filesystem
        # shared by all nodes; otherwise pass --shard_dir explicitly.
        shard_dir_box = [tempfile.mkdtemp(prefix="vig_shards_") if rank == 0 else None]
        dist.broadcast_object_list(shard_dir_box, src=0)
        args.shard_dir = shard_dir_box[0]
        is_temp_shard_dir = True
        if rank == 0:
            print(f"⚠️  --shard_dir not specified. Using temporary directory: {args.shard_dir}")
            print("   (This will be cleaned up after merge. Use --shard_dir to keep files.)")

    local_results = []
    local_np = []

    def flush_parts():
        """Write this rank's accumulated rows to its shard files and reset the buffers."""
        nonlocal local_results, local_np
        if not local_results and not local_np:
            return
        os.makedirs(args.shard_dir, exist_ok=True)
        # Write JSONL (stream-friendly)
        if args.save_json:
            jsonl_path = os.path.join(args.shard_dir, f"{args.output_prefix}.rank{rank}.jsonl")
            with open(jsonl_path, "w") as f:
                for row in local_results:
                    f.write(json.dumps(row) + "\n")
        # Write np shard (pickle-allowed dict list)
        np_path = os.path.join(args.shard_dir, f"{args.output_prefix}.rank{rank}.npy")
        np.save(np_path, local_np, allow_pickle=True)
        local_results = []
        local_np = []

    iterator = tqdm(loader, disable=(rank != 0), desc="Inference")

    with torch.inference_mode():
        for data in iterator:
            save_idxs = data["save_idx"]

            # Data to GPU
            images = data["image"].to(device, non_blocking=True, dtype=torch.float16)
            input_ids = data["input_ids"].to(device, non_blocking=True)
            labels = data["labels"].to(device, non_blocking=True)
            attention_mask = data["attention_mask"].to(device, non_blocking=True)

            # Compute Loss
            # (w* / wo* presumably mean "with" / "without" the image -- see
            # util.compute_token_loss for the exact contract.)
            wloss, wmask, attn_w_mask, woloss, womask, attn_wo_mask = compute_token_loss(
                model,
                input_ids,
                labels,
                images,
                data["image_size"],
                attention_mask,
                criterion=criterion,
            )

            # GPU -> CPU. dtype is left untouched here; conversion happens
            # only when a value is serialised below.
            wloss = wloss.cpu()
            woloss = woloss.cpu()
            wmask = wmask.cpu()
            womask = womask.cpu()
            attn_w_mask = attn_w_mask.cpu()
            attn_wo_mask = attn_wo_mask.cpu()

            # Collect per-sample results
            for i in range(len(save_idxs)):
                si = int(save_idxs[i])
                if args.save_json:
                    # The JSON rows are only ever consumed when --save_json is
                    # on, so they are not built (or kept in RAM) otherwise.
                    # NOTE: attention_mask_ can include image tokens; use it for JSON list.
                    wl = wloss[i][attn_w_mask[i]]
                    wol = woloss[i][attn_wo_mask[i]]
                    local_results.append(
                        {
                            "save_idx": si,  # Global Index
                            "id": data["ids"][i],
                            "image": data["image_name"][i],
                            "conversations": data["convs"][i],
                            "w_loss": wl.float().tolist(),
                            "wo_loss": wol.float().tolist(),
                        }
                    )
                local_np.append(
                    {
                        "save_idx": si,  # Global Index
                        "id": data["ids"][i],
                        "w_loss": wloss[i][wmask[i]].to(np_loss_dtype).numpy(),
                        "wo_loss": woloss[i][womask[i]].to(np_loss_dtype).numpy(),
                    }
                )

    # Written once, after the whole inference loop has finished.
    flush_parts()

    # Every rank must have finished writing before rank 0 starts reading.
    dist.barrier()

    # =========================
    # 4. merge shards (on rank0) & sort & deduplicate
    # =========================
    if args.no_gather:
        if rank == 0:
            print(f"Saved sharded outputs under: {args.shard_dir}")
            print("Skipping gather/sort on rank0 (use a separate merge step if needed).")
        cleanup_ddp()
        return

    # File-based merge: shards are read back one at a time instead of being
    # gathered through the process group.
    if rank == 0:
        print("Merging sharded files (memory-efficient)...")
        # Locate the shard files of all ranks
        all_np_files = sorted(glob.glob(os.path.join(args.shard_dir, f"{args.output_prefix}.rank*.npy")))
        all_jsonl_files = sorted(glob.glob(os.path.join(args.shard_dir, f"{args.output_prefix}.rank*.jsonl"))) if args.save_json else []

        # Keyed by save_idx so duplicated samples (e.g. sampler padding)
        # collapse to a single entry.
        unique_np_map = {}
        unique_results_map = {}

        # Merge NPY shards
        for np_file in tqdm(all_np_files, desc="Merging npy shards"):
            chunk = np.load(np_file, allow_pickle=True).tolist()
            for item in chunk:
                unique_np_map[item["save_idx"]] = item
            del chunk
            gc.collect()

        # Merge JSONL shards (optional)
        if args.save_json:
            for jsonl_file in tqdm(all_jsonl_files, desc="Merging jsonl shards"):
                with open(jsonl_file, "r") as f:
                    for line in f:
                        item = json.loads(line)
                        unique_results_map[item["save_idx"]] = item

        # Sort by Index
        sorted_np = [unique_np_map[i] for i in sorted(unique_np_map.keys())]
        if args.save_json:
            sorted_results = [unique_results_map[i] for i in sorted(unique_results_map.keys())]

        print(f"Total dataset size: {len(dataset)}")
        print(f"Collected unique size: {len(sorted_np)}")

        if len(dataset) != len(sorted_np):
            print("⚠️ Warning: Collected size does not match dataset size!")

        # Merge with text only dataset
        if args.save_json:
            sorted_results.extend(dataset.text_dataset)

        # Save
        output_prefix = args.output_prefix
        if args.save_json:
            output_file = f"{output_prefix}.json"
            print(f"Saving to {output_file}...")
            with open(output_file, "w") as f:
                json.dump(sorted_results, f)
        npy_file = f"{output_prefix}_only_label.npy"
        np.save(npy_file, sorted_np, allow_pickle=True)
        print("Done.")

        # Remove the temporary shard directory
        if is_temp_shard_dir:
            import shutil
            print(f"Cleaning up temporary directory: {args.shard_dir}")
            shutil.rmtree(args.shard_dir, ignore_errors=True)

    cleanup_ddp()
    
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="liuhaotian/llava-v1.5-7b")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--data_path", type=str, default='llava_v1_5_mix665k.json')
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--prefetch_factor", type=int, default=4)
    # pin_memory / save_json default to True. A store_true flag with
    # default=True can never be switched off, so each one gets a paired
    # --no_* flag writing False to the same destination. The original
    # positive flags are kept so existing command lines still parse.
    parser.add_argument("--pin_memory", action="store_true", default=True)
    parser.add_argument("--no_pin_memory", dest="pin_memory", action="store_false",
                        help="Disable pinned host memory in the DataLoader.")
    parser.add_argument("--save_json", action="store_true", default=True)
    parser.add_argument("--no_save_json", dest="save_json", action="store_false",
                        help="Skip the JSON/JSONL outputs and write only the .npy results.")
    parser.add_argument("--output_prefix", type=str, default="llava_1_5_w_vig_value")
    parser.add_argument("--image_folder", type=str, default="data/llava")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--cache_preprocessed", action="store_true", default=False)
    parser.add_argument("--build_cache_only", action="store_true", default=False)
    parser.add_argument("--loss_dtype", type=str, choices=["fp16", "fp32"], default="fp16")
    # Perf: avoid rank0 gather + stream shards to disk
    parser.add_argument("--shard_dir", type=str, default=None)
    parser.add_argument("--no_gather", action="store_true", default=False)
    # NOTE(review): --flush_every is parsed but main() does not read it.
    parser.add_argument("--flush_every", type=int, default=200)  # in "effective samples" ~= batches*batch_size
    args = parser.parse_args()

    main(args)