import os
import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import heapq
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Literal, Union

import pandas as pd
import torch
from tqdm.auto import tqdm

from src.ablations.exhaustive_collisions import load_representations_binary
from src.utils.stats import RunningStats


def one_per_bin_number(root: Path) -> dict[int, str]:
    """
    Index the '<prefix>.bin-<num>' files directly inside `root` by number.

    Example: a file named 'suffix_one_exhaustive.bin-7' contributes the entry
    {7: 'suffix_one_exhaustive'}. Directories and names lacking the
    '.bin-<digits>' ending are skipped. If two files share a number, the one
    seen later while iterating `root` wins.

    Returns:
        A dict mapping number -> prefix, ordered by ascending number.
    """
    bin_name = re.compile(r"^(?P<prefix>.+)\.bin-(?P<num>\d+)$")

    # A later match for the same number overwrites the earlier one, exactly as
    # repeated dict assignment would.
    by_number = {
        int(hit.group("num")): hit.group("prefix")
        for entry in root.iterdir()
        if entry.is_file() and (hit := bin_name.match(entry.name))
    }

    return {num: by_number[num] for num in sorted(by_number)}



def find_closest(
    data_dir: Path,
    model_name: str,
    run_args_name: str,
    norm_ord: int = 2,
    norm_type: Union[Literal['sum'], Literal['mean']] = 'sum',
    top_k: int = 100,
):
    """
    Find zero-distance (colliding) embedding pairs in every '<prefix>.bin-<num>'
    file of a model and record per-file distance statistics.

    For each binary file in `data_dir / model_name`, every embedding is compared
    against every later embedding of the same file with a `norm_ord`-norm.
    Pairs whose distance is below 1e-16 are recorded as collisions. All other
    distances are fed into a `RunningStats` accumulator.

    Outputs, written to `data_dir / model_name`:
      * `zeros.csv`: one row per colliding pair. Columns: distance, i,
        i_sample_idx, i_token_idx, j, j_sample_idx, j_token_idx. `i` and `j`
        are row indices within the file the pair came from.
      * `sample-stats.csv`: one statistics row per (non-empty) binary file.

    Args:
        data_dir: Root directory containing one sub-directory per model.
        model_name: Name of the sub-directory of `data_dir` holding the files.
        run_args_name: Name of the run-args text file inside that
            sub-directory. It must contain 'Hidden Size: d = <int>'.
        norm_ord: Order of the p-norm passed to `torch.linalg.norm`.
        norm_type: 'sum' keeps the raw norm; 'mean' divides it by the
            embedding dimension.
        top_k: Currently unused; kept for CLI/API compatibility.

    Raises:
        ValueError: If the hidden size is not found in the run-args file.
    """
    DATA_DIR  = data_dir / model_name
    ARGS_PATH = DATA_DIR / run_args_name

    content = ARGS_PATH.read_text()

    match = re.search(r"Hidden Size:\s*d\s*=\s*(\d+)", content)
    if not match:
        raise ValueError(f'Hidden size not found in `{ARGS_PATH}`!')

    d = int(match.group(1))

    # Loop-invariant settings, hoisted out of the per-file loop.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    zero_tol = 1e-16

    paths_dict = one_per_bin_number(DATA_DIR)

    rows: list[dict[str, Any]] = []
    stats_rows: list[dict[str, Any]] = []
    for sample_idx, file in paths_dict.items():
        print(f'Computing distances for sample {sample_idx}...')

        file_path = DATA_DIR / f'{file}.bin-{sample_idx}'

        print('\tLoading Samples...', end=' ', flush=True)
        samples = load_representations_binary(file_path, d=d, flags='rb')
        print('Done')

        print(f'\tEmbeddings computed: {len(samples):,}')

        if not samples:
            # `zip(*samples)` cannot be unpacked into three names when the
            # file holds no embeddings, and there is nothing to compare.
            print('\tNo embeddings found, skipping.')
            continue

        stats = RunningStats()
        # Collisions are tracked per file: the (i, j) indices refer to rows of
        # THIS file's `sample_ids` / `token_ids`, so they must not leak into
        # the next file's iteration.
        zeros: list[tuple[int, int]] = []

        print('\tCreating arrays...', end=' ', flush=True)
        sample_ids, token_ids, arrays = zip(*samples)  # type: ignore
        print('Done')

        print('\tCreating tensors...', end=' ', flush=True)
        arrays = torch.stack(
            [torch.tensor(a, dtype=torch.float32, device=device) for a in arrays],
            dim=0,
        )
        print('Done')

        for i in (bar := tqdm(range(len(arrays) - 1), desc=f'{model_name} - {d}')):
            # "Pop" the first row (original index i) and compare it with all
            # remaining rows; slicing a tensor returns a view, not a copy.
            array = arrays[0]        # shape: (D,)
            arrays = arrays[1:]      # shape: (N - i - 1, D)

            diffs = torch.linalg.norm(
                arrays - array.unsqueeze(0),
                ord=norm_ord,
                dim=1,
            )

            if norm_type == 'mean':
                # `array` is 1-D, so the embedding dimension is size(0).
                diffs /= array.size(0)

            del array

            zero_mask = diffs.abs() < zero_tol
            stats.update_batch(diffs[~zero_mask])

            # Position x in the sliced tensor is original row i + 1 + x.
            zero_idx = torch.nonzero(zero_mask, as_tuple=True)[0]
            zeros.extend((i, i + 1 + x) for x in zero_idx.tolist())

            del diffs, zero_mask

            bar.set_postfix({
                'min': stats.min,
                'mean': stats.mean,
                '# zeros': len(zeros),
            })

        # Drop the last view so the stacked tensor's storage can be freed
        # before the next file is loaded.
        del arrays

        stats_rows.append(stats.finalize_row(sample_idx=sample_idx))

        for i, j in zeros:
            rows.append({
                'distance': 0.0,

                'i': i,
                'i_sample_idx': sample_ids[i],
                'i_token_idx': token_ids[i],

                'j': j,
                'j_sample_idx': sample_ids[j],
                'j_token_idx': token_ids[j],
            })

    # Explicit columns keep the header even when no collisions were found.
    zeros_columns = [
        'distance',
        'i', 'i_sample_idx', 'i_token_idx',
        'j', 'j_sample_idx', 'j_token_idx',
    ]
    out_csv = DATA_DIR / 'zeros.csv'
    pd.DataFrame(rows, columns=zeros_columns).to_csv(out_csv, index=False)

    out_csv = DATA_DIR / 'sample-stats.csv'
    df = pd.DataFrame(stats_rows, columns=["sample_idx", "count", "mean", "std", "min", "max"])
    df.sort_values(by=["sample_idx"], inplace=True)
    df.to_csv(out_csv, index=False)


def parse_args(argv: Union[list[str], None] = None):
    """
    Parse the command-line options of this script.

    Args:
        argv: Arguments to parse, without the program name. The default, None,
            makes argparse read `sys.argv[1:]`, as before.

    Returns:
        An `argparse.Namespace` with the attributes `data_dir`, `id`,
        `run_args`, `ord`, `norm_type` and `top_k`.
    """
    # The previous description ('Run inversion attack ...') was copied from a
    # different script and did not describe what this one does.
    parser = argparse.ArgumentParser(
        description='Find zero-distance (colliding) embedding pairs and '
                    'per-sample distance statistics for exhaustive runs.'
    )

    parser.add_argument(
        '-dir', '--data-dir',
        type=str, default='./data/exhaustive',
        help='Path to the directory containing the dataset embeddings'
    )
    parser.add_argument(
        '--id', '--model-id',
        type=str, default='roneneldan/TinyStories-1M',
        help='Name of HF model to use.'
    )
    parser.add_argument(
        '--run-args',
        type=str, default='run_args.txt',
        help='Name of the `run_args` text file.'
    )
    parser.add_argument(
        '--ord',
        type=int, default=2,
        help='Order of p-norm to use.'
    )
    parser.add_argument(
        '--norm-type',
        type=str, choices=['sum', 'mean'],
        default='sum',
        help='Mean or Sum differences.'
    )
    parser.add_argument(
        '-k', '--top-k',
        type=int, default=100,
        help='Top-K closest pairs to store.'
    )

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: map the parsed CLI options onto `find_closest`.
    cli = parse_args()

    find_closest(
        Path(cli.data_dir),
        # Keep only the last '/'-separated component of the model id
        # (e.g. 'roneneldan/TinyStories-1M' -> 'TinyStories-1M').
        cli.id.split('/')[-1],
        cli.run_args,
        norm_ord=cli.ord,
        norm_type=cli.norm_type,
        top_k=cli.top_k,
    )