import os
import re
import pickle
import pandas as pd
import numpy as np
import torch
from typing import List, Dict, Tuple, Union
import yaml

# Quantization levels for compressing softmax distributions: 0.01 steps over
# [0, 1], 0.001 steps from 0.001 to 0.099 (multiples of 0.01 skipped), plus 65
# log-spaced values from 10**-2.1 down to 10**-10. That is 256 entries, kept in
# descending order as float32.
digitize_table_high = sorted(
    [i/100 for i in range(101)]
    + [i/1000 for i in range(101) if i%10!=0]
    + list(np.logspace(-2.1, -10, num=65)),
    reverse=True,
)
digitize_table_high = np.asarray(digitize_table_high).astype(np.float32)

# The same levels in log space, for log-softmax values (nicer for lm outputs).
# The 0.0 level becomes -inf after the log.
digitize_table_high_lsm = sorted(
    [i/100 for i in range(101)]
    + [i/1000 for i in range(101) if i%10!=0]
    + list(np.logspace(-2.1, -10, num=65)),
    reverse=True,
)
digitize_table_high_lsm = np.log(np.asarray(digitize_table_high_lsm).astype(np.float32))


def safe_numpy(t):
    """Convert a torch tensor to a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, so tensors that require grad or live on another device convert
    too. bfloat16 tensors are cast to float32 before conversion; every other
    dtype is converted as-is.

    Args:
        t: a ``torch.Tensor``.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values.
    """
    # detach() and cpu() do not copy a CPU tensor without grad, so the
    # common case still returns an array backed by the tensor's own storage.
    t = t.detach().cpu()
    if t.dtype == torch.bfloat16:
        t = t.to(dtype=torch.float)
    return t.numpy()


def digitize_nparray(arr: np.ndarray, table='sm', verbose=0):
    """Quantize an array into uint8 bin indices of a module-level level table.

    Args:
        arr: values to quantize. With ``table='sm'`` every element must lie
            in [0, 1].
        table: ``'sm'`` to bin against ``digitize_table_high`` or ``'lsm'``
            to bin against ``digitize_table_high_lsm``.
        verbose: unused.

    Returns:
        The ``np.digitize`` bin indices cast to ``np.uint8``.

    Raises:
        AssertionError: if ``table='sm'`` and a value is outside [0, 1].
        ValueError: if ``table`` is neither ``'sm'`` nor ``'lsm'``.
    """
    if table=='sm':
        # np.alltrue was removed in NumPy 2.0; np.all is the supported spelling.
        assert np.all(arr <= 1.) and np.all(arr >= 0)
        return np.digitize(arr, digitize_table_high).astype(np.uint8)
    elif table=='lsm':
        return np.digitize(arr, digitize_table_high_lsm).astype(np.uint8)
    else:
        raise ValueError(f'Invalid digitization table: {table}')


def undigitize_nparray(arr: np.ndarray, table='sm', verbose=0, apply_correction_dim=None):
    """Map bin indices back to the level values of a module-level table.

    Args:
        arr: integer index array (as produced by ``digitize_nparray``).
        table: ``'sm'`` to look up ``digitize_table_high`` or ``'lsm'`` to
            look up ``digitize_table_high_lsm``.
        verbose: unused.
        apply_correction_dim: only used with ``table='sm'``. When not None,
            the restored values are rescaled so they sum to 1 along this axis.

    Returns:
        Array of table values with the same shape as ``arr``.

    Raises:
        ValueError: if ``table`` is neither ``'sm'`` nor ``'lsm'``.
    """
    if table=='sm':
        res = digitize_table_high[arr]
        if apply_correction_dim is not None:
            # keepdims keeps the summed axis (with length 1) so the factor
            # broadcasts against res; ndarray has no `expand_dim` method.
            multiplicative_factor = 1/res.sum(axis=apply_correction_dim, keepdims=True)
            res = res*multiplicative_factor
        return res
    elif table=='lsm':
        res = digitize_table_high_lsm[arr]
        # for correction just use softmax instead of exp if decoding or log softmax
        return res
    else:
        raise ValueError(f'Invalid digitization table: {table}')

#######################################################################################
# The following are the tools for better storage of experiment outcomes
#######################################################################################
def numpify_array(arr):
    """Turn ``arr`` into a numpy array.

    ndarrays are returned unchanged, tensors go through ``safe_numpy``,
    lists/tuples are converted element by element and stacked along a new
    leading axis, and scalars are wrapped with ``np.array``.

    Raises:
        NotImplementedError: for any other input (its type and value are
            printed first).
    """
    if isinstance(arr, np.ndarray):
        return arr
    if isinstance(arr, torch.Tensor):
        return safe_numpy(arr)
    if isinstance(arr, List) or isinstance(arr, Tuple):
        # convert every element recursively, then stack on a new first axis
        return np.stack([numpify_array(item) for item in arr], axis=0)
    if np.isscalar(arr):
        return np.array(arr)
    # unsupported input: show what it was before failing
    print(type(arr), arr)
    raise NotImplementedError()

def check_if_array_type(arr):
    """Tell whether ``arr`` is array-like.

    Returns True for an ndarray or tensor, and for a list/tuple whose
    elements all pass this same check (an empty list/tuple therefore
    passes). Anything else returns False.
    """
    if isinstance(arr, (np.ndarray, torch.Tensor)):
        return True
    if isinstance(arr, List) or isinstance(arr, Tuple):
        # a container qualifies only if every element does
        for item in arr:
            if not check_if_array_type(item):
                return False
        return True
    return False

def check_array_dim_compatibility(arrlist):
    """Return True when every array in ``arrlist`` has the shape of the first one."""
    shapes = [item.shape for item in arrlist]
    reference = shapes[0]
    return not any(shape != reference for shape in shapes)


import subprocess
import secrets


def save_records_compact(
    records: List[Dict],
    path,
    digitize_conf=None,
    verbose=True,
    use_pigz=False,
    pigz_fast_temp_path=None,
):
    """Save a list of record dicts as a parquet file plus an npz archive.

    Fields are split by inspecting ``records[0]``: values that pass
    ``check_if_array_type`` are stored as arrays, everything else goes into
    ``<path>.parquet``. Array fields whose per-record shapes all match are
    stacked and stored under ``all.<key>``; the others are stored one entry
    per record under ``<index:05d>.<key>``. Fields listed in
    ``digitize_conf`` are passed through ``digitize_nparray`` first and get
    a ``dig.`` prefix in front of the key.

    Args:
        records: the records to save; all are expected to share the keys of
            the first one.
        path: output path prefix (extensions are appended to it).
        digitize_conf: optional mapping ``key -> kwargs`` for
            ``digitize_nparray``. When non-empty it is also written to
            ``<path>.dig.yaml``.
        verbose: print progress messages.
        use_pigz: write an uncompressed temporary npz and compress it with
            the external ``pigz`` tool into ``<path>.npz.gz``; otherwise
            ``np.savez_compressed`` writes ``<path>.npz``.
        pigz_fast_temp_path: optional path prefix for the temporary npz
            (defaults to ``path``); a random number is appended to it.

    Raises:
        AssertionError: if a key of the first record contains ``'.'``.
        subprocess.CalledProcessError: if ``pigz`` exits with an error.
    """
    # A None default avoids sharing one mutable dict object between calls.
    if digitize_conf is None:
        digitize_conf = {}
    path = os.path.abspath(path)
    temp_base = path if pigz_fast_temp_path is None else os.path.abspath(pigz_fast_temp_path)
    pigz_fast_temp_path = temp_base + str(secrets.randbelow(9000000))
    # '.' separates the "all"/index and "dig" markers from the key in npz entry names
    assert all(['.' not in k for k in records[0].keys()]), "'.' not allowed in the keynames (used for technical purposes)!!!"
    vector_keys = {k for k, v in records[0].items() if check_if_array_type(v)}
    general_keys = {k for k in records[0] if k not in vector_keys}
    if verbose:
        print(general_keys, vector_keys)

    # split the records
    general_records = [{k: v for k, v in rec.items() if k in general_keys} for rec in records]
    vector_records = {k: [numpify_array(rec[k]) for rec in records] for k in vector_keys}

    # handle compression where applicable
    encoded_records = {}
    for k, arrs in vector_records.items():
        if k in digitize_conf:
            encoded_records[f"dig.{k}"] = [digitize_nparray(a, **digitize_conf[k]) for a in arrs]
        else:
            encoded_records[k] = arrs
    if verbose:
        print("Converted an digitized the input")

    # stack fields with a uniform shape, keep the rest as one entry per record
    vector_records_condensed = {}
    vector_records_uncondensable = {}
    for k, arrs in encoded_records.items():
        if check_array_dim_compatibility(arrs):
            vector_records_condensed[f"all.{k}"] = np.stack(arrs, axis=0)
        else:
            for i, a in enumerate(arrs):
                vector_records_uncondensable[f"{i:05d}.{k}"] = a
    npz_entries = {**vector_records_condensed, **vector_records_uncondensable}
    if verbose:
        print("Stacked arrays where possible.")

    # dump
    if verbose:
        print("Saving parquet...")
    gen_df = pd.DataFrame.from_records(general_records)
    gen_df.to_parquet(path+'.parquet')
    if use_pigz:
        temp_npz = pigz_fast_temp_path+'.npz'
        try:
            if verbose:
                print("Saving uncompressed npz...")
            np.savez(temp_npz, **npz_entries)
            if verbose:
                print("Compressing npz using pigz...")
            with open(path+".npz.gz", 'wb') as outfile:
                # check=True surfaces a failed compression instead of leaving
                # a truncated archive behind unnoticed
                subprocess.run(["pigz", "-9", temp_npz, "-c"], stdout=outfile, check=True)
        finally:
            # remove the temp file even when saving or compressing failed
            if verbose:
                print("Cleaning up temp...")
            if os.path.exists(temp_npz):
                os.remove(temp_npz)
    else:
        if verbose:
            print("Saving compressed npz...")
        np.savez_compressed(path, **npz_entries)
    if verbose:
        print("Gz dump done...")
    # digitization_config if present
    if len(digitize_conf) > 0:
        with open(path+".dig.yaml", 'wt') as f:
            yaml.safe_dump(digitize_conf, f)


def load_records_from_compact(path, pigz_fast_temp_path=None, load_arrays = True):
    """Load records written by ``save_records_compact``.

    Reads ``<path>.parquet`` into a list of record dicts and, unless
    ``load_arrays`` is False, attaches the arrays found in ``<path>.npz`` or,
    failing that, ``<path>.npz.gz`` (decompressed to a temporary file with
    the external ``pigz`` tool). If ``<path>.dig.yaml`` exists, entries named
    ``<prefix>.dig.<key>`` are restored with ``undigitize_nparray`` using the
    kwargs stored for ``<key>``.

    Args:
        path: path prefix that was passed to ``save_records_compact``.
        pigz_fast_temp_path: optional path prefix for the temporary
            decompressed file (defaults to ``path``); a random number is
            appended to it.
        load_arrays: when False only the parquet part is returned.

    Returns:
        A list of dicts, one per record.

    Raises:
        AssertionError: if the parquet file is missing, or a stacked array's
            first dimension does not match the number of records.
        subprocess.CalledProcessError: if ``pigz`` exits with an error.
    """
    path = os.path.abspath(path)
    temp_base = path if pigz_fast_temp_path is None else os.path.abspath(pigz_fast_temp_path)
    pigz_fast_temp_path = temp_base + str(secrets.randbelow(9000000))

    # find the files that belong to this dump
    dirname = os.path.dirname(path)
    fname = os.path.basename(path)
    print(dirname, fname)
    related_files = [f for f in os.listdir(dirname) if f.startswith(fname)]
    print(related_files)
    assert f"{fname}.parquet" in related_files, "Must at least have a parquet file!"
    parquet_part = pd.read_parquet(os.path.join(dirname, f"{fname}.parquet"))
    records = parquet_part.to_dict('records')

    if not load_arrays:
        return records

    path_npy = os.path.join(dirname, f"{fname}.npz")
    path_pigz = os.path.join(dirname, f"{fname}.npz.gz")
    path_dig = os.path.join(dirname, f"{fname}.dig.yaml")

    # now load the arrays
    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only load dumps from trusted sources.
    if os.path.exists(path_npy):
        # load from npz directly
        with np.load(path_npy, allow_pickle=True) as arrays:
            print(arrays.keys())
            darrays = {k: v for k, v in arrays.items()}
    elif os.path.exists(path_pigz):
        try:
            # extract using pigz; check=True surfaces a failed decompression
            with open(pigz_fast_temp_path, 'wb') as outfile:
                subprocess.run(["pigz", "-d", "-k", "-9", path_pigz, "-c"], stdout=outfile, check=True)
            # load from the temp file
            with np.load(pigz_fast_temp_path, allow_pickle=True) as arrays:
                print(arrays.keys())
                darrays = {k: v for k, v in arrays.items()}
        finally:
            # delete the temp file even when extraction or loading failed
            if os.path.exists(pigz_fast_temp_path):
                os.remove(pigz_fast_temp_path)
    else:
        # no arrays to load, return
        return records

    # first undigitize digitized arrays
    if os.path.exists(path_dig):
        with open(path_dig, 'rt') as f:
            digconfs = yaml.safe_load(f)

        for k in list(darrays):
            namesplit = k.split('.')
            # Digitized entries are named "<prefix>.dig.<key>". Checking the
            # position keeps a plain field that is itself called "dig" from
            # being mistaken for a digitized one.
            if len(namesplit) == 3 and namesplit[1] == 'dig':
                darrays[k] = undigitize_nparray(darrays[k], **digconfs[namesplit[-1]])

    # now we can assign the corresponding entries to the records
    for k, values in darrays.items():
        namesplit = k.split('.')
        cname = namesplit[-1]
        if namesplit[0]=='all':
            # joint array, one slice along the first axis per record
            assert len(records) == values.shape[0], f"Shape mismatch: records {len(records)}; array: {k}: {values.shape}"
            for i, rec in enumerate(records):
                rec[cname] = values[i]
        else:
            # the leading component is the index of the record this entry belongs to
            record_id = int(namesplit[0])
            records[record_id][cname] = values

    return records


def digitization_stats_check(undigitized, sm_dim=-1, log=True):
    """Measure how far restored distributions are from summing to one.

    Args:
        undigitized: restored values; treated as log-values (and
            exponentiated) when ``log`` is True, used directly otherwise.
        sm_dim: axis along which the values are summed.
        log: whether ``undigitized`` holds log-values.

    Returns:
        ``(mean, variance, max)`` of ``1 - sum`` taken over ``sm_dim``.
    """
    if not log:
        deficit = 1 - undigitized.sum(sm_dim)
    else:
        probabilities = np.exp(undigitized)
        deficit = 1. - probabilities.sum(sm_dim)
    stats = (deficit.mean(), deficit.var(), deficit.max())
    return stats


def load_from_compact_chunks(dump_dir: str, range_start=0, range_end=-1, restrict_sets=None, **loader_kwargs):
    """Load chunked dumps from ``dump_dir`` and concatenate them per prefix.

    Chunk files are expected to be named ``<prefix>_<start:05d>_<end:05d>.<ext>``
    where ``<prefix>`` contains neither ``_`` nor ``.``. Only prefixes that
    have a ``.parquet`` file are considered; each chunk is loaded with
    ``load_records_from_compact``.

    Args:
        dump_dir: directory containing the chunk files.
        range_start: chunks with ``end_range <= range_start`` are skipped and
            the chunk containing ``range_start`` is trimmed at the front.
        range_end: when ``>= 0``, chunks with ``start_range >= range_end`` are
            skipped and the chunk containing ``range_end`` is trimmed at the
            back; a negative value disables the upper limit.
        restrict_sets: optional collection of prefixes to load; others are skipped.
        **loader_kwargs: forwarded to ``load_records_from_compact``.

    Returns:
        Dict mapping each loaded prefix to its list of records. Empty when
        ``dump_dir`` holds no chunk files.
    """
    # assumes that the chunks are dumped in prefix_05d:start_range_05d:end_range
    applicable_files = []
    for f in os.listdir(dump_dir):
        if not re.fullmatch(r'.*\d{5}_\d{5}\..*', f):
            continue
        stem_parts = f.split('.')[0].split('_')
        applicable_files.append({
            'file': f,
            'prefix': stem_parts[0],
            'start_range': int(stem_parts[1]),
            'end_range': int(stem_parts[2]),
            'extension': '.'.join(f.split('.')[1:]),
        })
    # An empty record list yields a DataFrame without an `extension` column,
    # so bail out before touching it.
    if not applicable_files:
        return {}
    finfo = pd.DataFrame.from_records(applicable_files)
    guide = finfo[finfo.extension=='parquet'].groupby('prefix')
    retval = {}
    for g in guide.groups:
        if restrict_sets is not None and g not in restrict_sets:
            print(f"Skipping {g}: not in specified subsets of data {restrict_sets}")
            continue
        curset = guide.get_group(g).sort_values(by='start_range')
        last_val = 0
        accum = []
        for _, r in curset.iterrows():
            # NOTE(review): the gap check and the trimming arithmetic below
            # treat end_range as inclusive, yet this test skips a chunk whose
            # end_range equals range_start, and the cut_end condition does not
            # trim when end_range equals range_end -- confirm the intended
            # boundary convention.
            if r.end_range<=range_start or (r.start_range>=range_end and range_end>=0):
                print(f"Skipping {r.file}, outside of specified range {range_start}:{range_end}.")
                continue
            if r.start_range!=last_val:
                print(f"Warning: last val is {last_val}, the coming file starts from {r.start_range}. Something might be missing!")
            last_val = r.end_range+1
            chunk = load_records_from_compact(os.path.join(dump_dir, f"{r.prefix}_{r.start_range:05d}_{r.end_range:05d}"), **loader_kwargs)
            # trim to return appropriate number of stuff; cut_end is a
            # negative (from-the-end) index when range_end falls inside the chunk
            cut_start = range_start - r.start_range if r.start_range < range_start and r.end_range > range_start else 0
            cut_end = range_end - r.end_range - 1 if r.start_range < range_end and r.end_range > range_end else len(chunk)
            print(f"Selecting range {cut_start}:{cut_end}")
            accum += chunk[cut_start:cut_end]
        retval[g] = accum
    return retval

def iterate_over_compact_chunks(dump_dir: str, range_start=0, range_end=-1, restrict_sets=None, return_file_prefix=False, **loader_kwargs):
    """Iterate over chunked dumps in ``dump_dir``, one index range at a time.

    Chunk files are expected to be named ``<prefix>_<start:05d>_<end:05d>.<ext>``
    where ``<prefix>`` contains neither ``_`` nor ``.``. The ``.parquet``
    files are grouped by ``(start_range, end_range)``; for each group that
    overlaps the requested range, every set (prefix) in it is loaded with
    ``load_records_from_compact`` and trimmed to the range.

    Args:
        dump_dir: directory containing the chunk files.
        range_start: groups with ``end_range <= range_start`` are skipped and
            the group containing ``range_start`` is trimmed at the front.
        range_end: when ``>= 0``, groups with ``start_range >= range_end`` are
            skipped and the group containing ``range_end`` is trimmed at the
            back; a negative value disables the upper limit.
        restrict_sets: optional collection of prefixes to load. Groups in
            which none of these prefixes is present are skipped.
        return_file_prefix: also yield the path prefix of the last set loaded
            for the group.
        **loader_kwargs: forwarded to ``load_records_from_compact``.

    Yields:
        A dict mapping prefix to its list of records, or a
        ``(dict, path_prefix)`` tuple when ``return_file_prefix`` is True.
        Nothing is yielded when ``dump_dir`` holds no chunk files.
    """
    applicable_files = []
    for f in os.listdir(dump_dir):
        if not re.fullmatch(r'.*\d{5}_\d{5}\..*', f):
            continue
        stem_parts = f.split('.')[0].split('_')
        applicable_files.append({
            'file': f,
            'prefix': stem_parts[0],
            'start_range': int(stem_parts[1]),
            'end_range': int(stem_parts[2]),
            'extension': '.'.join(f.split('.')[1:]),
        })
    # An empty record list yields a DataFrame without an `extension` column,
    # so stop before touching it.
    if not applicable_files:
        return
    finfo = pd.DataFrame.from_records(applicable_files)
    guide = finfo[finfo.extension=='parquet'].groupby(by=['start_range', 'end_range'])

    last_val = 0
    for g in guide.groups:
        curset = guide.get_group(g)
        if restrict_sets is not None:
            curset = curset[curset.prefix.isin(restrict_sets)]
        # None of the requested sets has this range: nothing to load here
        # (np.stack below would fail on an empty list).
        if curset.empty:
            continue

        # all sets of a group must cover exactly the same range
        ranges = np.stack([[r.start_range, r.end_range] for _, r in curset.iterrows()], axis=0)
        assert np.all((ranges-ranges[0, :])==0), f"Bad ranging~! {ranges}"

        # check if we are looking at the right chunk
        # NOTE(review): the gap check and the trimming arithmetic below treat
        # end_range as inclusive, yet this test skips a group whose end_range
        # equals range_start, and the cut_end condition does not trim when
        # end_range equals range_end -- confirm the intended boundary convention.
        skip_iter = False
        for _, r in curset.iterrows():
            if r.end_range<=range_start or (r.start_range>=range_end and range_end>=0):
                print(f"Skipping {r.file}, outside of specified range {range_start}:{range_end}.")
                skip_iter = True
            if r.start_range!=last_val:
                print(f"Warning: last val is {last_val}, the coming file starts from {r.start_range}. Something might be missing!")
        # every row shares the same range, so the last row is representative
        last_val = r.end_range+1
        if skip_iter:
            continue

        # trim to return appropriate number of stuff; cut_end is a negative
        # (from-the-end) index when range_end falls inside the chunk
        cut_start = range_start - r.start_range if r.start_range < range_start and r.end_range > range_start else 0
        cut_end = range_end - r.end_range - 1 if r.start_range < range_end and r.end_range > range_end else r.end_range-r.start_range+1
        print(f"Selecting range {cut_start}:{cut_end}")

        # load the chunks for all sets for a given range
        retval = {}
        for _, r in curset.iterrows():
            chunk = load_records_from_compact(os.path.join(dump_dir, f"{r.prefix}_{r.start_range:05d}_{r.end_range:05d}"), **loader_kwargs)
            retval[r['prefix']] = list(chunk[cut_start:cut_end])

        # yield one chunk at a time
        if not return_file_prefix:
            yield retval
        else:
            yield retval, os.path.join(dump_dir, f"{r.prefix}_{r.start_range:05d}_{r.end_range:05d}")