import numpy as np
import pandas as pd
import copy

from scipy.interpolate import griddata

def compute_variogram(spatial_dist_matrix, wasser_dist_matrix, spatial_lag=0.0001):
    """Average the ``wasser`` distance within each spatial-distance bin.

    Only entries whose ``wasser`` distance is strictly positive are used.
    Each such entry's spatial distance is floor-divided by ``spatial_lag``
    to obtain an integer bin number. The ``wasser`` distances that share a
    bin are then averaged.

    Args:
        spatial_dist_matrix: array of spatial distances, indexed like
            ``wasser_dist_matrix``.
        wasser_dist_matrix: array of ``wasser`` distances (presumably
            Wasserstein distances -- confirm with the producer of the data);
            entries <= 0 are ignored.
        spatial_lag: width of one spatial bin.

    Returns:
        pandas.DataFrame indexed by integer bin number (index name
        ``"spatial"``) with one column ``"wasser"`` holding the bin's mean.
        Bins that received no entries are absent.
    """
    positive = wasser_dist_matrix > 0
    lag_bins = np.floor_divide(spatial_dist_matrix[positive], spatial_lag).astype(int)
    pairs = pd.DataFrame({"spatial": lag_bins, "wasser": wasser_dist_matrix[positive]})
    return pairs.groupby("spatial").mean()

def compute_match_grid(spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix, spatial_lag=0.0001, spatial_ticks=2500, wasser_lag=0.1, wasser_ticks=120):
    """Fraction of matching pairs in each (spatial bin, wasser bin) cell.

    The candidate population is every entry whose ``wasser`` distance is
    strictly positive. Each candidate is binned by
    ``spatial // spatial_lag`` (row) and ``wasser // wasser_lag`` (column).
    The cell value is (#candidates with ``cluster_match_matrix > 0``) /
    (#candidates), so it lies in [0, 1]. Cells with no candidates are 0.

    Args:
        spatial_dist_matrix: spatial distances, indexed like the other matrices.
        wasser_dist_matrix: ``wasser`` distances; entries <= 0 are excluded.
        cluster_match_matrix: entries > 0 mark a "match" for that pair.
        spatial_lag, spatial_ticks: row-bin width and number of row bins.
        wasser_lag, wasser_ticks: column-bin width and number of column bins.

    Returns:
        float ndarray of shape (spatial_ticks, wasser_ticks).

    Raises:
        IndexError: if a distance falls at or beyond the last bin edge
            (e.g. spatial >= spatial_lag * spatial_ticks).
    """
    # Numerator and denominator must come from the same population.
    # Previously matches were counted even when wasser == 0 (e.g. a zero
    # diagonal matching itself), while such pairs were excluded from the
    # denominator, which could push cell values above 1.
    valid = wasser_dist_matrix > 0
    matched = valid & (cluster_match_matrix > 0)

    def _histogram(mask):
        # 2-D count of the masked entries per (spatial bin, wasser bin).
        counts = np.zeros((spatial_ticks, wasser_ticks))
        rows = (spatial_dist_matrix[mask] // spatial_lag).astype(int)
        cols = (wasser_dist_matrix[mask] // wasser_lag).astype(int)
        # np.add.at accumulates repeated (row, col) pairs correctly.
        np.add.at(counts, (rows, cols), 1)
        return counts

    grid = _histogram(matched)
    grid_denom = _histogram(valid)
    # Avoid division by zero; empty cells have a zero numerator anyway.
    grid_denom[grid_denom == 0] = 1

    grid /= grid_denom

    return grid

def compute_match_grid_by_cluster(spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix, cluster_idx, spatial_lag=0.0001, spatial_ticks=2500, wasser_lag=0.1, wasser_ticks=120):
    """Per-cell match fraction, restricted to pairs touching a given cluster.

    A pair (i, j) is a candidate when its ``wasser`` distance is strictly
    positive and row i or column j is selected by ``cluster_idx``.
    Candidates are binned by ``spatial // spatial_lag`` (row) and
    ``wasser // wasser_lag`` (column). The cell value is
    (#candidates with ``cluster_match_matrix > 0``) / (#candidates), so it
    lies in [0, 1]. Cells with no candidates are 0.

    Args:
        spatial_dist_matrix: spatial distances, indexed like the other matrices.
        wasser_dist_matrix: ``wasser`` distances; entries <= 0 are excluded.
        cluster_match_matrix: square matrix; entries > 0 mark a "match".
        cluster_idx: row/column selector (boolean mask or integer indices)
            for the cluster's members.
        spatial_lag, spatial_ticks: row-bin width and number of row bins.
        wasser_lag, wasser_ticks: column-bin width and number of column bins.

    Returns:
        float ndarray of shape (spatial_ticks, wasser_ticks).

    Raises:
        IndexError: if a distance falls at or beyond the last bin edge.
    """
    # Select every row and every column belonging to the cluster.
    in_cluster = np.zeros_like(cluster_match_matrix, dtype=bool)
    in_cluster[cluster_idx] = True
    in_cluster[:, cluster_idx] = True

    # Numerator and denominator must come from the same population; see
    # compute_match_grid. Otherwise zero-wasser matches (e.g. the diagonal)
    # inflate the first column of cells above 1.
    valid = (wasser_dist_matrix > 0) & in_cluster
    matched = valid & (cluster_match_matrix > 0)

    def _histogram(mask):
        # 2-D count of the masked entries per (spatial bin, wasser bin).
        counts = np.zeros((spatial_ticks, wasser_ticks))
        rows = (spatial_dist_matrix[mask] // spatial_lag).astype(int)
        cols = (wasser_dist_matrix[mask] // wasser_lag).astype(int)
        np.add.at(counts, (rows, cols), 1)
        return counts

    grid = _histogram(matched)
    grid_denom = _histogram(valid)
    # Avoid division by zero; empty cells have a zero numerator anyway.
    grid_denom[grid_denom == 0] = 1

    grid /= grid_denom

    return grid

def compute_density_grid(spatial_dist_matrix, wasser_dist_matrix, spatial_lag=0.0001, spatial_ticks=2500, wasser_lag=0.1, wasser_ticks=120):
    """Count entries per (spatial bin, wasser bin), ready to use as a divisor.

    Every entry with a strictly positive ``wasser`` distance is binned by
    ``spatial // spatial_lag`` (row) and ``wasser // wasser_lag`` (column).
    Cells that received no entries are set to 1 instead of 0, so a cell
    value of 1 means "0 or 1 entries".

    Args:
        spatial_dist_matrix: spatial distances, indexed like ``wasser_dist_matrix``.
        wasser_dist_matrix: ``wasser`` distances; entries <= 0 are excluded.
        spatial_lag, spatial_ticks: row-bin width and number of row bins.
        wasser_lag, wasser_ticks: column-bin width and number of column bins.

    Returns:
        float ndarray of shape (spatial_ticks, wasser_ticks).
    """
    positive = wasser_dist_matrix > 0
    rows = np.floor_divide(spatial_dist_matrix[positive], spatial_lag).astype(int)
    cols = np.floor_divide(wasser_dist_matrix[positive], wasser_lag).astype(int)

    density = np.zeros((spatial_ticks, wasser_ticks))
    np.add.at(density, (rows, cols), 1)
    # Empty cells become 1 so the grid can be divided by without a zero check.
    density[density == 0] = 1
    return density


dname = "inat"
radius = 30
cluster = 0
spatial_lag, spatial_ticks, wasser_lag, wasser_ticks = 0.0001, 2500, 0.01, 600

# np.load on an .npz returns an NpzFile that keeps the archive open; use it as
# a context manager so the handle is closed. Indexing the NpzFile materialises
# each array, so the arrays stay usable after the `with` block.
with np.load("checkpoints/{}/{}-r{}-matrix.npz".format(dname, dname, radius)) as data:
    spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix = data["spatial"], data["wasser"], data["match"]

# Sanity check: the grids only cover distances below spatial_lag * spatial_ticks
# and wasser_lag * wasser_ticks; larger values make the binning raise IndexError.
print(spatial_dist_matrix.max(), wasser_dist_matrix.max())

with np.load("checkpoints/{}/{}-r{}-fitting.npz".format(dname, dname, radius)) as data:
    # Only needed by the disabled per-cluster run below.
    valid_labels = data["labels"]

grid_denom = compute_density_grid(spatial_dist_matrix, wasser_dist_matrix, spatial_lag, spatial_ticks, wasser_lag, wasser_ticks)

np.savez("checkpoints/{}/{}-r{}-c{}-grid_denom-{}-{}-{}-{}".format(dname, dname, radius, cluster, spatial_lag, spatial_ticks, wasser_lag, wasser_ticks), grid_denom=grid_denom)


# Disabled alternative run: per-cluster match-probability grid plus variogram.
# cluster_idx = (valid_labels==cluster)
# print(np.sum(cluster_idx))
# grid_z = compute_match_grid_by_cluster(spatial_dist_matrix, wasser_dist_matrix, cluster_match_matrix, cluster_idx, spatial_lag, spatial_ticks, wasser_lag, wasser_ticks)
#
# spatial_wasser_df = compute_variogram(spatial_dist_matrix, wasser_dist_matrix, spatial_lag)
#
# print(grid_z.shape, spatial_wasser_df.shape)
#
# np.savez("checkpoints/{}/{}-r{}-c{}-grid-{}-{}-{}-{}".format(dname, dname, radius, cluster, spatial_lag, spatial_ticks, wasser_lag, wasser_ticks), prob=grid_z, spatial=spatial_wasser_df.index.to_numpy(), vario=spatial_wasser_df.to_numpy())
#
