import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

from dotenv import load_dotenv

_ = load_dotenv()

import argparse
import math
import struct
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import transformers
from numpy.typing import DTypeLike
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

transformers.logging.set_verbosity_error()
from datasets import Dataset, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils.general import extract_hidden_states_ids
from src.utils.utils import num_layers, replace_last_norm, set_seed


def save_representations_binary(
    output: list[tuple[int, int, int, np.ndarray]], 
    output_path: Path,
    flags: str = 'wb',
):
    """Serialize ``(sample_idx, start_idx, pos_idx, vec)`` records to ``output_path``.

    Each record is laid out as three native ``int`` values packed with the
    ``struct`` format ``'iii'``, immediately followed by ``vec`` cast to
    float32 and dumped via ``ndarray.tobytes()``. No file header or per-record
    length is written, so a reader must already know the vector dimension.

    Args:
        output: Records to write, in order.
        output_path: Destination file.
        flags: Mode passed to ``Path.open``; ``'wb'`` truncates, ``'ab'`` appends.
    """
    with output_path.open(flags) as handle:
        for sample_idx, start_idx, pos_idx, vec in output:
            # Convert the vector first so a bad vector fails before the header is packed.
            payload = np.asarray(vec, dtype=np.float32).tobytes()
            header = struct.pack('iii', sample_idx, start_idx, pos_idx)
            handle.write(header + payload)

def load_representations_binary(
    output_path: Path, 
    d: int,
    flags: str = 'rb'
) -> list[tuple[int, int, int, np.ndarray]]:
    """Read records written by ``save_representations_binary``.

    Each record is three native ``int`` values (struct format ``'iii'``)
    followed by ``d`` float32 values.

    Args:
        output_path: File to read.
        d: Vector dimension of every record (not stored in the file).
        flags: Mode passed to ``Path.open``; must be a binary read mode.

    Returns:
        List of ``(sample_idx, start_idx, pos_idx, vec)`` where ``vec`` is a
        read-only float32 array of length ``d`` backed by the record's bytes.

    Raises:
        ValueError: if ``d`` is negative, or if the file ends with a partial
            record (truncated file, or ``d`` does not match the writer's dimension).
    """
    if d < 0:
        raise ValueError(f'd must be >= 0, got {d}')

    # Same format string as the writer, so header size stays in sync with it.
    header = struct.Struct('iii')
    record_size = header.size + 4 * d  # 3 * int32 + d * float32

    output = []
    with output_path.open(flags) as f:
        while chunk := f.read(record_size):
            if len(chunk) != record_size:
                raise ValueError(
                    f'{output_path}: trailing partial record of {len(chunk)} bytes '
                    f'(expected {record_size}); the file is truncated or d={d} is wrong.'
                )
            sample_idx, start_idx, pos_idx = header.unpack_from(chunk)
            vec = np.frombuffer(chunk, dtype=np.float32, count=d, offset=header.size)
            output.append((sample_idx, start_idx, pos_idx, vec))
    return output


def extract_multiple_hidden_states(
    llm: AutoModelForCausalLM,
    input_ids: torch.LongTensor,
    layers: list[int]
):
    """Run a single gradient-free forward pass and keep selected hidden states.

    Args:
        llm: Model called as ``llm(input_ids=..., output_hidden_states=True)``.
        input_ids: Token ids forwarded unchanged to ``llm``.
        layers: Indices into ``outputs.hidden_states``; repeated indices
            collapse into a single dictionary entry.

    Returns:
        dict mapping each requested index to its hidden-state tensor,
        detached and copied to CPU.
    """
    with torch.no_grad():
        all_hidden = llm(input_ids=input_ids, output_hidden_states=True).hidden_states

        selected = {}
        for layer in layers:
            selected[layer] = all_hidden[layer].detach().cpu()

    return selected

def extract_last_token_representations(
    sample_ids: list[int],
    start_ids: list[int],
    token_ids: list[torch.LongTensor], 
    llm: AutoModelForCausalLM, 
    layers: list[int],
    minimum_context_length: int = 100,
    context_increase_step: int = 40,
    seed: int = 8,
) -> dict[int, list[tuple[int, int, int, np.ndarray]]]:
    """Extract hidden states at the last token of growing sub-contexts of each window.

    All windows are right-padded with token id 0 into one batch and run through
    ``llm`` in a single forward pass (no attention mask is passed). For every
    requested layer and every window of true length ``seq_len``, the hidden state
    is taken at in-window positions ``minimum_context_length - 1``,
    ``minimum_context_length - 1 + context_increase_step``, ... (all < ``seq_len``).

    Args:
        sample_ids: Dataset index of each window (stored verbatim in the output).
        start_ids: Offset of each window inside its source token sequence.
        token_ids: 1-D token tensors, one per window.
        llm: Causal LM used for the forward pass.
        layers: Hidden-state indices to extract.
        minimum_context_length: Length of the shortest sub-context; values < 1
            are treated as 1 so positions never go negative.
        context_increase_step: Stride between consecutive sub-context lengths.
        seed: Passed to ``set_seed`` before the forward pass.

    Returns:
        ``{layer_idx: [(sample_idx, start_idx, start_idx + j, vec), ...]}`` where
        ``vec`` is the float32 numpy hidden state at window position ``j``.

    Raises:
        ValueError: if ``context_increase_step <= 0`` or the three per-window
            lists differ in length.
    """
    # Validate before the (expensive) forward pass.
    if context_increase_step <= 0:
        raise ValueError('context_increase_step must be > 0')
    if not (len(sample_ids) == len(start_ids) == len(token_ids)):
        raise ValueError('sample_ids, start_ids and token_ids must have the same length')

    set_seed(seed)

    output = defaultdict(list)
    if not token_ids:
        return output

    # Right-pad token sequences to a common length - shape: [n, max_seq_len].
    # NOTE(review): no attention mask is used; for a causal model right padding
    # should not affect positions < seq_len, which are the only ones read below.
    token_ids_padded: torch.LongTensor = pad_sequence(
        token_ids,  # type: ignore
        batch_first=True,
        padding_value=0,
    ).to(dtype=torch.long, device=llm.device)  # type: ignore

    # Hidden states for the requested layers - each of shape [n, max_seq_len, d].
    representations = extract_multiple_hidden_states(
        llm, token_ids_padded, layers
    )

    # Recover the original (unpadded) sequence lengths.
    seq_lengths = [t.size(0) for t in token_ids]

    # First in-window position to sample; clamped so j is never negative
    # (a negative j would silently index from the end of the window).
    first_pos = max(minimum_context_length, 1) - 1

    for layer_idx, H_batched_padded in representations.items():
        # NumPy has no bfloat16, so Tensor.numpy() raises on bfloat16 models;
        # upcast once per layer (no-op for float32; values are saved as float32 anyway).
        H_batched_padded = H_batched_padded.float()

        for i, (sample_idx, start_idx, seq_len) in enumerate(zip(sample_ids, start_ids, seq_lengths)):
            H = H_batched_padded[i, :seq_len, :]  # shape: [seq_len_i, d]

            output[layer_idx].extend(
                (sample_idx, start_idx, start_idx + j, H[j].numpy())
                for j in range(first_pos, seq_len, context_increase_step)
            )

    return output

def _sample_windows(
    input_ids_list: list[list[int]],
    prompt_tokens: int,
    rng: np.random.Generator,
    offset: int,
) -> tuple[list[torch.LongTensor], list[int], list[int]]:
    """Draw one random ``prompt_tokens``-long window per sequence that is long enough.

    Returns ``(token_ids, start_ids, sample_ids)``, where ``sample_ids`` are
    ``offset`` plus the position in ``input_ids_list``. Too-short sequences are
    skipped without consuming randomness from ``rng``.
    """
    token_ids: list[torch.LongTensor] = []
    start_ids: list[int] = []
    sample_ids: list[int] = []

    for local_i, ids in enumerate(input_ids_list):
        n_tokens = len(ids)
        if n_tokens < prompt_tokens:
            continue  # skip too-short sample

        # Choose the start uniformly among all n_tokens - prompt_tokens + 1 valid windows.
        s = int(rng.integers(0, n_tokens - prompt_tokens + 1))
        token_ids.append(torch.tensor(ids[s : s + prompt_tokens], dtype=torch.long))  # type: ignore
        start_ids.append(s)
        sample_ids.append(offset + local_i)  # index w.r.t. the full dataset

    return token_ids, start_ids, sample_ids


def extract_dataset_representations(
    dataset_path: Path,
    output_path: Path,
    text_column: str,

    llm: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,

    layers: list[int],
    prompt_tokens: int,
    max_prompts: int,
    batch_size: int,
    
    minimum_context_length: int = 100,
    context_increase_step: int = 40,
    seed: int = 8,
):
    """Sample random token windows from a dataset and save per-layer representations.

    Rows are scanned in dataset order, ``batch_size`` at a time, until
    ``min(max_prompts, len(ds))`` eligible rows have been used or the dataset is
    exhausted. Each text is tokenized without special tokens and truncated to
    ``min(tokenizer.model_max_length, 1000)`` tokens. Rows shorter than
    ``prompt_tokens`` are skipped; otherwise one window of exactly
    ``prompt_tokens`` tokens is drawn at a uniformly random start.

    Results for layer ``k`` are written with ``save_representations_binary`` to
    ``<output_path.name>-<k>`` in ``output_path``'s directory. The first written
    batch truncates each file and later batches append.

    Raises:
        ValueError: if ``prompt_tokens``, ``batch_size`` or ``max_prompts`` is not
            positive, if ``prompt_tokens`` exceeds the truncation budget, or if no
            row was long enough.
    """
    # Basic guardrails, checked before touching the dataset.
    if prompt_tokens <= 0:
        raise ValueError('prompt_tokens must be > 0')
    if batch_size <= 0:
        raise ValueError('batch_size must be > 0')
    if max_prompts <= 0:
        raise ValueError('max_prompts must be > 0')

    # Truncation budget: never keep more than 1000 tokens per row.
    max_len = min(getattr(tokenizer, 'model_max_length', 1000) or 1000, 1000)
    if prompt_tokens > max_len:
        # Otherwise every row would be skipped after tokenizing the whole dataset.
        raise ValueError(
            f'prompt_tokens ({prompt_tokens}) exceeds the tokenization budget of '
            f'{max_len} tokens, so no sample could ever be eligible.'
        )

    set_seed(seed)
    rng = np.random.default_rng(seed)

    ds: Dataset = load_from_disk(str(dataset_path))  # type: ignore
    ds_len = len(ds)

    # We take at most this many *eligible* samples, scanning as far as needed.
    total_to_take = min(max_prompts, ds_len)
    taken_so_far = 0
    first_write = True
    cursor = 0  # next dataset row to read

    with tqdm(total=total_to_take, desc='Processing samples') as pbar:
        while taken_so_far < total_to_take and cursor < ds_len:
            start_idx_ds = cursor
            end_idx_ds = min(cursor + batch_size, ds_len)
            cursor = end_idx_ds

            batch = ds.select(range(start_idx_ds, end_idx_ds))
            texts: list[str] = list(batch[text_column])

            # Tokenize without special tokens; truncate to max_len.
            enc = tokenizer(
                texts,
                add_special_tokens=False,
                truncation=True,
                max_length=max_len,
                return_attention_mask=False,
            )  # type: ignore

            token_ids, start_ids, sample_ids = _sample_windows(
                enc['input_ids'], prompt_tokens, rng, offset=start_idx_ds
            )
            if not token_ids:
                continue  # no eligible samples in this batch

            # Never exceed the global quota; the batch may hold more eligible rows than needed.
            remaining = total_to_take - taken_so_far
            token_ids = token_ids[:remaining]
            start_ids = start_ids[:remaining]
            sample_ids = sample_ids[:remaining]

            output_dict = extract_last_token_representations(
                sample_ids=sample_ids,
                start_ids=start_ids,
                token_ids=token_ids,
                llm=llm,
                layers=layers,
                minimum_context_length=minimum_context_length,
                context_increase_step=context_increase_step,
                seed=seed,
            )

            # Save to disk; the first written batch overwrites, later batches append.
            for layer_idx, output in output_dict.items():
                save_representations_binary(
                    output=output,
                    output_path=output_path.with_name(f'{output_path.name}-{layer_idx}'),
                    flags='wb' if first_write else 'ab',
                )
            first_write = False

            taken_so_far += len(token_ids)
            pbar.update(len(token_ids))

    if taken_so_far == 0:
        raise ValueError(
            'No samples were long enough to satisfy `prompt_tokens`. '
            f'Try reducing prompt_tokens (currently {prompt_tokens}) or check your dataset.'
        )
    if taken_so_far < total_to_take:
        tqdm.write(
            f'Warning: only {taken_so_far}/{total_to_take} rows of {dataset_path} '
            f'had at least {prompt_tokens} tokens.'
        )


def parse_args(argv: 'list[str] | None' = None) -> argparse.Namespace:
    """Parse the command-line options of this extraction script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the previous zero-argument behaviour.

    Returns:
        The parsed namespace. ``--id/--model-id`` is stored as ``args.id``;
        ``-mcl`` / ``--minimum_context_length`` / ``--minimum-context-length`` are
        all stored as ``args.minimum_context_length``.
    """
    parser = argparse.ArgumentParser(description='Run inversion attack with given configuration.')

    parser.add_argument(
        '-dir', '--data-dir', 
        type=str, default='./data',
        help='Path to the directory containing the datasets.'
    )
    parser.add_argument(
        '--datasets', 
        type=str, nargs='+', 
        default=[
            'wikipedia', 
            'colossal_clean_crawled_corpus', 
            'arxiv_pile', 
            'github_python'
        ],
        help='Name of datasets to use. Should be subdirectories of `--data-dir`.'
    )
    parser.add_argument(
        '-o', '--output-dir', 
        type=str, required=True,
        help='Name of the output subdirectory to be created.'
    )
    parser.add_argument(
        '--text-column',
        type=str, default='text',
        help='Name of the column containing the text in the datasets'
    )

    parser.add_argument(
        '--seed', 
        type=int, default=8,
        help='Random seed to use.'
    )

    parser.add_argument(
        '--id', '--model-id',
        type=str, default='roneneldan/TinyStories-1M',
        help='Name of HF model to use.'
    )
    parser.add_argument(
        '--quantize',
        action='store_true',
        help='Flag for whether to quantize the model or not'
    )
    parser.add_argument(
        '--float16',
        action='store_true',
        help='Flag for whether or not to use half precision when not quantizing the model.'
    )
    parser.add_argument(
        '--layer', 
        type=int, default=-1,
        help='Layer to extract embeddings on. Negative values index the model from the last layer.'
    )

    parser.add_argument(
        '--semi',
        action='store_true',
        help='Flag for whether to extract representations for the first, middle and last layers.'
    )
    parser.add_argument(
        '--full',
        action='store_true',
        help='Flag for whether to extract representations for all layers.'
    )

    parser.add_argument(
        '-n', '--max-prompts', 
        type=int, default=10,
        help='Maximum amount of prompts to use.'
    )
    parser.add_argument(
        '--prompt-tokens', 
        type=int, default=500,
        help='Number of tokens per prompt to use.'
    )
    parser.add_argument(
        '--batch-size', 
        type=int, default=64,
        help='Batch Size for the forward pass of the LLM.'
    )

    # The underscore spelling stays first so argparse keeps `minimum_context_length`
    # as the dest; the hyphenated alias matches the other multi-word flags.
    parser.add_argument(
        '-mcl', '--minimum_context_length', '--minimum-context-length',
        type=int, default=100,
        help='Minimum length of extracted representation for sub-contexts.'
    )
    parser.add_argument(
        '-cis', '--context-increase-step', 
        type=int, default=40,
        help='Increment to the length of the extracted representation for each sub-context.'
    )

    return parser.parse_args(argv)

if __name__ == '__main__':
    args = parse_args()

    # Scale the seed by the PID so concurrently launched processes draw different
    # samples; the scaled value is the one logged to run_args.txt below.
    args.seed    *= os.getpid()
    seed          = args.seed
    model_id      = args.id
    load_in_8bit  = args.quantize
    use_float16   = args.float16
    layer_idx     = args.layer
    extract_semi  = args.semi
    extract_full  = args.full

    # --full takes precedence over --semi.
    if extract_semi and extract_full:
        extract_semi = False

    model_name = model_id.split('/')[-1]

    data_dir                 = Path(args.data_dir)
    dataset_names: list[str] = args.datasets
    output_subfolder         = data_dir / args.output_dir / model_name
    text_column              = args.text_column

    prompt_tokens = args.prompt_tokens
    max_prompts   = args.max_prompts
    batch_size    = args.batch_size

    minimum_context_length = args.minimum_context_length
    context_increase_step  = args.context_increase_step

    total_layers = num_layers(model_id)
    # Negative --layer counts from the end: -1 -> total_layers (same index --full uses last).
    if layer_idx < 0:
        layer_idx = total_layers + layer_idx + 1

    # Fail fast, before loading the model, when the single requested layer falls
    # outside the 0..total_layers range implied by the mapping above.
    if not (extract_semi or extract_full) and not 0 <= layer_idx <= total_layers:
        raise ValueError(
            f'--layer {args.layer} is out of range for a model with {total_layers} layers '
            f'(expected {-total_layers - 1} <= layer <= {total_layers}).'
        )

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Ensure output dir exists
    output_subfolder.mkdir(parents=True, exist_ok=True)

    # --- Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Model ---
    model_kwargs = dict(pretrained_model_name_or_path=model_id, device_map='auto', local_files_only=True)
    if load_in_8bit:
        model_kwargs['load_in_8bit'] = True
    else:
        # NOTE: --float16 selects bfloat16, not IEEE float16.
        model_kwargs['torch_dtype'] = torch.bfloat16 if use_float16 else torch.float32

    model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
    model.eval()

    replace_last_norm(model_id, model)

    if not load_in_8bit:
        # In 8-bit, device_map handles placement; in full precision we can move the model explicitly
        model.to(device)

    for p in model.parameters():
        p.requires_grad = False

    config = getattr(model.config, 'text_config', None) or model.config
    hidden_size = (
        getattr(config, 'hidden_size', None) or 
        getattr(config, 'd_model', None)
    )

    if hidden_size is None:
        raise ValueError('Could not determine hidden state dimension for this model.')


    log_file = output_subfolder / "run_args.txt"
    with log_file.open("w", encoding="utf-8") as f:
        print("Running with arguments:")
        f.write("Running with arguments:\n")

        for k, v in vars(args).items():
            line = f"  {k}: {v}"
            print(line)
            f.write(line + "\n")

        extra_lines = [
            f"Using device: {device}",
            f"Number of layers: {total_layers}",
            f"Hidden Size: d = {hidden_size}"
        ]
        for line in extra_lines:
            print(line)
            f.write(line + "\n")


    # --- Process datasets ---
    layers = [layer_idx]
    if extract_semi:
        layers = [1, total_layers // 2 + 1, total_layers]
    elif extract_full:
        layers = list(range(1, total_layers + 1))

    for dataset_name in dataset_names:
        ds_path = data_dir / dataset_name
        # Append the suffix instead of using with_suffix(): with_suffix() would
        # replace a dotted part of the name (e.g. 'c4.en' -> 'c4.bin') and make
        # different datasets collide on the same output file.
        out_file = output_subfolder / f'{dataset_name}.bin'
        # dataset_name may contain '/', so its parent can be a nested directory.
        out_file.parent.mkdir(parents=True, exist_ok=True)

        print(f'\nExtracting from `{ds_path}` -> `{out_file}`')
        extract_dataset_representations(
            dataset_path=ds_path,
            output_path=out_file,
            text_column=text_column,

            llm=model,
            tokenizer=tokenizer,

            layers=layers,
            prompt_tokens=prompt_tokens,
            max_prompts=max_prompts,
            batch_size=batch_size,

            minimum_context_length=minimum_context_length,
            context_increase_step=context_increase_step,

            seed=seed,
        )
        print(f'Saved: {out_file}')