import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys
from pathlib import Path
import argparse

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

from dotenv import load_dotenv

_ = load_dotenv()

import pandas as pd
import transformers
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer

transformers.logging.set_verbosity_error()


def verify_same_context(
    tokenizer: AutoTokenizer,

    i_dataset: Path,
    i_sample_idx: int,
    i_start_idx: int,
    i_end_idx: int,

    j_dataset: Path,
    j_sample_idx: int,
    j_start_idx: int,
    j_end_idx: int,

    text_column: str = 'text',
    verbose: bool = False,
    max_length: int = 2048,
) -> bool:
    """Check whether two token spans taken from two on-disk datasets are identical.

    Each sample's text is tokenized without special tokens and truncated to
    ``max_length`` tokens. The slice ``input_ids[start_idx:end_idx]`` (Python
    slice semantics) is then extracted from each encoding. The spans count as
    the same context only if they have the same length and every token id
    matches position by position.

    Args:
        tokenizer: Hugging Face tokenizer used to encode both samples.
        i_dataset: Path passed to ``datasets.load_from_disk`` for the first sample.
        i_sample_idx: Row index of the first sample in ``i_dataset``.
        i_start_idx: Start of the first token slice.
        i_end_idx: End (exclusive) of the first token slice.
        j_dataset: Path passed to ``datasets.load_from_disk`` for the second sample.
        j_sample_idx: Row index of the second sample in ``j_dataset``.
        j_start_idx: Start of the second token slice.
        j_end_idx: End (exclusive) of the second token slice.
        text_column: Dataset column holding the raw text.
        verbose: If True, print a token-by-token comparison and a summary line.
        max_length: Truncation length used when encoding each sample.

    Returns:
        True if both spans contain exactly the same token ids in the same
        order, otherwise False. Note that two empty spans (for example, bounds
        beyond ``max_length``) compare equal.
    """
    def _token_span(dataset_path: Path, sample_idx: int, start: int, end: int) -> list[int]:
        # Load one sample, encode it and return the requested slice of token ids.
        ds: Dataset = load_from_disk(str(dataset_path))
        text = ds[sample_idx][text_column]
        enc = tokenizer(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
            return_attention_mask=False,
        )  # type: ignore
        return enc['input_ids'][start:end]

    input_ids_list1 = _token_span(i_dataset, i_sample_idx, i_start_idx, i_end_idx)
    input_ids_list2 = _token_span(j_dataset, j_sample_idx, j_start_idx, j_end_idx)

    # Whole-list equality requires equal length AND every position to match.
    # Two cases would otherwise go wrong: an element-wise OR accumulation
    # reports "same" after a single coincident token, and zip() ignores
    # trailing tokens of the longer span.
    are_same = input_ids_list1 == input_ids_list2

    if verbose:
        for i, j in zip(input_ids_list1, input_ids_list2):
            dec1 = tokenizer.decode(i)
            dec2 = tokenizer.decode(j)
            marker = '' if i == j else '  <-- differs'
            print(f'{i:6d} - {j:6d} | {dec1:>10s} - {dec2:>10s}{marker}')

        if len(input_ids_list1) != len(input_ids_list2):
            print(
                f'Span lengths differ: {len(input_ids_list1)} vs '
                f'{len(input_ids_list2)} tokens'
            )

        print(
            f'Contexts for sample ids [{i_sample_idx:5d}/{j_sample_idx:5d}] '
            f'are {"" if are_same else "not "}the same!'
        )

    return are_same


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``data_root`` (Path), ``csv_path`` (Path), ``model_id``,
        ``verbose`` (bool, off unless ``--verbose`` is given), ``text_column``,
        ``distance_threshold`` (float) and ``pattern`` (glob used when
        ``csv_path`` is a directory).
    """
    parser = argparse.ArgumentParser(
        description="Verify whether token spans from two datasets encode the same context."
    )

    parser.add_argument(
        "--data-root",
        type=Path,
        default=Path("./data"),
        help="Root directory where datasets live (default: ./data)",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("data/dataset_exp/Mathstral-7B-v0.1/top100-closest.csv"),
        help=(
            "CSV with candidate pairs, or a directory of CSVs (see --pattern) "
            "(default: data/dataset_exp/Mathstral-7B-v0.1/top100-closest.csv)"
        ),
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="mistralai/Mathstral-7B-v0.1",
        help="Tokenizer model id (default: mistralai/Mathstral-7B-v0.1)",
    )
    # Plain on/off flag: verbose output is disabled unless --verbose is passed.
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print token-by-token comparisons",
    )
    parser.add_argument(
        "--text-column",
        type=str,
        default="text",
        help="Text column name in the dataset (default: text)",
    )
    # %(default)s is expanded by argparse, so the help text always matches the real default.
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=1e-12,
        help="Max distance to keep pairs from CSV (default: %(default)s)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*.csv",
        help="Glob pattern if `--csv-path` is a dir (default: %(default)s).",
    )

    return parser.parse_args()


# CSV columns each candidate-pair row must provide.
_REQUIRED_COLUMNS = (
    'distance',
    'i_dataset', 'i_sample_idx', 'i_start_idx', 'i_end_idx',
    'j_dataset', 'j_sample_idx', 'j_start_idx', 'j_end_idx',
)


def main() -> None:
    """Re-check every candidate pair in one or more CSVs by comparing token spans.

    For each CSV (a single file, or every file matching ``--pattern`` when
    ``--csv-path`` is a directory), the function keeps the rows whose
    ``distance`` is at most ``--distance-threshold``. For each kept row it
    calls ``verify_same_context``. A pair whose spans do not match is counted
    as a true collision. Unreadable CSVs, or CSVs missing required columns,
    are reported on stderr and skipped.
    """
    args = parse_args()

    data_root: Path = args.data_root
    csv_path: Path = args.csv_path
    model_id: str = args.model_id
    verbose: bool = args.verbose
    text_column: str = args.text_column
    distance_threshold: float = args.distance_threshold
    pattern: str = args.pattern

    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    csv_files = sorted(csv_path.glob(pattern)) if csv_path.is_dir() else [csv_path]

    true_collision_all = False
    for csv_file in csv_files:
        print(f'Verifying `{csv_file}`')

        try:
            df = pd.read_csv(csv_file)
        except (OSError, ValueError) as exc:
            # Best effort: a missing, empty, malformed or undecodable file is
            # reported and skipped rather than aborting the whole scan.
            # (pandas EmptyDataError/ParserError and UnicodeDecodeError are ValueErrors.)
            print(f'\tSkipping `{csv_file}`: could not read CSV ({exc})', file=sys.stderr)
            continue

        missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            print(f'\tSkipping `{csv_file}`: missing columns {missing}', file=sys.stderr)
            continue

        df = df[df['distance'] <= distance_threshold]

        true_collision = False
        for _, row in df.iterrows():
            same = verify_same_context(
                tokenizer=tokenizer,
                i_dataset=data_root / row['i_dataset'],
                j_dataset=data_root / row['j_dataset'],
                # int(): the CSV may have been read with float dtypes.
                i_sample_idx=int(row['i_sample_idx']),
                j_sample_idx=int(row['j_sample_idx']),
                i_start_idx=int(row['i_start_idx']),
                j_start_idx=int(row['j_start_idx']),
                i_end_idx=int(row['i_end_idx']),
                j_end_idx=int(row['j_end_idx']),
                text_column=text_column,
                verbose=verbose,
            )
            # A pair within the distance threshold whose spans differ is a true collision.
            if not same:
                true_collision = True
                true_collision_all = True

        if verbose:
            print('=' * 100)

        print(f'\tPossible Collisions: {len(df)} - Any true collisions: {true_collision}')

    print(f"\nAny true collision (non-matching contexts) found: {true_collision_all}")


if __name__ == "__main__":
    main()
    
