import torch
import re
import time
import numpy as np
from copy import deepcopy
import jax
from jax.interpreters import xla
import jax.numpy as jnp
from jax import tree_util, vmap, pmap
from jax.ops import index, index_update
from scipy.special import log_softmax
from src.utils import util, gpu_util
from src.utils.ntk_computation import predict
from functools import partial, lru_cache
from collections import Counter
from src.utils.ntk_computation.empirical import jacobian_calculator
from src.utils import torch_to_flax_util


def get_full_data(data, indices, noise=0.0):
    """Gather the samples at ``indices`` into one input batch and one label tensor.

    Args:
        data: Indexable collection; ``data[idx]`` must be a mapping with an
            ``'inputs'`` tensor and a ``'labels'`` entry.
        indices: Iterable of positions to read from ``data``.
        noise: When greater than zero, standard normal noise scaled by this
            value is added to the stacked inputs.

    Returns:
        Tuple ``(X, y)`` where ``X`` stacks the per-sample inputs along a new
        leading axis and ``y`` is ``torch.tensor`` of the collected labels.
    """
    inputs, labels = [], []
    for idx in indices:
        # Index the collection once per sample; the previous code evaluated
        # data[idx] twice, repeating whatever work __getitem__ performs.
        sample = data[idx]
        inputs.append(sample['inputs'].unsqueeze(0))
        labels.append(sample['labels'])
    X = torch.cat(inputs)
    if noise > 0.0:
        X += torch.normal(0., 1., X.size()) * noise
    y = torch.tensor(labels)
    return X, y


def get_ntk_input_shape(data_config, num_input_channels, old=False):
    """Build the 4-tuple input shape ``(batch, crop, crop, channels)``.

    Args:
        data_config: Mapping providing ``data_config['transform']['crop_size']``.
        num_input_channels: Value placed in the last position of the shape.
        old: When true the leading entry is ``-1``; otherwise it is ``1``.

    Returns:
        Tuple ``(leading, crop_size, crop_size, num_input_channels)``.
    """
    side = data_config['transform']['crop_size']
    if old:
        leading = -1
    else:
        leading = 1
    return (leading, side, side, num_input_channels)


def update_ntk_params(params, model, bn_with_running_stats):
    """Dispatch to the parameter-transfer routine selected by the flag.

    Args:
        params: Parameter pytree forwarded to the selected routine.
        model: Torch model forwarded to the selected routine.
        bn_with_running_stats: Truthy selects
            ``update_ntk_params_w_running_stats``; falsy selects
            ``update_ntk_params_wo_running_stats``.

    Returns:
        Whatever the selected routine returns.
    """
    transfer = (update_ntk_params_w_running_stats
                if bn_with_running_stats
                else update_ntk_params_wo_running_stats)
    return transfer(params, model)


def update_ntk_params_wo_running_stats(params, model):
    """Copy the parameters of a torch ``model`` into the leaves of ``params``.

    The i-th entry of ``model.named_parameters()`` is written into the i-th
    leaf of ``params``. Batch-norm running statistics are not used; only the
    learnable batch-norm weight and bias are transferred.

    Layout conversions applied per parameter:
        * last leaf (final layer bias): reshaped to ``(1, -1)``;
        * second-to-last leaf (final layer weight): axes 0 and 1 swapped;
        * names matching ``bn<digit>.weight``: expanded to a diagonal matrix;
        * names matching ``bn<digit>.bias``: reshaped to ``(1, 1, 1, -1)``;
        * everything else: 1-D tensors are viewed as ``(-1, 1, 1, 1)`` and the
          result is transposed with ``(2, 3, 1, 0)`` (presumably torch OIHW
          to HWIO convolution kernels -- confirm against the target network).

    Args:
        params: Pytree whose leaves receive the converted values.
        model: Torch module (a wrapper exposing ``.module`` is unwrapped). It
            is switched to eval mode as a side effect.

    Returns:
        A pytree with the same structure as ``params`` holding the new values.

    Raises:
        ValueError: If a converted parameter cannot be written into the
            matching leaf (missing leaf or incompatible shape).
    """
    treedef = tree_util.tree_structure(params)
    new_jax_params = tree_util.tree_leaves(params)
    num_leaves = len(new_jax_params)
    bn_weight_pattern = re.compile(r"bn\d.weight")
    bn_bias_pattern = re.compile(r"bn\d.bias")

    if hasattr(model, 'module'):
        model = model.module

    model.eval()
    for i, (name, param) in enumerate(model.named_parameters()):
        if i == num_leaves - 1:  # last fc layer bias
            new_param = param.view(1, -1).detach().cpu().numpy()
        elif i == num_leaves - 2:  # last fc layer weight
            new_param = param.detach().cpu().numpy().swapaxes(1, 0)
        elif bn_weight_pattern.search(name):  # batch norm weight
            new_param = torch.diag(param).detach().cpu().numpy()
        elif bn_bias_pattern.search(name):  # batch norm bias
            new_param = param.view(1, 1, 1, -1).detach().cpu().numpy()
        else:
            if len(param.shape) == 1:
                param = param.view(-1, 1, 1, 1)
            new_param = param.detach().cpu().numpy().transpose(2, 3, 1, 0)
        try:
            # Functional full-slice overwrite; keeps the leaf's shape and dtype.
            new_jax_params[i] = jnp.asarray(new_jax_params[i]).at[:].set(new_param)
        except (IndexError, TypeError, ValueError) as err:
            target_shape = (getattr(new_jax_params[i], 'shape', None)
                            if i < num_leaves else None)
            raise ValueError(
                "Cannot copy torch parameter '{}' (index {}, converted shape {}) "
                "into NTK parameter leaf with shape {}".format(
                    name, i, new_param.shape, target_shape)) from err

    return tree_util.tree_unflatten(treedef, new_jax_params)


def update_ntk_params_w_running_stats(params, model):
    """Transfer torch parameters via ``torch_to_flax_util`` using ``'float64'``.

    Args:
        params: Parameter pytree handed to the transfer helper.
        model: Torch model handed to the transfer helper.

    Returns:
        The result of
        ``torch_to_flax_util.transfer_params_from_torch_model(model, params, 'float64')``.
    """
    return torch_to_flax_util.transfer_params_from_torch_model(model, params, 'float64')


# TODO: confirm whether batching is only needed because the full input does not fit in memory; if it does fit, a single apply_fn call would remove the loop.
def batch_apply_fn(apply_fn, params, rng, X, batch_size=1000):
    """Apply ``apply_fn(params, batch)`` over ``X`` in consecutive row batches.

    Args:
        apply_fn: Callable invoked as ``apply_fn(params, X_subset)``.
        params: First argument forwarded to ``apply_fn``.
        rng: Unused; kept so existing call sites keep working.
        X: Array-like with a ``shape`` attribute, sliced along axis 0.
        batch_size: Maximum number of rows per call.

    Returns:
        The per-batch outputs joined along axis 0. A single batch is returned
        as produced by ``apply_fn``; ``None`` is returned when ``X`` has no rows.
    """
    # Collect every batch output and concatenate once: re-concatenating the
    # accumulated result on each iteration copies it again and again.
    outputs = [
        apply_fn(params, X[start:start + batch_size])
        for start in range(0, X.shape[0], batch_size)
    ]
    if not outputs:
        return None
    if len(outputs) == 1:
        return outputs[0]
    return jnp.concatenate(outputs, axis=0)


def calculate_maximum_ntk_batch_fitting_in_memory(num_classes, params_size, num_devices, coef, trace_axes=()):
    """Estimate the largest NTK batch size permitted by the reported memory.

    The memory figure comes from ``gpu_util`` (GPU figure when torch sees at
    least one CUDA device, CPU figure otherwise), is multiplied by
    ``num_devices`` and converted to a float count with
    ``gpu_util.get_float_size()``.

    Args:
        num_classes: Per-sample output count, used only when ``trace_axes`` is empty.
        params_size: Number of model parameters.
        num_devices: Multiplier applied to the reported memory.
        coef: Extra divisor applied to the estimate (presumably a safety
            factor -- confirm with callers).
        trace_axes: When non-empty, the class factor is replaced by 1.

    Returns:
        ``np.floor`` of the float budget divided by
        ``effective_classes * params_size * coef``.
    """
    bytes_per_float = gpu_util.get_float_size()
    memory_probe = (gpu_util.get_usable_gpu_memory_in_bytes
                    if torch.cuda.device_count() > 0
                    else gpu_util.get_usable_cpu_memory_in_bytes)
    float_budget = (memory_probe() * num_devices) / bytes_per_float
    outputs_per_sample = 1 if len(trace_axes) > 0 else num_classes
    return np.floor(float_budget / (outputs_per_sample * params_size * coef))


def ntk_fn_dynamic_batched(ntk_fn_builder, num_devices, params, num_classes, params_size, x1, x2=None, c=12, trace_axes=(), max_bs=None, get=None):
    """Build an NTK function with a batch size chosen from the memory estimate.

    The batch size is the largest divisor of ``gcd(n1, n2) // num_devices``
    that is strictly smaller than the memory-derived limit, where ``n1`` and
    ``n2`` are the leading sizes of ``x1`` and ``x2`` (``x1`` is used for both
    when ``x2`` is ``None``).

    Args:
        ntk_fn_builder: Callable accepting ``batch_size=...`` and returning the
            kernel function.
        num_devices: Number of devices used for the divisibility check and the
            memory estimate.
        params: Passed as third argument to the kernel function when not ``None``.
        num_classes: Forwarded to the memory estimate.
        params_size: Forwarded to the memory estimate.
        x1: First input batch.
        x2: Optional second input batch; forwarded unchanged (possibly ``None``).
        c: Coefficient forwarded to the memory estimate.
        trace_axes: Forwarded to the memory estimate.
        max_bs: Optional upper cap on the memory-derived limit; ``None`` means
            no additional cap.
        get: Passed as third argument to the kernel function when ``params``
            is ``None``.

    Returns:
        ``k(x1, x2, params)`` or, when ``params`` is ``None``, ``k(x1, x2, get)``.
    """
    x2a = x1 if x2 is None else x2
    n1, n2 = x1.shape[0], x2a.shape[0]
    maximum_batch_size = calculate_maximum_ntk_batch_fitting_in_memory(
        num_classes, params_size, num_devices, c, trace_axes)
    # The default max_bs=None used to reach min(float, None), which raises
    # TypeError; only apply the cap when one was actually provided.
    if max_bs is not None:
        maximum_batch_size = min(maximum_batch_size, max_bs)
    assert n1 * n2 % (num_devices * num_devices) == 0, "both {} and {} should be divisible by #GPUs!".format(n1, n2)
    optimal_batch_size = np.gcd(n1, n2) // num_devices
    # Divisors in descending order; assumes util.get_sorted_divisors returns a
    # numpy array so the boolean mask below works -- TODO confirm.
    divisors = util.get_sorted_divisors(optimal_batch_size)[::-1]
    batch_size = divisors[divisors < maximum_batch_size][0]
    print('n1 - {} n2 - {} batch size - {}'.format(n1, n2, batch_size))
    k = ntk_fn_builder(batch_size=batch_size)
    return k(x1, x2, params) if params is not None else k(x1, x2, get)


def pad_kernels_to_same_shape(k_test_train, k_train_train, fx_train_0, y_train_onehot, n, nt, num_classes, padded_size):
    """Pad the kernels and train vectors from ``n`` to ``padded_size - 1`` train points.

    With ``p = padded_size - 1`` and ``k = num_classes``:
        * ``k_train_train`` (n, n, k, k) becomes (p, p, k, k); the new
          off-diagonal blocks are zero and the new diagonal block is the
          identity of size ``(p - n) * k`` arranged as (p - n, p - n, k, k);
        * ``k_test_train`` (nt, n, k, k) becomes (nt, p, k, k) with zero columns;
        * ``fx_train_0`` and ``y_train_onehot`` (n, k) become (p, k) with zero rows.

    Returns:
        Tuple ``(k_test_train, k_train_train, fx_train_0, y_train_onehot)`` of
        padded arrays.
    """
    c = num_classes
    # One slot of padded_size is reserved for a point added by later computations.
    extra = padded_size - 1 - n

    def zero_block(rows, cols):
        return np.zeros((rows, cols, c, c))

    upper_rows = np.concatenate((k_train_train, zero_block(n, extra)), axis=1)
    identity_block = np.eye(extra * c).reshape(extra, c, extra, c).transpose(0, 2, 1, 3)
    lower_rows = np.concatenate((zero_block(extra, n), identity_block), axis=1)
    padded_train_train = np.concatenate((upper_rows, lower_rows), axis=0)

    padded_test_train = np.concatenate((k_test_train, zero_block(nt, extra)), axis=1)

    padded_fx = np.concatenate((fx_train_0, np.zeros((extra, c))), axis=0)
    padded_y = np.concatenate((y_train_onehot, np.zeros((extra, c))), axis=0)
    return padded_test_train, padded_train_train, padded_fx, padded_y


def calculate_jacobians(X, jacobian_calculator=None, p=-1, num_classes=10, batch_size=40):
    """Compute per-sample Jacobians of ``X`` in batches into one host array.

    Args:
        X: Sliceable batch of inputs supporting ``len``.
        jacobian_calculator: Callable applied to ``gpu_util.gpu_split(batch)``.
        p: Size of the last (parameter) axis of the result.
        num_classes: Size of the middle axis of the result.
        batch_size: Number of samples sent to ``jacobian_calculator`` per call.

    Returns:
        ``numpy`` array of shape ``(len(X), num_classes, p)``.
    """
    num_samples = len(X)
    js = np.empty((num_samples, num_classes, p))
    # Step through every sample, including a final short batch. The previous
    # floor-division loop skipped the remainder and left those rows of the
    # np.empty buffer uninitialised.
    # NOTE(review): gpu_split may require the batch length to be divisible by
    # the device count -- confirm for the trailing batch.
    for start in range(0, num_samples, batch_size):
        batch = X[start:start + batch_size]
        j = jax.device_get(jacobian_calculator(gpu_util.gpu_split(batch)))
        js[start:start + len(batch)] = j.reshape(len(batch), num_classes, -1)
    return js


def calculate_kernel(j1, j2=None, kernel_calculator=None, batch_size=40):
    """Accumulate a kernel from Jacobians by summing over parameter chunks.

    Args:
        j1: Array indexed as ``j1[:, :, parameter]``; its third axis is chunked.
        j2: Second Jacobian array with the same layout; defaults to ``j1``.
        kernel_calculator: Callable applied to the two ``gpu_split`` chunks
            (each chunk is transposed before splitting); its device-fetched
            result is summed over axis 0.
        batch_size: Number of parameter columns per chunk.

    Returns:
        ``numpy`` array of shape
        ``(j1.shape[0], j1.shape[1], j2.shape[1], j2.shape[0])``.
    """
    p = j1.shape[2]
    j2 = j1 if j2 is None else j2
    ker = np.zeros((j1.shape[0], j1.shape[1], j2.shape[1], j2.shape[0]))
    # Cover every parameter column, including a final short chunk. The previous
    # floor-division loop dropped the last p % batch_size columns from the sum.
    # NOTE(review): gpu_split may require the chunk length to be divisible by
    # the device count -- confirm for the trailing chunk.
    for start in range(0, p, batch_size):
        stop = start + batch_size
        ker += jax.device_get(kernel_calculator(
            gpu_util.gpu_split(j1[:, :, start:stop].T),
            gpu_util.gpu_split(j2[:, :, start:stop].T)
        )).sum(axis=0)
    return ker


def calculate_weight_update(f, rng, params, X_train, X_test, k_train_train, k_test_train, k_test_test,
                            y_train_onehot, y_test_onehot, fx_train_0, fx_test_0, batch_size=100, selections=None):
    """Return a copy of ``params`` shifted by the Jacobian-weighted kernel solve.

    The solve result comes from ``calculate_weight_update_rhs``. For every
    batch of inputs the transposed Jacobians are contracted with the matching
    rows of that result, the contributions are summed into one flat update
    vector, and the vector is subtracted leaf by leaf from a deep copy of
    ``params``.

    Args:
        f: Model function; called with ``rng`` bound as a keyword argument.
        rng: Random key bound into ``f``.
        params: Parameter pytree; not modified.
        X_train: Training inputs.
        X_test: Test inputs; rows in ``selections`` are appended to the
            training inputs.
        k_train_train, k_test_train, k_test_test: Kernel blocks forwarded to
            ``calculate_weight_update_rhs``.
        y_train_onehot, y_test_onehot: Targets forwarded to the solve.
        fx_train_0, fx_test_0: Initial outputs forwarded to the solve.
        batch_size: Number of inputs per Jacobian batch.
        selections: Optional indices of test points treated as extra training
            points.

    Returns:
        Parameter pytree with the same structure as ``params``.
    """
    p_size = util.get_params_size(params)
    num_classes = y_train_onehot.shape[1]
    jac = jax.pmap(partial(jacobian_calculator(partial(f, rng=rng), vmap_axes=0), params=params))
    calc_jacs = partial(calculate_jacobians, jacobian_calculator=jac, p=p_size, num_classes=num_classes)

    # Forward the caller's selections: X below is extended with the selected
    # test points, so the solve must be extended the same way or its rows no
    # longer line up with X (the call previously hard-coded selections=None).
    rhs = calculate_weight_update_rhs(k_train_train, k_test_train, k_test_test,
                                      y_train_onehot, y_test_onehot, fx_train_0, fx_test_0,
                                      selections=selections)

    if selections is not None and len(selections) > 0:
        X = np.concatenate((X_train, X_test[selections]), axis=0)
    else:
        X = X_train

    new_params = deepcopy(params)

    def update_params(ps, update):
        # Subtract consecutive slices of the flat update from each leaf.
        offset = 0
        flat_ps, p_def = tree_util.tree_flatten(ps)
        for i, leaf in enumerate(flat_ps):
            flat_ps[i] = leaf - update[offset:offset + leaf.size].reshape(leaf.shape)
            offset += leaf.size
        return tree_util.tree_unflatten(p_def, flat_ps)

    params_update = np.zeros(p_size)
    # Step over every sample, including a final short batch (previously the
    # trailing len(X) % batch_size samples were ignored).
    for start in range(0, len(X), batch_size):
        data = X[start:start + batch_size]
        jT = calc_jacs(data, batch_size=len(data)).T
        params_update += np.tensordot(jT, rhs[start:start + batch_size].T)

    return update_params(new_params, params_update)


def calculate_weight_update_rhs(k_train_train, k_test_train, k_test_test,
                                y_train_onehot, y_test_onehot, fx_train_0, fx_test_0, selections=None):
    """Solve the train kernel system against the residual ``y_train - fx_train``.

    When ``selections`` is non-empty, the selected test points are appended to
    the training set first: the kernel is grown by the matching rows of
    ``k_test_train`` and the matching block of ``k_test_test``, and the
    targets and initial outputs are extended with the selected test rows.

    Returns:
        The result of ``predict.jit_raw_cho_solve_no_multiply_staged`` applied
        to the factor from ``predict.jit_prepare_cho_solve_staged`` and the
        residual.
    """
    use_selection = selections is not None and len(selections) > 0
    if use_selection:
        selected_block = k_test_test[np.ix_(selections, selections)]
        cross_rows = k_test_train[selections]
        # Left part: original columns with the cross rows stacked underneath.
        left_part = np.concatenate((k_train_train, cross_rows), axis=0)
        # Right part: cross rows next to the selected block, then sample and
        # class axes swapped so it supplies the new columns.
        right_part = np.concatenate((cross_rows, selected_block), axis=1).transpose(1, 0, 3, 2)
        k_train_train = np.concatenate((left_part, right_part), axis=1)
        y_train_onehot = np.concatenate((y_train_onehot, y_test_onehot[selections]), axis=0)
        fx_train_0 = np.concatenate((fx_train_0, fx_test_0[selections]), axis=0)

    factor = predict.jit_prepare_cho_solve_staged(k_train_train)[0]
    non_channel_shape = k_train_train.shape[1::2]
    residual = y_train_onehot - fx_train_0
    return predict.jit_raw_cho_solve_no_multiply_staged(factor, residual, (), non_channel_shape)


# Assuming everything will fit in memory
def construct_ntk_kernels(f, params, rng, X_train, X_test, num_classes):
    """Compute Jacobians and the self-kernel for the smaller of the two input sets.

    Only the smaller set (``X_test`` when ``X_train`` is strictly longer,
    otherwise ``X_train``) is processed; the entries for the larger set are
    returned as ``None`` placeholders.

    Args:
        f: Model function handed to ``jacobian_calculator``.
        params: Parameter pytree bound into the Jacobian function.
        rng: Unused; kept so existing call sites keep working.
        X_train: Training inputs.
        X_test: Test inputs.
        num_classes: Number of outputs per sample.

    Returns:
        Tuple ``(j_big, j_small, k_big_big, k_big_small, k_small_small)`` where
        ``j_big``, ``k_big_big`` and ``k_big_small`` are ``None``.
    """
    jac = jax.pmap(partial(jacobian_calculator(f, vmap_axes=0), params=params))
    p = util.get_params_size(params)
    calc_jacs = partial(calculate_jacobians, jacobian_calculator=jac, p=p, num_classes=num_classes)

    X_small = X_test if len(X_train) > len(X_test) else X_train

    prev_time = time.time()
    j_small = calc_jacs(X_small, batch_size=100)
    elapsed = time.time() - prev_time
    print(f'Calculating jacobians took {elapsed:.3f}s.')

    prev_time = time.time()
    # Contract the parameter axis of the Jacobians with itself.
    k_small_small = np.tensordot(j_small, j_small.T, axes=1)
    elapsed = time.time() - prev_time
    print(f'Calculating kernels took {elapsed:.3f}s.')

    j_big = k_big_big = k_big_small = None
    return j_big, j_small, k_big_big, k_big_small, k_small_small


def memory_efficient_diagonal_regularize(A, diag_reg, n, k, batch=20000):
    """Add a trace-scaled multiple of the identity to a 4-D kernel.

    ``A`` of shape (n, n, k, k) is copied into an (n*k, n*k) matrix, the value
    ``diag_reg * trace / (n*k)`` is added to its diagonal, and the matrix is
    reshaped back to (n, n, k, k).

    Args:
        A: Kernel of shape (n, n, k, k); not modified.
        diag_reg: Relative regularisation strength; not modified.
        n: Number of samples.
        k: Number of classes.
        batch: Unused; kept for backward compatibility. The diagonal is updated
            through a view, so no chunking is required.

    Returns:
        Regularised kernel of shape (n, n, k, k).
    """
    dimension = n * k
    A = np.array(A.transpose(0, 2, 1, 3).reshape((dimension, dimension)))
    # Compute a new value rather than using `*=`, which would modify a
    # caller-owned array passed as diag_reg.
    scaled_reg = diag_reg * (np.trace(A) / dimension)
    # einsum('ii->i') yields a writeable view of the diagonal, so nothing of
    # size batch x batch (let alone nk x nk) is allocated.
    diagonal = np.einsum('ii->i', A)
    diagonal += scaled_reg
    return A.reshape(n, k, n, k).transpose(0, 2, 1, 3)


def memory_efficient_fixed_diagonal_regularize(A, diag_reg, n, k, batch=20000):
    """Add ``diag_reg`` in place to the first ``n * k`` diagonal entries of ``A``.

    Args:
        A: Writable 2-D ``numpy`` matrix with at least ``n * k`` rows and
            columns; modified in place.
        diag_reg: Value added to each of the first ``n * k`` diagonal entries.
        n: Number of samples.
        k: Number of classes.
        batch: Unused; kept for backward compatibility. The diagonal is updated
            through a view, so no chunking is required.

    Returns:
        The same array object ``A``.

    Raises:
        ValueError: If ``A`` has fewer than ``n * k`` rows or columns.
    """
    dimension = n * k
    if A.shape[0] < dimension or A.shape[1] < dimension:
        raise ValueError(
            "A with shape {} is too small for an n * k = {} diagonal".format(A.shape, dimension))
    # einsum('ii->i') yields a writeable view of the diagonal, so no identity
    # block is materialised and errors surface to the caller directly.
    diagonal = np.einsum('ii->i', A[:dimension, :dimension])
    diagonal += diag_reg
    return A


def transpose_wrt_trace_axes(mat, transpose_args, trace_axes=()):
    """Transpose ``mat``, dropping all but axes 0 and 1 when trace axes are set.

    Args:
        mat: Array exposing ``transpose(*axes)``.
        transpose_args: Axis permutation for the untraced case.
        trace_axes: When non-empty, only the entries of ``transpose_args``
            equal to 0 or 1 are kept, in their original order.

    Returns:
        The transposed array.
    """
    if len(trace_axes) > 0:
        transpose_args = [axis for axis in transpose_args if axis in [0, 1]]
    return mat.transpose(*transpose_args)