import jax
import jax.numpy as jnp
import numpy as np
from copy import deepcopy
from tqdm import tqdm

from jax.flatten_util import ravel_pytree
import optax
from utils import tool, mp

import einops
import torch
import haiku as hk
from functools import partial
from copy import deepcopy

from absl import flags

# Enable 64-bit JAX types globally. Without this, the float64 casts used for
# the Hessian computation in this module would be truncated to float32.
jax.config.update("jax_enable_x64", True)

# Global absl flag registry; the functions below read values from it at call
# time (e.g. weight_decay, target_attr_idx).
FLAGS = flags.FLAGS

def loss_fn(params, state, batch):
    """Return the summed cross-entropy plus L2 weight decay for one batch.

    Args:
        params: Model parameter pytree, used in place of ``state.params`` when
            calling ``state.apply_fn``.
        state: Train-state-like object exposing ``apply_fn`` and
            ``batch_stats``.
        batch: Dict with ``'x'`` (model inputs) and ``'y'`` (targets passed to
            ``optax.softmax_cross_entropy``; callers in this module pass
            one-hot labels).

    Returns:
        Scalar ``sum(CE) + FLAGS.weight_decay * 0.5 * ||params||^2``. The norm
        runs over every leaf of ``params``. The model is applied with
        ``train=False``.
    """
    variables = {'params': params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, batch['x'], train=False)

    # Summed (not averaged) over the batch.
    data_term = optax.softmax_cross_entropy(logits, batch['y']).sum()

    flat_params, _ = ravel_pytree(params)
    penalty = 0.5 * jnp.sum(jnp.square(flat_params))

    return data_term + FLAGS.weight_decay * penalty

def loss_fn_wrapped(last_params, state, batch):
    """Evaluate ``loss_fn`` with only the ``'Dense_0'`` module's params replaced.

    This lets ``jax.grad``/``jax.hessian`` differentiate the loss with respect
    to that single module. Every other parameter is taken from
    ``state.params``.

    Args:
        last_params: Mapping ``{'Dense_0': {leaf_name: array, ...}}``. It must
            contain every leaf name that ``state.params['Dense_0']`` has.
        state: Train-state-like object whose ``params`` supply all other
            modules.
        batch: Batch dict forwarded unchanged to ``loss_fn``.

    Returns:
        The scalar loss computed by ``loss_fn``.
    """
    target = 'Dense_0'
    replacement = last_params[target]

    def _swap(module_name, name, value):
        # Leave every module except the target untouched.
        if module_name != target:
            return value
        return replacement[name]

    merged_params = hk.data_structures.map(_swap, state.params)
    return loss_fn(merged_params, state, batch)

def mv_full(hess, vector):
    """Right-multiply ``vector`` by ``hess``, i.e. return ``vector @ hess``.

    With a 2-D ``hess`` of shape ``(P, P)``, ``vector`` may be a single row of
    shape ``(P,)`` or a stack of rows of shape ``(..., P)`` (``jnp.matmul``
    broadcasting). For a symmetric ``hess`` this equals applying ``hess`` to
    each row.
    """
    return jnp.matmul(vector, hess)


def _hessian_batch(batch, num_classes, num_devices):
    """Turn an ``(index, data, attr, _)`` loader batch into a float64 NHWC loss batch.

    Labels are one-hot encodings of ``attr[:, FLAGS.target_attr_idx]``.
    """
    _, data, attr, _ = batch
    # NCHW -> NHWC. The '(n b)' split also requires the batch size to be
    # divisible by num_devices (einops raises otherwise).
    data = einops.rearrange(data, '(n b) c h w -> (n b) h w c', n=num_devices)
    data = jnp.asarray(data.numpy(), dtype=jnp.float64)
    label = torch.nn.functional.one_hot(attr[:, FLAGS.target_attr_idx], num_classes)
    return {'x': data, 'y': jnp.asarray(label.numpy())}


def _grad_batch(data, attr, num_classes, num_devices):
    """Reshape a loader batch to ``(devices, per_device, 1, ...)`` for pmap + vmap.

    Raises:
        ValueError: If the batch size is not divisible by ``num_devices``.
    """
    if data.shape[0] % num_devices != 0:
        raise ValueError(
            f'batch of {data.shape[0]} examples is not divisible by '
            f'{num_devices} devices')
    # The trailing u=1 axis makes every vmapped example a batch of one.
    data = einops.rearrange(data, '(n b u) c h w -> n b u h w c', n=num_devices, u=1)
    # NOTE(review): labels here come from attr[:, 0], while the Hessian pass
    # uses attr[:, FLAGS.target_attr_idx]. The two agree only when
    # target_attr_idx == 0 — confirm this is intended.
    label = torch.nn.functional.one_hot(attr[:, 0], num_classes)
    label = label.reshape(num_devices, -1, 1, num_classes)
    return {'x': jnp.asarray(data.numpy()), 'y': jnp.asarray(label.numpy())}


def _inverse_from_eigh(hess):
    """Invert a symmetric matrix through its eigendecomposition (no damping)."""
    eigvals, eigvecs = np.linalg.eigh(hess)
    # Same as eigvecs @ diag(1/eigvals) @ eigvecs.T, without building the dense
    # diagonal matrix. Zero or near-zero eigenvalues yield inf / very large
    # entries; nothing here regularizes them.
    return (eigvecs * (1.0 / eigvals)) @ eigvecs.T


def compute_influence_full_last(state, data_loader, data_loader_hess, num_classes, num_samples, *args, **kwargs):
    """Compute per-example self-influence restricted to the last layer ('Dense_0').

    For every example ``z`` in ``data_loader`` this returns
    ``(g_z @ H^{-1}) . g_z``, where:

    * ``g_z`` is the gradient of ``loss_fn`` on ``z`` with respect to the
      flattened ``'Dense_0'`` parameters.
    * ``H`` is the Hessian of ``loss_fn`` with respect to the same
      parameters, summed over all batches of ``data_loader_hess``.

    Args:
        state: Train-state-like object (``params``, ``batch_stats``,
            ``apply_fn``). Every leaf is cast with ``astype`` (float64 for
            the Hessian pass, float32 for the gradient pass).
        data_loader: Iterable of ``(index, data, attr, _)`` torch batches with
            NCHW ``data``. Each batch size must be divisible by
            ``jax.device_count()``.
        data_loader_hess: Same format; used only to build ``H``.
        num_classes: Number of classes for the one-hot labels.
        num_samples, *args, **kwargs: Accepted for interface compatibility;
            unused.

    Returns:
        Dict of NumPy arrays concatenated in loader order:

        * ``'influence'``
        * ``'index'``
        * ``'true_label'`` (``attr[:, 0]``)
        * ``'bias_label'`` (``attr[:, 1]``)

    Raises:
        ValueError: If ``data_loader_hess`` yields no batches, or a
            ``data_loader`` batch size is not divisible by the device count.
    """
    num_devices = jax.device_count()
    last_key = 'Dense_0'

    # ---- Hessian phase (float64) ----
    state = jax.tree_util.tree_map(lambda x: x.astype(jnp.float64), state)
    # (flat_vector, unravel_fn) for the last layer; presumably ravel_pytree-like.
    vec_params, unravel_fn = tool.params_to_vec(state.params[last_key], True)

    def flat_loss(vec, state, batch):
        # Loss as a function of the flattened last-layer parameters only. The
        # same function feeds both the Hessian and the per-example gradients,
        # so both share one parameter ordering.
        new_last = unravel_fn(vec)
        swap = lambda module_name, name, value: new_last[name] if module_name == last_key else value
        return loss_fn(hk.data_structures.map(swap, state.params), state, batch)

    hess_fn = jax.jit(jax.hessian(flat_loss))
    accum_hess = None
    for batch in tqdm(data_loader_hess):
        hess = np.asarray(hess_fn(vec_params, state, _hessian_batch(batch, num_classes, num_devices)))
        if accum_hess is None:
            accum_hess = np.array(hess)  # writable copy for in-place accumulation
        else:
            accum_hess += hess
    if accum_hess is None:
        raise ValueError('data_loader_hess yielded no batches; cannot build the Hessian')

    # Converted to a device array once instead of on every jitted call below.
    h_inv = jnp.asarray(_inverse_from_eigh(accum_hess).astype(np.float32))

    # ---- Gradient phase (float32 state) ----
    # vec_params is not re-cast and keeps the dtype it was created with above.
    state = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), state)
    mv_j = jax.jit(mv_full)

    # Per-example gradients: vmap over examples while broadcasting the
    # parameter vector and state (in_axes=None) instead of materializing one
    # state copy per example. pmap splits examples across devices; vec_params
    # is broadcast, and state carries a per-device axis.
    per_example_grad = jax.vmap(jax.grad(flat_loss), in_axes=(None, None, 0))
    grad_fn_j = jax.pmap(per_example_grad, in_axes=(None, 0, 0))
    # Assumes mp.replicate adds a leading per-device axis for pmap — confirm.
    state_r = mp.replicate(state)

    dataset_influence = []
    dataset_index = []
    dataset_true_label = []
    dataset_bias_label = []

    for index, data, attr, _ in tqdm(data_loader):
        batch = _grad_batch(data, attr, num_classes, num_devices)
        batch_grad = grad_fn_j(vec_params, state_r, batch).reshape(-1, vec_params.shape[0])

        # (g @ H^{-1}) . g for every example in the batch.
        tgt_grad = mv_j(h_inv, batch_grad)
        influence = (tgt_grad * batch_grad).sum(axis=-1)

        dataset_influence.append(np.array(influence))
        dataset_index.append(index.numpy())
        dataset_true_label.append(attr[:, 0].numpy())
        dataset_bias_label.append(attr[:, 1].numpy())

    output = {}
    output['influence'] = np.concatenate(dataset_influence, axis=0)
    output['index'] = np.concatenate(dataset_index, axis=0)
    output['true_label'] = np.concatenate(dataset_true_label, axis=0)
    output['bias_label'] = np.concatenate(dataset_bias_label, axis=0)

    return output
