import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

from dotenv import load_dotenv

_ = load_dotenv()

import argparse
import struct
from pathlib import Path
from typing import Iterable, Iterator

import numpy as np
import pandas as pd
import torch
import transformers
from numpy.typing import DTypeLike
from tqdm import tqdm

transformers.logging.set_verbosity_error()
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.general import extract_hidden_states_ids
from src.utils.utils import set_seed, all_token_ids_strict, num_layers, replace_last_norm


def save_representations_binary(
    output: list[tuple[int, int, np.ndarray]],
    output_path: Path,
    flags: str = "wb",
) -> None:
    """
    Serialize (sample_idx, token_idx, vector) records to a flat binary file.

    Each record is laid out back-to-back, with no file header:
      int32 sample_idx        (native byte order, via struct "ii")
      int32 token_idx
      float32[d] vec          (vector cast to float32, C order)

    Parameters:
        output: Records to write, in order.
        output_path: Destination file.
        flags: File mode, e.g. "wb" to create/truncate or "ab" to append.
    """
    header = struct.Struct("ii")
    with output_path.open(flags) as handle:
        for s_idx, t_idx, raw_vec in output:
            # Convert the vector first so a bad vector leaves no partial record behind.
            values = np.asarray(raw_vec, dtype=np.float32)
            handle.write(header.pack(int(s_idx), int(t_idx)) + values.tobytes())

def load_representations_binary(
    output_path: Path,
    d: int,
    flags: str = "rb",
) -> list[tuple[int, int, np.ndarray]]:
    """
    Read records written by `save_representations_binary`.

    Each record is:
      int32 sample_idx        (native byte order, struct "ii")
      int32 token_idx
      float32[d] vec

    Parameters:
        output_path: File to read.
        d: Vector dimension of every record (must match what was written).
        flags: File mode used to open the file (binary read by default).

    Returns:
        List of (sample_idx, token_idx, vec) with `vec` a writable float32 array of length `d`.

    Raises:
        ValueError: if `d` is negative, or if the file ends with a partial record
            (e.g. a writer was interrupted, or `d` does not match the file).
    """
    if d < 0:
        raise ValueError(f"Vector dimension d must be non-negative, got {d}.")

    header = struct.Struct("ii")
    record_size = header.size + 4 * d  # 2 * int32 + d * float32
    output: list[tuple[int, int, np.ndarray]] = []
    with output_path.open(flags) as f:
        while chunk := f.read(record_size):
            # read() returns fewer bytes only at EOF, so a short chunk means a truncated file.
            if len(chunk) != record_size:
                raise ValueError(
                    f"Truncated record in {output_path}: expected {record_size} bytes, "
                    f"got {len(chunk)} (record #{len(output)}). Check that d={d} is correct."
                )
            sample_idx, token_idx = header.unpack_from(chunk, 0)
            # .copy() detaches the array from the read-only bytes buffer.
            vec = np.frombuffer(chunk, dtype=np.float32, count=d, offset=header.size).copy()
            output.append((int(sample_idx), int(token_idx), vec))
    return output


def _batched_iterator(items: Iterable[int], batch_size: int) -> Iterator[list[int]]:
    batch: list[int] = []
    for x in items:
        batch.append(x)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

def _compose_batch_sequences(
    base_ids: list[int],
    new_token_ids: list[int],
    is_prefix: bool,
) -> torch.LongTensor:
    """
    Given a single base sequence (list[int]) and a list of new token IDs, build a 2D LongTensor
    of shape (batch_size, len(base)+1). Each row is either:
      [t, *base]   (if prefix)
      [*base, t]   (if suffix)
    """
    L = len(base_ids)
    B = len(new_token_ids)
    out = torch.empty((B, L + 1), dtype=torch.long)
    base = torch.tensor(base_ids, dtype=torch.long)
    if is_prefix:
        out[:, 0] = torch.tensor(new_token_ids, dtype=torch.long)
        out[:, 1:] = base.unsqueeze(0).expand(B, L)
    else:
        out[:, :-1] = base.unsqueeze(0).expand(B, L)
        out[:, -1] = torch.tensor(new_token_ids, dtype=torch.long)
    
    return out # type: ignore


@torch.no_grad()
def extract_sample_representations(
    csv_sample_idx: int,

    dataset_path: Path,
    sample_idx: int,
    start_idx: int,
    end_idx: int,
    text_column: str,
    output_path: Path,

    model: AutoModelForCausalLM, 
    layer_idx: int,

    tokenizer: AutoTokenizer,
    include_special: bool,

    batch_size: int,
    is_prefix: bool,
    seed: int = 8,
):
    """
    Sweep the vocabulary over one dataset sample and dump one hidden-state vector per token.

    The text in column `text_column` of row `sample_idx` of the dataset at `dataset_path` is
    tokenized without special tokens, truncated to `max_len` (the tokenizer's model_max_length,
    capped at 2048) and windowed to `full_ids[start_idx:end_idx]`. A negative `end_idx` means
    "to the end". If the window is `max_len` long, its last token is dropped so that base + added
    token still fits in `max_len`.

    For every token id returned by `all_token_ids_strict` (special ids excluded unless
    `include_special`), the token is prefixed (`is_prefix=True`) or suffixed to the base. The model
    is run in batches of `batch_size`, and the hidden state at `layer_idx` of the LAST position is
    stored. In suffix mode that position is the added token. In prefix mode it is the last base
    token, not the added one.

    Output file layout (see `save_representations_binary`):
      * first record: the base prompt run alone, stored with sample_idx = -1 and token_idx = -1.
        The file is created/truncated ("wb").
      * then one record per swept token, (csv_sample_idx, token_id, vector), appended ("ab").

    The progress bar shows the smallest pairwise Euclidean distance seen within any batch.

    Raises:
        ValueError: if the selected window contains no tokens.
    """
    set_seed(seed)
    device = getattr(model, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # Respect model_max_length but cap at 2048 unless model supports more and you choose to raise the cap.
    max_len = min(getattr(tokenizer, 'model_max_length', 2048) or 2048, 2048)

    # Load only the requested row. `.to_list()` would materialize the whole dataset in memory.
    # Indices may arrive as numpy/pandas scalars, so normalize them to plain ints.
    dataset = load_from_disk(str(dataset_path))
    sample = dataset[int(sample_idx)]  # type: ignore
    prompt: str = sample[text_column]

    # Tokenize base prompt
    tokenized_prompt = tokenizer(
        prompt,
        add_special_tokens=False,
        truncation=True,
        max_length=max_len,
        return_attention_mask=False,
    )  # type: ignore

    full_ids: list[int] = tokenized_prompt["input_ids"]
    # Apply optional window
    start_idx = max(0, int(start_idx))
    end_idx = int(end_idx)
    end_idx = min(len(full_ids), end_idx if end_idx >= 0 else len(full_ids))
    base_ids = full_ids[start_idx:end_idx]

    # Leave room for the one token added during the sweep.
    if len(base_ids) >= max_len:
        base_ids = base_ids[:-1]

    if len(base_ids) == 0:
        raise ValueError("Selected window produced an empty base sequence. Check start_idx/end_idx.")

    # First, run the prompt-only pass and save with token_id = -1
    base_tensor: torch.LongTensor = torch.tensor(
        [base_ids], 
        dtype=torch.long, 
        device=device
    ) # type: ignore
    H_base = extract_hidden_states_ids(
        base_tensor, 
        model,  # type: ignore
        layer_idx=layer_idx,
        grad=False,
        batch=False
    )

    # `.float()` first: NumPy has no bfloat16, so `.numpy()` would raise on bf16 activations.
    # NOTE(review): `[-1]` assumes batch=False returns (seq, hidden) -- TODO confirm.
    save_representations_binary(
        [(int(-1), int(-1), H_base.float().cpu().numpy()[-1])], 
        output_path=output_path, 
        flags="wb"
    )

    # Prepare vocab sweep
    possible_tokens = all_token_ids_strict(tokenizer, include_special=include_special)

    # Write representations for the sweep
    minmin = float('inf')
    for batch in (bar := tqdm(list(_batched_iterator(possible_tokens, batch_size)))):
        # Compose input_ids (B, L+1)
        input_ids: torch.LongTensor = _compose_batch_sequences(
            base_ids, 
            batch, 
            is_prefix=is_prefix
        ).to(device) # type: ignore

        H_all = extract_hidden_states_ids(
            input_ids, 
            model,  # type: ignore
            layer_idx=layer_idx,
            grad=False,
            batch=True
        )
        # Slice the last position before the host copy, and upcast so pdist/numpy accept bf16 models.
        H = H_all[:, -1, :].float().cpu()  # (B, hidden)

        # pdist needs >= 2 rows: a trailing batch of one token gives an empty tensor and torch.min raises.
        if H.shape[0] > 1:
            min_dist = torch.min(torch.pdist(H)).item()
            if min_dist < minmin:
                minmin = min_dist
                bar.set_postfix({'Min Dist': minmin})

        # Save
        triplets = [(int(csv_sample_idx), int(tok_id), H[i].numpy()) for i, tok_id in enumerate(batch)]
        save_representations_binary(
            triplets,
            output_path=output_path,
            flags="ab",
        )
                




def process_distance_csv(
    file_path: Path,
    n: int,
    distance_threshold: float = 1e-12
) -> pd.DataFrame:
    """
    Select the `n` unique samples with the smallest pairwise distances from a distance CSV.

    Each CSV row describes a pair (i, j) with columns
    `distance`, `i_dataset`, `i_sample_idx`, `i_start_idx`, `i_end_idx`,
    `j_dataset`, `j_sample_idx`, `j_start_idx`, `j_end_idx`.

    Steps:
      1. Discard rows whose distance is <= `distance_threshold`.
      2. Split every pair into two rows, one per side (i and j).
      3. Sort by distance (stable), then keep the first occurrence of each unique
         (dataset, sample_idx, start_idx, end_idx). Each sample is ranked by its
         SMALLEST distance across all pairs it appears in.
      4. Keep the first `n` rows.

    Parameters:
        file_path (Path | str): Path to the input CSV file (anything `pd.read_csv` accepts).
        n (int): Number of rows to return after processing.
        distance_threshold (float): Strict lower bound; only rows with distance > this value are kept.

    Returns:
        pd.DataFrame: Rows sorted by ascending distance, with columns
        [distance, dataset, sample_idx, start_idx, end_idx].
    """
    df = pd.read_csv(file_path)

    # Step 1: keep only rows strictly above the threshold (drops exact/near-exact collisions).
    df = df[df['distance'] > distance_threshold]

    # Step 2: one row per side of each pair, with a shared column layout.
    columns = ['distance', 'dataset', 'sample_idx', 'start_idx', 'end_idx']
    halves = []
    for side in ('i', 'j'):
        half = df[[
            'distance',
            f'{side}_dataset',
            f'{side}_sample_idx',
            f'{side}_start_idx',
            f'{side}_end_idx',
        ]].copy()
        half.columns = columns
        halves.append(half)
    df_combined = pd.concat(halves, ignore_index=True)

    # Step 3: sort BEFORE de-duplicating so the surviving row per sample is its minimum distance.
    # A stable sort keeps tie order deterministic (CSV order, i-side before j-side).
    df_sorted = df_combined.sort_values(by='distance', ascending=True, kind='mergesort')
    df_unique = df_sorted.drop_duplicates(subset=columns[1:], keep='first')  # type: ignore

    # Step 4: keep first n rows
    return df_unique.head(n)


def parse_args(argv: "list[str] | None" = None):
    """
    Build the CLI parser and parse arguments.

    Parameters:
        argv: Argument list to parse (without the program name). `None` (the default)
            parses `sys.argv[1:]`, matching argparse's own behavior.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Run inversion attack with given configuration.')

    parser.add_argument(
        '-dir', '--data-dir', 
        type=str, default='./data',
        help='Path to the directory containing the datasets.'
    )
    parser.add_argument(
        '-i', '--input', 
        type=str, required=True,
        help=(
            'Path to CSV input file. ' +
            'Should have columns [`dataset`, `sample_idx`, `start_idx`, `end_idx`]. ' +
            '`dataset` should be a subdirectory of `--data-dir`.'
        )
    )
    parser.add_argument(
        '-o', '--output-dir', 
        type=str, required=True,
        help='Name of the output subdirectory to be created.'
    )
    parser.add_argument(
        '--text-column',
        type=str, default='text',
        help='Name of the column containing the text in the datasets'
    )

    parser.add_argument(
        '--seed', 
        type=int, default=8,
        help='Random seed to use.'
    )

    parser.add_argument(
        '--id', '--model-id',
        type=str, default='roneneldan/TinyStories-1M',
        help='Name of HF model to use.'
    )
    parser.add_argument(
        '--quantize',
        action='store_true',
        help='Flag for whether to quantize the model or not'
    )
    parser.add_argument(
        '--float16',
        action='store_true',
        help='Flag for whether or not to use half precision when not quantizing the model.'
    )
    parser.add_argument(
        '--layer', 
        type=int, default=-1,
        help='Layer to extract embeddings on. Negative values index the model from the last layer.'
    )

    parser.add_argument(
        '--exclude-special',
        action='store_true',
        help='Flag for whether or not to exclude special tokens.'
    )


    parser.add_argument(
        '-n', '--max-prompts', 
        type=int, default=10,
        help='Maximum amount of prompts to use.'
    )
    parser.add_argument(
        '--batch-size', 
        type=int, default=64,
        help='Batch Size for the forward pass of the LLM.'
    )
    parser.add_argument(
        '--is-prefix',
        action='store_true',
        help='If set, the swept token is prepended to the prompt (prefix); otherwise it is appended (suffix).'
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=1e-12,
        # The CSV filter keeps rows with distance > threshold, so this is a MINIMUM.
        help="Min distance: CSV pairs with distance <= this value are discarded (default: 1e-12)",
    )

    parser.add_argument(
        '--sample-idx', 
        type=int, default=-1,
        help='Pick a specific sample.'
    )

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    # NOTE: the seed is scaled by the process id, so runs with the same `--seed` get different
    # seeds; the scaled value is what gets written to run_args.txt below.
    args.seed    *= os.getpid()
    seed          = args.seed
    model_id      = args.id
    load_in_8bit  = args.quantize
    use_float16   = args.float16
    layer_idx     = args.layer

    include_special = not args.exclude_special

    model_name = model_id.split('/')[-1]

    data_dir         = Path(args.data_dir)
    dataset_path     = args.input
    output_subfolder = data_dir / args.output_dir / model_name
    text_column      = args.text_column

    max_prompts = args.max_prompts
    batch_size  = args.batch_size
    is_prefix   = args.is_prefix
    threshold   = args.distance_threshold

    sample_index = args.sample_idx

    total_layers = num_layers(model_id)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Ensure output dir exists. `out_file` is only a naming template: each prompt is written to
    # `<out_file.name>-<idx>` in the same directory (see the loop below).
    out_file = (output_subfolder / f'{"prefix" if is_prefix else "suffix"}_one_exhaustive').with_suffix('.bin')
    out_file.parent.mkdir(parents=True, exist_ok=True)

    # --- Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=os.getenv('HUGGINGFACE_TOKEN') or True,  # True -> use HF token from cache/env if available
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model_kwargs = dict(
        pretrained_model_name_or_path=model_id,
        device_map='auto',                    # let HF place modules; safer for 8bit
        trust_remote_code=True,
        token=os.getenv('HUGGINGFACE_TOKEN') or True,
    )
    if load_in_8bit:
        model_kwargs['load_in_8bit'] = True
    else:
        # NOTE: `--float16` selects bfloat16 here, not float16.
        model_kwargs['torch_dtype'] = torch.bfloat16 if use_float16 else torch.float32
        print(model_kwargs['torch_dtype'])

    model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
    model.eval()

    replace_last_norm(model_id, model)

    if not load_in_8bit:
        # In 8-bit, device_map handles placement; in full precision we can move the model explicitly
        model.to(device)

    for p in model.parameters():
        p.requires_grad = False

    config = getattr(model.config, 'text_config', None) or model.config
    hidden_size = (
        getattr(config, 'hidden_size', None) or 
        getattr(config, 'd_model', None)
    )

    if hidden_size is None:
        raise ValueError('Could not determine hidden state dimension for this model.')


    log_file = output_subfolder / "run_args.txt"
    with log_file.open("w") as f:
        print("Running with arguments:")
        f.write("Running with arguments:\n")

        for k, v in vars(args).items():
            line = f"  {k}: {v}"
            print(line)
            f.write(line + "\n")

        extra_lines = [
            f"Using device: {device}",
            f"Number of layers: {total_layers}",
            f"Hidden Size: d = {hidden_size}"
        ]
        for line in extra_lines:
            print(line)
            f.write(line + "\n")


    print(f'\nExtracting from `{dataset_path}` -> `{out_file}`')


    df = process_distance_csv(dataset_path, max_prompts, threshold)
    df.to_csv(output_subfolder / 'prompts.csv', index=False)

    saved_files: list[Path] = []
    for idx, (_, sample) in enumerate(df.iterrows()):

        if sample_index >= 0 and sample_index != idx:
            continue

        ds_path = data_dir / sample['dataset']

        # iterrows() yields pandas/numpy scalars; convert so downstream indexing sees plain ints.
        sample_idx = int(sample['sample_idx'])
        start_idx  = int(sample['start_idx'])
        end_idx    = int(sample['end_idx'])

        sample_out_file = out_file.with_name(f'{out_file.name}-{idx}')
        extract_sample_representations(
            csv_sample_idx=idx,
            dataset_path=ds_path,
            sample_idx=sample_idx,
            start_idx=start_idx,
            end_idx=end_idx,
            text_column=text_column,
            output_path=sample_out_file,
            model=model,
            layer_idx=layer_idx,
            tokenizer=tokenizer,
            include_special=include_special,
            batch_size=batch_size,
            is_prefix=is_prefix,
            seed=seed
        )
        saved_files.append(sample_out_file)

    # Report the files actually written; `out_file` itself is never created.
    if not saved_files:
        print('No prompts were processed; nothing saved.')
    for path in saved_files:
        print(f'Saved: {path}')
        