import os
import sys

# Extend the module search path with the current working directory and its
# parent before the project-local `src.*` imports further down are executed.
# NOTE(review): presumably this lets the script run from the repository root or
# from a direct sub-directory without installing the package — confirm.
sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import heapq
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Literal, Union

import pandas as pd
import torch
from tqdm.auto import tqdm

from src.ablations.dataset_collisions import load_representations_binary
from src.utils.stats import RunningStats


def group_by_bin_number(root: Path) -> dict[int, list[str]]:
    """Group the ``<stem>.bin-<N>`` files directly under *root* by ``N``.

    Only regular files whose name ends in ``.bin-<digits>`` are considered;
    sub-directories and files with any other name are ignored. The scan is
    not recursive.

    Args:
        root: Directory to scan.

    Returns:
        A dict mapping each bin number ``N`` to the sorted list of stems
        (the part of the file name that precedes ``.bin-N``). Keys are
        inserted in ascending order.
    """
    pat = re.compile(r"\.bin-(\d+)$")
    groups: dict[int, list[str]] = defaultdict(list)

    for p in root.iterdir():
        if not p.is_file():
            continue

        m = pat.search(p.name)
        if m is None:
            continue

        # Cut the name at the start of the matched suffix. The previous
        # `str.replace(<digits>, '')` approach also removed those digits
        # wherever else they occurred in the name (e.g. "c4.bin-4" -> "c"),
        # producing stems that no longer correspond to a file on disk.
        groups[int(m.group(1))].append(p.name[: m.start()])

    return {k: sorted(v) for k, v in sorted(groups.items())}


def find_closest(
    data_dir: Path,
    datasets: list[str],
    model_name: str,
    run_args_name: str,
    norm_ord: int = 2,
    norm_type: Union[Literal['sum'], Literal['mean']] = 'sum',
    top_k: int = 100,
) -> None:
    """Find zero-distance representation pairs and per-layer distance stats.

    For every layer that has ``<dataset>.bin-<layer>`` files under
    ``data_dir / model_name``, all representations of that layer are loaded,
    every unordered pair is compared with a vector p-norm, and

    * pairs whose distance is below ``1e-16`` are written to ``zeros.csv``;
    * running statistics over the remaining (non-zero) distances are written
      to ``layer-stats.csv``.

    Both CSV files are created inside ``data_dir / model_name``.

    Args:
        data_dir: Root directory holding one sub-directory per model.
        datasets: Currently unused. The datasets that are processed are the
            ones discovered from the file names in the model directory.
            NOTE(review): the original code shadowed this argument as well;
            filtering by it would change which files are read, so it is left
            untouched here — confirm the intended behaviour.
        model_name: Name of the model sub-directory inside *data_dir*.
        run_args_name: File name (inside the model directory) of the text
            file that contains a ``Hidden Size: d = <int>`` line.
        norm_ord: Order of the vector norm passed to ``torch.linalg.norm``.
        norm_type: ``'sum'`` keeps the raw norm; ``'mean'`` divides it by the
            representation dimension.
        top_k: Currently unused; kept for interface compatibility.

    Raises:
        ValueError: If the hidden size cannot be found in the run-args file.
    """
    DATA_DIR  = data_dir / model_name
    ARGS_PATH = DATA_DIR / run_args_name

    with ARGS_PATH.open('r') as f:
        content = f.read()

    match = re.search(r"Hidden Size:\s*d\s*=\s*(\d+)", content)
    if not match:
        # ValueError is still caught by callers that catch `Exception`.
        raise ValueError(f'Hidden size not found in `{ARGS_PATH}`!')

    d = int(match.group(1))

    paths_dict = group_by_bin_number(DATA_DIR)

    # The device does not change between layers, so pick it once.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The new `layer` column is appended last so that the position of the
    # pre-existing columns in `zeros.csv` does not change.
    zero_columns = [
        'distance',
        'i', 'i_dataset', 'i_sample_idx', 'i_start_idx', 'i_end_idx',
        'j', 'j_dataset', 'j_sample_idx', 'j_start_idx', 'j_end_idx',
        'layer',
    ]

    rows: list[dict[str, Any]] = []
    stats_rows: list[dict[str, Any]] = []
    for layer, layer_datasets in paths_dict.items():
        stats = RunningStats()

        # Collisions are tracked per layer. Sharing one list across layers
        # re-emitted earlier layers' pairs for every later layer and looked
        # them up against the wrong layer's metadata.
        zeros: list[tuple[int, int]] = []

        print(f'Computing distances for layer {layer}...')

        print('\tLoading Samples...', end=' ', flush=True)
        # NOTE(review): each item yielded by `load_representations_binary` is
        # presumably (sample_idx, start_idx, end_idx, vector); the 5-way
        # unpacking below relies on exactly four fields per item — confirm.
        samples = [
            (name, *sample)
            for name in layer_datasets
            for sample in load_representations_binary(
                DATA_DIR / f'{name}.bin-{layer}', d=d, flags='rb'
            )
        ]
        print('Done')

        print(f'\tEmbeddings computed: {len(samples):,}')

        if not samples:
            # `zip(*[])` cannot be unpacked and `torch.stack([])` raises, so
            # a layer without any representation is skipped explicitly.
            print('\tNo samples found, skipping layer.')
            continue

        print('\tCreating arrays...', end=' ', flush=True)
        names, sample_ids, start_ids, end_ids, arrays = zip(*samples) # type: ignore
        print('Done')

        print('\tCreating tesnors...', end=' ', flush=True)
        reps = torch.stack(
            [torch.tensor(a, dtype=torch.float32, device=device) for a in arrays],
            dim=0,
        )  # assumed shape: (N, D) — each stored representation is one vector
        print('Done')

        dim = reps.size(1)

        for i in (bar := tqdm(range(len(reps) - 1), desc=f'{model_name} - {d}')):
            # Compare sample i with every later sample, so each unordered
            # pair is visited exactly once.
            anchor = reps[i]         # shape: (D)
            rest = reps[i + 1:]      # shape: (N - i - 1, D)

            diffs = torch.linalg.norm(
                rest - anchor.unsqueeze(0),
                ord=norm_ord,
                dim=1
            )

            if norm_type == 'mean':
                # The anchor is 1-D, so the dimension has to come from the
                # stacked tensor; `anchor.size(1)` raised IndexError.
                diffs /= dim

            zero_mask = diffs.abs() < 1e-16
            stats.update_batch(diffs[~zero_mask])

            # Offsets are relative to `rest`, hence the `i + 1` shift back to
            # indices into the full per-layer sample list.
            zero_idx = torch.nonzero(zero_mask, as_tuple=True)[0]
            zeros.extend((i, i + 1 + offset) for offset in zero_idx.tolist())

            bar.set_postfix({
                'min': stats.min,
                'mean': stats.mean,
                '# zeros': len(zeros),
            })

        stats_rows.append(stats.finalize_row(layer=layer))

        for i, j in zeros:
            rows.append({
                'distance': 0.0,

                'i': i,
                'i_dataset': names[i],
                'i_sample_idx': sample_ids[i],
                'i_start_idx': start_ids[i],
                'i_end_idx': end_ids[i],

                'j': j,
                'j_dataset': names[j],
                'j_sample_idx': sample_ids[j],
                'j_start_idx': start_ids[j],
                'j_end_idx': end_ids[j],

                'layer': layer,
            })

    # Explicit columns keep a header in the file even when no collision was
    # found, so the CSV can always be read back.
    out_csv = DATA_DIR / 'zeros.csv'
    pd.DataFrame(rows, columns=zero_columns).to_csv(out_csv, index=False)

    out_csv = DATA_DIR / 'layer-stats.csv'
    df = pd.DataFrame(stats_rows, columns=["layer", "count", "mean", "std", "min", "max"])
    df.sort_values(by=["layer"], inplace=True)
    df.to_csv(out_csv, index=False)


def parse_args(argv: Union[list[str], None] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the behaviour of the
            original zero-argument call.

    Returns:
        Namespace with the attributes ``data_dir``, ``datasets``, ``id``,
        ``run_args``, ``ord``, ``norm_type`` and ``top_k``.
    """
    # The previous description ("Run inversion attack ...") did not describe
    # what this script does.
    parser = argparse.ArgumentParser(
        description=(
            'Find zero-distance (colliding) representation pairs and '
            'per-layer distance statistics.'
        )
    )

    parser.add_argument(
        '-dir', '--data-dir',
        type=str, default='./data/dataset_exp',
        help='Path to the directory containing the dataset embeddings'
    )
    # NOTE(review): `find_closest` discovers datasets from the file names in
    # the model directory and does not consult this list — confirm intent.
    parser.add_argument(
        '--datasets',
        type=str, nargs='+',
        default=[
            'wikipedia',
            'colossal_clean_crawled_corpus',
            'arxiv_pile',
            'github_python'
        ],
        help='Name of datasets to use. Should be subdirectories of `--data-dir`.'
    )
    # The first long option determines the attribute name, so the value is
    # available as `args.id` for both spellings.
    parser.add_argument(
        '--id', '--model-id',
        type=str, default='roneneldan/TinyStories-1M',
        help='Name of HF model to use.'
    )
    parser.add_argument(
        '--run-args',
        type=str, default='run_args.txt',
        help='Name of the `run_args` text file.'
    )
    parser.add_argument(
        '--ord',
        type=int, default=2,
        help='Order of p-norm to use.'
    )
    parser.add_argument(
        '--norm-type',
        type=str, choices=['sum', 'mean'],
        default='sum',
        help='Mean or Sum differences.'
    )
    parser.add_argument(
        '-k', '--top-k',
        type=int, default=100,
        help='Top-K closest pairs to store.'
    )

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: read the CLI options and run the scan with them.
    cli = parse_args()

    find_closest(
        data_dir=Path(cli.data_dir),
        datasets=cli.datasets,
        # Only the last path component of the model id (the part after the
        # final '/') is used as the model directory name.
        model_name=cli.id.split('/')[-1],
        run_args_name=cli.run_args,
        norm_ord=cli.ord,
        norm_type=cli.norm_type,
        top_k=cli.top_k,
    )

# ./scripts/find_close.sh --time 12:00:00 --top-k 100000