import jax
import jax.numpy as jnp
import torch

import torchvision.datasets as dset
import torchvision.transforms as transforms
import neural_tangents as nt

import numpy as np
import flax
import flax.linen as nn
import optax as tx
import neural_tangents.stax as stax

import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


from typing import Any, Callable, Sequence, Tuple
from flax.training import train_state, checkpoints

import matplotlib.pyplot as plt
import functools
import operator
import fire

import data
from utils import _sub, multiply_by_scalar, get_dot_product, bind
import models

@functools.partial(jax.jit, static_argnames=('train', 'has_bn', 'use_base_params'))
def get_training_loss_l2(params, images, labels, net_train_state, l2 = 0., train = False, has_bn = False, batch_stats = None, use_base_params = False, label_scale = 1, init_params = None):
    """Mean-squared-error loss and accuracy of ``net_train_state.apply_fn`` on one batch.

    Args:
        params: parameter pytree, passed to the model as the ``'params'`` collection.
        images: model inputs.
        labels: regression targets; they are expected to broadcast against the
            model outputs. A trailing dimension of size 1 is scored as sign
            (binary) classification, anything wider via ``argmax`` over axis 1.
        net_train_state: object whose ``apply_fn`` runs the model.
        l2: currently unused (no weight decay is added to the loss here).
        train: forwarded to ``apply_fn`` as ``train=``.
        has_bn: when True, ``batch_stats`` is passed to the model as the
            ``'batch_stats'`` collection and that collection is marked mutable.
        batch_stats: batch-norm statistics; only read when ``has_bn``.
        use_base_params, label_scale, init_params: currently unused; kept so
            existing call sites keep working.

    Returns:
        ``(loss, [new_batch_stats, acc, n_correct])`` where ``loss`` is
        ``mean(0.5 * (outputs - labels) ** 2)``, ``new_batch_stats`` is the dict
        of mutated collections returned by ``apply_fn`` when ``has_bn`` (else
        ``None``), and ``acc`` / ``n_correct`` are the fraction / count of
        correct predictions.
    """
    variables = {'params': params}
    if has_bn:
        # BN models need their running statistics, and may update them in train mode.
        variables['batch_stats'] = batch_stats
        mutable = ['batch_stats']
    else:
        mutable = []

    outputs, mutated = net_train_state.apply_fn(variables, images, train = train, mutable = mutable)

    loss = jnp.mean(0.5 * (outputs - labels) ** 2)

    if labels.shape[-1] == 1:
        # Single output: a prediction is correct when its sign agrees with the label's.
        correct = (outputs > 0) == (labels > 0)
    else:
        correct = outputs.argmax(1) == labels.argmax(1)
    acc = jnp.mean(correct)
    n_correct = jnp.sum(correct)

    new_batch_stats = mutated if has_bn else None
    return loss, [new_batch_stats, acc, n_correct]

def bce(logits, labels):
    """Unreduced binary cross-entropy computed directly from logits.

    ``labels`` is reshaped to a column vector ``(-1, 1)`` before it is combined
    with ``logits``. The result is ``-(y * log sigmoid(z) + (1 - y) * log sigmoid(-z))``
    element-wise, with no reduction.
    """
    targets = labels.reshape(-1, 1)
    positive_term = targets * jax.nn.log_sigmoid(logits)
    negative_term = (1. - targets) * jax.nn.log_sigmoid(-logits)
    return -positive_term - negative_term

@functools.partial(jax.jit, static_argnames=('has_bn', 'train', 'update_ema', 'use_base_params', 'use_dp'))
def do_training_step(train_state, training_batch, l2 = 0., has_bn = False, train = True, update_ema = False, ema_decay = 0.995, use_base_params = False, use_dp = False, label_scale = None, init_params = None):
    """Run one jitted optimisation step on ``get_training_loss_l2``.

    Args:
        train_state: state exposing ``params``, ``apply_gradients`` and
            ``train_it`` (plus ``batch_stats`` when ``has_bn``).
        training_batch: mapping with ``'images'`` and ``'labels'`` entries.
        l2, use_base_params: forwarded to ``get_training_loss_l2``.
        has_bn: read ``train_state.batch_stats`` and, when the loss returns
            updated statistics, store them in the new state.
        train: forwarded to the model as ``train=``.
        update_ema: also refresh ``ema_hidden`` / ``ema_average`` via
            ``get_updated_ema`` (not visible in this chunk).
        ema_decay: decay passed to ``get_updated_ema``.
        use_dp: compute per-example gradients with ``jax.vmap``. Each example
            is evaluated as its own batch of size 1. The returned ``grad`` keeps
            the leading batch axis, presumably for an optimizer that aggregates
            per-example gradients itself — confirm against the configured ``tx``.
        label_scale: forwarded to the loss on the non-``use_dp`` path only.
        init_params: unused.

    Returns:
        ``(new_state, (loss, grad))``. ``loss`` is averaged over examples when
        ``use_dp``.
    """
    images = training_batch['images']
    labels = training_batch['labels']

    batch_stats = train_state.batch_stats if has_bn else None

    if use_dp:
        (loss, _), grad = jax.vmap(jax.value_and_grad(bind(get_training_loss_l2, ..., ..., ..., train_state, l2 = l2, train = train, has_bn = has_bn, batch_stats = batch_stats, use_base_params = use_base_params), has_aux = True), in_axes = [None, 0, 0])(train_state.params, images[:, None], labels[:, None])
        loss = jnp.mean(loss)
        # Per-example (size-1) batches cannot produce meaningful running
        # statistics, so the stored ones are kept (previously this path
        # referenced an unbound `new_batch_stats`).
        new_batch_stats = None
    else:
        (loss, (new_batch_stats, _, _)), grad = jax.value_and_grad(get_training_loss_l2, argnums = 0, has_aux = True)(train_state.params, images, labels, train_state, l2 = l2, train = train, has_bn = has_bn, batch_stats = batch_stats, use_base_params = use_base_params, label_scale = label_scale)

    step_updates = {'grads': grad, 'train_it': train_state.train_it + 1}
    # Only overwrite batch_stats when the loss actually handed back the mutated collection.
    if has_bn and new_batch_stats is not None:
        step_updates['batch_stats'] = new_batch_stats['batch_stats']
    new_state = train_state.apply_gradients(**step_updates)

    if update_ema:
        new_ema_hidden, new_ema_average = get_updated_ema(new_state.params, new_state.ema_hidden, ema_decay, new_state.train_it, order = 1)
        new_state = new_state.replace(ema_average = new_ema_average, ema_hidden = new_ema_hidden)

    return new_state, (loss, grad)


class TrainStateWithBatchStats(train_state.TrainState):
    """Flax ``TrainState`` extended with the extra fields used by the training steps in this file.

    The training steps bump ``train_it`` through ``apply_gradients``.
    ``do_training_step`` also updates ``batch_stats`` when ``has_bn``, and both
    steps ``replace`` the EMA fields when ``update_ema`` is set.
    """
    batch_stats: flax.core.FrozenDict  # batch-norm statistics; read by the training steps when has_bn=True
    train_it: int  # step counter, incremented by 1 on every training step
    ema_hidden: Any = None  # EMA accumulator passed to get_updated_ema (defined outside this chunk)
    ema_average: Any = None  # second value returned by get_updated_ema (presumably the parameter EMA — confirm)
    base_params: Any = None  # only referenced by commented-out code in this chunk — TODO confirm it is still needed


@functools.partial(jax.jit, static_argnames=('train', 'has_bn', 'use_base_params', 'use_softplus', 'use_solve', 'stat_loss_only'))
def get_recon_loss(all_params, init_params, net_params, labels, image_train_state, l2 = 0., train = False, has_bn = False, batch_stats = None, use_base_params = False, beta = 1., use_softplus = True, img_min = 0, img_max = 1, use_solve = True, stat_loss_only = False):
    """Reconstruction objective, differentiated with respect to ``all_params``.

    ``all_params`` holds the optimised quantities: ``'images'`` (inputs fed to
    the model) and ``'duals'`` (weights on the model outputs). The loss is
    ``value_loss + stationary_loss + 0.1 * img_loss``:

    * ``value_loss`` compares ``r = (labels - f(images; init_params)) - f(images; net_params)``
      against the model features ``feat``:
      - With ``use_solve``, it is ``mean(r.T @ solve(K, r)) / n`` with ``K = feat @ feat.T``.
        ``K`` gets a small ridge when there are fewer feature columns than rows.
      - Otherwise it is ``mean(r ** 2)`` divided by the smallest eigenvalue of the
        smaller Gram matrix.
      - It is ``0`` when ``stat_loss_only``.
    * ``stationary_loss`` is the squared norm of
      ``(net_params - init_params) - grad_params mean(duals * f(images; params))``,
      evaluated at ``net_params``. It is computed with the ``utils`` helpers ``_sub`` and
      ``get_dot_product`` (assumed to be pytree subtraction / inner product).
    * ``img_loss`` is a squared hinge penalty on image values outside
      ``[img_min, img_max]``, scaled by the number of images.

    Args:
        l2: currently unused.
        has_bn / batch_stats: when ``has_bn``, ``batch_stats`` is supplied with
            ``net_params`` (not with ``init_params``).
        use_base_params: forwarded to ``apply_fn`` in the gradient computation.
        beta, use_softplus: forwarded to every ``apply_fn`` call.
        img_min, img_max: bounds indexed with ``[None, None, None]``, i.e.
            scalars or arrays that broadcast after three leading singleton axes.

    Returns:
        ``(loss, (outputs, outputs, 0))``. The aux tuple repeats the model
        outputs under ``net_params``.
    """
    image_params = all_params['images']
    dual_params = all_params['duals']

    model_kwargs = dict(train = train, mutable = ['batch_stats'], use_softplus = use_softplus, beta = beta)

    def make_variables(params):
        # Attach the fixed batch-norm statistics only for models that have them.
        if has_bn:
            return {'params': params, 'batch_stats': batch_stats}
        return {'params': params}

    # Target residual: what the change from init_params must account for.
    init_outputs = image_train_state.apply_fn({'params': init_params}, image_params, **model_kwargs)[0]
    labels_fixed = labels - init_outputs

    def fwd(params):
        # Dual-weighted mean output; its parameter gradient enters the stationarity term.
        extra = {'use_base_params': use_base_params} if use_base_params else {}
        outputs, _ = image_train_state.apply_fn(make_variables(params), image_params, **model_kwargs, **extra)
        return jnp.mean(dual_params * outputs)

    (outputs, feat), _ = image_train_state.apply_fn(make_variables(net_params), image_params, return_feat = True, **model_kwargs)
    residual = labels_fixed - outputs

    if stat_loss_only:
        value_loss = 0
    elif use_solve:
        K_final = feat @ feat.T
        if feat.shape[1] < feat.shape[0]:
            # feat @ feat.T is singular when there are fewer feature columns than rows.
            # Add a ridge of 1e-3 times the mean eigenvalue (trace / n).
            K_final = K_final + (1e-3 * jnp.trace(K_final) * jnp.eye(K_final.shape[0]))/K_final.shape[0]
        value_loss = jnp.mean((residual.T @ jnp.linalg.solve(K_final, residual))/residual.shape[0])
    else:
        # Use the smaller of the two Gram matrices; both share their nonzero spectrum.
        K_final = feat.T @ feat if feat.shape[1] < feat.shape[0] else feat @ feat.T
        value_loss = jnp.mean(residual ** 2)/(jnp.linalg.eigvalsh(K_final)[0])

    grad = jax.grad(fwd)(net_params)
    delta_params = _sub(net_params, init_params)
    diff = _sub(delta_params, grad)
    stationary_loss = get_dot_product(diff, diff)

    img_loss = (jnp.mean(jax.nn.relu(image_params - img_max[None, None, None])**2 + jax.nn.relu(- image_params + img_min[None, None, None])**2)) * image_params.shape[0]

    loss = value_loss + stationary_loss + 0.1 * img_loss

    return loss, (outputs, outputs, 0)


@functools.partial(jax.jit, static_argnames=('has_bn', 'train', 'update_ema', 'use_base_params',  'use_softplus', 'use_solve', 'stat_loss_only'))
def do_training_step_recon(train_state, labels, init_params, final_params, l2 = 0., has_bn = False, train = True, update_ema = False, ema_decay = 0.995, use_base_params = False, beta = 1., use_softplus = True, img_min = 0, img_max = 1, use_solve = True, stat_loss_only = False):
    """Take one optimisation step of ``get_recon_loss`` on ``train_state.params``.

    ``train_state.params`` is the ``all_params`` pytree optimised by
    ``get_recon_loss``. ``init_params`` and ``final_params`` are the network
    parameters it compares (passed as its ``init_params`` / ``net_params``).
    When ``has_bn``, ``train_state.batch_stats`` is supplied to the model.

    The recon loss does not hand back updated batch-norm statistics; its first
    aux entry is the model outputs. The stored ``batch_stats`` are therefore
    left unchanged. Previously the ``has_bn`` path indexed the outputs array
    with ``'batch_stats'`` and failed.

    Returns:
        ``(new_state, (loss, outputs))`` where ``outputs`` is the second aux
        entry returned by ``get_recon_loss``.
    """
    batch_stats = train_state.batch_stats if has_bn else None

    (loss, (_, outputs, _)), grad = jax.value_and_grad(get_recon_loss, argnums = 0, has_aux = True)(train_state.params, init_params, final_params, labels, train_state, l2 = l2, train = train, has_bn = has_bn, batch_stats = batch_stats, use_base_params = use_base_params, beta = beta, use_softplus = use_softplus,
    img_min = img_min, img_max = img_max, use_solve = use_solve, stat_loss_only = stat_loss_only)

    new_state = train_state.apply_gradients(grads = grad, train_it = train_state.train_it + 1)

    if update_ema:
        # get_updated_ema is defined outside this chunk.
        new_ema_hidden, new_ema_average = get_updated_ema(new_state.params, new_state.ema_hidden, ema_decay, new_state.train_it, order = 1)
        new_state = new_state.replace(ema_average = new_ema_average, ema_hidden = new_ema_hidden)

    return new_state, (loss, outputs)


# @functools.partial(jax.jit, static_argnames=('has_bn', 'train', 'update_ema', 'use_base_params',  'use_softplus'))
# def do_training_step_recon(train_state, training_batch, direction, net_params, l2 = 0., has_bn = False, train = True, update_ema = False, ema_decay = 0.995, use_base_params = False, use_softplus = True, n_steps = 10000):
    
#     def body_fn(i, val):
#         train