import os
import random

import numpy as np
import scipy as sp 
from scipy.cluster.hierarchy import fcluster
from scipy.cluster.hierarchy import linkage
import scipy.spatial.distance as ssd

from src import data


def batches_count_each_type(m, k):
  """
  Split m batches as evenly as possible across k cluster components.

  Returns a list of length k whose entries are floor(m/k) + c_i with
  c_i in {0, 1}; the first (m mod k) entries get the extra batch, so the
  entries sum to m.

  m and k are expected to be integers with m >= 0 and k > 0. Exact integer
  arithmetic (divmod) is used instead of float division so the split stays
  correct for arbitrarily large m.
  """
  base, extra = divmod(m, k)
  return [base + 1 if i < extra else base for i in range(k)]


def generate_batches(W, n, m, noise_level):
  """
  Sample m batches of n (covariate, response) pairs from a mixture of
  linear models.

  W is a num_components x dimension array; row i is the weight vector of
  component i. The m batches are divided among the components by
  batches_count_each_type, and batches of the same component are
  contiguous in the output, ordered by component index.

  Covariates are standard normal; responses are covariates @ w plus
  Gaussian noise scaled by sqrt(noise_level).

  Returns (X, Y) with X of shape (m, n, dimension) and Y of shape (m, n).
  """
  num_components, dim = np.shape(W)
  noise_std = np.sqrt(noise_level)
  covariate_blocks = []
  response_blocks = []
  for weight, count in zip(W, batches_count_each_type(m, num_components)):
    rows = count * n
    # Draw covariates first, then noise, so the RNG stream is consumed
    # in a fixed order per component.
    features = np.random.randn(rows, dim)
    noise = np.random.randn(rows, 1)
    targets = features @ weight.reshape(dim, 1) + noise_std * noise
    covariate_blocks.append(features.reshape(-1, n, dim))
    response_blocks.append(targets.reshape(-1, n))
  return (np.concatenate(covariate_blocks, axis=0),
          np.concatenate(response_blocks, axis=0))


def get_estimates_prior(B, P, k, T=1):
    """Cluster batches by projected gradients and return one estimate per cluster.

    Copy of the implementation of the algorithm in prior work KSS+20.

    Args:
        B: tuple (X, Y); X is unpacked as (m, n, d) and Y is sliced as (m, n).
        P: matrix that right-multiplies the (m, d) gradient array
            (presumably a d x d subspace projection -- confirm with caller).
        k: maximum number of flat clusters requested from fcluster.
        T: number of contiguous chunks the sample axis is split into; the
            per-chunk statistics are combined with an element-wise median.

    Returns:
        A list of column vectors of shape (-1, 1), one per distinct cluster
        label in order of first appearance. Each is the negated mean of the
        projected unclipped gradients (evaluated at w = 0) of that cluster's
        batches.
    """
    X, Y = B
    m, n, d = np.shape(X)
    # One m x m pairwise statistic per chunk.
    HH = np.zeros((T, m, m))
    start = 0
    for l in range(T):
        # Chunk l covers samples [start, end) of every batch.
        end = (l + 1) * n // T
        Xl = X[:, start:end, :]
        Yl = Y[:, start:end]
        # Each chunk is split again into two halves (a and b).
        n1 = (end - start) // 2
        start = end
        # kappa = inf disables clipping; gradients are evaluated at w = 0.
        clipped_grads_a = calculate_clipped_grad(Xl[:, :n1, :], Yl[:, :n1], np.inf, np.zeros((d, 1)))
        proj_a = np.matmul(clipped_grads_a, P)
        clipped_grads_b = calculate_clipped_grad(Xl[:, n1:, :], Yl[:, n1:], np.inf, np.zeros((d, 1)))
        proj_b = np.matmul(clipped_grads_b, P)
        # v[i] = <proj_a[i], proj_b[i]>
        v = np.sum(proj_a * proj_b, axis=1)
        # HH[l][i, j] = v[j] - <proj_a[i], proj_b[j]>  (v broadcasts over rows)
        HH[l] -= np.dot(proj_a, proj_b.T)
        HH[l] += v
        # Symmetrise: entry (i, j) becomes
        #   v[i] + v[j] - <proj_a[i], proj_b[j]> - <proj_a[j], proj_b[i]>.
        # NOTE(review): in-place add of an array's own transpose; relies on
        # NumPy handling the overlapping operands correctly -- confirm for the
        # NumPy version in use.
        HH[l] += HH[l].T
    if T > 1:
        # Element-wise median across the T chunks.
        H = np.median(HH, axis=0)
    else:
        H = HH[0]
    np.fill_diagonal(H, 0, wrap=False)
    # |H| is used as a pairwise dissimilarity matrix for average-linkage
    # hierarchical clustering, cut into at most k flat clusters.
    Z = linkage(ssd.squareform(np.abs(H)), method="average")
    # fcluster labels start at 1; shift to 0-based.
    Clusters = fcluster(Z, k, criterion='maxclust') - 1
    # grads[label] = [running mean of projected gradients, batch count]
    grads = {}
    for i in range(m):
        # Unclipped gradient of the whole batch i at w = 0.
        grad = calculate_clipped_grad(X[i:i + 1, :, :], Y[i:i + 1, :], np.inf, np.zeros((d, 1)))
        proj = np.matmul(grad, P)
        if Clusters[i] not in grads:
            grads[Clusters[i]] = [proj, 1]
        else:
            # Incremental mean update: mean <- (count * mean + proj) / (count + 1)
            grads[Clusters[i]][0] = grads[Clusters[i]][1] * grads[Clusters[i]][0] + proj
            grads[Clusters[i]][1] += 1
            grads[Clusters[i]][0] = grads[Clusters[i]][0] / grads[Clusters[i]][1]
    L = []
    for item in grads:
        # Negate the mean gradient and return it as a column vector.
        L.append(-1 * np.array(grads[item][0]).reshape(-1, 1))
    return L


def calculate_clipped_grad(X, Y, kappa, w):
    """Return the per-batch average clipped gradient of the squared loss.

    X has shape (m, n, d) and Y holds m * n responses. For every sample the
    residual x.w - y is computed; when kappa is finite the residual r is
    replaced by kappa * r / max(kappa, |r|), which caps its magnitude at
    kappa. The per-sample gradient x * residual is then averaged over the n
    samples of each batch.

    Returns an array of shape (m, d).
    """
    num_batches, batch_len, dim = X.shape
    total = num_batches * batch_len
    flat_X = X.reshape(total, dim)
    predictions = np.dot(flat_X, w.reshape(dim, 1)).reshape(total, 1)
    residual = predictions - Y.reshape(total, 1)
    if kappa == np.inf:
        # No clipping requested.
        clipped = residual
    else:
        clipped = (kappa * residual) / np.maximum(kappa, np.abs(residual))
    per_sample = (flat_X * clipped).reshape(num_batches, batch_len, dim)
    batch_mean = per_sample.sum(axis=1) / batch_len
    return batch_mean.reshape(num_batches, dim)


def estimate_mse(bspe, w):
    """Return the mean squared prediction error of w on a single batch.

    bspe is a tuple (X, Y) where X has shape (1, n, d); the reshape to
    (n, d) fails if more than one batch is supplied. The result is
    sum((Y - X w)^2) / n as a scalar.
    """
    covariates, responses = bspe
    _, batch_len, dim = np.shape(covariates)
    predictions = np.matmul(covariates.reshape(batch_len, dim), w.reshape(dim, 1))
    residual = responses.reshape(batch_len, 1) - predictions
    # residual.T @ residual is a 1 x 1 matrix; unwrap it to a scalar.
    return (np.matmul(residual.T, residual) / batch_len)[0][0]


def clipping_estimate(bspe, w, sigma):
    """Return the clipping level sqrt(2 * (mse + sigma^2)).

    mse is the mean squared error of w on the single batch bspe, as
    computed by estimate_mse.
    """
    error_plus_noise = estimate_mse(bspe, w) + sigma * sigma
    return np.sqrt(2 * error_plus_noise)


def grad_est(B, bspe, kappa, w, P, weights, epsilon=0.2):
    """Estimate a projected gradient from weighted batches.

    For both the batch collection B and the reference batch bspe, the
    samples of each batch are split in two halves and the clipped gradient
    of each half is right-multiplied by P. A batch of B is kept when the
    inner product of its two half-differences from the reference is below
    epsilon times the inner product of the reference's two halves.

    Side effect: weights is modified in place -- the entry of every batch
    that is not kept is multiplied by 0.1.

    Returns the weighted average of the batches' projected gradients
    (mean of both halves) as a column vector of shape (-1, 1).
    """

    def _half_projections(batches):
        # Projected clipped gradients of the first and second half of each batch.
        feats, resp = batches
        _, batch_len, _ = np.shape(feats)
        half = batch_len // 2
        first = calculate_clipped_grad(feats[:, :half, :], resp[:, :half], kappa, w)
        second = calculate_clipped_grad(feats[:, half:, :], resp[:, half:], kappa, w)
        return np.matmul(first, P), np.matmul(second, P)

    proj_all_a, proj_all_b = _half_projections(B)
    proj_spe_a, proj_spe_b = _half_projections(bspe)

    # Using independent halves a and b gives cross inner products rather
    # than squared norms of a single noisy estimate.
    diff_inner = np.sum((proj_all_a - proj_spe_a) * (proj_all_b - proj_spe_b), axis=1)
    spe_inner = np.sum(proj_spe_a * proj_spe_b, axis=1)
    keep_mask = diff_inner < epsilon * spe_inner

    # Down-weight (in place) every batch that failed the test.
    for idx in np.flatnonzero(np.logical_not(keep_mask)):
        weights[idx] *= 0.1

    averaged = 0.5 * (proj_all_a + proj_all_b)
    weighted_grad = np.matmul(weights, averaged) / np.sum(weights)
    return weighted_grad.reshape(-1, 1)


def subspace_estimation(B, ell, kappa, w):
    """Return the projection onto the top-ell eigenspace of a gradient cross-moment.

    Each batch of B = (X, Y) is split in two halves; with G_a and G_b the
    (m, d) clipped gradients of the halves, the symmetric matrix
    G_a^T G_b + G_b^T G_a is formed and the eigenvectors for its largest
    min(ell, d) eigenvalues are taken as columns of U. Returns U U^T
    (shape d x d).
    """
    covariates, responses = B
    _, batch_len, dim = np.shape(covariates)
    half = batch_len // 2
    grads_first = calculate_clipped_grad(covariates[:, :half, :], responses[:, :half], kappa, w)
    grads_second = calculate_clipped_grad(covariates[:, half:, :], responses[:, half:], kappa, w)
    cross_moment = grads_first.T @ grads_second + grads_second.T @ grads_first
    # eigh orders eigenvalues ascending, so the last indices are the largest.
    lowest_index = max(dim - ell, 0)
    _, top_vectors = sp.linalg.eigh(cross_moment, subset_by_index=(lowest_index, dim - 1))
    return top_vectors @ top_vectors.T


def main_algo_single_comp(BS, BM, bspe, sigma, ell, con_num, R, init_est, weights):
    """Run R rounds of projected, clipped gradient descent for one reference batch.

    Args:
        BS: small-size batches, used for subspace estimation.
        BM: medium-size batches, used for gradient estimation.
        bspe: a single medium-size batch drawn from the distribution whose
            regression vector we wish to recover.
        sigma: noise scale; it is squared inside clipping_estimate.
        ell: dimension of the estimated subspace.
        con_num: upper bound on the condition number of the covariance of
            the covariates; the step size is 0.8 / con_num.
        R: number of rounds.
        init_est: initial estimate of the regression vector.
        weights: per-batch weights for BM (callers use 0 to mark batches
            that are excluded). grad_est modifies this array in place, and
            the same array is reused in every round.

    Returns:
        The estimate after R rounds (init_est itself when R == 0).
    """
    estimate = init_est
    for _ in range(R):
        clip_level = clipping_estimate(bspe, estimate, sigma)
        projector = subspace_estimation(BS, ell, clip_level, estimate)
        direction = grad_est(BM, bspe, clip_level, estimate, projector, weights)
        estimate = estimate - 0.8 * direction / con_num
    return estimate


def main_algo_multiple_comp(BS, BM, Bspe, sigma, ell, con_num, R, init_est):
    """Run main_algo_single_comp once for each reference batch in Bspe.

    Args:
        BS: small-size batches (X, Y) used for subspace estimation.
        BM: medium-size batches (X, Y) used for gradient estimation.
        Bspe: tuple (Xspe, Yspe) of reference batches; Xspe has shape
            (kspe, n_spe, d). The reference batches may have a different
            length than the batches in BM.
        sigma, ell, con_num, R, init_est: forwarded unchanged to
            main_algo_single_comp.

    Returns:
        A list with one estimate per reference batch, in the order of Bspe.
    """
    X_medium, _ = BM
    num_medium = X_medium.shape[0]
    X_spe, Y_spe = Bspe
    # Use the reference batches' own shape rather than BM's batch length,
    # so the reshape below is valid even when the two lengths differ.
    num_spe, len_spe, dim = X_spe.shape
    estimates = []
    for i in range(num_spe):
        # Fresh unit weights per run: grad_est mutates the array in place.
        weights = np.ones(num_medium)
        bspe = (X_spe[i].reshape(1, len_spe, dim), Y_spe[i].reshape(1, len_spe))
        estimates.append(
            main_algo_single_comp(BS, BM, bspe, sigma, ell, con_num, R, init_est, weights)
        )
    return estimates


def main_algo_all_comp(BS, BM, sigma, ell, con_num, init_est, R):
    """Repeatedly recover regression vectors until few batches remain unexplained.

    In each iteration a random batch of BM that is not yet explained is
    chosen as the reference batch, main_algo_single_comp is run for it
    (with the reference batch's own weight set to 0), and every pending
    batch whose MSE under the new estimate is at most d / 5 is dropped
    from the pending set. The loop continues while at least 2% of the
    batches are pending and at most m / 8 estimates have been produced
    before the iteration starts.

    Returns the list of estimates in the order they were produced.
    """
    medium_X, medium_Y = BM
    m2, n2, d = medium_X.shape

    def _single_batch(idx):
        # Batch idx of BM as a one-batch (X, Y) pair.
        return (medium_X[idx].reshape(1, n2, d), medium_Y[idx].reshape(1, n2))

    estimates = []
    pending = list(range(m2))
    weights = np.ones(m2)
    while len(pending) >= 0.02 * m2 and len(estimates) <= m2 / 8:
        anchor = random.choice(pending)
        # Work on a copy: grad_est rescales the weights it receives in place.
        run_weights = weights.copy()
        run_weights[anchor] = 0
        w = main_algo_single_comp(BS, BM, _single_batch(anchor), sigma, ell,
                                  con_num, R, init_est, run_weights)
        estimates.append(w)
        # Keep only the batches the new estimate does not fit well.
        still_pending = [idx for idx in pending
                         if estimate_mse(_single_batch(idx), w) > d / 5]
        weights = np.zeros(m2)
        weights[np.asarray(still_pending, dtype=int)] = 1
        pending = still_pending
    return estimates

def main_algo_for_us(batches_one, batches_two,num_clusters, noise_level):
  """
  Convenience wrapper around main_algo_all_comp.

  batches_one is passed as the small batches (BS) and batches_two as the
  medium batches (BM). The noise scale is sqrt(noise_level), the subspace
  size is num_clusters, the condition number bound is 1.0, the initial
  estimate is the zero vector, and the number of rounds is
  10 * ceil(log(dimension)).

  Returns an array of shape (num_estimates, dimension).
  """
  dim = batches_one[0].shape[2]
  prefactor = 10  # prefactor hidden in \Theta notation
  num_rounds = prefactor * int(np.ceil(np.log(dim)))
  estimates = main_algo_all_comp(
      BS=batches_one,
      BM=batches_two,
      sigma=np.sqrt(noise_level),
      ell=num_clusters,
      con_num=1.0,
      init_est=np.zeros((dim, 1)),
      R=num_rounds,
  )
  # Each estimate is a (dimension, 1) column; drop the trailing axis.
  return np.array(estimates)[:, :, 0]

def generate_clusters(num_clusters, input_size):
  """
  Draw num_clusters random weight vectors of dimension input_size.

  Each row is a standard normal vector rescaled to Euclidean norm
  sqrt(input_size), i.e. a uniformly random direction on that sphere.
  Returns an array of shape (num_clusters, input_size).
  """
  directions = np.random.randn(num_clusters, input_size)
  row_norms = np.linalg.norm(directions, axis=1).reshape(-1, 1)
  unit_rows = directions / row_norms
  return unit_rows * np.sqrt(input_size)

def get_list_error(list_of_weights, correct_weights):
  """
  List of weights is a num_list x dimension matrix of weights.
  Correct weights is a num_weights x dimension matrix of weights.
  We compute:
  (1 / num_weights) * sum_{i in [num_weights]} min_{j in [num_list]}
    norm(correct_weights[i] - list_of_weights[j])^2 / dimension
  i.e. the excess MSE (/dim) of the best list entry, averaged over the
  correct weights. Returns inf when the list is empty.
  """
  num_weights = np.shape(correct_weights)[0]
  dimension = np.shape(correct_weights)[1]
  total = 0.0
  for target in correct_weights:
    # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
    closest = np.inf
    for candidate in list_of_weights:
      closest = min(closest, np.linalg.norm(target - candidate))
    total += closest ** 2
  return total / (dimension * num_weights)

def get_min_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size):
  """
  Predict each test point with the single best-fitting cluster weight.

  For batch i, the row of clusters with the smallest residual norm on
  (X[i], Y[i]) is selected and used to predict Xt[i]. Returns the squared
  norm of (Yt - predictions) divided by test_batch_size * input_size.
  """
  residuals = np.matmul(X, clusters.T) - Y[:, :, None]
  # Residual norm over the samples of each batch, one column per cluster.
  best_component = np.linalg.norm(residuals, axis=1).argmin(axis=1)
  test_preds = np.matmul(Xt, clusters.T)
  chosen = np.array([test_preds[row, best_component[row]]
                     for row in range(test_batch_size)])
  squared_error = np.linalg.norm(np.array(Yt) - chosen) ** 2
  return squared_error / (test_batch_size * input_size)

def get_bayes_opt_error(X, Y, Xt, Yt, clusters, test_batch_size, input_size,
                        noise_level=1.0):
  """
  Predict each test point with the posterior-mean weight vector.

  For batch i, each row j of clusters gets posterior weight proportional
  to exp(-SSR_ij / (2 * noise_level)), where SSR_ij is the sum of squared
  residuals of cluster j on (X[i], Y[i]) (uniform prior over clusters,
  Gaussian noise of variance noise_level). Xt[i] is then predicted with
  the posterior-weighted average of the cluster weights.

  noise_level defaults to 1.0 (unit noise variance).

  Returns the squared norm of (Yt - predictions) divided by
  test_batch_size * input_size.
  """
  # Sum of squared residuals per (batch, cluster).
  sq_residuals = np.square(np.linalg.norm((X @ clusters.T) - Y[:, :, np.newaxis],
                                          axis=1))
  # Subtract the per-batch minimum so the exponent is <= 0 and the best
  # cluster gets exp(0) = 1; this does not change the normalised weights.
  shifted = sq_residuals - np.min(sq_residuals, axis=1)[:, None]
  # The exponent is linear in the squared residual; squaring it again
  # would not correspond to a Gaussian likelihood.
  probs = np.exp(-0.5 * shifted / noise_level)
  probs = probs / np.sum(probs, axis=1)[:, None]
  w_hats = clusters.T @ probs.T
  yp = np.array([Xt[i, :] @ w_hats[:, i] for i in range(test_batch_size)])
  return (np.linalg.norm(np.array(Yt) - yp) ** 2) / (test_batch_size * input_size)

