import os
import sys

sys.path.append(os.getcwd())
sys.path.append('.')
sys.path.append('..')

import argparse
import heapq
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Literal, Union

import pandas as pd
import torch
from tqdm.auto import tqdm

from src.ablations.dataset_collisions import load_representations_binary
from src.utils.stats import RunningStats

def group_by_bin_number(root: Path) -> dict[int, list[str]]:
    """Group the ``<name>.bin-<N>`` files directly inside *root* by ``N``.

    Args:
        root: Directory to scan (non-recursive). Sub-directories and files
            whose name does not end in ``.bin-<digits>`` are ignored.

    Returns:
        A dict mapping each integer suffix ``N`` (keys in ascending order) to
        the sorted list of file-name prefixes ``<name>`` that carry it.
    """
    pat = re.compile(r"\.bin-(\d+)$")
    groups: dict[int, list[str]] = defaultdict(list)

    for p in root.iterdir():
        if not p.is_file():
            continue
        m = pat.search(p.name)
        if m is None:
            continue
        # Slice off exactly the matched ``.bin-<N>`` suffix. Removing the
        # digits with ``str.replace`` would also strip them from inside the
        # name itself (e.g. "c4.bin-4" -> "c").
        groups[int(m.group(1))].append(p.name[:m.start()])

    return {k: sorted(v) for k, v in sorted(groups.items())}


# Column order of the per-(layer, length) ``zeros-*.csv`` files. Passing it
# explicitly keeps a header even when no zero-distance pair was found.
_ZERO_PAIR_COLUMNS = [
    'distance',
    'i', 'i_dataset', 'i_sample_idx', 'i_start_idx', 'i_end_idx',
    'j', 'j_dataset', 'j_sample_idx', 'j_start_idx', 'j_end_idx',
]


def _read_hidden_size(args_path: Path) -> int:
    """Return the hidden size ``d`` recorded in a ``run_args`` text file.

    The file must contain text matching ``Hidden Size: d = <int>``.

    Raises:
        ValueError: If no such text is found. This subclasses the bare
            ``Exception`` raised previously, so existing handlers still match.
    """
    with args_path.open('r') as f:
        content = f.read()

    match = re.search(r"Hidden Size:\s*d\s*=\s*(\d+)", content)
    if not match:
        raise ValueError(f'Hidden size not found in `{args_path}`!')
    return int(match.group(1))


def _load_layer_samples(
    model_dir: Path, names: list[str], layer: int, d: int
) -> list[tuple]:
    """Load every ``<name>.bin-<layer>`` file as flat ``(name, *sample)`` tuples."""
    # A dict de-duplicates names while preserving their order.
    paths = {name: model_dir / f'{name}.bin-{layer}' for name in names}
    samples: list[tuple] = []
    for name, path in paths.items():
        samples.extend(
            (name, *sample)
            for sample in load_representations_binary(path, d=d, flags='rb')
        )
    return samples


def _group_by_length(samples: list[tuple], max_samples: int) -> dict[int, tuple]:
    """Bucket ``(dataset, sample_id, start, end, array)`` tuples by span length.

    The span length is ``end - start + 1``. Each bucket is returned as five
    parallel column lists (datasets, sample ids, start ids, end ids, arrays),
    truncated to the first ``max_samples`` entries in load order.
    """
    grouped: dict[int, tuple] = defaultdict(lambda: ([], [], [], [], []))
    for ds, sid, st, en, arr in samples:
        columns = grouped[int(en - st + 1)]
        for column, value in zip(columns, (ds, sid, st, en, arr)):
            column.append(value)

    return {
        length: tuple(column[:max_samples] for column in columns)
        for length, columns in grouped.items()
    }


def _find_zero_pairs(
    arrays: torch.Tensor, norm_ord: int, norm_type: str, desc: str
) -> tuple['RunningStats', list[tuple[int, int]]]:
    """Compare every pair ``(i, j)``, ``i < j``, of rows in ``arrays``.

    Row ``i`` is compared against all later rows at once with
    ``torch.linalg.norm(..., dim=1)``. Distances below ``1e-16`` are recorded
    as zero pairs. All other distances are fed to a ``RunningStats``.

    Returns:
        ``(stats, zeros)``, where ``zeros`` lists ``(i, j)`` index pairs into
        ``arrays`` whose distance is numerically zero.
    """
    stats = RunningStats()
    zeros: list[tuple[int, int]] = []
    remaining = arrays

    for i in (bar := tqdm(range(len(arrays) - 1), desc=desc)):
        # "Pop" row i; ``remaining`` then holds original rows i+1 .. N-1.
        anchor, remaining = remaining[0], remaining[1:]

        diffs = torch.linalg.norm(
            remaining - anchor.unsqueeze(0),
            ord=norm_ord,
            dim=1,
        ).cpu()

        if norm_type == 'mean':
            # Divide by the size of the last (feature) axis. ``size(1)``
            # raised IndexError when each row is a 1-D vector.
            diffs /= anchor.size(-1)

        zero_mask = diffs.abs() < 1e-16
        stats.update_batch(diffs[~zero_mask])

        # Offsets into ``remaining`` map back to original index i + 1 + offset.
        offsets = torch.nonzero(zero_mask, as_tuple=True)[0].tolist()
        zeros.extend((i, i + 1 + off) for off in offsets)

        bar.set_postfix({
            'min': stats.min,
            'mean': stats.mean,
            '# zeros': len(zeros),
        })

    return stats, zeros


def _zero_pair_rows(zeros, datasets, sample_ids, start_ids, end_ids) -> list[dict[str, Any]]:
    """Expand ``(i, j)`` zero-distance pairs into CSV rows describing both samples."""
    return [
        {
            'distance': 0.0,

            'i': i,
            'i_dataset': datasets[i],
            'i_sample_idx': sample_ids[i],
            'i_start_idx': start_ids[i],
            'i_end_idx': end_ids[i],

            'j': j,
            'j_dataset': datasets[j],
            'j_sample_idx': sample_ids[j],
            'j_start_idx': start_ids[j],
            'j_end_idx': end_ids[j],
        }
        for i, j in zeros
    ]


def find_closest(
    data_dir: Path,
    datasets: list[str],
    model_name: str,
    run_args_name: str,
    norm_ord: int = 2,
    norm_type: Union[Literal['sum'], Literal['mean']] = 'sum',
    top_k: int = 100,
    max_samples: int = 10_000
):
    """Find identical (zero-distance) representation pairs per layer and span length.

    Inputs, all under ``data_dir / model_name``:
      * ``run_args_name``: text file containing ``Hidden Size: d = <int>``.
      * ``<dataset>.bin-<layer>``: representation files read with
        ``load_representations_binary``.

    For every layer, samples from all files of that layer are bucketed by span
    length (``end - start + 1``), and each bucket is capped at ``max_samples``.
    All pairs inside a bucket are then compared with a p-norm distance.

    Outputs, written into the same directory:
      * ``zeros-<layer>-<length>.csv``: one row per zero-distance pair.
      * ``layer-length-stats.csv``: one row per (layer, length) from
        ``RunningStats.finalize_row``, with columns layer, length, count, mean,
        std, min, max.

    Args:
        data_dir: Root directory; the model's files live in ``data_dir / model_name``.
        datasets: NOTE(review): currently unused. Every ``*.bin-<layer>`` file
            found in the model directory is processed. Kept for interface
            compatibility.
        model_name: Sub-directory of ``data_dir`` holding the model's files.
        run_args_name: File name of the run-args text file.
        norm_ord: ``ord`` passed to ``torch.linalg.norm``.
        norm_type: ``'sum'`` keeps the raw norm. ``'mean'`` divides it by the
            size of the last axis of each representation.
        top_k: Currently unused; kept for interface compatibility.
        max_samples: Maximum number of samples kept per span-length bucket
            (the first ones in load order).

    Raises:
        ValueError: If ``max_samples < 1``, or if the hidden size is missing
            from the run-args file.
    """
    if max_samples < 1:
        # 0 would make torch.stack fail on an empty bucket; negatives would
        # silently drop samples from the *end* of each bucket.
        raise ValueError(f'max_samples must be >= 1, got {max_samples}')

    model_dir = data_dir / model_name
    d = _read_hidden_size(model_dir / run_args_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    stats_rows: list[dict[str, Any]] = []

    for layer, layer_datasets in group_by_bin_number(model_dir).items():
        print(f'Computing distances for layer {layer}...')

        print('\tLoading Samples...', end=' ', flush=True)
        samples = _load_layer_samples(model_dir, layer_datasets, layer, d)
        print('Done')

        print(f'\tEmbeddings computed: {len(samples):,}')

        print('\tCreating arrays...', end=' ', flush=True)
        groups_by_length = _group_by_length(samples, max_samples)
        # Free samples beyond each bucket's cap before the pairwise pass.
        del samples
        print('Done')

        for length, (ds_names, sample_ids, start_ids, end_ids, arrays) in groups_by_length.items():
            print(f'\tComputing distances for length {length}...')

            print('\t\tCreating tensors...', end=' ', flush=True)
            tensors = torch.stack(
                [torch.tensor(a, dtype=torch.float32, device=device) for a in arrays],
                dim=0,
            )
            print('Done')

            stats, zeros = _find_zero_pairs(
                tensors, norm_ord, norm_type, desc=f'{model_name} - {d}'
            )
            del tensors

            stats_rows.append(stats.finalize_row(layer=layer, length=length))

            # Rows follow discovery order (by i, then j); every distance is 0.
            rows = _zero_pair_rows(zeros, ds_names, sample_ids, start_ids, end_ids)
            out_csv = model_dir / f'zeros-{layer}-{length}.csv'
            pd.DataFrame(rows, columns=_ZERO_PAIR_COLUMNS).to_csv(out_csv, index=False)
            print(f'\t\tZeros saved to: {out_csv}\n\n')

    out_csv = model_dir / 'layer-length-stats.csv'
    df = pd.DataFrame(stats_rows, columns=["layer", "length", "count", "mean", "std", "min", "max"])
    df.sort_values(by=["layer", "length"], inplace=True)
    df.to_csv(out_csv, index=False)

def parse_args(argv: Union[list[str], None] = None):
    """Parse the command-line options for the zero-distance search.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Find zero-distance (colliding) representation pairs per layer and span length.'
    )

    parser.add_argument(
        '-dir', '--data-dir', 
        type=str, default='./data/dataset_exp',
        help='Path to the directory containing the dataset embeddings'
    )
    parser.add_argument(
        '--datasets', 
        type=str, nargs='+', 
        default=[
            'wikipedia', 
            'colossal_clean_crawled_corpus', 
            'arxiv_pile', 
            'github_python'
        ],
        help='Name of datasets to use. Should be subdirectories of `--data-dir`.'
    )
    parser.add_argument(
        '--id', '--model-id',
        type=str, default='roneneldan/TinyStories-1M',
        help='Name of HF model to use.'
    )
    parser.add_argument(
        '--run-args',
        type=str, default='run_args.txt',
        help='Name of the `run_args` text file.'
    )
    parser.add_argument(
        '--ord',
        type=int, default=2,
        help='Order of p-norm to use.'
    )
    parser.add_argument(
        '--norm-type',
        type=str, choices=['sum', 'mean'], 
        default='sum',
        help='Mean or Sum differences.'
    )
    parser.add_argument(
        '-k', '--top-k', 
        type=int, default=100,
        help='Top-K closest pairs to store.'
    )
    parser.add_argument(
        '-n', '--max-samples', 
        type=int, default=10_000,
        # Previously a copy of the --top-k help text.
        help='Maximum number of samples kept per span-length bucket.'
    )

    return parser.parse_args(argv)


if __name__ == '__main__':
    cli = parse_args()

    # The HF id has the form "org/name"; only its last component is used as
    # the model's sub-directory of --data-dir.
    find_closest(
        data_dir=Path(cli.data_dir),
        datasets=cli.datasets,
        model_name=cli.id.split('/')[-1],
        run_args_name=cli.run_args,
        norm_ord=cli.ord,
        norm_type=cli.norm_type,
        top_k=cli.top_k,
        max_samples=cli.max_samples,
    )

# ./scripts/find_close.sh --time 12:00:00 --top-k 100000