import argparse
import gzip
import json
import os
import sys
from pathlib import Path
from typing import List, Tuple

import torch
from transformers import AutoTokenizer

# Make the project root (home of utils_grouter) and its Megatron-LM checkout
# importable. Megatron-LM is inserted after the project root, so it ends up
# ahead of it on sys.path.
_THIS_FILE = Path(__file__).resolve()
_PROJECT_ROOT = _THIS_FILE.parents[2]  # .../general_router
_MEGATRON_ROOT = _PROJECT_ROOT / "Megatron-LM"
for _extra in (_PROJECT_ROOT, _MEGATRON_ROOT):
    if str(_extra) not in sys.path:
        sys.path.insert(0, str(_extra))
del _extra

from megatron.core.datasets.indexed_dataset import IndexedDataset
from utils_grouter.general_router import grouter


def read_c4_lines(gz_path: str, doc_indices: List[int]) -> List[str]:
    """Read specific JSON lines (0-based indices) from a gzip-compressed JSONL file.

    Args:
        gz_path: Path to the ``.gz`` JSONL file (e.g. a C4 shard).
        doc_indices: 0-based line indices to fetch. Order is preserved in the
            result and duplicate indices are allowed.

    Returns:
        The ``"text"`` field of each requested line, in the order of
        ``doc_indices``. A line that is not valid JSON, is not a JSON object,
        or has no ``"text"`` key yields ``""``.

    Raises:
        IndexError: if a requested index is negative or beyond the last line.
    """
    if not doc_indices:
        return []
    wanted = set(doc_indices)
    max_idx = max(wanted)
    results = {}
    with gzip.open(gz_path, 'rt', encoding='utf-8') as fin:
        for i, line in enumerate(fin):
            if i in wanted:
                try:
                    results[i] = json.loads(line)["text"]
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Best-effort: a malformed line is compared as an empty document.
                    results[i] = ""
            # Every requested index is <= max_idx, so no later line is needed.
            # (Comparing against len(doc_indices) never triggers with duplicates.)
            if i >= max_idx:
                break
    missing = sorted(wanted.difference(results))
    if missing:
        raise IndexError(f"line indices {missing} not found in {gz_path}")
    return [results[idx] for idx in doc_indices]


def split_to_sentences(text_or_list) -> List[str]:
    """Return the units fed to the grouter for one document.

    No sentence splitting is performed: a list is returned unchanged (the same
    object), and any other value is wrapped as a one-element list, so a whole
    document is treated as a single "sentence".
    """
    return text_or_list if isinstance(text_or_list, list) else [text_or_list]


def load_grouter(args, tokenizer):
    """Build the grouter, load its checkpoint, and move it to the best available device.

    Constructor arguments come from ``args.grouter_config`` (a JSON object of
    keyword arguments) when that file exists. Otherwise they are taken
    positionally from the CLI args plus ``tokenizer.pad_token_id``.

    Args:
        args: Parsed CLI namespace (reads grouter_config, grouter_ckpt, topk,
            expert_num, scoring_func, vocab_size, hidden_size, construct_type).
        tokenizer: Tokenizer providing ``pad_token_id``.

    Returns:
        ``(grt, device)``: the model in eval mode on ``device`` (CUDA if
        available, else CPU).
    """
    if args.grouter_config and os.path.exists(args.grouter_config):
        with open(args.grouter_config, "r", encoding="utf-8") as f:
            grt_config = json.load(f)
        grt = grouter(**grt_config)
    else:
        if args.grouter_config:
            # An explicitly requested config that does not exist is almost always
            # a typo; say so instead of silently using the CLI hyper-parameters.
            print(f"Warning: grouter config {args.grouter_config!r} not found; "
                  "building grouter from command-line arguments instead.",
                  file=sys.stderr)
        grt = grouter(args.topk,
                      args.expert_num,
                      tokenizer.pad_token_id,
                      args.scoring_func,
                      args.vocab_size,
                      args.hidden_size,
                      args.construct_type)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # The result is passed to load_state_dict, so weights_only=True (the default
    # on recent torch releases) should suffice.
    ckpt = torch.load(args.grouter_ckpt, map_location="cpu")
    grt.load_state_dict(ckpt)
    grt.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    grt.to(device)
    return grt, device


def run_grouter_on_sentences(grt, tokenizer, sentences: List[str], device) -> Tuple[List[int], List[int]]:
    """Tokenize ``sentences`` as one padded batch and run the grouter on it.

    Args:
        grt: Grouter model; called as ``grt(input_ids, attention_mask, None)``.
            Element ``[0]`` of its result is the per-row output.
        tokenizer: Tokenizer used without special tokens and without truncation.
        sentences: Texts forming one batch.
        device: ``torch.device`` to run on (bf16 autocast on CUDA or CPU).

    Returns:
        ``(flat, lens)``: ``flat`` concatenates every element of each row's
        output in row-major order, and ``lens[i]`` is the number of elements
        contributed by row ``i``.
    """
    enc = tokenizer(
        sentences,
        return_tensors="pt",
        padding=True,
        truncation=False,
        add_special_tokens=False,
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    with torch.inference_mode():
        with torch.autocast(device_type=("cuda" if device.type == "cuda" else "cpu"), dtype=torch.bfloat16):
            batch_out = grt(input_ids, attention_mask, None)[0]
    # NOTE(review): with more than one sentence, padding=True pads the shorter
    # rows. If the grouter emits values at pad positions, they are included
    # below. Only the single-sentence case is guaranteed padding-free; confirm
    # before batching several sentences.
    flat: List[int] = []
    lens: List[int] = []
    for row in batch_out:
        # reshape(-1) keeps `flat` a flat list of ints even when a row is
        # multi-dimensional (e.g. seq x topk); tolist() alone would nest it.
        values = row.detach().reshape(-1).cpu().tolist()
        flat.extend(values)
        lens.append(len(values))
    return flat, lens


def flatten_idx_slice(dataset: IndexedDataset, start_doc: int, end_doc: int) -> Tuple[List[int], List[int]]:
    """Concatenate the token ids of documents ``[start_doc, end_doc)`` of an IndexedDataset.

    ``dataset.document_indices`` holds document boundaries as sequence indices.
    The requested documents therefore span sequences
    ``[document_indices[start_doc], document_indices[end_doc])``.

    Returns:
        ``(flat_ids, lens)``: the ids of all those sequences concatenated, and
        the length of each individual sequence (one entry per sequence, not
        per document).
    """
    boundaries = dataset.document_indices
    first_seq = int(boundaries[start_doc])
    stop_seq = int(boundaries[end_doc])
    flat_ids: List[int] = []
    seq_lens: List[int] = []
    for seq in dataset[first_seq:stop_seq]:
        size = int(seq.size)
        if size:
            flat_ids.extend(seq.tolist())
        seq_lens.append(size)
    return flat_ids, seq_lens


def _parse_args():
    """Define and parse the command-line interface of this validator."""
    parser = argparse.ArgumentParser(description="Validate predispatch outputs vs grouter")
    parser.add_argument("--output_prefix", required=True, type=str, help="Final merged output prefix (without _key suffix)")
    parser.add_argument("--key", default="text", type=str, help="JSON key used in predispatch")
    parser.add_argument("--data_path", required=True, type=str, help="Original gz jsonl file path")
    parser.add_argument("--grouter_ckpt", required=True, type=str)
    parser.add_argument("--tokenizer_path", required=True, type=str)
    parser.add_argument("--topk", type=int, default=6)
    parser.add_argument("--expert_num", type=int, default=None)
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--scoring_func", type=str, default=None)
    parser.add_argument("--vocab_size", type=int, default=None)
    parser.add_argument("--hidden_size", type=int, default=None)
    parser.add_argument("--construct_type", type=str, default=None)
    parser.add_argument("--grouter_config", type=str, default=None)
    parser.add_argument("--sample_doc_indices", type=str, default="0,1,2", help="Comma-separated doc line indices to sample")
    parser.add_argument("--check_coverage", action="store_true", help="Check if counts align with document count")
    parser.add_argument("--validate_tokenized", action="store_true", help="Validate tokenized data vs dispatch data alignment")
    parser.add_argument("--append_eod", action="store_true")
    return parser.parse_args()


def _retokenize_flat(tokenizer, sentences: List[str], max_length: int) -> List[int]:
    """Tokenize ``sentences`` (truncated to ``max_length``) and concatenate their ids."""
    # Plain Python lists: with return_tensors="pt" and padding=False, sentences
    # of unequal length cannot be stacked into one tensor and the call fails.
    enc = tokenizer(
        sentences,
        padding=False,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    return [tok for ids in enc["input_ids"] for tok in ids]


def _validate_tokenized_doc(args, tokenizer, tokenized_ds, idx: int,
                            sentences: List[str], ds_flat: List[int]) -> bool:
    """Check document ``idx`` of the tokenized dataset against the dispatch ids and a re-tokenization.

    Prints the individual check results and returns True only if all of them passed.
    """
    tokenized_flat, _ = flatten_idx_slice(tokenized_ds, idx, idx + 1)
    re_tokenized_flat = _retokenize_flat(tokenizer, sentences, args.max_length)

    # Expected layout: ``topk`` dispatch entries per (non-EOD) token.
    if args.append_eod:
        # The trailing EOD token has no dispatch entries and is not produced by
        # re-tokenizing the raw text, so it is excluded from both comparisons.
        token_count_match = (len(tokenized_flat) - 1) * args.topk == len(ds_flat)
        has_eos = bool(tokenized_flat) and tokenized_flat[-1] == tokenizer.eos_token_id
        re_tokenized_match = re_tokenized_flat == tokenized_flat[:-1]
        print(f"  tokenized: count_match={token_count_match}, has_eos={has_eos}, token_match={re_tokenized_match}")
        return token_count_match and has_eos and re_tokenized_match

    token_count_match = len(tokenized_flat) * args.topk == len(ds_flat)
    re_tokenized_match = re_tokenized_flat == tokenized_flat
    print(f"  tokenized: count_match={token_count_match}, token_match={re_tokenized_match}")
    return token_count_match and re_tokenized_match


def _report_coverage(args, ds, tokenized_ds) -> bool:
    """Compare document counts of the original file and the indexed dataset(s).

    Prints the comparison and returns True only if every count agrees.
    """
    print("\nCOVERAGE CHECK:")
    print("-" * 30)
    # Count original document lines (one JSON document per line)
    with gzip.open(args.data_path, 'rt', encoding='utf-8') as fin:
        total_docs = sum(1 for _ in fin)

    # document_indices has one more entry than there are documents
    ds_docs = int(ds.document_indices.shape[0] - 1)
    ok = total_docs == ds_docs
    print(f"dispatch data: original_docs={total_docs}, indexed_docs={ds_docs}, equal={total_docs == ds_docs}")

    if args.validate_tokenized and tokenized_ds is not None:
        tokenized_docs = int(tokenized_ds.document_indices.shape[0] - 1)
        print(f"tokenized data: original_docs={total_docs}, indexed_docs={tokenized_docs}, equal={total_docs == tokenized_docs}")

        # Check if dispatch and tokenized data document counts are consistent
        doc_count_match = ds_docs == tokenized_docs
        print(f"data consistency: dispatch_docs={ds_docs}, tokenized_docs={tokenized_docs}, equal={doc_count_match}")
        ok = ok and total_docs == tokenized_docs and doc_count_match
    return ok


def main():
    """Compare predispatch outputs against a fresh grouter run on sampled documents.

    Returns:
        0 if every enabled check passed, 1 otherwise (used as the process exit code).
    """
    args = _parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
    grt, device = load_grouter(args, tokenizer)

    # 1) Read preprocessed output (already merged on rank0)
    merged_prefix = f"{args.output_prefix}_{args.key}_dispatch_ids"
    ds = IndexedDataset(merged_prefix)

    # 2) If tokenized data validation is enabled, also read tokenized data
    tokenized_ds = None
    if args.validate_tokenized:
        tokenized_prefix = f"{args.output_prefix}_{args.key}_tokenized"
        try:
            tokenized_ds = IndexedDataset(tokenized_prefix)
        except Exception as e:  # top-level boundary: report and stop
            print(f"✗ Failed to load tokenized dataset: {e}")
            print("Make sure predispatch.py was run with both dispatch and tokenized data generation")
            return 1
        print(f"✓ Successfully loaded tokenized dataset: {tokenized_prefix}")

    # 3) Read the sampled lines of the original data file
    sample_ids = [int(x) for x in args.sample_doc_indices.split(',') if x.strip() != ""]
    sample_texts = read_c4_lines(args.data_path, sample_ids)

    # 4) For each sampled document: run the grouter on it and compare with the
    #    corresponding document slice of the IndexedDataset(s)
    all_match = True
    all_tokenized_match = True
    for idx, text in zip(sample_ids, sample_texts):
        sentences = split_to_sentences(text)
        # NOTE(review): the grouter input here is not truncated, while the
        # tokenized re-check truncates to --max_length. Documents longer than
        # max_length will disagree between the two paths; confirm which one
        # matches predispatch.
        grouter_flat, _ = run_grouter_on_sentences(grt, tokenizer, sentences, device)

        # Dispatch ids stored for the idx-th document
        ds_flat, _ = flatten_idx_slice(ds, idx, idx + 1)

        # Validate if grouter output matches predispatch dispatch data
        same = grouter_flat == ds_flat
        print(f"doc {idx}: dispatch_ids_equal={same}, total_dispatch_tokens={len(ds_flat)}")
        if not same:
            all_match = False

        if args.validate_tokenized and tokenized_ds is not None:
            if not _validate_tokenized_doc(args, tokenizer, tokenized_ds, idx, sentences, ds_flat):
                all_tokenized_match = False

    # 5) Output validation results
    print("\n" + "="*60)
    print("VALIDATION RESULTS:")
    print("="*60)

    if all_match:
        print("All sampled documents match between grouter outputs and predispatch dispatch IndexedDataset.")
    else:
        print("Found mismatches in dispatch data. Inspect logs above.")

    if args.validate_tokenized:
        if all_tokenized_match:
            print("All tokenized data validation passed.")
        else:
            print("Found issues in tokenized data validation. Inspect logs above.")

    # 6) Optional coverage check: whether IndexedDataset document count matches original line count
    coverage_ok = True
    if args.check_coverage:
        coverage_ok = _report_coverage(args, ds, tokenized_ds)

    tokenized_ok = all_tokenized_match or not args.validate_tokenized
    return 0 if (all_match and tokenized_ok and coverage_ok) else 1


if __name__ == "__main__":
    sys.exit(main())


