import os

import torch
import numpy as np

import utils.const as C

points = None

def generate_random_points(dim):
    """Return a module-wide cached tensor of uniform [0, 1) random points.

    The tensor has shape (C.NUM_POINTS, dim) and lives on C.DEVICE. It is
    created lazily on first use and reused on later calls, so every caller
    sees the same sample. If a call asks for a different ``dim`` (or
    C.NUM_POINTS has changed), a fresh tensor is drawn instead of silently
    returning one of the wrong width.

    Args:
        dim: number of coordinates per point.

    Returns:
        The cached (C.NUM_POINTS, dim) tensor (also stored in the module
        global ``points``).
    """
    global points
    # `is None` is the correct identity test; `==` is overloaded on tensors.
    if points is None or tuple(points.shape) != (C.NUM_POINTS, dim):
        points = torch.rand(C.NUM_POINTS, dim, device=C.DEVICE)
    return points

def rescale_random_points(points, range_val):
    """Min-max rescale each column of ``points`` into [-range_val, range_val].

    Every column is mapped linearly so that its minimum becomes -range_val
    and its maximum becomes +range_val. A constant column (max == min) would
    otherwise be 0/0 = NaN, so it is mapped to the centre of the range (0).

    Args:
        points: 2-D tensor; column-wise statistics are taken over dim 0.
        range_val: half-width of the target interval.

    Returns:
        A new tensor with the same shape as ``points``.
    """
    col_min = torch.min(points, dim=0).values
    col_max = torch.max(points, dim=0).values
    span = col_max - col_min
    varying = span > 0
    # Divide by 1 where the column is constant to avoid NaN, then overwrite
    # those entries with 0.5 (the midpoint of [0, 1]).
    safe_span = torch.where(varying, span, torch.ones_like(span))
    unit = (points - col_min[None, :]) / safe_span[None, :]
    unit = torch.where(varying[None, :], unit, torch.full_like(unit, 0.5))
    return unit * 2 * range_val - range_val

def get_median_Ks(Ks):
    """Map every (i, j) key of ``Ks`` to ``np.median`` of its value sequence."""
    return {(row, col): np.median(Ks[row, col]) for row, col in Ks}

def get_max_Ks(Ks, q=99):
    """Return a robust maximum (the q-th percentile) of each value list in ``Ks``.

    ``np.percentile`` expects ``q`` on a 0-100 scale. The previous hard-coded
    ``0.99`` therefore selected the 0.99th percentile (almost the minimum)
    rather than the intended 99th.

    Args:
        Ks: mapping from key to a sequence of numbers.
        q: percentile in [0, 100]; defaults to 99.

    Returns:
        Dict with the same keys, each mapped to the q-th percentile of its values.
    """
    return {key: np.percentile(values, q) for key, values in Ks.items()}


def load_or_compute(
    cache_dir,
    filename,
    compute_fn,
):
    """Return the object cached at ``cache_dir/filename``, computing it on a miss.

    On a hit the file is read with ``torch.load`` onto ``C.DEVICE``. On a miss,
    ``compute_fn()`` is called and its result is saved with ``torch.save``
    before being returned. ``cache_dir`` is created if it does not exist.

    The save goes to a temporary sibling file that is then renamed into
    place with ``os.replace``. An interrupted save therefore never leaves a
    truncated file that a later call would mistake for a valid cache entry.
    """
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, filename)

    if os.path.exists(path):
        print(f"[cache] loading {path}")
        # NOTE: torch.load is pickle-based; only point this at files you wrote.
        return torch.load(path, map_location=C.DEVICE)

    print(f"[cache] computing {path}")
    obj = compute_fn()
    tmp_path = f"{path}.tmp{os.getpid()}"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return obj

import json
from pathlib import Path

def save_dicts_to_cache(dicts, subpath, filename):
    """Write ``dicts`` as 2-space-indented JSON to cache/<subpath>/<filename>.

    Missing directories are created. Returns the Path of the written file.
    """
    target_dir = Path("cache").joinpath(subpath)
    target_dir.mkdir(exist_ok=True, parents=True)
    destination = target_dir.joinpath(filename)
    with open(destination, "w") as handle:
        json.dump(dicts, handle, indent=2)
    return destination

def encode_keys(d):
    """Return a JSON-friendly copy of ``d``: keys via ``str()``, values via ``list()``."""
    encoded = {}
    for key, value in d.items():
        encoded[str(key)] = list(value)
    return encoded

import math
from collections import defaultdict

def _floor_power(x, base):
    """Return the largest integer power of ``base`` that is <= ``x`` (x > 0, base > 1)."""
    exponent = int(math.floor(math.log(x, base)))
    # math.log is inexact (math.log(1000, 10) == 2.9999999999999996), so
    # nudge the estimate until base**exponent <= x < base**(exponent + 1).
    while base ** (exponent + 1) <= x:
        exponent += 1
    while base ** exponent > x:
        exponent -= 1
    return base ** exponent


def bin_tuple_dict_log_to_tuple(data, base=2):
    """Merge (i, j)-keyed value lists into logarithmic bins.

    Each key component is floored to the largest power of ``base`` not
    exceeding it. For example, with base 2 the key (5, 9) goes to bin (4, 8).
    Value lists of keys that land in the same bin are concatenated, in
    iteration order, into new lists; the input lists are not mutated. Keys
    with a non-positive component are skipped because they have no logarithm.

    Returns:
        Plain dict mapping (bin_i, bin_j) to the merged list of values.
    """
    bins = {}
    for (i, j), values in data.items():
        # Guard both components; previously j <= 0 crashed in math.log.
        if i <= 0 or j <= 0:
            continue
        key = (_floor_power(i, base), _floor_power(j, base))
        bins.setdefault(key, []).extend(values)
    return bins

def diagonalize(data):
    """Pool each row's value lists onto that row's diagonal key.

    All values stored under (i, j), for any j, are concatenated in iteration
    order into a new list keyed by (i, i).
    """
    pooled = {}
    for (row, _col), values in data.items():
        pooled.setdefault((row, row), []).extend(values)
    return pooled

def get_mean_diagonal(data):
    """Average the scalar values of each row i into a single (i, i) entry.

    ``data`` maps (i, j) to a scalar; NaN values are ignored. Returns a dict
    of (i, i) -> ``np.mean`` of that row's non-NaN values.
    """
    per_row = {}
    for (row, _col), value in data.items():
        if not math.isnan(value):
            per_row.setdefault((row, row), []).append(value)
    return {key: np.mean(vals) for key, vals in per_row.items()}

def get_mean_deltas_diagonal(data):
    """Per-row means of ``data`` expressed relative to the first row's mean.

    ``data`` maps (i, j) to a scalar; NaN values are ignored. Non-NaN values
    are grouped by row into (i, i) and averaged with ``np.mean``. Each mean is
    then divided by the mean of the first (i, i) key in insertion order.
    NOTE(review): that baseline is the smallest row only if ``data`` is built
    in ascending row order; confirm against callers.

    Returns:
        Dict of (i, i) -> row mean / baseline mean (empty if no data).
    """
    per_row = {}
    for (row, _col), value in data.items():
        if math.isnan(value):
            continue
        per_row.setdefault((row, row), []).append(value)
    means = {key: np.mean(vals) for key, vals in per_row.items()}
    if not means:
        return {}
    # Hoisted: the baseline key list used to be rebuilt for every entry (O(n^2)).
    baseline = next(iter(means.values()))
    return {key: mean / baseline for key, mean in means.items()}

def get_max_deltas_diagonal(data):
    """Per-row maxima of ``data`` expressed relative to the first row's maximum.

    ``data`` maps (i, j) to a scalar; NaN values are ignored. Non-NaN values
    are grouped by row into (i, i) and reduced with ``max``. Each maximum is
    then divided by the maximum of the first (i, i) key in insertion order.
    NOTE(review): that baseline is the smallest row only if ``data`` is built
    in ascending row order; confirm against callers.

    Returns:
        Dict of (i, i) -> row max / baseline max (empty if no data).
    """
    per_row = {}
    for (row, _col), value in data.items():
        if math.isnan(value):
            continue
        per_row.setdefault((row, row), []).append(value)
    maxima = {key: max(vals) for key, vals in per_row.items()}
    if not maxima:
        return {}
    # Hoisted: the baseline key list used to be rebuilt for every entry (O(n^2)).
    baseline = next(iter(maxima.values()))
    return {key: peak / baseline for key, peak in maxima.items()}

def get_median_deltas_diagonal(data):
    """Per-row medians of ``data`` expressed relative to the first row's median.

    ``data`` maps (i, j) to a scalar; NaN values are ignored. Non-NaN values
    are grouped by row into (i, i) and reduced with ``np.median``. Each median
    is then divided by the median of the first (i, i) key in insertion order.
    NOTE(review): that baseline is the smallest row only if ``data`` is built
    in ascending row order; confirm against callers.

    Returns:
        Dict of (i, i) -> row median / baseline median (empty if no data).
    """
    per_row = {}
    for (row, _col), value in data.items():
        if math.isnan(value):
            continue
        per_row.setdefault((row, row), []).append(value)
    medians = {key: np.median(vals) for key, vals in per_row.items()}
    if not medians:
        return {}
    # Hoisted: the baseline key list used to be rebuilt for every entry (O(n^2)).
    baseline = next(iter(medians.values()))
    return {key: med / baseline for key, med in medians.items()}

def load_from_cache(subpath, filename):
    """Read JSON from cache/<subpath>/<filename> and decode its stringified keys."""
    source = Path("cache").joinpath(subpath, filename)
    with open(source, "r") as handle:
        raw = json.load(handle)
    return decode_keys(raw)
    
def decode_keys(d):
    """Turn ``str(key)``-encoded keys (as written by encode_keys) back into objects.

    Each key string is evaluated with ``eval``; values are passed through
    unchanged.

    NOTE(review): ``eval`` runs arbitrary code from the cache file, so only use
    this on files this project wrote. ``ast.literal_eval`` would be safer for
    plain int/tuple keys. However, under NumPy 2, ``str()`` of a tuple of numpy
    scalars yields e.g. "(np.int64(1), ...)", which only ``eval`` (with ``np``
    in scope) can parse. Confirm the key types before switching.
    """
    decoded = {}
    for text, value in d.items():
        decoded[eval(text)] = value
    return decoded