import jax
import jax.numpy as jnp
import torch
import numpy as np
import neural_tangents as nt
from functools import partial
import torch.nn.init as init

from src.core.models import wide_resnet, resnet, fcn, convnet, ll_resnet
from src.utils.ntk_util import get_ntk_input_shape
from src.utils.ntk_computation.empirical import empirical_ntk_fn
from src.utils import util, model_utils

# PyTorch model constructors, keyed by the value of model_config['backbone'].
# build() instantiates the selected entry with **model_arch and rejects any
# backbone that is not a key of this table.
MODELS = {
    'wide_resnet': wide_resnet.torch.Wide_ResNet,
    'resnet18': resnet.torch.ResNet18,
    'resnet34': resnet.torch.ResNet34,
    'resnet50': resnet.torch.ResNet50,
    'resnet101': resnet.torch.ResNet101,
    'resnet152': resnet.torch.ResNet152,
    'resnet200': resnet.torch.ResNet200,
    'll_resnet18': ll_resnet.torch.ResNet18,
    'fcn': fcn.torch.FCNet,
    'convnet': convnet.torch.ConvNet
}

# Flax-side models used by build() for the empirical NTK. Keys are a backbone
# name from MODELS plus the '_w' suffix that build() appends when
# bn_with_running_stats is true.
# NOTE(review): build() looks up a '_wo' suffix when bn_with_running_stats is
# false, but no '_wo' keys are registered in this table.
NTK_MODELS = {
    'wide_resnet_w': wide_resnet.flax.WideResNetNTK,
    'resnet18_w': resnet.flax.ResNet18,
    'resnet34_w': resnet.flax.ResNet34,
    'resnet50_w': resnet.flax.ResNet50,
    'resnet101_w': resnet.flax.ResNet101,
    'resnet152_w': resnet.flax.ResNet152,
    'resnet200_w': resnet.flax.ResNet200,
    'll_resnet18_w': ll_resnet.flax.ResNet18,
    'fcn_w': fcn.flax.FCNetNTK,
    'convnet_w': convnet.flax.ConvNet
}

# stax builders, keyed by backbone name. build() unpacks the result of calling
# an entry as (init_fn, apply_fn, kernel_fn) and uses the kernel function as
# the infinite-width NTK. Only a subset of the MODELS backbones is registered;
# build() falls back to ntk_inf_fn=None when the lookup or construction fails.
STAX_MODELS = {
    'wide_resnet': wide_resnet.stax.WideResNet,
    'resnet18': resnet.stax.ResNet18,
    'resnet34': resnet.stax.ResNet34,
    'fcn': fcn.stax.FCNet,
    'convnet': convnet.stax.ConvNet
}


def ntk_initialize(m):
    """Re-initialize a single module in place; meant for ``model.apply``.

    Dispatch is by class name:

    * A name containing ``Conv`` or ``Linear`` (but not the ``ConvNet``
      container itself): the weight is drawn from a zero-mean normal with
      standard deviation ``1 / sqrt(fan_in)``. If a bias exists it is drawn
      from a zero-mean normal with standard deviation
      ``0.05 / sqrt(fan_in)`` for a multi-dimensional bias, or
      ``0.05 / sqrt(len(bias))`` for a 1-D bias.
    * A name containing ``BatchNorm``: weight is set to 1 and bias to 0.

    Modules that lack the relevant parameter are left untouched instead of
    raising. This covers a wrapper class whose name merely contains ``Conv``,
    a layer built with ``bias=False``, or a BatchNorm built with
    ``affine=False``.

    Args:
        m: The module to initialize.
    """
    classname = m.__class__.__name__
    # getattr with a default: container modules have no weight/bias attribute
    # at all, and affine=False / bias=False layers expose them as None.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    is_conv_or_linear = (
        ('Conv' in classname or 'Linear' in classname)
        and classname != 'ConvNet'
    )

    if is_conv_or_linear:
        if weight is None:
            return
        init.normal_(weight, 0., 1. / np.sqrt(init._calculate_correct_fan(weight, 'fan_in')))
        if bias is not None:
            if len(bias.shape) > 1:
                init.normal_(bias, 0., 0.05 / np.sqrt(init._calculate_correct_fan(bias, 'fan_in')))
            else:
                # A 1-D tensor has no fan-in; scale by its length instead.
                init.normal_(bias, 0., 0.05 / np.sqrt(len(bias)))
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.constant_(weight, 1)
        if bias is not None:
            init.constant_(bias, 0)


def initialize_model(model, logger, init_method, init_checkpoint=None):
    """Initialize ``model`` in place according to ``init_method``.

    Args:
        model: Module exposing ``apply`` and ``load_state_dict``.
        logger: Object with an ``error(msg)`` method.
        init_method: ``'ntk'`` applies ``ntk_initialize`` to every submodule.
            ``'pretrained'`` does the same and then loads ``init_checkpoint``
            on top with ``strict=False``.
        init_checkpoint: Path (or file-like object) handed to ``torch.load``;
            required when ``init_method == 'pretrained'``.

    Raises:
        SystemExit: With exit status 1, after logging an error, when
            ``init_method`` is unknown or ``'pretrained'`` is requested
            without a checkpoint.
    """
    if init_method == 'ntk':
        model.apply(ntk_initialize)
        return

    if init_method == 'pretrained':
        if init_checkpoint is None:
            logger.error("Initialization method 'pretrained' requires init_checkpoint, got None.")
            raise SystemExit(1)
        # NTK-initialize first; presumably so that parameters absent from the
        # checkpoint (tolerated by strict=False) still get a defined init.
        model.apply(ntk_initialize)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(init_checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint, strict=False)
        return

    logger.error('Specify valid initialization method, got {}.'.format(init_method))
    raise SystemExit(1)


def build(model_config, data_config, logger):
    """Build the PyTorch model and its NTK companions for one backbone.

    Args:
        model_config: Mapping with keys ``'backbone'``, ``'model_arch'`` and
            ``'init_method'``, plus optional ``'init_checkpoint'``.
            ``model_arch`` must contain ``'num_classes'`` and may contain
            ``'activation'`` and ``'bn_with_running_stats'`` (default True).
        data_config: Mapping with key ``'name'``; also forwarded to
            ``get_ntk_input_shape``.
        logger: Object with ``info`` and ``error`` methods, and ``warning``
            or ``warn``.

    Returns:
        dict with keys ``'model'``, ``'ntk_model'``, ``'ntk_fn_builder'``,
        ``'ntk_trace_fn_builder'``, ``'ntk_inf_fn_builder'``,
        ``'ntk_inf_fn'``, ``'single_device_ntk_fn'``,
        ``'single_device_ntk_trace_fn'``, ``'ntk_params'``,
        ``'ntk_params_size'``, ``'apply_fn'``, ``'loss_fn'`` (always None),
        ``'rng'`` and ``'use_ntk'`` (always True). ``'ntk_inf_fn'`` and
        ``'ntk_inf_fn_builder'`` are None when no stax model could be built.

    Raises:
        SystemExit: With exit status 1, after logging an error, when the
            backbone is not a key of ``MODELS``.
        KeyError: When ``'num_classes'`` is missing from ``model_arch``, or
            when no NTK model is registered for the backbone / batch-norm
            variant.

    Note:
        ``model_config['model_arch']`` is modified in place:
        ``'bn_with_running_stats'`` is popped, ``'num_input_channels'`` is
        set, and ``'activation'`` (when given) is replaced by the constructed
        nonlinearity.
    """
    backbone = model_config['backbone']
    model_arch = model_config['model_arch']
    data_name = data_config['name']
    if 'num_classes' not in model_arch:
        raise KeyError('num_classes')
    # Popped so the flag is not forwarded to the model constructors below.
    bn_with_running_stats = model_arch.pop('bn_with_running_stats', True)

    rgb_datasets = ('cifar10', 'cifar100', 'svhn', 'tiny_imagenet')
    num_input_channels = 3 if data_name in rgb_datasets else 1
    model_arch['num_input_channels'] = num_input_channels

    if backbone not in MODELS:
        logger.error(
            'Specify valid backbone or model type among {}.'.format(MODELS.keys())
        )
        raise SystemExit(1)

    ####################
    # Building PyTorch model
    ####################

    model_activation = model_arch.get('activation', None)
    if model_activation is not None:
        model_arch['activation'] = model_utils.construct_nonlinearity(model_activation, 'torch')

    model = MODELS[backbone](**model_arch)
    models = {'model': model}

    initialize_model(model, logger, model_config['init_method'], model_config.get('init_checkpoint', None))

    logger.info('A model {} is built.'.format(backbone))

    ####################
    # Building stax model
    ####################

    # The stax model is optional: not every backbone has one registered.
    # NOTE(review): model_arch['activation'] still holds the torch
    # nonlinearity at this point - confirm the stax builders accept it.
    try:
        _, _, ntk_inf_fn = STAX_MODELS[backbone](**model_arch)
        ntk_inf_fn_batch_builder = partial(nt.batch, kernel_fn=ntk_inf_fn, device_count=-1, store_on_device=False)
    except Exception as err:
        ntk_inf_fn = None
        ntk_inf_fn_batch_builder = None
        # Prefer `warning`; fall back to `warn` for loggers that only have it.
        warn = getattr(logger, 'warning', None) or logger.warn
        warn('Couldn\'t build the stax model, skipping ({!r})'.format(err))

    ####################
    # Building Flax model
    ####################

    ntk_key = backbone + ('_w' if bn_with_running_stats else '_wo')
    if ntk_key not in NTK_MODELS:
        logger.error(
            'No NTK model registered for {} among {}.'.format(ntk_key, list(NTK_MODELS.keys()))
        )
        raise KeyError(ntk_key)

    if model_activation is not None:
        model_arch['activation'] = model_utils.construct_nonlinearity(model_activation, 'flax')

    # Initialize ntk params
    rng = jax.random.PRNGKey(1313)

    if bn_with_running_stats:
        ntk_model = NTK_MODELS[ntk_key](**model_arch)
        init_fn = ntk_model.init
        apply_fn = partial(ntk_model.apply, mutable=False)
    else:
        init_fn, apply_fn, _ = NTK_MODELS[ntk_key](**model_arch)
        apply_fn = partial(apply_fn, **{'rng': rng})
        ntk_model = None

    # Define a loss function
    loss_fn = None

    ntk_implementation = 3

    def apply_fn_trace(params, x):
        # Sum the outputs over the last axis, scaled by 1 / sqrt(width of
        # that axis). jnp (not np) because this function is handed to
        # empirical_ntk_fn as a JAX function.
        out = apply_fn(params, x)
        return jnp.sum(out, axis=-1) / out.shape[-1] ** 0.5

    ntk_fn = empirical_ntk_fn(
        apply_fn, vmap_axes=0, implementation=ntk_implementation, trace_axes=())
    ntk_trace_fn = empirical_ntk_fn(
        apply_fn_trace, vmap_axes=0, implementation=ntk_implementation, trace_axes=())

    ntk_fn_batch_builder = partial(nt.batch, kernel_fn=ntk_fn, device_count=-1, store_on_device=False)
    ntk_trace_fn_batch_builder = partial(nt.batch, kernel_fn=ntk_trace_fn, device_count=-1, store_on_device=False)

    single_device_ntk_fn = partial(nt.batch, kernel_fn=ntk_fn, device_count=1, store_on_device=False)
    single_device_ntk_trace_fn = partial(nt.batch, kernel_fn=ntk_trace_fn, device_count=1, store_on_device=False)

    if bn_with_running_stats:
        # This init_fn is called with a dummy input array.
        ntk_params = init_fn(rng, jnp.ones(get_ntk_input_shape(data_config, num_input_channels)))
    else:
        # This init_fn is called with the input shape and returns a pair
        # whose second element is the parameters.
        _, ntk_params = init_fn(rng, get_ntk_input_shape(data_config, num_input_channels, old=True))

    models.update({
        'ntk_model': ntk_model,
        'ntk_fn_builder': ntk_fn_batch_builder,
        'ntk_trace_fn_builder': ntk_trace_fn_batch_builder,
        'ntk_inf_fn_builder': ntk_inf_fn_batch_builder,
        'ntk_inf_fn': ntk_inf_fn,
        'single_device_ntk_fn': single_device_ntk_fn,
        'single_device_ntk_trace_fn': single_device_ntk_trace_fn,
        'ntk_params': ntk_params,
        'ntk_params_size': util.get_params_size(ntk_params),
        'apply_fn': apply_fn,
        'loss_fn': loss_fn,
        'rng': rng,
        'use_ntk': True
    })
    logger.info('A NTK model {} is built.'.format(ntk_key))

    return models

