import sys
import numpy as np
import pandas as pd

from sklearn.metrics.pairwise import haversine_distances
import matplotlib.pyplot as plt
import seaborn as sns
import os
import ast


def sample_mask(shape, p=0.002, p_noise=(0.1, 0.2), mode="random", pos=None, adj=None, sample_strategy=None):
    """Draw disjoint validation and test masks that hide whole columns.

    One uniform random number in [0, 1) is drawn per column (second axis of `shape`) with `np.random.random`.
    A column goes to the validation mask when its draw is below `p_noise[0]`, and to the test mask when its draw
    lies in `[p_noise[0], p_noise[1])`; a column can therefore never be in both masks. A selected column is set to 1
    for every position along all the other axes.

    Parameters
    ----------
    shape : tuple of int
        Shape of the masks to build. It must have at least two dimensions.
    p : float
        Currently unused; kept so existing callers keep working.
    p_noise : sequence of two floats
        Cumulative thresholds `(val_threshold, test_threshold)`. The expected fraction of validation columns is
        `p_noise[0]`; the expected fraction of test columns is `p_noise[1] - p_noise[0]`.
    mode : str
        One of "random", "road" or "mix". The value is only validated: every mode currently produces the same
        column-wise masks.
    pos, adj, sample_strategy :
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The validation mask and the test mask, both of shape `shape` and dtype uint8.

    Raises
    ------
    ValueError
        If `mode` is not one of the accepted values.
    """
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if mode not in ("random", "road", "mix"):
        raise ValueError("The missing mode must be 'random' or 'road' or 'mix'.")

    val_threshold, test_threshold = p_noise[0], p_noise[1]

    val_mask = np.zeros(shape, dtype='uint8')
    test_mask = np.zeros(shape, dtype='uint8')

    # A single draw per column decides which mask (if any) the column belongs to.
    draws = np.random.random(val_mask.shape[1])
    val_mask[:, draws < val_threshold] = 1
    test_mask[:, (val_threshold <= draws) & (draws < test_threshold)] = 1
    return val_mask, test_mask


def compute_mean(x, index=None):
    """Fill NaN entries with group means computed over progressively coarser time groupings.

    NaNs are first replaced by the mean of the entries sharing the same (year, ISO week, hour). Entries still
    missing are then filled, in order, with the mean over (year, month, hour), (month, hour) and finally (hour)
    alone. Anything that is still NaN afterwards (an hour with no valid entry at all) is forward-filled and then
    backward-filled along the time axis.

    Parameters
    ----------
    x : pd.DataFrame or np.ndarray
        The data, with time on the first axis. A DataFrame must carry a datetime-like index (its `year`, `month`,
        `hour` and `isocalendar()` attributes are used). An array is flattened to 2D for the computation and
        reshaped back to its original shape on return.
    index : pd.DatetimeIndex, optional
        Time index of `x`; required when `x` is an array.

    Returns
    -------
    pd.DataFrame or np.ndarray
        The filled data, of the same kind as the input. The input itself is not modified.

    Raises
    ------
    ValueError
        If `x` is an array and no `index` is given.
    """
    from_array = isinstance(x, np.ndarray)
    if from_array:
        if index is None:
            raise ValueError("`index` must be provided when `x` is a numpy array.")
        shape = x.shape
        df_mean = pd.DataFrame(x.reshape((shape[0], -1)), index=index)
    else:
        df_mean = x.copy()

    stamps = df_mean.index
    by_year_week_hour = [stamps.year, stamps.isocalendar().week, stamps.hour]
    by_year_month_hour = [stamps.year, stamps.month, stamps.hour]
    # From the most specific grouping to the coarsest one.
    groupings = [by_year_week_hour, by_year_month_hour, by_year_month_hour[1:], by_year_month_hour[2:]]

    for keys in groupings:
        if not df_mean.isna().values.any():
            break
        # 'mean' skips NaNs, and is what pandas resolves `np.nanmean` to for group-wise transforms.
        group_mean = df_mean.groupby(keys).transform('mean')
        df_mean = df_mean.fillna(group_mean)

    if df_mean.isna().values.any():
        # `fillna(method=...)` is deprecated; ffill()/bfill() are the supported equivalents.
        df_mean = df_mean.ffill().bfill()

    if from_array:
        df_mean = df_mean.values.reshape(shape)
    return df_mean


def geographical_distance(x=None, to_rad=True):
    """
    Compute the as-the-crow-flies (haversine) distance between every pair of samples in `x`. The first dimension of
    each point is assumed to be the latitude, the second the longitude. The inputs are assumed to be in degrees; if
    they are already in radians, `to_rad` must be set to False. The dimension of the data must be 2.

    Parameters
    ----------
    x : pd.DataFrame or np.ndarray
        array_like structure of shape (n_samples, 2).
    to_rad : bool
        whether to convert inputs to radians (provided that they are in degrees).

    Returns
    -------
    distances : pd.DataFrame or np.ndarray
        The pairwise distances between the points in kilometers. When `x` is a DataFrame the result is a DataFrame
        whose rows and columns are both labelled with `x.index`; otherwise it is a plain array.
    """
    _AVG_EARTH_RADIUS_KM = 6371.0088

    # Extract values of X if it is a DataFrame, else assume it is 2-dim array of lat-lon pairs
    latlon_pairs = x.values if isinstance(x, pd.DataFrame) else x

    # If the input values are in degrees, convert them in radians
    if to_rad:
        latlon_pairs = np.asarray(latlon_pairs)
        if latlon_pairs.dtype == object:
            # The ufunc cannot operate on Python objects; coerce to floats first.
            latlon_pairs = latlon_pairs.astype(float)
        # `np.radians` is already an element-wise ufunc: no need for a Python-level `np.vectorize` loop.
        latlon_pairs = np.radians(latlon_pairs)

    distances = haversine_distances(latlon_pairs) * _AVG_EARTH_RADIUS_KM

    # Cast response
    if isinstance(x, pd.DataFrame):
        return pd.DataFrame(distances, x.index, x.index)
    return distances


def infer_mask(df, infer_from='next'):
    """Infer evaluation mask from DataFrame. In the evaluation mask a value is 1 if it is present in the DataFrame and
    absent at the corresponding timestamp of the `infer_from` month.

    The (year, month) pairs found in the index are sorted and treated cyclically: with `infer_from='next'` the last
    month is compared against the first one, and with `infer_from='previous'` the first month is compared against
    the last one.

    @param pd.DataFrame df: the DataFrame, indexed by a datetime-like index.
    @param str infer_from: denotes from which month the evaluation value must be inferred.
    Can be either `previous` or `next`.
    @return: pd.DataFrame test_mask: the evaluation mask for the DataFrame (same index and columns as `df`).
    @raise ValueError: if `infer_from` is neither `previous` nor `next`.
    """
    if infer_from == 'previous':
        offset = -1
    elif infer_from == 'next':
        offset = 1
    else:
        raise ValueError('infer_from can only be one of %s' % ['previous', 'next'])

    # 1 where a value is observed, 0 where it is missing.
    mask = (~df.isna()).astype('uint8')
    test_mask = pd.DataFrame(index=mask.index, columns=mask.columns, data=0).astype('uint8')

    months = sorted(set(zip(mask.index.year, mask.index.month)))
    length = len(months)
    for i, (year_i, month_i) in enumerate(months):
        # Reference month, wrapping around at both ends of the sorted list.
        year_j, month_j = months[(i + offset) % length]
        mask_j = mask[(mask.index.year == year_j) & (mask.index.month == month_j)]
        # Re-date the reference month's observation mask onto month i.
        mask_i = mask_j.shift(1, pd.DateOffset(months=12 * (year_i - year_j) + (month_i - month_j)))
        # Month shifting can map several source timestamps onto the same target one: keep the first.
        mask_i = mask_i[~mask_i.index.duplicated(keep='first')]
        # Keep only the shifted timestamps that actually exist in the data.
        mask_i = mask_i[np.isin(mask_i.index, mask.index)]
        # On uint8 0/1 values, `~a & b` is 1 exactly when a == 0 (absent in reference) and b == 1 (present here).
        test_mask.loc[mask_i.index] = ~mask_i.loc[mask_i.index] & mask.loc[mask_i.index]
    return test_mask


def prediction_dataframe(y, index, columns=None, aggregate_by='mean'):
    """Aggregate batched predictions in a single DataFrame.

    Every element of `y` is reshaped to its first two dimensions (so any trailing dimension must have size 1) and
    labelled with the matching element of `index`. All the resulting frames are stacked, and the rows sharing the
    same index value are reduced with the requested aggregation method(s).

    @param (list or np.ndarray) y: the list of predictions.
    @param (list or np.ndarray) index: the list of time indexes coupled with the predictions.
    @param (list or pd.Index) columns: the columns of the returned DataFrame.
    @param (str or list) aggregate_by: how to aggregate the predictions in case there are more than one for a step.
    - `mean`: take the mean of the predictions
    - `central`: take the prediction at the central position, assuming that the predictions are ordered chronologically
    - `smooth_central`: average the predictions weighted by a gaussian signal with std=1
    - `last`: take the first stacked prediction of each step (see the note in the code)
    @return: pd.DataFrame df: the aggregated predictions when `aggregate_by` is a string, otherwise a list with one
    DataFrame per requested method, in the same order.
    @raise ValueError: if an aggregation method is not one of the supported ones.
    """
    dfs = [pd.DataFrame(data=data.reshape(data.shape[:2]), index=idx, columns=columns) for data, idx in zip(y, index)]
    df = pd.concat(dfs)
    preds_by_step = df.groupby(df.index)
    # aggregate according passed methods
    aggr_methods = ensure_list(aggregate_by)
    dfs = []
    for aggr_by in aggr_methods:
        if aggr_by == 'mean':
            dfs.append(preds_by_step.mean())
        elif aggr_by == 'central':
            # `.iloc` makes the lookup explicitly positional, whatever the index type is.
            dfs.append(preds_by_step.aggregate(lambda x: x.iloc[len(x) // 2]))
        elif aggr_by == 'smooth_central':
            # `scipy.signal.gaussian` was only an alias and is gone from recent SciPy releases.
            from scipy.signal.windows import gaussian
            dfs.append(preds_by_step.aggregate(lambda x: np.average(x, weights=gaussian(len(x), 1))))
        elif aggr_by == 'last':
            # Deliberately position 0: first imputation has missing value in last position.
            dfs.append(preds_by_step.aggregate(lambda x: x.iloc[0]))
        else:
            raise ValueError('aggregate_by can only be one of %s' % ['mean', 'central', 'smooth_central', 'last'])
    if isinstance(aggregate_by, str):
        return dfs[0]
    return dfs


def ensure_list(obj):
    """Return `obj` as a list: a list or tuple is copied into a new list, any other object is wrapped in one."""
    return list(obj) if isinstance(obj, (list, tuple)) else [obj]


def missing_val_lens(mask):
    """Collect the lengths of the runs of consecutive missing values along the first axis of a 2D mask.

    An entry of `mask` counts as missing when it is falsy (zero). Columns are scanned one after the other, and the
    lengths of the runs found in each of them are appended to a single flat list.

    Parameters
    ----------
    mask : np.ndarray
        2D array in which truthy entries mark available values.

    Returns
    -------
    list
        The run lengths (numpy integers), ordered by column and then by position inside the column.
    """
    pad_row = np.zeros((1, mask.shape[1]))
    is_missing = (~mask.astype('bool')).astype('int')
    # Zero rows on both sides guarantee that every run has an opening and a closing edge.
    padded = np.concatenate([pad_row, is_missing, pad_row])
    edges = np.diff(padded, axis=0)

    run_lengths = []
    for column in edges.T:
        boundaries, = column.nonzero()
        # Boundaries alternate start, end, start, end, ...: every other gap is the length of a run.
        run_lengths.extend(list(np.diff(boundaries)[::2]))
    return run_lengths


def disjoint_months(dataset, months=None, synch_mode='window'):
    """Split the sample indices of `dataset` into those lying inside `months` and those lying outside.

    A sample is assigned to a side only when both ends of its reference span fall in that side's set of months, so
    samples straddling the two sets are returned in neither output. The span is the input window when `synch_mode`
    is 'window', and the prediction horizon when it is 'horizon'.

    @param dataset: object exposing `__len__`, `index` (datetime-like), `_indices`, `window` and, for the 'horizon'
    mode, `horizon_offset` and `horizon`.
    @param (int or list) months: the month number(s), 1-12, of the "after" split.
    @param str synch_mode: either `window` or `horizon`.
    @return: (np.ndarray prev_idxs, np.ndarray after_idxs): the indices of the samples fully contained in the
    complementary months, and the indices of the samples fully contained in `months`.
    @raise ValueError: if `synch_mode` is neither `window` nor `horizon`.
    """
    idxs = np.arange(len(dataset))
    months = ensure_list(months)
    # divide indices according to window or horizon
    if synch_mode == 'window':
        start, end = 0, dataset.window - 1
    elif synch_mode == 'horizon':
        start, end = dataset.horizon_offset, dataset.horizon_offset + dataset.horizon - 1
    else:
        raise ValueError('synch_mode can only be one of %s' % ['window', 'horizon'])

    # Month of the first and of the last step of every sample's span (computed once, used for both splits).
    start_months = dataset.index[dataset._indices + start].month
    end_months = dataset.index[dataset._indices + end].month

    def _fully_inside(month_set):
        # `np.isin` is the supported replacement for the deprecated `np.in1d`.
        return np.isin(start_months, month_set) & np.isin(end_months, month_set)

    # after idxs
    after_idxs = idxs[_fully_inside(months)]
    # previous idxs
    other_months = np.setdiff1d(np.arange(1, 13), months)
    prev_idxs = idxs[_fully_inside(other_months)]
    return prev_idxs, after_idxs


def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False):
    """Turn values into Gaussian kernel weights `exp(-(x / theta) ** 2)`, optionally zeroing some of them.

    Parameters
    ----------
    x : np.ndarray
        The input values.
    theta : float, optional
        Kernel width. Defaults to the standard deviation of `x`.
    threshold : float, optional
        When given, some weights are set to 0: those whose input is above `threshold` if `threshold_on_input`
        is True, otherwise those whose weight is below `threshold`.
    threshold_on_input : bool
        Whether `threshold` is compared against the inputs rather than against the weights.

    Returns
    -------
    np.ndarray
        The kernel weights, as a new array with the shape of `x`.
    """
    bandwidth = np.std(x) if theta is None else theta
    kernel = np.exp(-np.square(x / bandwidth))
    if threshold is None:
        return kernel
    if threshold_on_input:
        suppressed = x > threshold
    else:
        suppressed = kernel < threshold
    kernel[suppressed] = 0.
    return kernel


def plot_weights(adj, output_path, pos=None, val_nodes=None, test_nodes=None):
    """Save diagnostic plots of an adjacency matrix into the `output_path` directory.

    Always written: `adj.png` (heatmap of `adj`) and the degree histograms produced by
    `plot_degree_distribution`. When `pos` is given, `weights_pos.png` is written as well: a scatter of the nodes
    (first column of `pos` on the x axis, second on the y axis) with a line for every pair (i, j) such that
    `adj[i, j] > 0`.

    Parameters
    ----------
    adj :
        Adjacency matrix, indexable as `adj[i, j]`.
    output_path : str
        Directory in which the images are saved.
    pos : np.ndarray, optional
        Node coordinates, one row per node.
    val_nodes, test_nodes : optional
        Row selectors into `pos` for the nodes to be drawn in green (validation) and red (test) on top of the
        blue scatter of all the nodes.
    """
    # heatmap
    sns.heatmap(adj, cmap='Blues', cbar=True, xticklabels=False, yticklabels=False)
    heatmap_file = os.path.join(output_path, 'adj.png')
    plt.savefig(heatmap_file, dpi=300)
    plt.close()

    # degree distribution
    plot_degree_distribution(adj, output_path)

    # graph: nothing more to draw without node coordinates
    if pos is None:
        return

    plt.scatter(pos[:, 0], pos[:, 1], c='blue', s=10)
    for subset, colour in ((val_nodes, 'green'), (test_nodes, 'red')):
        if subset is not None:
            plt.scatter(pos[subset, 0], pos[subset, 1], c=colour, s=10)

    # Draw node connections
    n_nodes = len(pos)
    for src in range(n_nodes):
        for dst in range(n_nodes):
            if adj[src, dst] > 0:
                plt.plot([pos[src, 0], pos[dst, 0]], [pos[src, 1], pos[dst, 1]], 'k-', lw=0.5)
    graph_file = os.path.join(output_path, 'weights_pos.png')
    plt.savefig(graph_file, dpi=300)
    plt.close()


def plot_degree_distribution(adj, output_path):
    """Save histograms of the out-degree and in-degree of the graph described by `adj`.

    Edge weights are ignored: every strictly positive entry of (a copy of) `adj` counts as one edge. Out-degrees
    are the row sums and in-degrees the column sums of the binarised matrix. Each histogram uses one bin per
    integer between the smallest and the largest observed degree, and is written to
    `out_degree_distribution.png` / `in_degree_distribution.png` inside `output_path`.

    Parameters
    ----------
    adj :
        Adjacency matrix supporting `copy()`, boolean-mask assignment and `sum(axis=...)`.
    output_path : str
        Directory in which the images are saved.
    """
    # Set non-zero elements to 1 (remove weights), working on a copy so the caller's matrix is untouched
    binary_adj = adj.copy()
    binary_adj[binary_adj > 0] = 1

    # Calculate out-degree (row sum) and in-degree (column sum)
    out_degrees = binary_adj.sum(axis=1)
    in_degrees = binary_adj.sum(axis=0)

    # (degrees, axis label, title, bar colour, file name) for each of the two histograms
    panels = (
        (out_degrees, 'Out-degree', 'Out-degree Distribution', 'skyblue', 'out_degree_distribution.png'),
        (in_degrees, 'In-degree', 'In-degree Distribution', 'lightcoral', 'in_degree_distribution.png'),
    )
    for degrees, label, title, colour, filename in panels:
        distinct = np.unique(degrees)
        n_bins = int(distinct.max() - distinct.min() + 1)
        plt.figure(figsize=(8, 6))
        plt.hist(degrees, bins=n_bins, color=colour, alpha=0.7)
        plt.xlabel(label)
        plt.ylabel('Frequency')
        plt.title(title)
        plt.savefig(os.path.join(output_path, filename), dpi=300)
        plt.close()


def parse_mask_ratio(value):
    """Parse a known-mask-ratio given as text into a float or a list of numbers.

    The value is first tried as a single float (e.g. "0.5"). If that fails it is parsed as a Python literal, which
    must be a list containing only ints/floats (e.g. "[0.1, 0.2]").

    Parameters
    ----------
    value : str
        The textual value to parse.

    Returns
    -------
    float or list
        The parsed ratio, or the parsed list of ratios (elements keep the numeric type they were written with).

    Raises
    ------
    ValueError
        If `value` is neither a float nor a list of numbers.
    """
    try:
        # Try to convert directly to float (single value)
        return float(value)
    except ValueError:
        pass  # not a scalar: fall through and try to parse it as a list

    try:
        parsed = ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        raise ValueError(f"Invalid format for known-mask-ratio: {value}") from e

    # Explicit check instead of `assert`, so that the validation is not stripped under `python -O`.
    if not isinstance(parsed, list) or not all(isinstance(v, (float, int)) for v in parsed):
        cause = ValueError("known-mask-ratio must be a float or list of floats.")
        raise ValueError(f"Invalid format for known-mask-ratio: {value}") from cause
    return parsed


def adj_keep_idx(adj):
    """Return the indices of the nodes of `adj` that have at least one non-zero entry in their row or column.

    A node is dropped only when both its row and its column are entirely zero (no outgoing and no incoming edge).

    Parameters
    ----------
    adj : array_like
        The adjacency matrix.

    Returns
    -------
    list of int
        The sorted indices of the nodes to keep.
    """
    matrix = np.asarray(adj)

    rows_empty = np.flatnonzero(~matrix.any(axis=1))
    cols_empty = np.flatnonzero(~matrix.any(axis=0))
    # Isolated nodes: empty row AND empty column.
    isolated = np.intersect1d(rows_empty, cols_empty)

    every_node = np.arange(matrix.shape[0])
    return np.setdiff1d(every_node, isolated).tolist()


def isolate_nodes_detected(adj):
    """Tell whether `adj` contains isolated nodes, i.e. whether `adj_keep_idx` would drop at least one node."""
    kept = adj_keep_idx(adj)
    return len(kept) != len(adj)