import datetime
import itertools
import json
import os
import pickle
import socket
import warnings
from collections import Counter

import cvxpy as cp
import git
import numpy as np
import pandas as pd
from ruamel.yaml import YAML
from scipy.stats import rankdata
from sklearn.metrics.pairwise import pairwise_distances, rbf_kernel

# Path of the YAML file that get_dataset() opens to look up, per dataset name,
# the "data" CSV, the "plans" CSV and an optional "config" file.
DATASET_PATHSPEC = "./config/data_pathspec.yml"
# Presumably the default evaluator configuration; it is not referenced in the
# visible part of this module -- TODO confirm where it is consumed.
EVALUATOR = "./config/default_evaluator.yml"

# Shared ruamel.yaml handle (safe mode) used below for both loading and dumping.
yaml = YAML(typ='safe')

def _validate_ranking_metric_inputs(y_true, y_score):
    """Check that the ground truth and the predictions have the same length.

    Args:
        y_true: Sized collection of ground-truth values.
        y_score: Sized collection of predicted values.

    Raises:
        AssertionError: If ``len(y_true) != len(y_score)``.
    """
    # Raised explicitly instead of via `assert`, which is stripped when Python
    # runs with -O. AssertionError is kept so callers see the same exception
    # type as before.
    if len(y_true) != len(y_score):
        raise AssertionError(
            "y_true and y_score must have the same length, "
            f"got {len(y_true)} and {len(y_score)}"
        )

def _get_top_k(y_true, y_score, k):
    """Return the trailing ``k`` true ranks and the trailing ``k`` predictions.

    ``y_true`` is converted to zero-based dense ranks of its negated values
    (so the largest value gets rank 0), and the last ``k`` of those ranks are
    returned together with the last ``k`` entries of ``y_score``.

    Args:
        y_true: Ground-truth values, one per item.
        y_score: Predicted entries; must support ``len`` and slicing.
        k: Number of trailing entries to keep. Values larger than
            ``len(y_true)`` are clamped (with a warning); ``k <= 0`` yields
            empty results.

    Returns:
        Tuple ``(top_k_true, top_k_pred)``.
    """
    if k > len(y_true):
        warnings.warn(f"Value of k ({k}) is greater than the number of items. Setting k to the number of items.")
        k = len(y_true)
    # TODO: double-check whether the ranking needs to be flipped here.
    true_ranks = rankdata(-np.array(y_true), method='dense') - 1
    # Slice from an explicit start offset: `seq[-k:]` would return the whole
    # sequence for k == 0 (since -0 == 0) and slice from the front for k < 0.
    top_k_true = true_ranks[max(len(true_ranks) - k, 0):]
    top_k_pred = y_score[max(len(y_score) - k, 0):]
    return top_k_true, top_k_pred


def ranking_accuracy(y_true, y_score):
    """Return the share of item pairs that both inputs order the same way.

    Every unordered pair of positions is scored: 1 when the difference in
    ``y_true`` and the difference in ``y_score`` have the same sign, 1/2 when
    either difference is zero (a tie), and 0 otherwise.

    Args:
        y_true: Ground-truth values, indexable by position.
        y_score: Predicted values, same length as ``y_true``.

    Returns:
        Accumulated score divided by the number of pairs.

    Raises:
        ZeroDivisionError: If there are fewer than two items (no pairs).
    """
    _validate_ranking_metric_inputs(y_true, y_score)
    n_pairs = 0
    credit = 0.0
    for first, second in itertools.combinations(range(len(y_true)), 2):
        # Positive product: both differences share a sign (concordant pair).
        agreement = (y_true[first] - y_true[second]) * (y_score[first] - y_score[second])
        n_pairs += 1
        if agreement > 0:
            credit += 1.0
        elif agreement == 0:
            credit += 0.5
    return credit / n_pairs

def hitrate_at_k(k, y_true, y_score): # Default is ascending order
    """Return the hit rate between the predicted and the true top-k entries.

    The trailing ``k`` entries of ``y_score`` are compared, as multisets, with
    the trailing ``k`` zero-based dense ranks that ``_get_top_k`` derives from
    ``y_true``. Every predicted entry without a counterpart on the true side
    counts as a miss.

    Args:
        k: Number of trailing entries to compare.
        y_true: Ground-truth values, one per item.
        y_score: Predicted entries, same length as ``y_true``.

    Returns:
        ``1 - misses / effective_k`` with ``effective_k = min(k, len(y_true))``.

    Raises:
        AssertionError: If the inputs differ in length.
        ZeroDivisionError: If the effective ``k`` is zero.
    """
    _validate_ranking_metric_inputs(y_true, y_score)
    top_k_true, top_k_pred = _get_top_k(y_true, y_score, k)
    # left over: all predicted top ks that aren't in the true list
    misses = Counter(top_k_pred) - Counter(top_k_true)
    # _get_top_k clamps k to the number of items, so normalise by the same
    # clamped value; dividing by the raw k would inflate the score.
    effective_k = min(k, len(y_true))
    return 1 - misses.total() / effective_k

""" 
    Deprecated -- use hitrate_at_k
"""
def precision_at_k(k, y_true, y_score): # Default is ascending order
    """Return the precision between the predicted and the true top-k entries.

    .. deprecated::
        Use ``hitrate_at_k`` instead; it performs the same computation.

    The trailing ``k`` entries of ``y_score`` are compared, as multisets, with
    the trailing ``k`` zero-based dense ranks that ``_get_top_k`` derives from
    ``y_true``. Every predicted entry without a counterpart on the true side
    counts as a miss.

    Args:
        k: Number of trailing entries to compare.
        y_true: Ground-truth values, one per item.
        y_score: Predicted entries, same length as ``y_true``.

    Returns:
        ``1 - misses / effective_k`` with ``effective_k = min(k, len(y_true))``.

    Raises:
        AssertionError: If the inputs differ in length.
        ZeroDivisionError: If the effective ``k`` is zero.
    """
    _validate_ranking_metric_inputs(y_true, y_score)
    top_k_true, top_k_pred = _get_top_k(y_true, y_score, k)
    # left over: all predicted top ks that aren't in the true list
    misses = Counter(top_k_pred) - Counter(top_k_true)
    # _get_top_k clamps k to the number of items, so normalise by the same
    # clamped value; dividing by the raw k would inflate the score.
    effective_k = min(k, len(y_true))
    return 1 - misses.total() / effective_k

"""
    Deprecated -- use hitrate_at_k
"""
def recall_at_k(k, y_true, y_score):
    """Return the recall between the predicted and the true top-k entries.

    .. deprecated::
        Use ``hitrate_at_k`` instead.

    The trailing ``k`` zero-based dense ranks that ``_get_top_k`` derives from
    ``y_true`` are compared, as multisets, with the trailing ``k`` entries of
    ``y_score``. Every true entry without a counterpart among the predictions
    counts as an extra.

    Args:
        k: Number of trailing entries to compare.
        y_true: Ground-truth values, one per item.
        y_score: Predicted entries, same length as ``y_true``.

    Returns:
        ``1 - extras / effective_k`` with ``effective_k = min(k, len(y_true))``.

    Raises:
        AssertionError: If the inputs differ in length.
        ZeroDivisionError: If the effective ``k`` is zero.
    """
    _validate_ranking_metric_inputs(y_true, y_score)
    top_k_true, top_k_pred = _get_top_k(y_true, y_score, k)
    # left over: all true top ks that weren't in the predictions
    extras = Counter(top_k_true) - Counter(top_k_pred)
    # _get_top_k clamps k to the number of items, so normalise by the same
    # clamped value; dividing by the raw k would inflate the score.
    effective_k = min(k, len(y_true))
    return 1 - extras.total() / effective_k

def match_wo_replacement(dist1, dist2):
    """Pick entries of the longer input, one per entry of the shorter input.

    Solves a mixed-integer linear program (via the SCIP solver) with one 0/1
    weight per entry of the longer input; exactly ``len(shorter)`` weights are
    set to one. Every entry of the shorter input is always kept.

    Args:
        dist1: First collection of values.
        dist2: Second collection of values.

    Returns:
        Tuple ``(weight1, weight2)`` of boolean arrays marking which entries
        of ``dist1`` and ``dist2`` are kept.

    Raises:
        RuntimeError: If the solver finishes without producing a solution.
    """
    # for every one in dist1, do 1:1 matching w/o replacement -- this is a mixed-integer linear program
    dist1_is_longer = len(dist1) > len(dist2)
    longer_dist, shorter_dist = (dist1, dist2) if dist1_is_longer else (dist2, dist1)

    matching_weights = cp.Variable(len(longer_dist), integer=True)
    # NOTE(review): the objective is a signed difference of sums, and
    # sum(shorter_dist) is constant, so minimising it maximises the sum of the
    # selected entries of longer_dist. Confirm whether an absolute difference
    # was intended.
    objective = cp.Minimize(cp.sum(shorter_dist) - cp.sum(cp.multiply(matching_weights, longer_dist)))
    prob = cp.Problem(objective, [
        matching_weights.sum() == len(shorter_dist),
        matching_weights >= 0,
        matching_weights <= 1,
    ])
    prob.solve(solver=cp.SCIP)

    if matching_weights.value is None:
        raise RuntimeError(f"Matching problem was not solved (status: {prob.status})")

    # Integer solvers return floats within an integrality tolerance (e.g. 1e-9
    # or 0.999999); round before casting so near-zero values do not become True.
    solution = np.rint(matching_weights.value)
    keep_all_shorter = np.ones_like(shorter_dist)

    weight1 = solution if len(dist1) == len(longer_dist) else keep_all_shorter
    weight2 = solution if len(dist2) == len(longer_dist) else keep_all_shorter

    # astype() copies, so the two results never alias each other.
    return weight1.astype(bool), weight2.astype(bool)

def match_w_replacement(dist1, dist2):
    """Match every entry of ``dist1`` to its nearest entry of ``dist2``.

    Entries of ``dist2`` may be matched more than once. When several entries
    of ``dist2`` are equally close, one of them is chosen at random.

    Args:
        dist1: Array of values; must support ``reshape``.
        dist2: Array of values; must support ``reshape``.

    Returns:
        Tuple ``(weight1, weight2)``: ``weight1`` is all ones with the shape
        and dtype of ``dist1``; ``weight2`` has the shape and dtype of
        ``dist2`` and holds how many entries of ``dist1`` matched each entry.
    """
    pairwise = pairwise_distances(dist1.reshape(-1, 1), dist2.reshape(-1, 1))

    # Rows correspond to dist1: flag every dist2 column at the row minimum.
    row_minima = np.min(pairwise, axis=1, keepdims=True)
    is_nearest = pairwise == row_minima
    # break min-distance ties randomly: random scores survive only on flagged
    # columns, and argmax then picks one of them per row
    tie_breaker = np.random.rand(*pairwise.shape)
    match_indices = np.argmax(is_nearest * tie_breaker, axis=1)

    weight1 = np.ones_like(dist1)
    weight2 = np.zeros_like(dist2)
    matched, n_matches = np.unique(match_indices, return_counts=True)
    weight2[matched] = n_matches
    return weight1, weight2


def get_dataset(dataset_name):
    """Load the data and plan tables registered under ``dataset_name``.

    The dataset's entry is looked up in the YAML file at ``DATASET_PATHSPEC``;
    its ``"data"`` and ``"plans"`` values are read as CSV files with the first
    column as index.

    Args:
        dataset_name: Key of the dataset in the pathspec file.

    Returns:
        Tuple ``(df, plan_df, data_config_file)`` where ``data_config_file``
        is the entry's ``"config"`` value, or ``None`` when it is absent.
    """
    with open(DATASET_PATHSPEC, "r") as spec_file:
        entry = yaml.load(spec_file)[dataset_name]

    # e.g. "./analytic/synthetic/synthetic_uniform.csv"
    data_csv = entry["data"]
    plans_csv = entry["plans"]
    data_config_file = entry.get("config")

    csv_options = {"index_col": 0, "low_memory": False}
    df = pd.read_csv(data_csv, **csv_options)
    plan_df = pd.read_csv(plans_csv, **csv_options)
    return df, plan_df, data_config_file

def get_true_ranking(data_config_file):
    """Read the ``"plans"`` entry of a dataset config file as a NumPy array.

    Args:
        data_config_file: Path of the dataset's YAML config file.

    Returns:
        ``np.ndarray`` built from the file's ``"plans"`` value.
    """
    with open(data_config_file, "r") as cfg_file:
        plans = yaml.load(cfg_file)["plans"]
    # descending order by upcoding parameter = ascending order by CATE
    return np.array(plans)

def save_model(save_path, meta_learner):
    """Pickle ``meta_learner`` to ``save_path`` and report the location.

    Args:
        save_path: Destination file path; the file is overwritten.
        meta_learner: Object to serialise with the highest pickle protocol.
    """
    with open(save_path, "wb") as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(meta_learner)
    print("Saved to", save_path)


def save_results(result_path, result_dict, ranking, true_ranking=None):
    """Write ``result_dict`` plus the rankings to ``result_path`` as JSON.

    Note that ``result_dict`` is modified in place: the rankings are stored
    under ``"rank_pred"`` and, when given, ``"rank_true"``.

    Args:
        result_path: Destination file path; the file is overwritten.
        result_dict: JSON-serialisable mapping of results.
        ranking: Predicted ranking; every entry is converted with ``int``.
        true_ranking: Optional true ranking, converted the same way.
    """
    # Convert to built-in ints so the values are JSON-serialisable.
    result_dict["rank_pred"] = [int(rank) for rank in ranking]
    if true_ranking is not None:
        result_dict["rank_true"] = [int(rank) for rank in true_ranking]
    with open(result_path, "w") as f:
        # Relies on the module-level `import json`.
        json.dump(result_dict, f, sort_keys=True, indent=4)
    print("Saved results to", result_path)


def save_config(config_path, cfg):
    """Stamp ``cfg`` with run metadata and dump it to ``config_path`` as YAML.

    ``cfg`` is modified in place: its ``"run"`` entry is replaced by the
    current git commit hash, the current local time, the hostname and, when
    the ``SLURM_JOB_ID`` environment variable is set, the SLURM job id.

    Args:
        config_path: Destination file path; the file is overwritten.
        cfg: Mutable configuration mapping to annotate and save.
    """
    repo = git.Repo(search_parent_directories=True)
    run_info = {
        "hash": repo.head.object.hexsha,
        "date": str(datetime.datetime.now()),
        "hostname": socket.gethostname(),
    }
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    if slurm_job_id is not None:
        run_info["slurm_jobid"] = slurm_job_id
    cfg["run"] = run_info

    with open(config_path, "w") as config_file:
        yaml.dump(cfg, config_file)
    print("Saved config to", config_path)
    
        
    
