import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import linregress
from scipy import linalg

class MetaLearnProb(object):
  """Synthetic linear meta-learning problem with a shared rank-r representation.

  Task i (of t) has regressor ``U @ V[i]`` where ``U = I_d[:, :r]`` is the
  shared representation and the rows of ``V`` (t x r) are drawn i.i.d. from
  a standard normal.  Responses are ``y_ij = <x_ij, U V[i]> + noise_ij``.

  Args:
    d: ambient dimension.
    r: rank of the representation; must satisfy ``r <= d``.
    t: number of tasks.
    sigma: scale multiplying the response noise (``None`` means 0).
    homogeneity: optional fraction in [0, 1]; ``int(homogeneity * t)``
      distinct tasks get the common task vector ``(sqrt(r), 0, ..., 0)``.
  """

  def __init__(self, d, r, t, sigma=None, homogeneity=None):
    self.d = d
    self.r = r
    self.t = t

    assert d >= r
    self.U = np.eye(d)[:, :r]
    self.V = np.random.normal(size=[t, r])

    self.homogeneity = homogeneity
    if self.homogeneity is not None:
      assert 0.0 <= self.homogeneity and self.homogeneity <= 1.0

      nof_duplicates = int(self.homogeneity*self.t)
      # Sample task indices *without* replacement so exactly nof_duplicates
      # tasks share the common vector; with replacement, repeated indices
      # would make the effective fraction smaller than requested.
      duplicate_idx = np.random.choice(
        self.t, size=nof_duplicates, replace=False)
      duplicate_vector = np.zeros([r])
      # Norm sqrt(r), presumably chosen to match E||V[i]||^2 = r for the
      # Gaussian rows.
      duplicate_vector[0] = r**0.5
      self.V[duplicate_idx] = duplicate_vector

    self.sigma = 0 if sigma is None else sigma

    # Training data and derived statistics (filled by the methods below).
    self.m = None
    self.X = None
    self.noise = None
    self.S = None
    self.Z = None
    self.moment = None
    self.second_ord_moment = None

    # Validation data (filled by generate_val_data).
    self.m_val = None
    self.X_val = None
    self.noise_val = None
    self.S_val = None
    self.Z_val = None

  def generate_data(self, m, noise=True, dist_x=None, dist_y=None):
    """Sample per-task covariates and response noise.

    Args:
      m: number of samples per task.
      noise: if False and ``sigma != 0``, the noise is a (1, 1) zero array
        that broadcasts against the responses.  When ``sigma == 0`` a full
        (t, m) noise array (all zeros) is sampled regardless.
      dist_x: covariate distribution: 'normal' (default), 'exp' or 'laplace'.
      dist_y: noise distribution, same choices; always scaled by ``sigma``.

    Returns:
      ``(X, noise)`` with X of shape (t, m, d); both are also stored on self.

    Raises:
      ValueError: for an unknown distribution name.
    """
    dist_x = 'normal' if dist_x is None else dist_x
    dist_y = 'normal' if dist_y is None else dist_y

    self.m = m

    if dist_x == 'normal':
      self.X = np.random.normal(
        size=[self.t, self.m, self.d])
    elif dist_x == 'exp':
      # Standard exponential entries have mean 1 (they are not centred).
      self.X = np.random.standard_exponential(
        size=[self.t, self.m, self.d])
    elif dist_x == 'laplace':
      # Scale 1/sqrt(2) gives unit variance (Var = 2 * scale**2).
      self.X = np.random.laplace(
        scale=(2.0)**-0.5, size=[self.t, self.m, self.d])
    else:
      raise ValueError('Unknown distribution={}!'.format(dist_x))

    if noise or self.sigma == 0:
      if dist_y == 'normal':
        self.noise = self.sigma*np.random.normal(
          size=[self.t, self.m])
      elif dist_y == 'exp':
        self.noise = self.sigma*np.random.standard_exponential(
          size=[self.t, self.m])
      elif dist_y == 'laplace':
        # Scaled by sigma like the other noise distributions, so that
        # sigma == 0 really yields noiseless responses.
        self.noise = self.sigma*np.random.laplace(
          scale=(2.0)**-0.5, size=[self.t, self.m])
      else:
        raise ValueError('Unknown distribution={}!'.format(dist_y))
    else:
      self.noise = np.zeros(shape=[1, 1])

    return self.X, self.noise

  def generate_val_data(self, m_val, noise=True):
    """Sample Gaussian validation covariates (t, m_val, d) and sigma-scaled
    Gaussian noise (t, m_val); returns and stores ``(X_val, noise_val)``."""
    self.m_val = m_val
    self.X_val = np.random.normal(
      size=[self.t, self.m_val, self.d])

    if noise or self.sigma == 0:
      self.noise_val = self.sigma*np.random.normal(
        size=[self.t, self.m_val])
    else:
      self.noise_val = np.zeros(shape=[1, 1])

    return self.X_val, self.noise_val

  def get_altmin_data(self):
    """Per-task sufficient statistics for the least-squares updates.

    Returns:
      ``S`` of shape (t, d, d) with ``S[i] = X_i^T X_i / m`` and ``Z`` with
      ``Z[i] = X_i^T noise_i / m``; both are also stored on self.
    """
    self.S = np.einsum('ijk,ijl->ikl', self.X, self.X)/self.m
    self.Z = np.einsum('ij,ijk->ik', self.noise, self.X)/self.m
    return self.S, self.Z

  def get_method_of_moments_data(self):
    """Return the (d, d) moment ``(1/(m t)) sum_ij y_ij^2 x_ij x_ij^T``."""
    y = np.einsum(
      'ir,dr,ijd->ij', self.V, self.U, self.X)
    y = y + self.noise
    self.moment = np.einsum(
      'ij,ij,ijk,ijl->kl', y, y, self.X, self.X)/self.m/self.t

    return self.moment

  def get_2nd_ord_method_of_moments_data(self):
    """Return the task-averaged (d, d) pairwise moment.

    For each task this is ``(1/(m(m-1))) sum_{j != k} y_ij y_ik x_ij x_ik^T``,
    computed as a rescaled outer product of the per-task mean of ``y x``
    minus its diagonal (j == k) terms.  Requires ``m >= 2``.
    """
    y = np.einsum(
      'ir,dr,ijd->ij', self.V, self.U, self.X)
    y = y + self.noise

    second_moment = np.einsum(
      'ij,ijd->id', y, self.X)/self.m

    second_ord_moment = np.einsum(
      'ik,il->ikl', second_moment, second_moment)
    second_ord_self_moment = np.einsum(
      'ij,ij,ijk,ijl->ikl', y, y, self.X, self.X)/self.m

    self.second_ord_moment = (
      (self.m/(self.m-1))*second_ord_moment
      - second_ord_self_moment/(self.m-1)).mean(axis=0)
    return self.second_ord_moment

  def mse_loss(self, U, V, average=None):
    """Mean squared training error of the estimate (U, V).

    Returns a scalar when ``average`` is None or truthy, otherwise the
    per-task errors of shape (t,).
    """
    y = np.einsum(
      'ir,dr,ijd->ij', self.V, self.U, self.X)
    y = y + self.noise
    y_est = np.einsum(
      'ir,dr,ijd->ij', V, U, self.X)
    if average is None or average:
      return ((y - y_est)**2).mean()
    else:
      return ((y - y_est)**2).mean(axis=1)

  def avg_mse_loss_val(self, U, V, average=None):
    """Mean squared difference between true and estimated regressors.

    Compares ``V U^T`` with ``self.V self.U^T`` entrywise (t x d); it does
    not use the validation samples, and ``average`` is ignored.
    """
    regressors = np.einsum(
      'ir,dr->id', self.V, self.U)
    regressors_est = np.einsum(
      'ir,dr->id', V, U)
    return ((regressors - regressors_est)**2).mean()


def operator_update_U(S, V, d, r):
  """Build the (d*r, d*r) normal-equation matrix for the least-squares U step.

  With U flattened row-major, entry [j*r + a, k*r + b] equals
  ``sum_i S[i, j, k] * V[i, a] * V[i, b]``, i.e. ``sum_i kron(S_i, v_i v_i^T)``.

  Args:
    S: (t, d, d) per-task second-moment matrices.
    V: (t, r) task vectors.
    d, r: dimensions used for the final reshape.
  """
  four_index = np.einsum('ijk,ir,il->jrkl', S, V, V)
  return np.reshape(four_index, (d*r, d*r))

def update_U(V, U_opt, V_opt, S, Z=None):
    """Exact least-squares update of U for fixed V, then QR-orthonormalise.

    Solves for the (d, r) matrix U in
        sum_i S_i U v_i v_i^T = sum_i (S_i U_opt v*_i + Z_i) v_i^T
    by vectorising U row-major, and returns the Q factor of the solution.

    Args:
      V: (t, r) current task vectors.
      U_opt, V_opt: ground-truth representation (d, r_opt) and task vectors
        (t, r_opt) forming the noiseless right-hand side.
      S: (t, d, d) per-task second moments.
      Z: optional (t, d) per-task noise correlations added to the RHS.

    Returns:
      (d, r) matrix with orthonormal columns.
    """
    # r must come from V: both the operator and U_hat are built from V.
    d, r = S.shape[1], V.shape[1]
    U_hat = np.einsum('ijk,kr,ir,il->jl', S, U_opt, V_opt, V)
    if Z is not None:
        U_hat += np.einsum('ij,ik->jk', Z, V)

    operator = operator_update_U(S, V, d, r)
    # Solve the linear system rather than forming an explicit inverse:
    # cheaper and numerically more accurate.
    U_hat = np.linalg.solve(operator, U_hat.reshape(d*r)).reshape(d, r)
    U, _ = np.linalg.qr(U_hat, mode='reduced')

    return U

def update_V(U, U_opt, V_opt, S, Z=None):
    """Exact per-task least-squares update of the task vectors for fixed U.

    For every task i solves the r x r system
        (U^T S_i U) v_i = U^T S_i U_opt v*_i + U^T Z_i.

    Args:
      U: (d, r) current representation.
      U_opt, V_opt: ground-truth representation (d, r_opt) and task vectors
        (t, r_opt) forming the noiseless right-hand side.
      S: (t, d, d) per-task second moments.
      Z: optional (t, d) per-task noise correlations.

    Returns:
      (t, r) array of updated task vectors.
    """
    rhs = np.einsum('jk,ijl,lr,ir->ik', U, S, U_opt, V_opt)
    if Z is not None:
        rhs += np.einsum('jk,ij->ik', U, Z)

    operators = np.einsum('jk,ijl,lr->ikr', U, S, U)
    # One batched solve replaces a Python loop of explicit inverses; the
    # trailing axis makes each right-hand side an explicit column vector
    # (required for stacked systems in NumPy >= 2.0).
    return np.linalg.solve(operators, rhs[..., np.newaxis])[..., 0]

def gradients_UV(U, V, U_opt, V_opt, S, Z=None):
    """Gradients of the empirical squared loss with respect to U and V.

    With e_i = U v_i - U_opt v*_i, the loss (up to an additive constant) is
        L(U, V) = (1/t) sum_i [ 1/2 e_i^T S_i e_i - Z_i^T e_i ],
    which equals (1/(2 t m)) sum_i ||y_i - X_i U v_i||^2 + const when
    S_i = X_i^T X_i / m and Z_i = X_i^T noise_i / m.

    Returns:
      grad_U: (d, r) gradient dL/dU.
      grad_V: (t, r) per-task gradients; row i is U^T (S_i e_i - Z_i), i.e.
        t * dL/dv_i (no 1/t factor).
    """
    t = V_opt.shape[0]
    regs_diff = U.dot(V.T) - U_opt.dot(V_opt.T)  # (d, t), column i is e_i
    resid = np.einsum('idk,ki->id', S, regs_diff)  # (t, d), row i is S_i e_i
    if Z is not None:
        # The residual is X(Uv - U*v*) - noise, so the noise correlation
        # enters with a minus sign (matching the +Z RHS of update_U/update_V).
        resid = resid - Z
    grad_U = np.einsum('id,ir->dr', resid, V)/t
    grad_V = np.einsum('id,dr->ir', resid, U)

    return grad_U, grad_V

def gradient_U(U, V, U_opt, V_opt, S, Z=None):
    """Gradient w.r.t. U of the task-averaged empirical squared loss.

    With e_i = U v_i - U_opt v*_i the loss (up to a constant) is
        L(U, V) = (1/t) sum_i [ 1/2 e_i^T S_i e_i - Z_i^T e_i ],
    so dL/dU = (1/t) sum_i (S_i e_i - Z_i) v_i^T, returned as a (d, r) array.
    Here t is taken from V_opt.shape[0].
    """
    t = V_opt.shape[0]
    regs_diff = U.dot(V.T) - U_opt.dot(V_opt.T)  # (d, t), column i is e_i
    resid = np.einsum('idk,ki->id', S, regs_diff)  # (t, d), row i is S_i e_i
    if Z is not None:
        # Residual is X(Uv - U*v*) - noise: the noise term is subtracted.
        resid = resid - Z
    grad_U = np.einsum('id,ir->dr', resid, V)/t

    return grad_U

def gradient_V(U, V, U_opt, V_opt, S, Z=None):
    """Per-task gradients w.r.t. the task vectors V.

    With e_i = U v_i - U_opt v*_i and per-task loss
        l_i = 1/2 e_i^T S_i e_i - Z_i^T e_i   (up to a constant),
    row i of the returned (t, r) array is dl_i/dv_i = U^T (S_i e_i - Z_i).
    """
    regs_diff = U.dot(V.T) - U_opt.dot(V_opt.T)  # (d, t), column i is e_i
    resid = np.einsum('idk,ki->id', S, regs_diff)  # (t, d), row i is S_i e_i
    if Z is not None:
        # Residual is X(Uv - U*v*) - noise: the noise term is subtracted.
        resid = resid - Z
    grad_V = np.einsum('id,dr->ir', resid, U)

    return grad_V

def distance_U(U, U_ref):
    """Normalised Frobenius distance of U from the column span of U_ref.

    Computes ||U - U_ref U_ref^T U||_F / sqrt(r), where r = U.shape[1].
    The projection interpretation assumes U_ref has orthonormal columns
    (as the identity-based ground truth does).
    """
    projected = U_ref.dot(U_ref.T.dot(U))
    residual = U - projected
    n_cols = U.shape[1]
    return np.linalg.norm(residual) / (n_cols**0.5)

def distance_U_spectral(U, U_ref):
    """Spectral-norm distance of U from the column span of U_ref.

    Returns the largest singular value of U - U_ref U_ref^T U (a projection
    residual when U_ref has orthonormal columns).
    """
    residual = U - U_ref.dot(U_ref.T.dot(U))
    singular_values = linalg.svdvals(residual)
    return singular_values[0]

####
def apply_alt_min(
  prob, N_step, U_init=None,
  partition=False,
  init_mom=False,
  # sample_split=False, subset=False
  ):
  """Alternating least-squares minimisation over (U, V).

  Every iteration solves exactly for V given U (using all tasks), records the
  regressor error, then solves exactly for U given V on the current group of
  tasks (update_U orthonormalises the result).

  Args:
    prob: problem instance (reads d, r, U, V; calls get_altmin_data and
      avg_mse_loss_val).
    N_step: number of alternating iterations.
    U_init: optional (d, r) starting point.  When None, the method-of-moments
      estimate is used if ``init_mom`` is true, otherwise the first r columns
      of the Q factor of a d x d Gaussian matrix.
    partition: when true and ``N_step <= t``, the tasks are randomly split
      into N_step disjoint groups and group ``step`` drives the U update of
      iteration ``step``; otherwise every U update uses all tasks.
    init_mom: see U_init.

  Returns:
    dict with 'U_init', 'U_list' (N_step + 1 iterates), 'dist_U_list',
    'dist_U_spectral_list', 'avg_mse_loss_list' (N_step + 1 values) and the
    final 'U'.
  """
  if U_init is None and init_mom:
    U_init = apply_method_of_moments(prob)['U']
  elif U_init is None:
    gaussian = np.random.normal(size=[prob.d, prob.d])
    U_init = np.linalg.qr(gaussian)[0][:, :prob.r]

  S, Z = prob.get_altmin_data()

  num_tasks = S.shape[0]
  if partition and N_step <= num_tasks:
    groups = np.array_split(np.random.permutation(num_tasks), N_step)
  else:
    groups = [slice(None)] * N_step

  U_list, dist_U_list, dist_U_spectral_list = [], [], []
  avg_mse_loss_list = []

  def track(U_now):
    # Record an iterate together with both subspace distances to prob.U.
    U_list.append(U_now)
    dist_U_list.append(distance_U(U_now, prob.U))
    dist_U_spectral_list.append(distance_U_spectral(U_now, prob.U))

  U = U_init
  track(U)
  for group in groups:
    V = update_V(U, prob.U, prob.V, S=S, Z=Z)
    avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
    U = update_U(V[group], prob.U, prob.V[group], S=S[group], Z=Z[group])
    track(U)

  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
  return {
    'U_init': U_init,
    'U_list': U_list,
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    }

####
def apply_method_of_moments(prob):
  """Estimate U as the top-r eigenvectors of the method-of-moments matrix.

  Args:
    prob: problem instance (reads d, r, U, V; calls
      get_method_of_moments_data, get_altmin_data and avg_mse_loss_val).

  Returns:
    dict with single-element 'dist_U_list', 'dist_U_spectral_list' and
    'avg_mse_loss_list' (loss after an exact V update for this U), plus the
    (d, r) estimate 'U'.
  """
  moment = prob.get_method_of_moments_data()

  # eigh returns eigenvalues in ascending order, so indices d-r .. d-1 select
  # the r largest.  `subset_by_index` replaces the `eigvals` keyword, which
  # was removed in SciPy 1.14.
  _, U = linalg.eigh(moment, subset_by_index=[prob.d - prob.r, prob.d - 1])

  dist_U_list = [distance_U(U, prob.U)]
  dist_U_spectral_list = [distance_U_spectral(U, prob.U)]
  S, Z = prob.get_altmin_data()
  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  avg_mse_loss_list = [prob.avg_mse_loss_val(U, V)]

  output = {
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    }
  return output

####
def apply_2nd_ord_method_of_moments(prob):
  """Estimate U as the top-r eigenvectors of the pairwise (second-order)
  moment matrix.

  Args:
    prob: problem instance (reads d, r, U, V; calls
      get_2nd_ord_method_of_moments_data, get_altmin_data and
      avg_mse_loss_val).

  Returns:
    dict with single-element 'dist_U_list', 'dist_U_spectral_list' and
    'avg_mse_loss_list' (loss after an exact V update for this U), plus the
    (d, r) estimate 'U'.
  """
  second_ord_moment = prob.get_2nd_ord_method_of_moments_data()

  # eigh returns eigenvalues in ascending order, so indices d-r .. d-1 select
  # the r largest.  `subset_by_index` replaces the `eigvals` keyword, which
  # was removed in SciPy 1.14.
  _, U = linalg.eigh(
    second_ord_moment, subset_by_index=[prob.d - prob.r, prob.d - 1])

  dist_U_list = [distance_U(U, prob.U)]
  dist_U_spectral_list = [distance_U_spectral(U, prob.U)]
  S, Z = prob.get_altmin_data()
  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  avg_mse_loss_list = [prob.avg_mse_loss_val(U, V)]

  output = {
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    }
  return output

####
def apply_grad_descent(
  prob, N_step,
  stepsize,
  regularizer=False,
  U_init=None,
  # partition=False,
  init_mom=False,
  # sample_split=False, subset=False
  qr_decomp=True,
  ):
  """Simultaneous gradient descent on (U, V).

  V starts at the exact least-squares solution for U_init.  Each iteration
  records the regressor error, takes one gradient step on both U and V with
  the same ``stepsize``, and QR-orthonormalises U.  The orthonormal factor is
  always what is recorded in 'U_list' and the distance lists; it replaces U
  for the next iteration only when ``qr_decomp`` is true.

  Args:
    prob: problem instance (reads d, r, U, V; calls get_altmin_data and
      avg_mse_loss_val).
    N_step: number of gradient iterations.
    stepsize: step size shared by U and V.
    regularizer: not supported; a truthy value raises inside the first
      iteration.
    U_init: optional (d, r) start; None means the method-of-moments estimate
      (``init_mom``) or a random orthonormal matrix.
    init_mom: see U_init.
    qr_decomp: whether to keep the orthonormalised U between iterations.

  Returns:
    dict with the iterates, distances, losses, final 'U' and the Frobenius
    norms of the U and V gradients per iteration.

  Raises:
    NotImplementedError: when ``regularizer`` is truthy and N_step > 0.
  """
  if U_init is None:
    U_init = (
      apply_method_of_moments(prob)['U'] if init_mom
      else np.linalg.qr(np.random.normal(size=[prob.d, prob.d]))[0][:, :prob.r])
  S, Z = prob.get_altmin_data()

  U = U_init
  U_list = [U]
  dist_U_list = [distance_U(U, prob.U)]
  dist_U_spectral_list = [distance_U_spectral(U, prob.U)]
  grad_U_norm_list, grad_V_norm_list = [], []
  avg_mse_loss_list = []

  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  for _ in range(N_step):
    avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
    grad_U, grad_V = gradients_UV(U, V, prob.U, prob.V, S=S, Z=Z)
    grad_U_norm_list.append(np.sum(grad_U**2)**0.5)
    grad_V_norm_list.append(np.sum(grad_V**2)**0.5)
    if regularizer:
      raise NotImplementedError('regularizer not implemented')
    U, V = U - stepsize*grad_U, V - stepsize*grad_V
    U_orth = np.linalg.qr(U, mode='reduced')[0]
    if qr_decomp:
      U = U_orth
    U_list.append(U_orth)
    dist_U_list.append(distance_U(U_orth, prob.U))
    dist_U_spectral_list.append(distance_U_spectral(U_orth, prob.U))

  avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
  return {
    'U_init': U_init,
    'U_list': U_list,
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    'grad_U_norm_list': grad_U_norm_list,
    'grad_V_norm_list': grad_V_norm_list,
    }

####
def apply_alt_min_gd(
  prob, N_step, U_init=None,
  partition=False,
  init_mom=False,
  # sample_split=False, subset=False
  U_gd=True, stepsize=None,
  ):
  """Alternating minimisation with an exact V step and a GD or exact U step.

  Each iteration sets V to the exact least-squares solution for the current U
  (all tasks) and records the regressor error.  It then updates U on the
  current group of tasks, either by one gradient step followed by QR
  (``U_gd=True``) or by the exact least-squares update (``U_gd=False``).

  Args:
    prob: problem instance (reads d, r, U, V; calls get_altmin_data and
      avg_mse_loss_val).
    N_step: number of iterations.
    U_init: optional (d, r) start; None means the method-of-moments estimate
      (``init_mom``) or a random orthonormal matrix.
    partition: when true and ``N_step <= t``, tasks are randomly split into
      N_step disjoint groups and group ``step`` drives iteration ``step``'s
      U update; otherwise all tasks are used.
    init_mom: see U_init.
    U_gd: gradient step (True) or exact update (False) for U.
    stepsize: fixed U step size; None means 1 / lambda_max(V_g^T V_g) for the
      current group g (stored as a shape-(1,) array).

  Returns:
    dict with 'U_init', 'U_list', 'dist_U_list', 'dist_U_spectral_list',
    'avg_mse_loss_list', final 'U', 'grad_U_norm_list' and 'stepsize_list'
    (the last two are only filled when ``U_gd``).
  """
  if U_init is None:
    if init_mom:
      U_init = apply_method_of_moments(prob)['U']
    else:
      U_init = np.linalg.qr(np.random.normal(
        size=[prob.d, prob.d]))[0][:, :prob.r]
  S, Z = prob.get_altmin_data()

  t = S.shape[0]
  if N_step <= t and partition:
    partitions = np.array_split(np.random.permutation(t), N_step)
  else:
    partitions = [slice(None)]*N_step

  U = U_init
  U_list = [U]
  dist_U_list = [distance_U(U, prob.U)]
  dist_U_spectral_list = [distance_U_spectral(U, prob.U)]
  avg_mse_loss_list = []
  grad_U_norm_list = []
  stepsize_list = []

  for step in range(N_step):
    idx = partitions[step]
    V = update_V(
      U, prob.U, prob.V,
      S=S, Z=Z)

    avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))

    if U_gd:
      if stepsize is None:
        W = V[idx].T.dot(V[idx])
        k = W.shape[0]
        # Largest eigenvalue only (eigh orders ascending).  `subset_by_index`
        # replaces the `eigvals` keyword removed in SciPy 1.14.
        W_eig_max = linalg.eigh(
          W, eigvals_only=True, subset_by_index=[k - 1, k - 1])
        _stepsize = 1/W_eig_max
      else:
        _stepsize = stepsize
      stepsize_list.append(_stepsize)

      grad_U = gradient_U(
        U, V[idx], prob.U, prob.V[idx], S=S[idx], Z=Z[idx])
      grad_U_norm_list.append((np.sum(grad_U**2))**0.5)
      U = U - _stepsize*grad_U
      U, _ = np.linalg.qr(U, mode='reduced')
    else:
      # Restrict V to the same task group as S/Z/prob.V; passing the full V
      # breaks the shapes whenever partitioning is active.
      U = update_U(
        V[idx], prob.U, prob.V[idx],
        S=S[idx], Z=Z[idx])

    U_list.append(U)
    dist_U_list.append(distance_U(U, prob.U))
    dist_U_spectral_list.append(distance_U_spectral(U, prob.U))

  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
  output = {
    'U_init': U_init,
    'U_list': U_list,
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    'grad_U_norm_list': grad_U_norm_list,
    'stepsize_list': stepsize_list,
    }
  return output


  prob, N_step, U_init=None,
  partition=False, 
  init_mom=False,
  # sample_split=False, subset=False
  mode=None, U_stepsize=None, V_stepsize=None,
  ):
  if mode is None:
    mode = 'altmin'

  if mode == 'altmin':
    U_gd, V_gd = False, False
  elif mode == 'altmingd':
    U_gd, V_gd = True, False
  elif mode == 'altgdgd':
    if partition:
      raise ValueError('cannot partition tasks and do GD on V')
    U_gd, V_gd = True, True
  # elif mode == 'simulgd':
  #   U_gd, V_gd = True, True  
  else:
    raise ValueError  

  if U_init is None:
    if init_mom:
      U_init = apply_method_of_moments(prob)['U']
    else:
      U_init = np.linalg.qr(np.random.normal(
        size=[prob.d, prob.d]))[0][:, :prob.r]
  S, Z = prob.get_altmin_data()

  t = S.shape[0]
  if N_step <= t and partition:
    partitions = np.array_split(np.random.permutation(t), N_step)
  else:
    partitions = [slice(None)]*N_step

  U = U_init
  U_list = [U]
  dist_U_list = [distance_U(U, prob.U)]
  dist_U_spectral_list = [distance_U_spectral(U, prob.U)]
  avg_mse_loss_list = []
  grad_U_norm_list = []
  U_stepsize_list = []
  V_stepsize_list = []

  # import ipdb; ipdb.set_trace()

  V = update_V(
    U, prob.U, prob.V,
    S=S, Z=Z)
  for step in range(N_step):
      if V_gd:
          if partition:
                raise ValueError('cannot partition tasks and do GD on V')

          V_best = update_V(
            U, prob.U, prob.V,
            S=S, Z=Z)

          if V_stepsize is None:
              raise NotImplementedError
          else:
            _V_stepsize = V_stepsize
          V_stepsize_list.append(_V_stepsize)

          grad_V = gradient_V(U, V[partitions[step]], prob.U, prob.V[partitions[step]], S=S[partitions[step]], Z=Z[partitions[step]])
          grad_V_norm_list.append((np.sum(grad_V**2))**0.5)
          V = V - _V_stepsize*grad_V
      else:
          V = update_V(
            U, prob.U, prob.V,
            S=S, Z=Z)
          V_best = V

      avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V_best))

      if U_gd:
          if U_stepsize is None:
              W = V[partitions[step]].T.dot(V[partitions[step]])
              W_eig_max = linalg.eigh(W, eigvals_only=True, eigvals=(W.shape[0]-1, W.shape[0]-1))
              _U_stepsize = 1/W_eig_max
          else:
            _U_stepsize = U_stepsize
          U_stepsize_list.append(_U_stepsize)

          grad_U = gradient_U(U, V[partitions[step]], prob.U, prob.V[partitions[step]], S=S[partitions[step]], Z=Z[partitions[step]])
          grad_U_norm_list.append((np.sum(grad_U**2))**0.5)
          U = U - _U_stepsize*grad_U
          U, R = np.linalg.qr(U, mode='reduced')
      else:
          U = update_U(
            V, prob.U, prob.V[partitions[step]], 
            S=S[partitions[step]], Z=Z[partitions[step]])

      U_list.append(U)
      dist_U_list.append(distance_U(U, prob.U))
      dist_U_spectral_list.append(distance_U_spectral(U, prob.U))

  V = update_V(U, prob.U, prob.V, S=S, Z=Z)
  avg_mse_loss_list.append(prob.avg_mse_loss_val(U, V))
  output = {
    'U_init': U_init,
    'U_list': U_list,
    'dist_U_list': dist_U_list,
    'dist_U_spectral_list': dist_U_spectral_list,
    'avg_mse_loss_list': avg_mse_loss_list,
    'U': U,
    'grad_U_norm_list': grad_U_norm_list,
    'U_stepsize_list': U_stepsize_list,
    'V_stepsize_list': V_stepsize_list,
    }
  return output  