#!/usr/bin/env python
# coding: utf-8

import jax
import jax.lax
import numpy as np
import jax.numpy as jnp

import torch
from torchvision import datasets, transforms
from tqdm import trange, tqdm
import os
import PIL.Image

import haiku as hk
import optax

from typing import Any, Iterable, Mapping, NamedTuple, Tuple
import atexit
import resnet_cifar
import tree
import shutil
import argparse

from collections import defaultdict
from datetime import datetime
import jsonlines
import time

from jax_resnet_cifar.fcn import BasicFCN
from jaxopt.tree_util import *
from jaxopt.linear_solve import solve_gmres, solve_cg
from jaxopt import IterativeRefinement
from jax import vmap

from functools import partial
import jaxopt
import copy

from tqdm import tqdm as Tqdm
import pathlib
import shutil

# Command-line interface. Every option has a short flag, a long flag whose
# name doubles as the metavar, and a default so the script runs with no
# arguments at all.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-l', '--lsub_id',
    default=-1, type=int, metavar='lsub_id',
)  # to parallelize
parser.add_argument(
    '-f', '--loss_type',
    default='FRM', type=str, metavar='loss_type',
)
parser.add_argument(
    '-u', '--loss_space',
    default='positives', type=str, metavar='loss_space',
)
parser.add_argument(
    '-n', '--num_data',
    default=8, type=int, metavar='num_data',
)
parser.add_argument(
    '-s', '--seed',
    default=0, type=int, metavar='seed',
)
parser.add_argument(
    '-e', '--epochs',
    default=10001, type=int, metavar='epochs',
)
parser.add_argument(
    '-d', '--dataset',
    default='cifar', type=str, metavar='dataset',
)
parser.add_argument(
    '-a', '--name',
    default='', type=str, metavar='name',
)
parser.add_argument(
    '-r', '--seedrange',
    default='', type=str, metavar='seedrange',
)

def tree_dot(tree_x, tree_y):
    """Return the sum over all leaves of the element-wise product of two pytrees.

    Both trees are mapped leaf-by-leaf, so they are expected to share the
    same structure.
    """
    def _leaf_inner(x, y):
        return jnp.sum(x*y)

    per_leaf = tree_map(_leaf_inner, tree_x, tree_y)
    return tree_sum(per_leaf)

def hvp(f, primals, tangents):
    """Hessian-vector product of ``f`` via forward-over-reverse differentiation.

    Args:
      f: function to differentiate; ``jax.grad`` is applied to it.
      primals: tuple of points at which the Hessian is taken.
      tangents: tuple of directions to multiply the Hessian with.

    Returns:
      The tangent output of ``jax.jvp`` applied to ``jax.grad(f)``.
    """
    grad_f = jax.grad(f)
    _, hessian_times_v = jax.jvp(grad_f, primals, tangents)
    return hessian_times_v

# Global experiment configuration.
# None of the attributes below carries a type annotation, so they are plain
# class attributes; the class itself is used as a namespace and is never
# instantiated in this file. main() overwrites several of them
# (DATA_ROOT, KEY, NUM_DATA, LOSS, MAX_EPOCH, BATCH_SIZE, INIT_LR) and adds
# NAME, SEED and LOSS_SPACE before each run.
class FLAGS(NamedTuple): # From vanilla CIFAR code
    KEY = jax.random.PRNGKey(1)  # evaluated once, when the class body runs
    BATCH_SIZE = 32
    DATA_ROOT = '.'
    LOG_ROOT = './cifar_logs/'
    MAX_EPOCH = 2000
    INIT_LR = 5e-3
    N_WORKERS = 8  # passed as num_workers to the data loaders in main()
    # Per-channel normalization statistics; only the CIFAR10 pair is used in main().
    MNIST_MEAN = (0.1307,)
    MNIST_STD = (0.3081,)
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2023, 0.1994, 0.2010)
    CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
    CIFAR100_STD = (0.2675, 0.2565, 0.2761)
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    NUM_DATA = 50000
    LOSS = 'ERM'  # key into LOSS_FNS
    LR_SCHEDULE = 'UNIFORM'  # only copied into the results dict by main()
    BN_DECAY_RATE = 0.979  # only copied into the results dict by main()


# Runtime environment hints for JAX/XLA: ask XLA not to preallocate device
# memory and request the GPU platform.
# NOTE(review): these are set after `import jax` and after FLAGS.KEY has
# already been created above; confirm they still take effect at this point,
# or move them ahead of the jax import.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['JAX_PLATFORM_NAME'] = 'gpu'

def tprint(obj):
    """Print ``obj`` through ``tqdm.write`` so active progress bars are not broken.

    ``str(obj)`` is used instead of calling ``obj.__str__()`` directly: the
    builtin goes through the type's ``__str__`` and therefore also works when
    ``obj`` is itself a class, where the direct dunder call fails.
    """
    tqdm.write(str(obj))

class TrainState(NamedTuple):
    """Immutable bundle of the three pieces of training state.

    Not referenced elsewhere in the visible part of this file.
    """
    params: hk.Params  # Haiku parameter tree
    state: hk.State  # Haiku mutable state tree
    opt_state: optax.OptState  # optax optimizer state


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that creates its underlying iterator once and reuses it.

    The batch sampler is wrapped in ``_RepeatSampler`` so the single iterator
    built in ``__init__`` never runs out; each ``__iter__`` call then pulls
    exactly ``len(self)`` batches from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temporarily clear DataLoader's private "initialized" flag while the
        # batch sampler is swapped.
        # NOTE(review): presumably needed because DataLoader refuses to
        # reassign batch_sampler once initialized — confirm against the
        # installed torch version.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single long-lived iterator shared by every epoch.
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is the _RepeatSampler wrapper; .sampler is the
        # original batch sampler it wraps, whose length is returned.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # One "epoch" = len(self) items drawn from the persistent iterator.
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """Endlessly replay a sampler.

    Each time the wrapped sampler is exhausted it is iterated again from the
    start, so iteration over this object never terminates on its own.

    Args:
        sampler (Sampler): the sampler to replay.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for item in self.sampler:
                yield item


def numpy_collate(batch):
    """Collate a list of samples into numpy arrays.

    The type of the first sample decides the strategy:
    ndarray samples are stacked along a new leading axis; tuple/list samples
    are transposed and each field is collated recursively (returning a list);
    anything else is converted with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    collated = []
    for field_samples in zip(*batch):
        collated.append(numpy_collate(field_samples))
    return collated


class ArrayNormalize(torch.nn.Module):
    """Transform that standardizes an image array in place: ``(arr - mean) / std``.

    ``mean`` and ``std`` are cast to the array's dtype; one-dimensional
    statistics are reshaped to ``(1, 1, C)`` so they broadcast over the last
    axis. The input array is modified and returned.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        assert isinstance(arr, np.ndarray), f'Input should be ndarray. Got {type(arr)}.'
        assert arr.ndim >= 3, f'Expected array to be a image of size (..., H, W, C). Got {arr.shape}.'

        dtype = arr.dtype
        std = np.asarray(self.std, dtype=dtype)
        mean = np.asarray(self.mean, dtype=dtype)
        # A std that rounds to zero in the target dtype would divide by zero.
        if np.any(std == 0):
            raise ValueError(
                f'std evaluated to zero after conversion to {dtype}, leading to division by zero.')
        mean = mean.reshape(1, 1, -1) if mean.ndim == 1 else mean
        std = std.reshape(1, 1, -1) if std.ndim == 1 else std
        # In-place updates: the caller's array is the one returned.
        np.subtract(arr, mean, out=arr)
        np.divide(arr, std, out=arr)
        return arr

class ArrayDenormalize(torch.nn.Module):
    """Transform that undoes standardization in place: ``arr * std + mean``.

    ``mean`` and ``std`` are cast to the array's dtype; one-dimensional
    statistics are reshaped to ``(1, 1, C)`` so they broadcast over the last
    axis. The input array is modified and returned.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        assert isinstance(arr, np.ndarray), f'Input should be ndarray. Got {type(arr)}.'
        assert arr.ndim >= 3, f'Expected array to be a image of size (..., H, W, C). Got {arr.shape}.'

        dtype = arr.dtype
        std = np.asarray(self.std, dtype=dtype)
        mean = np.asarray(self.mean, dtype=dtype)
        # Same zero-std guard as the forward transform.
        if np.any(std == 0):
            raise ValueError(
                f'std evaluated to zero after conversion to {dtype}, leading to division by zero.')
        mean = mean.reshape(1, 1, -1) if mean.ndim == 1 else mean
        std = std.reshape(1, 1, -1) if std.ndim == 1 else std
        # In-place updates: the caller's array is the one returned.
        np.multiply(arr, std, out=arr)
        np.add(arr, mean, out=arr)
        return arr

class ToArray(torch.nn.Module):
    '''Turn a PIL image into a float array scaled by 1/255.'''
    dtype = np.float32

    def __call__(self, x):
        assert isinstance(x, PIL.Image.Image)
        pixels = np.asarray(x, dtype=self.dtype)
        # Scale in place on the freshly converted array.
        np.divide(pixels, 255.0, out=pixels)
        return pixels


def l2_loss(params):
    """Half the sum of squared entries of every parameter outside batchnorm modules.

    Leaves are visited with their two-element path; a leaf is skipped when
    the first path element (the module name) contains 'batchnorm'.
    """
    total = 0
    for (mod_name, _), p in tree.flatten_with_path(params):
        if 'batchnorm' in mod_name:
            continue
        total = total + jnp.sum(jnp.square(p))
    return 0.5 * total


def forward(images, is_training: bool):
    """Apply a freshly constructed ``BasicFCN`` to ``images``.

    ``is_training`` is accepted for interface compatibility but is not read.
    """
    return BasicFCN(channels=4, bias_h=32, bias_w=32, stride=4)(images)

def random_split_like_tree(rng_key, target=None, treedef=None):
    """Split ``rng_key`` into one PRNG key per leaf, arranged like a pytree.

    Args:
      rng_key: key to split.
      target: pytree whose structure is used when ``treedef`` is not given.
      treedef: optional tree definition; takes precedence over ``target``.

    Returns:
      A pytree with the structure of ``treedef`` whose leaves are the split keys.
    """
    # Use the jax.tree_util spellings: the top-level aliases
    # (jax.tree_structure / jax.tree_unflatten) are deprecated and absent
    # from recent JAX releases, while jax.tree_util works on old and new ones.
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, keys)

def tree_random_like(rng_key, target, extra_dims=(), mul=0.003):
    """Alias for ``tree_random_normal_like``; all arguments are passed through."""
    return tree_random_normal_like(
        rng_key,
        target,
        extra_dims,
        mul,
    )

def tree_random_normal_like(rng_key, target, extra_dims=(), mul=0.003):
    """Draw scaled standard-normal noise with the structure of ``target``.

    Args:
      rng_key: PRNG key; split so that each leaf gets its own key.
      target: pytree providing leaf shapes and dtypes.
      extra_dims: tuple of leading dimensions prepended to every leaf shape.
      mul: scalar multiplier applied to the samples.

    Returns:
      A pytree like ``target`` whose leaves have shape
      ``extra_dims + leaf.shape`` and the leaf's dtype.
    """
    keys_tree = random_split_like_tree(rng_key, target)

    def _sample(leaf, key):
        return mul*jax.random.normal(key, extra_dims+leaf.shape, leaf.dtype)

    # jax.tree_util.tree_map instead of the top-level jax.tree_map alias,
    # which is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(_sample, target, keys_tree)

@partial(jax.jit, static_argnames=['loss_space'])
def ERM_loss(p, X, y, loss_space):
    """Masked mean-squared error between the model output and ``y``.

    Args:
      p: model parameters passed to ``model.apply``.
      X: model inputs.
      y: targets; their sign also defines the mask.
      loss_space: 'positives' weights each squared error by
        ``heaviside(y, 0.5)``, 'negatives' by ``heaviside(-y, 0.5)``.
        Static under ``jit``.

    Returns:
      Scalar mean of the masked squared errors.

    Raises:
      NotImplementedError: if ``loss_space`` is neither of the two values.
    """
    # Validate before running the forward pass. The previous
    # `raise NotImplemented` raised a TypeError because NotImplemented is a
    # comparison sentinel, not an exception class.
    if loss_space == 'positives':
        mask = jnp.heaviside(y, 0.5)
    elif loss_space == 'negatives':
        mask = jnp.heaviside(-y, 0.5)
    else:
        raise NotImplementedError(f'Unknown loss_space: {loss_space!r}')
    pred = model.apply(p, X, is_training=True)
    return jnp.mean(mask*(pred-y)**2)

@partial(jax.jit, static_argnames=['loss_space'])
def ERM_loss_w_aux(all_p, X, y, loss_space):
    """ERM loss on ``all_p['params']`` paired with a dummy auxiliary value ``0.``.

    Gives ERM the same ``(loss, aux)`` return shape as the other entry in
    ``LOSS_FNS``.
    """
    model_params = all_p['params']
    loss = ERM_loss(model_params, X, y, loss_space=loss_space)
    return loss, 0.

def get_params(init_p, x,y, steps, inner_lr=1e-1, loss_space='positives'):
    """Run ``steps`` plain gradient-descent updates on ``ERM_loss``.

    Args:
      init_p: starting parameters (pytree).
      x: inputs passed to ``ERM_loss``.
      y: targets passed to ``ERM_loss``.
      steps: number of iterations of the ``fori_loop``.
      inner_lr: step size of each update.
      loss_space: forwarded to ``ERM_loss``. The previous version omitted
        this required argument, so the call could not succeed; the default
        matches the command-line default.

    Returns:
      The parameters after ``steps`` updates.
    """
    def _sgd_step(_, p):
        grads = jax.grad(ERM_loss)(p, x, y, loss_space=loss_space)
        return tree_map(lambda a, b : a-inner_lr*b, p, grads)

    return jax.lax.fori_loop(0, steps, _sgd_step, init_p)

def error_norms(p, X, y):
    """Per-example L2 norm of the residual ``model.apply(p, X) - y``.

    The residual is flattened to ``(X.shape[0], -1)`` and the norm is taken
    along the second axis.
    """
    residual = model.apply(p, X, is_training=True)-y
    flat_residual = residual.reshape(X.shape[0], -1)
    return jnp.linalg.norm(flat_residual, axis=1)

def FRM_loss(all_params, X, y, loss_space):
    """FRM objective: a likelihood-style term plus a ridge-calibration term.

    Args:
      all_params: mapping with keys 'params' (model parameters), 'log_l'
        (read through ``exp``) and 'logit' (read through ``sigmoid``).
      X: batch of inputs; ``X.shape[0]`` is used as the batch size ``N``.
      y: targets compared with ``model.apply(params, X)``.
      loss_space: forwarded to every ``ERM_loss`` call.

    Returns:
      ``(loss, (gmres_error, l))`` where ``loss`` is
      ``maximum_likelihood_loss + ridge_loss``, ``gmres_error`` is the mean
      residual norm of the linear solves and ``l`` is ``exp(log_l)`` with
      gradients stopped.
    """
    # Two views of the same parameters with complementary stop_gradients:
    # the first lets gradients reach only the model parameters, the second
    # (suffix _ridge) lets gradients reach only log_l and logit.
    params, l, s_logit = all_params['params'], jax.lax.stop_gradient(jnp.exp(all_params['log_l'])), jax.lax.stop_gradient(jax.nn.sigmoid(all_params['logit']))
    params_ridge, l_ridge, s_logit_ridge = jax.lax.stop_gradient(all_params['params']), jnp.exp(all_params['log_l']), jax.nn.sigmoid(all_params['logit'])
    N = X.shape[0]
    # v -> l*v + s_logit * (Hessian of ERM_loss at params, using the model's
    # own predictions on X as targets) @ v
    matvec_H = lambda v : tree_add(tree_scalar_mul(l, v),
            tree_scalar_mul(s_logit,
            hvp(lambda pp : ERM_loss(pp, X, 
            model.apply(params, X, is_training=True),
            loss_space=loss_space),
            (params,), (v,))))
    # Not used further in this function.
    matvec_HNone = lambda _, v : matvec_H(v)
    # Per-example gradient of the model output w.r.t. the parameters.
    # assumes the model output for a single example is a scalar, as jax.grad
    # requires — TODO confirm
    aux = jax.vmap(lambda x : jax.grad(model.apply)(params, x, is_training=True))(X)
    # Per-example solve of matvec_H(sol) = gradient using conjugate gradients.
    sols = jax.vmap(lambda x : solve_cg(matvec_H, x))(aux) # gmres is sometimes unstable
    # Residual norm of each solve; the name says "gmres" but the solver above is CG.
    g_errors = jax.vmap(lambda sol, x : tree_l2_norm(tree_sub(matvec_H(sol), x)))(sols, aux)
    gmres_error = jnp.mean(g_errors)
    # g^T H^{-1} g for each example.
    g_Hinv_g = jax.vmap(lambda x,sol : tree_vdot(x, sol))(aux, sols)
    err = y-model.apply(params, X, is_training=True)
    # Squared error scaled by g^T H^{-1} g, plus its log.
    per_elt_loss = err**2/g_Hinv_g + jnp.log(g_Hinv_g)
    maximum_likelihood_loss = jnp.mean(per_elt_loss)

    '''
    The problem is over-parameterized so we have to estimate a meaningful Ridge coefficient l.
    This coefficient is used to compute the Hessian inversion.

    We cannot do it with the max-likelihood loss 
    because the job of l is to represent exp(-E[...]) well,
    not to have a higher maximum likelihood by changing the distribution over functions.

    What we want is (H+l*Id) to be a good approximation of exp(-E[L(p+delta, p)]),
    so we literally do MSE on this, avoiding any backpropagation w.r.t. p.
    One subtle thing is that we have to do leave-one-out cross-validation, or else
    H actually over-estimates expected error, not under-estimates. We want it to be
    representative of unseen points.
    '''
    def run_QP(pp,ll,s_log,XX,x,yy):
        # Equality-constrained QP for one example (x, yy): the quadratic form
        # is the ridge-regularized Hessian built on the full batch XX, the
        # linear constraint uses the gradient of the model output at x, and
        # the right-hand side is the residual yy - model(pp, x).
        # The primal solution is returned with gradients stopped.
        umatvec_H = lambda v : tree_add(tree_scalar_mul(ll, v),
                tree_scalar_mul(s_log,
                hvp(lambda p : ERM_loss(p, XX, 
                model.apply(pp, XX, is_training=True),
                loss_space=loss_space), (pp,), (v,))))
        umatvec_HNone = lambda _, v : umatvec_H(v)
        matvec_A = lambda _, v : tree_vdot(jax.grad(lambda p : model.apply(
            p, x, is_training=True).reshape())(jax.lax.stop_gradient(pp)), v)
        ECQP = jaxopt.EqualityConstrainedQP(matvec_Q=umatvec_HNone, matvec_A=matvec_A, 
                solve=jaxopt.linear_solve.solve_normal_cg)
        return jax.lax.stop_gradient(ECQP.run(None, (None, tree_zeros_like(pp)),
                (x, yy-model.apply(pp, x, is_training=True))).params.primal)

    #smallest perturbation according to H that fits QP
    deltas = jax.lax.stop_gradient(jax.vmap(partial(run_QP, params_ridge,
        jax.lax.stop_gradient(l_ridge), jax.lax.stop_gradient(s_logit_ridge), X))(X,y))

    per_elt_matvec_H = lambda v,x : ( #computes Hessian at x and returns Hv
        hvp(lambda pp : ERM_loss(pp, x, 
            model.apply(params_ridge, x, is_training=True),
            loss_space=loss_space),
            (params_ridge,), (v,)))
    # Same as matvec_H, but now stopping gradient w.r.t. all but l
    matvec_H_sg = lambda v : hvp(lambda pp : 
            ERM_loss(pp, X, model.apply(params_ridge, X, is_training=True), loss_space=loss_space),
            (jax.lax.stop_gradient(params_ridge),), (jax.lax.stop_gradient(v),))
    # Leave-one-out operator for example x:
    # l_ridge*v + s_logit_ridge * (N/(N-1) * H_batch v - 1/(N-1) * H_x v).
    # Only l_ridge and s_logit_ridge carry gradients here.
    all_but = lambda v,x : tree_add(
            tree_scalar_mul(l_ridge, jax.lax.stop_gradient(v)),
            tree_scalar_mul(s_logit_ridge,
                tree_add(
                    tree_scalar_mul(N/(N-1.), jax.lax.stop_gradient(matvec_H_sg(v))),
                    tree_scalar_mul(-1./(N-1.), jax.lax.stop_gradient(per_elt_matvec_H(v,x))))))
    compute_all_but_norm = lambda v,x : tree_vdot(v, all_but(v, x))
    # Random perturbations shaped like each delta. The same key FLAGS.KEY is
    # used for every example, and it is read when this function is traced.
    fake_deltas = jax.vmap(lambda d : tree_random_like(FLAGS.KEY, d))(deltas)
    # NOTE(review): the lambda is declared as (d, fd) but is called with
    # (fake_deltas, deltas), so `d` is the random draw and `fd` is the QP
    # solution: the result is the QP solution rescaled to the random draw's
    # norm, not the other way round. Confirm which was intended.
    fake_deltas = jax.vmap(lambda d, fd: tree_scalar_mul(
        tree_l2_norm(d)/tree_l2_norm(fd), fd))(fake_deltas, deltas)
    all_but_norms = jax.vmap(compute_all_but_norm)(fake_deltas, X)

    # E_x[L(p+delta, p)] for each delta
    per_point_deviations = jax.vmap(lambda x, delta : ERM_loss(tree_add(params_ridge, delta), x,
        model.apply(params_ridge, x, is_training=True), loss_space=loss_space))(
                X, fake_deltas)
        
    deviations = per_point_deviations
    # ridge_loss is not a regularization, only is used to determine the ridge coefficient.
    # It matches sqrt(quadratic form / 2) to sqrt(observed deviation); the
    # deviations are treated as constants.
    ridge_loss = ((jnp.sqrt(all_but_norms/2.)-jnp.sqrt(jax.lax.stop_gradient(deviations)))**2).mean()

    return maximum_likelihood_loss+ridge_loss, (gmres_error, l)


# Dispatch table from loss name to loss function. Both entries take
# (all_params, X, y, loss_space) and return a (loss, aux) pair.
LOSS_FNS = {'ERM': ERM_loss_w_aux, 'FRM': FRM_loss}
@partial(jax.jit, static_argnames=['loss_fn', 'loss_space'])
def train_step(X, true_params, all_params, opt_state, e_key, loss_fn, loss_space):
    """Perform one optimizer update on freshly generated noisy targets.

    Targets are produced by evaluating the model with ``true_params`` plus an
    independent random perturbation per example (drawn from ``e_key``).

    Returns:
      ``(all_params, opt_state, loss)`` after the update; ``loss`` is the
      value before the update.
    """
    batch_size = X.shape[0]
    noise = tree_random_like(e_key, true_params, (batch_size,))
    noisy_params = tree_add(true_params, noise)

    def _predict(p, inp):
        return model.apply(p,inp,is_training=True)

    y = jax.vmap(_predict)(noisy_params, X)

    objective = jax.value_and_grad(LOSS_FNS[loss_fn], has_aux = True)
    (loss, _aux), grads = objective(all_params, X, y,loss_space)

    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(all_params, updates)
    return new_params, opt_state, loss

@partial(jax.jit, static_argnames=['loss_fn', 'loss_space'])
def test_step(X, true_params, params, e_key, loss_fn, loss_space):
    """Evaluate the selected loss on noisy targets without updating anything.

    Targets are generated the same way as in training: the model is evaluated
    with ``true_params`` plus a per-example random perturbation drawn from
    ``e_key``. Only the loss value (first element of the loss function's
    output) is returned.
    """
    batch_size = X.shape[0]
    noise = tree_random_like(e_key, true_params, (batch_size,))
    noisy_params = tree_add(true_params, noise)

    def _predict(p, inp):
        return model.apply(p,inp,is_training=True)

    y = jax.vmap(_predict)(noisy_params, X)
    outputs = LOSS_FNS[loss_fn](params, X, y,loss_space=loss_space)
    return outputs[0]

def main(args):
    """Train one model for a single (seed, loss type, loss space) setting.

    Reads ``args.seed``, ``args.num_data``, ``args.loss_type``,
    ``args.loss_space``, ``args.name``, ``args.epochs`` and ``args.do_print``.
    The dataset is always CIFAR-10 (``args.dataset`` is not consulted).
    Rebinds the module-level ``model`` and ``optimizer`` used by the jitted
    step functions, and overwrites several ``FLAGS`` attributes.

    Returns:
      dict with the run configuration plus, from the final epoch, the
      train/test FRM and ERM losses in both loss spaces and the recorded
      loss histories.
    """
    global model
    global optimizer
    # FLAGS is only mutated attribute-wise below, never rebound, so no
    # `global FLAGS` declaration is needed.
    args_seed = args.seed
    print('num_data: ', args.num_data)
    args_num_data = args.num_data
    args_loss_type = args.loss_type
    FLAGS.DATA_ROOT = '.'
    args_dataset = 'cifar'
    print(f"seed {args_seed} loss {args_loss_type} space {args.loss_space}")
    # Seed from the argument. This previously read a bare `seed`, which only
    # resolved through the script's module-level loop variable and raised
    # NameError when main() was called from anywhere else.
    torch.manual_seed(args_seed)
    np.random.seed(args_seed)

    FLAGS.KEY = jax.random.PRNGKey(args_seed)
    FLAGS.NAME = args.name
    FLAGS.SEED = args_seed
    FLAGS.NUM_DATA = args_num_data
    FLAGS.LOSS = args_loss_type
    FLAGS.LOSS_SPACE = args.loss_space
    FLAGS.MAX_EPOCH = args.epochs
    FLAGS.BATCH_SIZE = 64
    FLAGS.INIT_LR = {'FRM': 1e-5, 'ERM': 1e-3}[FLAGS.LOSS]
    transform_train = transforms.Compose([
        ToArray(),
        ArrayNormalize(FLAGS.CIFAR10_MEAN, FLAGS.CIFAR10_STD),
    ])
    transform_test = transforms.Compose([
        ToArray(),
        ArrayNormalize(FLAGS.CIFAR10_MEAN, FLAGS.CIFAR10_STD),
    ])
    print(FLAGS.DATA_ROOT)
    train_dataset = datasets.CIFAR10(
        FLAGS.DATA_ROOT, train=True, download=True, transform=transform_train)
    MAX_SAMPLES = 60000 if args_dataset == 'mnist' else 50000
    # Pick a random contiguous window of the training and test sets.
    start_train = np.random.randint(MAX_SAMPLES-args_num_data)
    end_train = start_train + args_num_data
    test_dataset = datasets.CIFAR10(
        FLAGS.DATA_ROOT, train=False, transform=transform_test)
    start_test = np.random.randint(10000-max(FLAGS.BATCH_SIZE, args_num_data))
    end_test = start_test + max(FLAGS.BATCH_SIZE, args_num_data)
    train_dataset = torch.utils.data.Subset(train_dataset, list(range(start_train, end_train)))
    print(f"Data: {start_train}..{end_train}    {start_test}..{end_test}")
    print(f"FLAGS: {FLAGS.__dict__}", flush=True)
    test_dataset = torch.utils.data.Subset(test_dataset, 
            list(range(start_test, end_test)))
    train_loader = MultiEpochsDataLoader(
        train_dataset,
        batch_size=FLAGS.BATCH_SIZE,
        shuffle=False, #it's full batch training
        drop_last=FLAGS.BATCH_SIZE<FLAGS.NUM_DATA,
        num_workers=FLAGS.N_WORKERS,
        collate_fn=numpy_collate,
    )
    test_loader = MultiEpochsDataLoader(
        test_dataset,
        batch_size=FLAGS.BATCH_SIZE*2,
        shuffle=False,
        drop_last=False,
        num_workers=FLAGS.N_WORKERS,
        collate_fn=numpy_collate,
    )
    ## INITIALIZE MODEL ##
    model = hk.transform(forward)
    model = hk.without_apply_rng(model)

    sample_input = jnp.ones((1, 32, 32, 3))
    true_params = model.init(FLAGS.KEY, sample_input, is_training=True)
    # NOTE(review): FLAGS.KEY+1 does integer arithmetic on the raw key array
    # to obtain a second key; kept as-is so existing runs stay reproducible.
    params = model.init(FLAGS.KEY+1, sample_input, is_training=True)
    n_params = sum([p.size for p in jax.tree_util.tree_leaves(params)])
    tprint(n_params)

    # Specify learning rate schedule.
    learning_rate_fn = lambda lr : optax.cosine_decay_schedule(
        init_value=lr,
        decay_steps=len(train_loader) * FLAGS.MAX_EPOCH,
        alpha=0.0
        )

    # Separate SGD optimizers (and learning rates) for the network weights,
    # the log ridge coefficient and the Hessian-scale logit.
    optimizer = optax.multi_transform({
        'params' : optax.sgd(learning_rate_fn(FLAGS.INIT_LR), momentum=0.9, nesterov=False),
        'log_l'  : optax.sgd(learning_rate_fn(1e-1), momentum=0.9, nesterov=False),
        'logit' : optax.sgd(learning_rate_fn(1e-6), momentum=0.9, nesterov=False)},
        param_labels ={'params':'params', 'log_l':'log_l', 'logit':'logit'})

    log_l = jnp.zeros(1)
    logit = 10*jnp.ones(1)
    all_params = {'params': params, 'log_l': log_l, 'logit': logit}
    opt_state = optimizer.init(all_params)

    # Only epoch_keys[0] (train) and epoch_keys[1] (test) are used below, so
    # the same perturbation key is reused in every epoch.
    epoch_keys = jax.random.split(FLAGS.KEY, 2*FLAGS.MAX_EPOCH)

    res = dict((k, FLAGS.__dict__[k]) for k in ['BATCH_SIZE', 'DATA_ROOT', 'MAX_EPOCH', 'INIT_LR',
        'N_WORKERS', 'NUM_DATA', 'LOSS', 'LR_SCHEDULE', 'BN_DECAY_RATE', 'SEED', 'LOSS_SPACE', 'NAME'])
    train_losses = []
    test_losses = []
    recorded_epochs = []
    for epoch in range(FLAGS.MAX_EPOCH):
        for X, _ in train_loader:
            all_params, opt_state, loss = train_step(X, true_params, all_params, opt_state, 
                    epoch_keys[0], FLAGS.LOSS, loss_space=FLAGS.LOSS_SPACE)
            if args.do_print and epoch % 10 == 0: 
                print(f"Train[{epoch}]: ", loss, all_params['log_l'], all_params['logit'])
            if not args.do_print and (epoch % 1000 == 0 or epoch == FLAGS.MAX_EPOCH-1): 
                print(f"Train[{epoch}]: ", loss, all_params['log_l'], all_params['logit'])
                train_losses.append(loss.item())

            if epoch == FLAGS.MAX_EPOCH-1: # Save results in last epoch
                res['train_FRM'] = test_step(X, true_params, all_params, epoch_keys[0], loss_fn='FRM',
                        loss_space='positives').item()
                res['train_ERM'] = test_step(X, true_params, all_params, epoch_keys[0], 
                        loss_fn='ERM', loss_space='positives').item()
                res['train_FRM_negatives'] = test_step(X, true_params, all_params, epoch_keys[0], 
                        loss_fn='FRM', loss_space='negatives').item()
                res['train_ERM_negatives'] = test_step(X, true_params, all_params, epoch_keys[0], 
                        loss_fn='ERM', loss_space='negatives').item()
                res['train_losses'] = train_losses
        for X, _ in test_loader:
            # The check stays inside the loop so the persistent loader
            # iterator still advances on skipped epochs, as before.
            if epoch%10!=0 and epoch!=FLAGS.MAX_EPOCH-1: continue #skip unnecessary tests
            test_loss = test_step(X, true_params, all_params, epoch_keys[1], loss_fn='ERM', 
                    loss_space=args.loss_space)
            if args.do_print and epoch % 10 == 0: print(f"Test[{epoch}]: ", test_loss)
            if not args.do_print and (epoch % 1000 == 0 or epoch == FLAGS.MAX_EPOCH-1): 
                print(f"Test[{epoch}]: ", test_loss)
                test_losses.append(test_loss.item())
                recorded_epochs.append(epoch)
            if epoch == FLAGS.MAX_EPOCH-1:
                res['test_FRM'] = test_step(X, true_params, all_params, epoch_keys[1], 
                        loss_fn='FRM', loss_space='positives').item()
                res['test_ERM'] = test_step(X, true_params, all_params, epoch_keys[1], 
                        loss_fn='ERM', loss_space='positives').item()
                res['test_FRM_negatives'] = test_step(X, true_params, all_params, epoch_keys[1], 
                        loss_fn='FRM', loss_space='negatives').item()
                res['test_ERM_negatives'] = test_step(X, true_params, all_params, epoch_keys[1], 
                        loss_fn='ERM', loss_space='negatives').item()
                res['test_losses'] = test_losses
                res['recorded_epochs'] = recorded_epochs
    return res
if __name__ == '__main__':
    # Script entry point: sweep over seeds and the four (loss type, loss
    # space) combinations, running main() once per combination.
    args = parser.parse_args()
    args.do_print = False
    if args.lsub_id >= 0:
        # A non-negative lsub_id selects a block of seeds, overriding -r.
        howManyToDo = 25  # How many different datasets to try sequentially.
        args.seedrange = f'{args.lsub_id*howManyToDo},{(args.lsub_id+1)*howManyToDo}'
    if args.name == '':
        # Unnamed runs get a timestamp name and are not written to disk.
        args.has_name = False
        args.name = datetime.now().strftime('fdt_%Y_%m_%d_%H_%M_%S')
    else:
        args.has_name = True

    LOSSES = [('FRM', 'positives'), ('FRM', 'negatives'), ('ERM', 'positives'), ('ERM', 'negatives')]
    # The training-set size is fixed here, overriding any -n given on the
    # command line.
    args.num_data = 8
    outfile = f'logs_functional_data/res_{args.name}'
    if args.has_name:
        # Create the log directory up front so the first result, produced
        # only after a full training run, is not lost to a missing directory.
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
    if args.seedrange == '':
        start, finish = 0, 50
    else:
        bounds = args.seedrange.split(',')
        start, finish = int(bounds[0]), int(bounds[1])
    # The loop variable keeps the module-level name `seed`; code elsewhere in
    # this file may read it.
    for seed in Tqdm(range(start, finish)):
        for loss in LOSSES:
            args.seed = seed
            (args.loss_type, args.loss_space) = loss
            res = main(args)
            print(res)
            if args.has_name:
                with jsonlines.open(outfile, mode='a') as writer:
                    writer.write({'time': datetime.now().strftime('%Y/%m/%d, %H:%M:%S'), **res})
