import os
import torch
import pandas as pd
import numpy as np
import itertools


def get_l2_distance(x, y):
    '''Return the pairwise Euclidean (L2) distances between the rows of x and y.

    - x: 2-D tensor of shape (n, d)
    - y: 2-D tensor of shape (m, d)

    Returns a tensor of shape (n, m) whose entry [i, j] is ||x[i] - y[j]||.
    '''
    # Broadcasting (n, 1, d) against (1, m, d) produces every pairwise
    # difference directly, without first building tiled copies of both
    # inputs, and keeps a well-defined (n, m) result even when n or m is 0
    # (a reshape with -1 is ambiguous for empty tensors and raises).
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    return torch.norm(pairwise_diff, dim=2)

def get_cosine_similarity(x, y):
    '''Return the pairwise cosine similarity between the rows of x and y.

    - x: 2-D tensor of shape (n, d)
    - y: 2-D tensor of shape (m, d)

    Returns a tensor of shape (n, m) whose entry [i, j] is
    (x[i] . y[j]) / (||x[i]|| * ||y[j]||).

    No epsilon is added to the denominator, so a zero-norm row produces
    nan/inf entries.
    '''
    dot = torch.mm(x, torch.transpose(y, 0, 1))
    norm_x = torch.linalg.norm(x, dim=1)
    norm_y = torch.linalg.norm(y, dim=1)
    # Outer product of the row norms: entry [i, j] is ||x[i]|| * ||y[j]||,
    # which lines up with dot[i, j] and works when n != m.
    norm = norm_x.unsqueeze(1) * norm_y.unsqueeze(0)
    return dot / norm

def get_topk_intersecting_neighbors(relationship_points, nb_neighbors, largest):
    '''Return, for each column of relationship_points, the row indices of its top-k entries.

    - relationship_points: how distant/similar each point is with respect to
      all the other points
    - nb_neighbors: number of neighbors (k) to select
    - largest: True to select the largest entries, False to select the
      smallest ones (forwarded to torch.topk)

    The selection runs along dim 0 and the index tensor is transposed, so for
    a 2-D input each row of the result lists the k selected row indices of
    one column.
    '''
    _, neighbor_idxs = torch.topk(
        relationship_points, nb_neighbors, dim=0, largest=largest)
    return neighbor_idxs.T


def get_ratio_intersection_neighbors(relationship_points_x, relationship_points_y, nb_neighbors, largest=True):
    '''Compute, per point, the fraction of top-k neighbors shared by two relationship matrices.

    - relationship_points_x / relationship_points_y: relationship
      (distance/similarity) matrices to compare
    - nb_neighbors: number of neighbors selected for each point
    - largest: forwarded to the top-k selection

    Returns a tuple (ratios, topk_neighbors_x, topk_neighbors_y) where ratios
    is a numpy array holding, for each pair of neighbor rows, the number of
    common indices divided by the length of the x row.
    '''
    topk_neighbors_x, topk_neighbors_y = (
        get_topk_intersecting_neighbors(points, nb_neighbors, largest=largest)
        for points in (relationship_points_x, relationship_points_y)
    )
    shared_fractions = [
        np.intersect1d(row_x, row_y).shape[0] / row_x.shape[0]
        for row_x, row_y in zip(topk_neighbors_x, topk_neighbors_y)
    ]
    return np.array(shared_fractions), topk_neighbors_x, topk_neighbors_y

def jaccard_set(list1, list2):
    """Return the Jaccard similarity |A & B| / |A | B| of two collections.

    Both inputs are converted to sets first, so repeated elements do not
    inflate the union size. Elements must be hashable.

    When both inputs are empty the similarity is defined as 1.0 (two empty
    sets are identical) instead of dividing by zero.
    """
    set1, set2 = set(list1), set(list2)
    union = set1 | set2
    if not union:
        return 1.0
    return len(set1 & set2) / len(union)

def get_jaccard_intersection(relationship_points_x, relationship_points_y, nb_neighbors, largest=True):
    '''Compute, per point, the Jaccard similarity of its top-k neighbors in two relationship matrices.

    - relationship_points_x / relationship_points_y: relationship
      (distance/similarity) matrices to compare
    - nb_neighbors: number of neighbors selected for each point
    - largest: forwarded to the top-k selection

    Returns a tuple (similarities, topk_neighbors_x, topk_neighbors_y) where
    similarities is a numpy array with one Jaccard score per pair of
    neighbor rows.
    '''
    topk_neighbors_x, topk_neighbors_y = (
        get_topk_intersecting_neighbors(points, nb_neighbors, largest=largest)
        for points in (relationship_points_x, relationship_points_y)
    )
    similarities = [
        jaccard_set(row_x.tolist(), row_y.tolist())
        for row_x, row_y in zip(topk_neighbors_x, topk_neighbors_y)
    ]
    return np.array(similarities), topk_neighbors_x, topk_neighbors_y

def merge_perturbed_notperturbed_preds(preds, perturbed_preds, perturbation_idxs=None):
    '''Combine original and perturbed predictions position by position.

    - preds: predictions for the unperturbed inputs
    - perturbed_preds: predictions for the perturbed inputs
    - perturbation_idxs: positions at which the perturbed prediction is
      taken; when None, preds is returned unchanged

    Returns preds itself when perturbation_idxs is None, otherwise a new
    tensor stacked from the selected entries. Iteration stops at the shorter
    of the two inputs.
    '''
    if perturbation_idxs is None:
        return preds
    merged = [
        perturbed if position in perturbation_idxs else original
        for position, (original, perturbed) in enumerate(zip(preds, perturbed_preds))
    ]
    return torch.stack(merged)

def get_percentage(array):
    '''Divide the sum of array's elements by array.shape.

    NOTE(review): array.shape is a tuple, so for NumPy-backed inputs the
    result is an array with one entry per dimension (a length-1 array for
    1-D input) rather than a plain scalar. Presumably a scalar fraction was
    intended; confirm with callers before changing the return shape.
    '''
    total = array.sum()
    dims = array.shape
    return total / dims

def compare_equality(array1, array2, array3=None):
    '''Return an element-wise mask of positions where the given arrays agree.

    - array1: reference array
    - array2: array compared against array1
    - array3: optional third array, also compared against array1

    The mask is True where array1 equals array2 and, when array3 is given,
    array1 also equals array3.
    '''
    agreement = np.equal(array1, array2)
    if array3 is None:
        return agreement
    return agreement & np.equal(array1, array3)

def get_relative_ranking(sorted_df1, sorted_df2):
    '''Create the ranking of sorted_df2 relative to sorted_df1.

    Both dataframes should be sorted in descending order.

    Returns two numpy arrays: the positions 0..n-1 of the labels of
    sorted_df1.index, and the location of each of those labels within
    sorted_df2.index (as reported by Index.get_loc).
    '''
    labels = sorted_df1.index
    locate_in_df2 = sorted_df2.index.get_loc
    ranking_in_df1 = list(range(len(labels)))
    ranking_in_df2 = [locate_in_df2(label) for label in labels]
    return np.array(ranking_in_df1), np.array(ranking_in_df2)

def sort_dataframe(df, column_name, ascending):
    '''Return df sorted by column_name.

    - df: dataframe to sort
    - column_name: column (or columns) to sort by
    - ascending: sort direction, forwarded to DataFrame.sort_values
    '''
    ordered = df.sort_values(by=column_name, ascending=ascending)
    return ordered

def create_df(preds=None, labels=None, sim=None, confidence=None):
    '''Assemble a dataframe from whichever inputs are provided.

    - preds: dict mapping a name to its predictions; each entry becomes a
      column, and when labels is given a 'correct_<name>' column with 1
      where the prediction equals the label and 0 elsewhere is added
      (computed with torch.where, so preds and labels are presumably torch
      tensors in that case)
    - labels: stored in the 'labels' column
    - sim: stored in the 'sim' column
    - confidence: stored in the 'confidence' column
    '''
    columns = {}
    has_labels = labels is not None
    if preds is not None:
        for name, pred in preds.items():
            columns[f'{name}'] = pred
            if has_labels:
                columns[f'correct_{name}'] = torch.where(pred == labels, 1, 0)
    optional_columns = (
        ('labels', labels),
        ('sim', sim),
        ('confidence', confidence),
    )
    for key, values in optional_columns:
        if values is not None:
            columns[key] = values
    return pd.DataFrame(columns)

def df_calculate_accuracy(data, preds_columns, labels_column_name='labels'):
    '''Compute, for each prediction column, how often it matches the label column.

    - data: dataframe holding the prediction and label columns
    - preds_columns: names of the prediction columns to evaluate
    - labels_column_name: name of the column holding the labels

    Returns a dict mapping each prediction column name to the value of
    get_percentage applied to its equality mask against the labels.
    '''
    return {
        column: get_percentage(
            compare_equality(data[column], data[labels_column_name]))
        for column in preds_columns
    }

def df_calculate_consistency(data, preds_columns, preds_column_name=''):
    '''Calculates consistency for dataframe.

    For every combination of at least two prediction columns (largest
    combinations first), computes how often all columns in the combination
    agree with each other.

    - data: dataframe holding the prediction columns
    - preds_columns: names of the prediction columns (at least two)
    - preds_column_name: unused; kept for backward compatibility

    Returns a dict keyed by the '_'-joined column names of each combination,
    with the value of get_percentage applied to the agreement mask.

    Raises ValueError when fewer than two prediction columns are given.
    '''
    len_preds_columns = len(preds_columns)
    if len_preds_columns < 2:
        # Raise instead of print + exit(): a library helper must not
        # terminate the interpreter, and the caller can handle this error.
        raise ValueError('Number of prediction columns must be >= 2')

    results = {}
    # Sizes go from "all columns" down to pairs, matching the key order of
    # the previous implementation.
    for size in range(len_preds_columns, 1, -1):
        for combination in itertools.combinations(range(len_preds_columns), size):
            names = [preds_columns[idx] for idx in combination]
            reference, *others = (data[name].values for name in names)
            # Every column is compared against the first one and the masks
            # are AND-ed together, so combinations of any size are handled
            # (compare_equality itself accepts at most three arrays).
            agreement = compare_equality(reference, others[0])
            for other in others[1:]:
                agreement = agreement & compare_equality(reference, other)
            results['_'.join(map(str, names))] = get_percentage(agreement)
    return results
