import random
from pathlib import Path

import torch
import os
import re
import time
import json
import numpy as np
import gc
from copy import deepcopy
import jax
import jaxlib
from jax.interpreters import xla
import jax.numpy as jnp
import scipy as sp
from jax import tree_util, vmap, pmap, lax
from jax.ops import index, index_update
from scipy.special import log_softmax
from src.utils import util, gpu_util
from functools import partial, lru_cache
from collections import Counter
from neural_tangents.utils import utils as nt_utils
from src.utils.ntk_computation.empirical import jacobian_calculator
from src.utils.ntk_computation import predict
from src.utils.ntk_util import ntk_fn_dynamic_batched, get_full_data, batch_apply_fn


def calculate_accuracy(y_pred, y):
    """Return the fraction of rows whose arg-max in ``y_pred`` matches ``y``.

    Both inputs are reduced with ``argmax(axis=1)``; the number of matching
    rows is divided by ``len(y)``.
    """
    predicted_classes = y_pred.argmax(axis=1)
    target_classes = y.argmax(axis=1)
    num_correct = (predicted_classes == target_classes).sum()
    return num_correct / len(y)


def make_2d(k):
    """Flatten a 4-D kernel into a 2-D matrix; pass 2-D input through unchanged.

    A kernel of shape ``(n, nt, l, l)`` is rearranged to ``(n, l, nt, l)`` and
    reshaped to ``(n * l, nt * l)``, so that ``out[i * l + a, j * l + b]``
    equals ``k[i, j, a, b]``.

    Raises:
        AssertionError: if ``k`` is neither 2-D nor 4-D, or if its last two
            axes differ in size.
    """
    rank = len(k.shape)
    if rank == 2:
        return k
    assert rank == 4, "Kernel must have 4 dimensions!"
    assert k.shape[2] == k.shape[3], "3rd and 4th axes must have the same dimension"
    num_rows, num_cols, block = k.shape[0], k.shape[1], k.shape[2]
    # Bring each (row, block) and (col, block) index pair next to each other
    # before collapsing them.
    interleaved = k.transpose(0, 2, 1, 3)
    return interleaved.reshape(num_rows * block, num_cols * block)


def _to_jax_inputs(x, move_channels_last):
    """Convert ``x`` (via ``detach().cpu().numpy()``) to a float64 jax array.

    When ``move_channels_last`` is true, axis 1 is moved to the last position
    with ``transpose(0, 2, 3, 1)`` before conversion.
    """
    x_np = x.detach().cpu().numpy()
    if move_channels_last:
        x_np = x_np.transpose(0, 2, 3, 1)
    return jnp.asarray(x_np, dtype=jnp.float64)


def _one_hot(labels, num_classes):
    """Return a ``(len(labels), num_classes)`` float64 jax one-hot matrix."""
    onehot = np.zeros((len(labels), num_classes))
    onehot[np.arange(len(labels)), labels] = 1
    return jnp.array(onehot, dtype=jnp.float64)


def _timed_kernel(label, kernel_fn, x1, x2):
    """Evaluate ``kernel_fn(x1, x2)`` as a numpy array and print the elapsed time."""
    start = time.time()
    kernel = np.array(kernel_fn(x1, x2))
    print('{} is done: {}'.format(label, time.time() - start), flush=True)
    return kernel


def _save_array(save_dir, name, array):
    """Save ``array`` with ``np.save`` under ``save_dir / name`` (absolute path)."""
    np.save(str(save_dir.joinpath(name).absolute()), array)


def compute_and_save_ntk(dataloaders, models, epoch, model_config, data_config, save_dir, num_devices=1):
    """Compute NTK kernels and kernel-based predictions, and save them to disk.

    Outputs go to ``<save_dir>/diagonality`` (created if missing). Depending on
    the ``model_config['diagonality']`` flags, the function computes the full
    kernel (``trace_axes=()``), the "pseudo" kernel (``trace_axes=(-1,)``) and,
    at ``epoch == 0`` only, the infinite-width kernel. When
    ``data_config['al_params']['add_num'] > 0`` it also computes test/train
    kernels on the first ``add_num`` validation samples and the resulting
    predictions via ``predict.direct_gradient_descent_mse_predict``.

    Args:
        dataloaders: mapping with ``'train'`` and ``'val'`` entries exposing a
            ``.dataset`` attribute, and ``'labeled_set'`` (passed to
            ``get_full_data`` together with the training dataset).
        models: mapping providing ``'apply_fn'``, ``'ntk_params'``, ``'rng'``,
            ``'ntk_fn_builder'``, ``'ntk_trace_fn_builder'``,
            ``'ntk_inf_fn_builder'`` and ``'ntk_params_size'``.
        epoch: value used in output file names; the infinite kernel is only
            computed when it equals 0.
        model_config: mapping with ``'model_arch'['num_classes']``,
            ``'diagonality'`` (option dict) and ``'backbone'``.
        data_config: mapping with ``'al_params'['add_num']``.
        save_dir: base output directory.
        num_devices: forwarded to ``ntk_fn_dynamic_batched``.

    Returns:
        None. Writes (as applicable) ``kdd_``, ``ktd_``, ``idd_``, ``itd_``,
        ``ptd_``, ``ytest_``, ``preds_full_``, ``preds_inf_``,
        ``preds_pseudo_``, ``preds_model_`` arrays and ``info_<epoch>.json``.
    """
    labeled_dataset = dataloaders['train'].dataset
    labeled_set = dataloaders['labeled_set']
    validation_dataset = dataloaders['val'].dataset
    add_num = data_config['al_params']['add_num']
    # Test inputs, test kernels and predictions only exist when add_num > 0.
    has_test = add_num > 0

    apply_fn, params, rng = models['apply_fn'], models['ntk_params'], models['rng']

    num_classes = model_config['model_arch']['num_classes']
    diagonality_config = model_config['diagonality']

    ntk_fn_builder = models['ntk_fn_builder']
    ntk_trace_fn_builder = models['ntk_trace_fn_builder']
    ntk_inf_fn_builder = models['ntk_inf_fn_builder']
    dynamic_ntk_batch_coef = diagonality_config.get('full_comp_coef', 12)
    dynamic_ntk_trace_batch_coef = diagonality_config.get('trace_comp_coef', 12)
    dynamic_ntk_inf_batch_coef = diagonality_config.get('inf_comp_coef', 12)
    max_ntk_batch_size = diagonality_config.get('max_ntk_batch_size', 100000)
    max_ntk_trace_batch_size = diagonality_config.get('max_ntk_trace_batch_size', 100000)
    max_ntk_inf_batch_size = diagonality_config.get('max_ntk_inf_batch_size', 100000)

    save_full_kernel = diagonality_config.get('save_full_kernel', False)
    compute_full_kernel = diagonality_config.get('compute_full_kernel', True)
    compute_full_predictions = diagonality_config.get('compute_full_predictions', True)

    save_inf_kernel = diagonality_config.get('save_inf_kernel', False)
    # The infinite kernel is only computed at epoch 0.
    compute_inf_kernel = diagonality_config.get('compute_inf_kernel', False) and epoch == 0
    compute_inf_predictions = diagonality_config.get('compute_inf_predictions', True)

    flax_apply_batch_size = diagonality_config.get('flax_apply_batch_size', 50)

    # Test/test kernels need test inputs, so the option is ignored without them
    # (previously this raised NameError on the undefined X_test).
    compute_test_kernels_first = (
        diagonality_config.get('compute_test_kernels_first', False) and has_test)

    save_dir = Path(save_dir).joinpath('diagonality')
    save_dir.mkdir(parents=True, exist_ok=True)

    dynamic_batched_ntk_fn = partial(
        ntk_fn_dynamic_batched, ntk_fn_builder,
        num_devices, params, num_classes, models['ntk_params_size'],
        c=dynamic_ntk_batch_coef, trace_axes=(), max_bs=max_ntk_batch_size)
    dynamic_batched_ntk_trace_fn = partial(
        ntk_fn_dynamic_batched, ntk_trace_fn_builder,
        num_devices, params, num_classes, models['ntk_params_size'],
        c=dynamic_ntk_trace_batch_coef, trace_axes=(-1,), max_bs=max_ntk_trace_batch_size)
    dynamic_batched_ntk_inf_fn = partial(
        ntk_fn_dynamic_batched, ntk_inf_fn_builder,
        num_devices, None, num_classes, models['ntk_params_size'],
        c=dynamic_ntk_inf_batch_coef, trace_axes=(-1,), max_bs=max_ntk_inf_batch_size, get='ntk')

    print('Min GPU available at start:', gpu_util.get_free_gpu_memory_in_bytes())

    is_fcn = model_config['backbone'] == 'fcn'

    X_train, y_train = get_full_data(labeled_dataset, labeled_set)
    # Non-'fcn' backbones get axis 1 moved last; 'fcn' inputs are used as-is.
    X_train = _to_jax_inputs(X_train, move_channels_last=not is_fcn)
    y_train_onehot = _one_hot(y_train, num_classes)
    print('X train shape: {}'.format(X_train.shape), flush=True)

    if has_test:
        X_test, y_test = get_full_data(validation_dataset, list(range(add_num)))
        X_test = _to_jax_inputs(X_test, move_channels_last=not is_fcn)
        y_test_onehot = _one_hot(y_test, num_classes)
        print('X test shape: {}'.format(X_test.shape), flush=True)

    # compute initial predictions
    prev_time = time.time()
    fx_train_0 = batch_apply_fn(
        apply_fn, params, rng, X_train, batch_size=flax_apply_batch_size)  # Todo: is this okay?
    if has_test:
        fx_test_0 = batch_apply_fn(
            apply_fn, params, rng, X_test, batch_size=flax_apply_batch_size)
    print('initial prediction is done: {}'.format(time.time() - prev_time), flush=True)

    X_train = jax.device_put(X_train)

    if compute_full_kernel:

        if compute_test_kernels_first:
            print('Computing test kernels...')
            # The result is not used afterwards; only the timing is reported.
            _timed_kernel('Full test test', dynamic_batched_ntk_fn, X_test, None)

        print('Computing full kernels...')

        # NOTE(review): the train/train kernel is computed twice and only the
        # second result is kept. Presumably the first pass is a warm-up so the
        # second timing excludes compilation - confirm before removing.
        k_train_train = _timed_kernel('Full train train', dynamic_batched_ntk_fn, X_train, None)
        k_train_train = _timed_kernel('Full train train', dynamic_batched_ntk_fn, X_train, None)

        if has_test:
            # nt x n x knt x kn
            k_test_train = _timed_kernel('Full test train', dynamic_batched_ntk_fn, X_test, X_train)

        if save_full_kernel:
            _save_array(save_dir, 'kdd_{}'.format(epoch), k_train_train)
            if has_test:
                _save_array(save_dir, 'ktd_{}'.format(epoch), k_test_train)

    if compute_inf_kernel:

        print('Computing infinite kernels...')

        # For the 'fcn' backbone every sample is flattened to a vector.
        if is_fcn:
            X_train_flat = jnp.reshape(X_train, (X_train.shape[0], -1))
        else:
            X_train_flat = X_train

        inf_k_train_train = _timed_kernel(
            'Inf train train', dynamic_batched_ntk_inf_fn, X_train_flat, None)

        if has_test:

            if is_fcn:
                X_test_flat = jnp.reshape(X_test, (X_test.shape[0], -1))
            else:
                X_test_flat = X_test

            inf_k_test_train = _timed_kernel(
                'Inf test train', dynamic_batched_ntk_inf_fn, X_test_flat, X_train_flat)

        if save_inf_kernel:
            _save_array(save_dir, 'idd_{}'.format(epoch), inf_k_train_train)
            if has_test:
                _save_array(save_dir, 'itd_{}'.format(epoch), inf_k_test_train)

    if compute_test_kernels_first:
        print('Computing test kernels...')
        # As above, the result is discarded; only the timing is reported.
        _timed_kernel('Full test test', dynamic_batched_ntk_trace_fn, X_test, None)

    print('Computing pseudo kernels...')

    # NOTE(review): computed twice, like the full train/train kernel above;
    # only the second result is kept - confirm whether the first pass is a
    # deliberate warm-up.
    pseudo_k_train_train = _timed_kernel(
        'Pseudo train train', dynamic_batched_ntk_trace_fn, X_train, None)
    pseudo_k_train_train = _timed_kernel(
        'Pseudo train train', dynamic_batched_ntk_trace_fn, X_train, None)

    if has_test:
        pseudo_k_test_train = _timed_kernel(
            'Pseudo test train', dynamic_batched_ntk_trace_fn, X_test, X_train)

        # The pseudo train/train kernel ('pdd_<epoch>') is not saved here.
        _save_array(save_dir, 'ptd_{}'.format(epoch), pseudo_k_test_train)
        _save_array(save_dir, 'ytest_{}'.format(epoch), y_test_onehot)

    info = {}

    if compute_full_kernel:

        # Average |K| over the first two axes, then compare the trace of the
        # remaining 2-D block with its total sum.
        k_avg = np.abs(k_train_train).mean(axis=(0, 1))
        diagonal_percentage = k_avg.trace() / k_avg.sum()
        print('Diagonal percentage:', diagonal_percentage)
        info['diagonal_percentage'] = float(diagonal_percentage)

        if compute_full_predictions and has_test:
            # 2D-izing; done only here because the prediction is the sole
            # consumer and k_test_train does not exist without test data.
            k_train_train_2d = make_2d(k_train_train)
            k_test_train_2d = make_2d(k_test_train)
            n2d = k_train_train_2d.shape[0]
            nt2d = k_test_train_2d.shape[0]

            preds_full = predict.direct_gradient_descent_mse_predict(
                k_train_train_2d, y_train_onehot.reshape(n2d, -1),
                fx_train_0.reshape(n2d, -1), fx_test_0.reshape(nt2d, -1), k_test_train_2d
            ).reshape(-1, num_classes)
            full_acc = calculate_accuracy(preds_full, y_test_onehot)
            full_nan_count = np.isnan(preds_full).sum()
            print('Full acc:', full_acc, 'Nan count:', full_nan_count)
            info['full_acc'] = float(full_acc)
            info['full_nan_count'] = float(full_nan_count)
            _save_array(save_dir, 'preds_full_{}'.format(epoch), preds_full)

    if compute_inf_kernel and compute_inf_predictions and has_test:

        # The infinite-kernel prediction is called without initial outputs.
        preds_inf = predict.direct_gradient_descent_mse_predict(
            inf_k_train_train, y_train_onehot, None, None, inf_k_test_train)
        inf_acc = calculate_accuracy(preds_inf, y_test_onehot)
        inf_nan_count = np.isnan(preds_inf).sum()
        print('Inf acc:', inf_acc, 'Nan count:', inf_nan_count)
        info['inf_acc'] = float(inf_acc)
        info['inf_nan_count'] = float(inf_nan_count)
        _save_array(save_dir, 'preds_inf_{}'.format(epoch), preds_inf)

    if has_test:

        preds_pseudo = predict.direct_gradient_descent_mse_predict(
            pseudo_k_train_train, y_train_onehot, fx_train_0, fx_test_0, pseudo_k_test_train)
        pseudo_acc = calculate_accuracy(preds_pseudo, y_test_onehot)
        pseudo_nan_count = np.isnan(preds_pseudo).sum()
        print('Pseudo acc:', pseudo_acc, 'Nan count:', pseudo_nan_count)
        info['pseudo_acc'] = float(pseudo_acc)
        info['pseudo_nan_count'] = float(pseudo_nan_count)
        _save_array(save_dir, 'preds_pseudo_{}'.format(epoch), preds_pseudo)

        # Accuracy of the model's own outputs on the test inputs, for comparison.
        model_acc = calculate_accuracy(fx_test_0, y_test_onehot)
        print('Model acc:', model_acc)
        info['model_acc'] = float(model_acc)
        _save_array(save_dir, 'preds_model_{}'.format(epoch), fx_test_0)

    with open(str(save_dir.joinpath('info_{}.json'.format(epoch)).absolute()), 'w') as fp:
        json.dump(info, fp)
