import warnings

import numpy as np
import scipy as sp
import scipy.spatial
import scipy.stats

import utils


def normalise_ratings(ratings, offer, consumers):
  """Scale ratings by the inverse product of consumer and offer L2 norms.

  The L2 norm of ``consumers`` and of ``offer`` is taken over their last
  axis. The consumer norms are broadcast along a new trailing axis and the
  offer norms along a new leading axis, so with 2-D ``consumers`` and
  ``offer`` the divisor has shape ``(len(consumers), len(offer))``.

  Args:
    ratings: values to rescale; must broadcast against the divisor above.
    offer: offer vectors (norm over the last axis).
    consumers: consumer vectors (norm over the last axis).

  Returns:
    ``ratings / (||consumers[i]|| * ||offer[j]||)`` element-wise.
  """
  consumer_norms = np.linalg.norm(consumers, ord=2, axis=-1)
  offer_norms = np.linalg.norm(offer, ord=2, axis=-1)
  divisor = consumer_norms[:, np.newaxis] * offer_norms[np.newaxis]
  return ratings / divisor


def group_mean_and_median(arr, group_ids):
  """Return the mean and median of ``arr`` for each group selector.

  Args:
    arr: array supporting NumPy indexing.
    group_ids: iterable of indexers into ``arr`` (e.g. boolean masks); each
      one selects a group. Iterated exactly once.

  Returns:
    Tuple ``(means, medians)`` of arrays with one entry per selector, in the
    order ``group_ids`` yields them. Both are taken over the flattened
    selection.
  """
  subsets = [arr[selector] for selector in group_ids]
  group_means = np.array([np.mean(subset) for subset in subsets])
  group_medians = np.array([np.median(subset) for subset in subsets])
  return group_means, group_medians


def _log_mean_median(stats, arr, label, group_ids):
  """Store the per-group mean and median of ``arr`` into ``stats`` in place.

  For every key ``gid`` of ``group_ids`` this sets
  ``stats[f'avg_{label}_{gid}']`` and ``stats[f'med_{label}_{gid}']``.

  Args:
    stats: dict that is mutated in place.
    arr: values to aggregate.
    label: text inserted into the stat keys.
    group_ids: mapping from group id to an indexer selecting that group's
      entries of ``arr``.
  """
  group_means, group_medians = group_mean_and_median(
    arr, tuple(group_ids.values()))
  for position, gid in enumerate(group_ids):
    stats[f'avg_{label}_{gid}'] = group_means[position]
    stats[f'med_{label}_{gid}'] = group_medians[position]


def _consumer_stats_ml100k(group_ids, ratings, probs, distances):
  """Per-group summaries of row-wise consumer statistics.

  Each quantity is reduced over the last axis, giving one value per row.
  In ``all_stats_ml100k`` the rows are consumers. Each quantity is then
  summarised by group mean and median via ``_log_mean_median``. The
  quantities are:

  * ``min_dist``: minimum of ``distances``.
  * ``max_rating``: maximum of ``ratings``.
  * ``mean_rating``: mean of ``ratings``.
  * ``median_rating``: median of ``ratings``.
  * ``rec_entropy``: ``scipy.stats.entropy`` of ``probs``.

  Args:
    group_ids: mapping from group id to an indexer over rows.
    ratings: rating array.
    probs: probabilities; entropy is taken over the last axis.
    distances: distance array.

  Returns:
    Dict with ``avg_<label>_<gid>`` and ``med_<label>_<gid>`` entries.
  """
  row_values = (
    ('min_dist', np.min(distances, axis=-1)),
    ('max_rating', np.max(ratings, axis=-1)),
    ('mean_rating', np.mean(ratings, axis=-1)),
    ('median_rating', np.median(ratings, axis=-1)),
    ('rec_entropy', sp.stats.entropy(probs, axis=-1)),
  )

  result = {}
  for name, values in row_values:
    _log_mean_median(stats=result, arr=values, label=name, group_ids=group_ids)
  return result


def _producer_stats_ml100k(group_info, group_ids, ratings, distances, label):
  """Per-group offer-side ("producer") proportions.

  Group statistics reduce ``ratings`` over axis 0 and ``distances`` with
  argmin/argmax over axis 0, so every proportion is taken over columns
  (offer items in ``all_stats_ml100k``). ``group_info`` gives the group label
  of each row. For each key ``gid`` of ``group_ids``:

  * ``prop_<label>_mean_best_<gid>``: fraction of columns where the group's
    mean rating is >= every group's mean. Ties count for all tied groups.
  * ``prop_<label>_median_best_<gid>``: the same, using group medians.
  * ``prop_<label>_nearest_<gid>``: fraction of columns whose minimum
    distance row belongs to the group.
  * ``prop_<label>_best_rated_<gid>``: fraction of columns whose maximum
    rating row belongs to the group.

  Returns:
    Dict mapping the keys above to floats.
  """
  selectors = list(group_ids.values())
  group_means = np.vstack(
    [np.mean(ratings[selector], axis=0) for selector in selectors])
  group_medians = np.vstack(
    [np.median(ratings[selector], axis=0) for selector in selectors])
  top_mean = group_means.max(0)
  top_median = group_medians.max(0)

  nearest_group = group_info[np.argmin(distances, axis=0)]
  best_rated_group = group_info[np.argmax(ratings, axis=0)]

  result = {}
  for row, gid in enumerate(group_ids.keys()):
    result[f'prop_{label}_mean_best_{gid}'] = np.mean(
      group_means[row] >= top_mean)
    result[f'prop_{label}_median_best_{gid}'] = np.mean(
      group_medians[row] >= top_median)
    result[f'prop_{label}_nearest_{gid}'] = np.mean(nearest_group == gid)
    result[f'prop_{label}_best_rated_{gid}'] = np.mean(best_rated_group == gid)
  return result


def all_stats_ml100k(
    group_info, offer, consumers, ratings, probs, label, normalise, **kwargs):
  """Compute all evaluation statistics for the ML-100k setting.

  Args:
    group_info: group label of each consumer (one entry per row of
      ``consumers``).
    offer: offer vectors, one per row.
    consumers: consumer vectors, one per row.
    ratings: ratings whose rows align with ``consumers`` and whose columns
      align with ``offer``.
    probs: recommendation probabilities; entropy is taken over the last axis.
    label: suffix for the ``rating_*`` and ``prop_*`` stat keys.
    normalise: if true, distances are cosine and ratings are divided by the
      consumer/offer L2 norms; otherwise distances are Euclidean and ratings
      are used as given.
    **kwargs: unused (presumably lets callers pass a shared argument dict).

  Returns:
    Dict mapping stat name to value.
  """
  # Coerce every array-like input, not just ratings/probs. Boolean-mask
  # indexing below (``consumers[mask]``, ``group_info == k``,
  # ``group_info[idx]``) raises or silently misbehaves on plain lists.
  group_info = np.asarray(group_info)
  offer = np.asarray(offer)
  consumers = np.asarray(consumers)
  ratings, probs = np.array(ratings), np.array(probs)

  metric = 'cosine' if normalise else 'euclidean'
  # distances[i, j]: distance between consumers[i] and offer[j].
  distances = sp.spatial.distance.cdist(consumers, offer, metric=metric)
  if normalise:
    ratings = normalise_ratings(ratings, offer=offer, consumers=consumers)
  group_ids = {k: group_info == k for k in np.unique(group_info)}

  stats = {
    f'rating_min_{label}': np.min(ratings),
    f'rating_max_{label}': np.max(ratings),
    f'rating_mean_{label}': np.mean(ratings),
    f'rating_std_{label}': np.std(ratings)
  }

  stats.update(_consumer_stats_ml100k(
    group_ids=group_ids, ratings=ratings, probs=probs, distances=distances))
  stats.update(_producer_stats_ml100k(
    group_info=group_info, group_ids=group_ids, ratings=ratings,
    distances=distances, label=label))

  for gid, mask in group_ids.items():
    # NOTE: mean *squared* L2 norm (no square root), despite the key name.
    stats[f'consumer_norm_{gid}'] = np.mean(
      np.sum(consumers[mask]**2, axis=-1)).item()

  return stats


def _iter_stats_lastfm360_neigh(n_neigh, distances, group_info, group_ids):
  """Group statistics over the ``n_neigh`` closest columns of each row.

  For every row of ``distances``, the first ``n_neigh`` columns in
  ``np.argsort`` order (smallest distance first) are selected.
  ``group_info`` gives the group of each column. For each key ``gid`` of
  ``group_ids``, over all selected (row, column) pairs:

  * ``neigh_prop_<gid>_<n_neigh>``: fraction of pairs whose column is in
    ``gid``.
  * ``neigh_avg_dist_<gid>_<n_neigh>``: mean distance of those pairs.
  * ``neigh_med_dist_<gid>_<n_neigh>``: median distance of those pairs.

  A group with no selected pairs gets ``nan`` distances. The resulting
  empty-slice ``RuntimeWarning`` is suppressed.
  """
  neighbours = np.argsort(distances, axis=-1)[:, :n_neigh]
  neighbour_groups = group_info[neighbours]
  neighbour_dists = np.take_along_axis(distances, neighbours, axis=-1)

  result = {}
  with warnings.catch_warnings():
    # np.nan* functions warn on empty selections; those become nan silently.
    warnings.simplefilter('ignore', category=RuntimeWarning)
    for gid in group_ids:
      in_group = neighbour_groups == gid
      selected = neighbour_dists[in_group]
      suffix = f'{gid}_{n_neigh}'
      result[f'neigh_prop_{suffix}'] = np.nanmean(in_group)
      result[f'neigh_avg_dist_{suffix}'] = np.nanmean(selected)
      result[f'neigh_med_dist_{suffix}'] = np.nanmedian(selected)

  return result


def iter_stats_lastfm360(
    group_info, offer, offer_util, base_offer, normalise, **kwargs):
  """Per-iteration statistics for the ``lastfm360`` setting.

  ``distances[i, j]`` is the distance between ``offer[i]`` and
  ``base_offer[j]``: cosine if ``normalise`` is true, Euclidean otherwise.
  ``group_info`` gives the group of each ``base_offer`` row.

  The following are reported per group:

  * Group mean/median (``avg_*``/``med_*``) of per-base-item values. These
    are the min, mean and median distance to the offer (``min_dist``,
    ``mean_dist``, ``median_dist``), and ``offer_util`` at the closest offer
    item (``near_util``).
  * ``pop_prop_<gid>``: fraction of base items in the group.
  * ``pop_avg_norm_<gid>``: mean *squared* L2 norm of the group's base
    items (no square root is taken).
  * Neighbourhood stats for 1, 5, 10, 25 and 50 closest base items per
    offer item, from ``_iter_stats_lastfm360_neigh``.

  ``**kwargs`` is unused.
  """
  metric = 'cosine' if normalise else 'euclidean'
  group_ids = {k: group_info == k for k in np.unique(group_info)}
  distances = sp.spatial.distance.cdist(offer, base_offer, metric=metric)

  # One value per base item (column), aggregated per group below.
  per_base_item = (
    ('min_dist', np.min(distances, 0)),
    ('mean_dist', np.mean(distances, 0)),
    ('median_dist', np.median(distances, 0)),
    ('near_util', offer_util[np.argmin(distances, 0)]),
  )

  stats = {}
  for name, values in per_base_item:
    _log_mean_median(stats=stats, arr=values, label=name, group_ids=group_ids)

  for gid, mask in group_ids.items():
    stats[f'pop_prop_{gid}'] = np.mean(mask)
    stats[f'pop_avg_norm_{gid}'] = np.mean(
      np.sum(base_offer[mask]**2, -1)).item()

  for n_neigh in (1, 5, 10, 25, 50):
    stats.update(_iter_stats_lastfm360_neigh(
      n_neigh, distances=distances, group_info=group_info, group_ids=group_ids))

  return stats


def base_stats_lastfm360(offer, ratings, label, **kwargs):
  """Baseline statistics for the ``lastfm360`` setting.

  Args:
    offer: passed to ``utils.recursive_cluster`` (``max_iters=50``,
      ``tol=1e-5``). The length of its result is reported as
      ``base_cluster_count``.
    ratings: summarised by min/max/mean/std.
    label: suffix for the ``rating_*`` keys.
    **kwargs: unused.

  Returns:
    Dict with ``base_cluster_count`` and ``rating_{min,max,mean,std}_<label>``.
  """
  clusters = utils.recursive_cluster(offer, max_iters=50, tol=1e-5)
  return {
    'base_cluster_count': len(clusters),
    f'rating_min_{label}': np.min(ratings),
    f'rating_max_{label}': np.max(ratings),
    f'rating_mean_{label}': np.mean(ratings),
    f'rating_std_{label}': np.std(ratings),
  }
