import jax
import jax.numpy as jnp
import torch

import torchvision.datasets as dset
import torchvision.transforms as transforms
import neural_tangents as nt

import numpy as np
import flax
import flax.linen as nn
import optax as tx
import neural_tangents.stax as stax

import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


from typing import Any, Callable, Sequence, Tuple
from flax.training import train_state, checkpoints

import matplotlib.pyplot as plt
import functools
import operator
import fire

import data
from utils import *
import models
import training_utils
import pickle


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only load checkpoints
    # produced by a trusted training run.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file afterwards."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _load_params(output_dir, checkpoint_name, init_checkpoint_name):
    """Return ``(init_params, final_params)`` read from pickled checkpoints in ``./output_dir``.

    Both values come from ``checkpoint_name``'s ``'init_params'`` / ``'final_params'``
    entries, except that when ``init_checkpoint_name`` is given, ``init_params`` is
    replaced by that checkpoint's ``'final_params'``.
    """
    checkpoint_dict = _load_pickle('./{}/{}.pkl'.format(output_dir, checkpoint_name))
    init_params = checkpoint_dict['init_params']
    final_params = checkpoint_dict['final_params']

    if init_checkpoint_name is not None:
        init_checkpoint_dict = _load_pickle('./{}/{}.pkl'.format(output_dir, init_checkpoint_name))
        init_params = init_checkpoint_dict['final_params']

    return init_params, final_params


def _init_reconstruction(dataset_name, reconstruction_size, num_outputs, image_key, dual_key):
    """Return the initial trainable pytree ``{'images': ..., 'duals': ...}``.

    Images are ``0.2 * N(0, 1)`` samples shaped ``(reconstruction_size, 28, 28, 1)``
    when ``dataset_name`` contains ``'mnist'``, or ``(reconstruction_size, 32, 32, 3)``
    when it contains ``'cifar10'`` (checked in that order). Duals are uniform on
    ``[-0.5, 0.5)`` with shape ``(reconstruction_size, num_outputs)``.

    Raises:
        ValueError: if ``dataset_name`` contains neither ``'mnist'`` nor ``'cifar10'``.
    """
    if 'mnist' in dataset_name:
        image_shape = (28, 28, 1)
    elif 'cifar10' in dataset_name:
        image_shape = (32, 32, 3)
    else:
        raise ValueError(
            "Unsupported dataset_name {!r}: expected it to contain 'mnist' or 'cifar10'".format(dataset_name))

    return {
        'images': 0.2 * jax.random.normal(image_key, shape = [reconstruction_size, *image_shape]),
        'duals': jax.random.uniform(dual_key, shape = [reconstruction_size, num_outputs]) - 0.5,
    }


def _make_reconstruction_labels(reconstruction_size):
    """Return float32 targets of shape ``(reconstruction_size, 1)``: first half -1, rest +1.

    The +1 half takes the remainder, so an odd size still yields exactly one label per
    reconstructed image.
    """
    is_negative = jnp.arange(reconstruction_size) < reconstruction_size // 2
    return jnp.where(is_negative, -1.0, 1.0).astype(jnp.float32).reshape(-1, 1)


def _output_path(output_dir, save_name, amp_factor, use_solve, stat_loss_only):
    """Return the pickle path the reconstruction is written to."""
    if stat_loss_only:
        template = './{}/{}_amp_{}_solve_{}_stat_only.pkl'
    else:
        template = './{}/{}_amp_{}_solve_{}.pkl'
    return template.format(output_dir, save_name, amp_factor, use_solve)


def main(seed = 0, dataset_name = 'mnist_odd_even', model_width = 1000, train_set_size = 100, output_dir = None, linearize = False, checkpoint_name = 'final_checkpoint', max_iters = 80000, amp_factor = 2, use_solve = False, stat_loss_only = True, ntk_param = False, no_bias = False, init_checkpoint_name = None, save_name = 'reconstruction'):
    """Optimize synthetic images/duals against a trained MLP checkpoint and pickle the result.

    Loads ``init_params``/``final_params`` from ``./output_dir/checkpoint_name.pkl``,
    runs ``max_iters + 1`` Adam steps of ``training_utils.do_training_step_recon`` on a
    ``{'images', 'duals'}`` pytree, then pickles the optimized pytree into ``output_dir``.

    Args:
        seed: Seeds both the dataset re-creation passed to ``data.get_dataset`` and the
            random initial images/duals.
        dataset_name: Passed to ``data.get_dataset``. It must contain ``'mnist'`` or
            ``'cifar10'``, which selects the image shape.
        model_width: Width of both hidden layers of ``models.MLP``.
        train_set_size: Passed to ``data.get_dataset``. Multiplied by ``amp_factor``
            to give the number of reconstructed images.
        output_dir: Directory (relative to ``./``) holding the checkpoints and
            receiving the output. It is created if missing.
        linearize: If True, ``model.apply`` is wrapped with ``get_linear_forward``
            around ``init_params``. Presumably this is a first-order (linearized)
            model; confirm in ``utils``.
        checkpoint_name: Pickle basename containing ``'init_params'`` and
            ``'final_params'``.
        max_iters: The optimizer runs ``max_iters + 1`` steps.
        amp_factor: Reconstruction size multiplier; also appears in the output
            filename.
        use_solve: Forwarded to ``do_training_step_recon``; also appears in the
            output filename.
        stat_loss_only: Forwarded to ``do_training_step_recon``. When True, the
            output filename gets a ``_stat_only`` suffix. It defaults to True, the
            value this script has always used; an explicit False is now honored
            instead of being silently overwritten.
        ntk_param: Forwarded to ``models.MLP``.
        no_bias: Forwarded to ``models.MLP``.
        init_checkpoint_name: If given, that checkpoint's ``'final_params'`` replace
            ``init_params``.
        save_name: Basename prefix of the output pickle.

    Raises:
        ValueError: if ``dataset_name`` contains neither ``'mnist'`` nor ``'cifar10'``.
    """
    if output_dir is not None:
        os.makedirs('./{}'.format(output_dir), exist_ok=True)

    _, image_key, dual_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Re-create the dataset with the same seed. Only the label width (output dim)
    # and train_mean (used for the image bounds below) are used here.
    _, train_labels, train_mean = data.get_dataset(dataset_name, jax.random.PRNGKey(seed), train_set_size)
    num_outputs = train_labels.shape[-1]

    # Build the initial images/duals before loading checkpoints so an unsupported
    # dataset name fails fast.
    reconstruction_size = int(amp_factor * train_set_size)
    init_images = _init_reconstruction(dataset_name, reconstruction_size, num_outputs, image_key, dual_key)
    # NOTE(review): labels are always shaped (N, 1) while duals have num_outputs
    # columns; this presumably assumes a single-output (binary) task.
    recon_labels = _make_reconstruction_labels(reconstruction_size)

    init_params, final_params = _load_params(output_dir, checkpoint_name, init_checkpoint_name)

    model = models.MLP(width = [model_width, model_width], ntk_param = ntk_param, no_bias = no_bias, output_dim = num_outputs)
    net_apply = get_linear_forward(model.apply, init_params) if linearize else model.apply

    opt = tx.chain(tx.adam(learning_rate=0.02))
    image_train_state = training_utils.TrainStateWithBatchStats.create(apply_fn = net_apply, params = init_images, tx = opt, batch_stats = None, train_it = 0, base_params = None)

    # Loop-invariant step settings. The bounds are [0, 1] shifted by -train_mean
    # (presumably mean-centred pixel values).
    img_min = 0 - train_mean
    img_max = 1 - train_mean
    use_softplus = True

    for i in range(max_iters + 1):
        # Schedule: starts at 11, grows as 10 ** (i / 20000), capped at 200
        # (around iteration 45.6k). It is passed with use_softplus to the step,
        # presumably as a softplus sharpness; confirm in training_utils.
        beta = min(10 + 10 ** (i/20000), 200)
        image_train_state, (loss, _) = training_utils.do_training_step_recon(image_train_state, recon_labels, init_params, final_params, beta = beta, use_softplus = use_softplus, img_min = img_min, img_max = img_max, use_solve = use_solve, stat_loss_only = stat_loss_only)

        if i % 10000 == 0:
            print(f'iter: {i}, loss: {loss}')

    _dump_pickle(image_train_state.params, _output_path(output_dir, save_name, amp_factor, use_solve, stat_loss_only))

if __name__ == '__main__':
    # Command-line entry point: python-fire exposes every keyword argument of
    # `main` as a `--flag`.
    fire.Fire(main)