from jax.tree_util import *
from jax.api import jvp
from jax.api import vjp
import lanczos as lanczos
from jax.api import eval_shape
from jax.api import jacobian
import jax.numpy as np
from jax import random, grad, jit
import numpy.random as npr
import os
import time
import pandas as pd
from data_util import minibatcher
import pickle
import math

# Some misc utils.

def save_weights(file_name, step, params):
    """Write a checkpoint to `file_name`.

    The file holds two consecutive pickle records: first `step`, then the
    parameters flattened into a single vector by `flatten_jax`.  The
    flattening happens before the file is opened, so a failure there does not
    leave an empty file behind.  Returns None.
    """
    flat_params = flatten_jax(params)
    with open(file_name, 'wb') as handle:
        for record in (step, flat_params):
            pickle.dump(record, handle)

def load_weights(file_name, discarded_params):
    """Read back a checkpoint written by `save_weights`.

    `discarded_params` is only used as a template: its pytree structure and
    leaf shapes are handed to `unflatten_jax` to rebuild the stored flat
    vector.  Returns the tuple (step, params).

    NOTE: the file is read with pickle, so only load files you trust.
    """
    with open(file_name, 'rb') as handle:
        # Records are stored in this order: step first, flat parameters second.
        step = pickle.load(handle)
        flat_params = pickle.load(handle)
        params = unflatten_jax(flat_params, discarded_params)
    return step, params



def get_jobpath(C,s0=0):
    """Build the output folder and run name for configuration `C`; create the folder.

    The run name is a sequence of tokens joined by '_'.  The first group of
    tokens (from `C.job_dir_params`, the normalisation flags and 'ntk'/'std')
    identifies the class of experiment; when `C.similarjobsfolder` is set and
    `C.job_id` does not start with 'largelrjob', that prefix is also used as an
    intermediate folder so similar runs share a parent.  The remaining tokens
    (from `C.job_id_params` and the flags below) identify the individual run.

    Args:
        C: configuration object providing `_asdict()` (e.g. a namedtuple) and
            the fields logdir, job_dir_params, job_id_params, NTK_norm, job_id,
            similarjobsfolder, linearize, w_std, opt, momcos, L2, seed and
            randomlabels.  `batch_norm` and `unit_norm` are optional.
            `job_dir_params` / `job_id_params` map a field name of `C` to the
            prefix used for it in the name.
        s0: index given as an int or a numeric string; adds an 's0<value>'
            token when it is greater than zero.

    Returns:
        (path, name): `path` is the run directory (ending in '/'), created if
        it does not exist yet; `name` is the run name (last path component).
    """
    path = C.logdir
    # endswith() also copes with an empty logdir (indexing [-1] raised IndexError).
    if path and not path.endswith('/'):
        path += '/'
    Cdic = C._asdict()

    # --- tokens describing the class of experiment -------------------------
    name = ''
    for el, prefix in C.job_dir_params.items():
        name += prefix + str(Cdic[el]) + '_'

    # batch_norm / unit_norm may be absent from older configs: treat a missing
    # field as "not set" instead of swallowing every exception.
    if getattr(C, 'batch_norm', False):
        name += 'bn_'
    if getattr(C, 'unit_norm', False):
        name += '1n_'
    name += 'ntk_' if C.NTK_norm else 'std_'

    if not C.job_id.startswith('largelrjob') and C.similarjobsfolder:
        # All runs of the same class of experiment share one parent folder.
        path += name[:-1] + '/'

    # --- tokens describing the individual run ------------------------------
    for el, prefix in C.job_id_params.items():
        if el == 'learning_rate':
            # Fixed-point with trailing zeros removed, e.g. 0.5 -> '0.5', 1.0 -> '1.'
            val = ('%f' % Cdic[el]).rstrip('0')
        else:
            val = Cdic[el]
        name += prefix + str(val) + '_'

    if int(s0) > 0:
        # str() so that integer s0 works too (plain '+' raised TypeError for ints).
        name += 's0' + str(s0) + '_'
    if C.linearize:
        name += 'lin_'

    # 'nobn' is only emitted when the field exists and is falsy.
    if hasattr(C, 'batch_norm') and not C.batch_norm:
        name += 'nobn_'
    if C.w_std != math.sqrt(2.0):
        name += 'w0{:.2}'.format(C.w_std) + '_'
    if C.opt == 'momentum':
        name += 'mom_'
    if C.momcos == True:
        # Glue 'cos' onto the previous token (e.g. 'mom' -> 'momcos').
        name = name[:-1] + 'cos_'
    if C.L2:
        name += 'L2' + str(C.L2) + '_'
    if C.seed != 0:
        name += 'seed' + str(C.seed) + '_'
    if C.randomlabels:
        # Trailing '_' like every other token; without it the strip below cut
        # the token down to 'rnd'.  NOTE: folders written by the old code for
        # random-label runs therefore end in 'rnd', not 'rndy'.
        name += 'rndy_'

    name = name[:-1]  # drop the trailing '_'
    path += name + '/'

    if not os.path.exists(path):
        print('Creating folder',path)
        # exist_ok guards against another process creating it in between.
        os.makedirs(path, exist_ok=True)

    return path, name




def flatten_jax(v):
  """Flatten a pytree into a single 1-D vector.

  Every leaf is reshaped to 1-D and the pieces are concatenated in the leaf
  order produced by `tree_flatten`.
  """
  leaves, _ = tree_flatten(v)
  return np.concatenate([leaf.reshape(-1) for leaf in leaves])

def unflatten_jax(flat_tensor, orig_tensors):
  """Unflatten a 1-D vector into a pytree shaped like `orig_tensors`.

  Args:
    flat_tensor: 1-D array holding the leaf values back to back, in the leaf
      order of `tree_flatten(orig_tensors)` (the layout `flatten_jax` writes).
    orig_tensors: pytree used only as a template; each leaf must expose
      `.shape`.  Its values are ignored.

  Returns:
    A pytree with the structure of `orig_tensors` whose leaves are the
    consecutive slices of `flat_tensor`, reshaped to the template shapes.
  """
  leaves, treedef = tree_flatten(orig_tensors)
  rebuilt = []
  offset = 0
  for leaf in leaves:
    # Leaf size as a plain Python int.  Shapes are static, so this must not go
    # through jax.numpy: that yields a traced value under jit (unusable as a
    # slice bound) and a float 1.0 for scalar leaves with shape ().
    num_elems = 1
    for dim in leaf.shape:
      num_elems *= int(dim)
    rebuilt.append(np.reshape(flat_tensor[offset:offset + num_elems], leaf.shape))
    offset += num_elems
  return tree_unflatten(treedef, rebuilt)


def meas_batch_generator(f,batch_gen,argnum=1,sample_size=4096):
  """Turn f(*args1, batch, *args2) into a batch-averaged f_meas(*args1, *args2).

  Args:
    f: function taking a batch at positional index `argnum`.
    batch_gen: iterator yielding batches; `batch[0].shape[0]` is read as the
      number of samples in the batch.
    argnum: position at which the batch is inserted into the arguments of `f`.
    sample_size: target number of samples per measurement.  Each call of the
      returned function evaluates `sample_size // batch_size` batches, and at
      least one.

  Returns:
    new_f(*args): draws fresh batches from `batch_gen` and returns the average
    of `f` over them, each batch weighted by its own size.

  Note:
    One batch is consumed from `batch_gen` right here, only to read the batch
    size; it is not used in any measurement.
  """
  first_batch = next(batch_gen)
  batch_size = first_batch[0].shape[0]
  # A single batch is always evaluated, even when sample_size < batch_size.
  num_evals = max(sample_size // batch_size, 1)

  # NOTE(review): the original also defined a single-batch variant under
  # `if sample_size<30`, but it was unconditionally overwritten by the
  # definition below and therefore never used; it has been removed.
  def new_f(*args):
    weighted_sum, num_samples = 0, 0
    for _ in range(num_evals):
      batch = next(batch_gen)
      bs = batch[0].shape[0]
      call_args = args[:argnum] + (batch,) + args[argnum:]
      # Weight by the actual size of this batch so uneven batches average correctly.
      weighted_sum = weighted_sum + bs * f(*call_args)
      num_samples += bs
    return weighted_sum / num_samples

  return new_f

# Functions related to the actual measurements.
# Each of the following functions returns a function that is evaluated at params and, in some cases, a vector v.
# This means that their batch dependence is fixed when they are called; use a measure generator to handle that.

def gradfn(loss):
    """Return a function (params, batches) -> flattened gradient of `loss`.

    The gradient is taken with respect to the first argument of `loss`
    (params) and flattened to a single vector with `flatten_jax`.
    """
    # Build the jitted gradient once.  Wrapping inside the lambda created a new
    # function object on every call, so jit could never reuse its cached trace.
    grad_loss = jit(grad(loss))
    return lambda params,batches: flatten_jax(grad_loss(params,batches))

def Hv(loss):
  """Build a jitted Hessian-vector product for `loss`.

  The returned Hvfn(params, batch, v) takes `v` as a flat vector, reshapes it
  to the structure of `params`, pushes it through the gradient of
  `loss(params, batch)` with a jvp (forward-over-reverse), and returns the
  result flattened again.
  """
  def Hvfn(params,batch,v):
      def grad_at(p):
          return grad(loss)(p, batch)

      tangent = unflatten_jax(v, params)
      _, hessian_vec = jvp(grad_at, (params,), (tangent,))
      return flatten_jax(hessian_vec)

  return jit(Hvfn)


def NTKv(f):
  """Build a jitted matrix-vector product with the NTK of the model `f`.

  The returned NTKv_general(params, batch, v) computes J (J^T v) / sizey, where
  J is the Jacobian of the flattened output f(params, batch[0]) with respect
  to params and sizey is the number of elements of batch[1].  `v` must be a
  vector of length sizey.  The flattened model output is assumed to have
  exactly sizey elements, since it is reshaped to that length.
  """
  @jit
  def NTKv_general(params,batch,v):
    """Returns matrix vector product for NTK(params, batch)."""
    inputs = batch[0]
    # Number of target entries as a plain Python int: it is used as a reshape
    # shape, which has to be static under jit (a jax.numpy product is traced).
    sizey = 1
    for dim in batch[1].shape:
      sizey *= int(dim)

    def flat_f(p):
      return np.reshape(f(p, inputs), (sizey,))

    # J^T v / sizey: pull v back through the scaled, flattened output ...
    _, pullback = vjp(lambda p: flat_f(p) / sizey, params)
    # ... then J applied to it: push that parameter-space vector forward again.
    return jvp(flat_f, (params,), pullback(v))[1]
  return NTKv_general


def trMvHutch(Mv, batches, sizev, tol=0.001):
  """Estimate tr(M) from a matrix-vector product using Hutchinson's trick.

  Args:
    Mv: function (params, batch, v) -> M.v for a flat vector v of length sizev.
    batches: iterator of batches; one batch is drawn per trace estimate.
    sizev: length of the probe vectors.
    tol: relative tolerance.  Sampling stops once the running mean changes by
      no more than `tol` times its previous magnitude.

  Returns:
    trHutchf(params, key=...): averages v.(M.v) over standard-normal probes v
    until the stopping rule fires and returns the running mean.

  Note:
    The default PRNG key is created once, when trMvHutch is called (seeded
    from the wall clock).  Calls to trHutchf that do not pass `key` therefore
    all use the same sequence of probe vectors.
  """
  key = random.PRNGKey(int(time.time()))

  def trHutchf(params,key=key):
    batch = next(batches)
    trest,prev_est=0,0
    N=0
    while True:
      # Draw the probe from the fresh subkey; `key` itself is only ever split,
      # never used for sampling (reusing a key correlates the draws).
      key,subkey=random.split(key)
      v=random.normal(subkey,(sizev,))

      trest+=np.vdot(v,Mv(params,batch,v))
      N+=1
      est=trest/N
      if not np.isfinite(est):
        # A NaN/inf estimate can never satisfy the test below; give up
        # instead of looping forever.
        break
      # Compare against |prev_est|: with a negative trace the bound
      # tol*prev_est was negative and the loop could never terminate.  `<=`
      # additionally lets an exactly-zero estimate terminate.
      if np.abs(prev_est-est)<=tol*np.abs(prev_est):
        break
      prev_est=est

    return trest/N
  return trHutchf


def getHg_scalarsfn(gradd,Hv):
  """Build a function measuring the curvature along the gradient direction.

  Args:
    gradd: function params -> flat gradient vector g.
    Hv: function (params, v) -> flat Hessian-vector product H.v.

  Returns:
    Hg_scalars(params) -> g.(H.g) / |g|^2, i.e. the Rayleigh quotient of H
    along g ("eigenvalue").  Only this scalar is returned; the overlap
    g.(H.g) / (|H.g| |g|) used to be computed as well but was never returned,
    so that dead computation has been dropped.
  """
  def Hg_scalars(params):
    gradn=gradd(params)
    Hvn=Hv(params,gradn)
    return np.dot(Hvn,gradn)/np.linalg.norm(gradn)**2

  return Hg_scalars

def compute_spectrumf(Mv_general, batches, sizev,k,v0=None):
  """Build a function returning eigenvalues of M from its matrix-vector product.

  `Mv_general(params, batch, v)` must return M.v.  The work is delegated to
  `lanczos.eigsh` (per the original note, a modification of scipy's Lanczos
  solver that accepts a function M.v(v) instead of M itself), called with the
  operator size `sizev`, dtype float32, and the `k` / `v0` given here.
  """
  def compute_spectrum(params):
    """Draw one batch and return the eigenvalues found by Lanczos (eigenvectors are discarded)."""
    batch = next(batches)

    def matvec(v):
      return Mv_general(params, batch, v)

    evals, _ = lanczos.eigsh(sizev, np.float32, matvec, k=k, v0=v0)
    return evals
  return compute_spectrum



#### Operator hierarchy measurements
from jax.lax import stop_gradient
def generate_scalingops(predict,ref_image,ref_vec= np.array([1] + 9*[0])):
    """Build the operator hierarchy Theta, O3, O4 (and primed variants) for `predict`.

    The base function is f(params) = predict(params, image)[0] . class_vec,
    evaluated by default at `ref_image` / `ref_vec`.  Each level of the
    hierarchy is the directional derivative of the previous level along the
    parameter gradient of f:

      * Theta, O3, O4:       the direction grad f is itself differentiated by
                             the next level ("continuous" step);
      * Theta_p, O3_p, O4_p: the direction is wrapped in `stop_gradient`, so
                             the next level treats it as a constant
                             ("discrete" step).

    Every returned operator is jitted and has the signature
    op(params, i=ref_image, v=ref_vec, *args): (i, v) select the image and
    class vector for the direction, `*args` are forwarded to the previous
    level.

    Returns:
        (Theta, O3, O4, Theta_p, O3_p, O4_p)
    """

    def f_inst(params, image=ref_image, class_vec=ref_vec):
        logits = predict(params, image)
        return logits[0].dot(class_vec)

    def _step(op, freeze_direction):
        # A jvp is used rather than flattening grad(op) and taking a dot
        # product with the flattened gradient of f; per the original author's
        # note the jvp form is about 3x faster (suspected cause: the flatten).
        @jit
        def next_op(params, i=ref_image, v=ref_vec, *args):
            direction = grad(f_inst)(params, i, v)
            if freeze_direction:
                direction = stop_gradient(direction)
            _, directional = jvp(lambda p: op(p, *args), (params,), (direction,))
            return directional

        return next_op

    def cont_step_hierarchy(op):
        return _step(op, False)

    def disc_step_hierarchy(op):
        return _step(op, True)

    cont_ops, disc_ops = [], []
    for ops, step in ((cont_ops, cont_step_hierarchy), (disc_ops, disc_step_hierarchy)):
        current = f_inst
        for _ in range(3):
            current = step(current)
            ops.append(current)

    Theta, O3, O4 = cont_ops
    Theta_p, O3_p, O4_p = disc_ops
    return Theta,O3,O4,Theta_p,O3_p,O4_p

def generate_Delta_f(predict, ref_image, ref_target):
  """Build a jitted function params -> root-mean-square error on a reference input.

  The returned Delta_f(params) evaluates predict(params, ref_image), subtracts
  `ref_target`, and returns the square root of the mean squared difference
  over all entries.
  """
  def Delta_f(params):
    residual = predict(params, ref_image) - ref_target
    return np.sqrt(np.mean(residual**2))

  return jit(Delta_f)
