import jax
import jax.numpy as jnp
import torch

import torchvision.datasets as dset
import torchvision.transforms as transforms
import neural_tangents as nt

import numpy as np
import flax
import flax.linen as nn
import optax as tx
import neural_tangents.stax as stax

import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


from typing import Any, Callable, Sequence, Tuple
from flax.training import train_state, checkpoints

import matplotlib.pyplot as plt
import functools
import operator
import fire

import data
from utils import *
import models
import training_utils
import pickle


def _as_checkpoint_list(value):
    """Return a fresh list of checkpoint values from None, a single number, or an iterable.

    A fresh list is returned so that consuming it with ``pop`` never mutates the
    caller's object. A bare number (e.g. ``--iter_checkpoints=1000`` on the CLI)
    becomes a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, (int, float)):
        return [value]
    return list(value)


def _make_checkpoint(init_params, state, loss, iteration):
    """Build the dict stored in every checkpoint pickle.

    Args:
        init_params: parameters the network was initialised with.
        state: current train state; only its ``params`` attribute is read.
        loss: most recent training loss.
        iteration: index of the most recent training step (-1 if no step ran).

    Returns:
        dict with keys ``init_params``, ``final_params``, ``final_loss`` and ``iter``.
    """
    return {
        'init_params': init_params,
        'final_params': state.params,
        'final_loss': loss,
        'iter': iteration,
    }


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing (and so flushing) the file right away."""
    with open(path, 'wb') as out_file:
        pickle.dump(obj, out_file)


def main(seed = 0, dataset_name = 'mnist_odd_even', n_epochs = 1e6, lr = 1e-3, model_width = 1000, train_set_size = 100, output_dir = None, linearize = False, loss_checkpoints = None, ntk_param = False, no_bias = False, iter_checkpoints = None, distilled_data_dir = None, use_dp = False, clip_grad_norm = 2, grad_noise_ratio = 0.01, use_adam = False, second_seed = None):
    """Train an MLP full-batch on a small training set and optionally save checkpoints.

    Every step passes the entire training set to ``training_utils.do_training_step``.
    Training stops after ``int(n_epochs)`` steps, or earlier once the loss drops
    below 1e-10.

    Args:
        seed: seed for ``data.get_dataset``. It is also used for network
            initialisation (unless ``second_seed`` is given) and for the DP-SGD
            noise.
        dataset_name: dataset identifier forwarded to ``data.get_dataset``.
        n_epochs: maximum number of training steps (cast to ``int``).
        lr: optimizer learning rate.
        model_width: used for both entries of ``models.MLP``'s ``width`` list.
        train_set_size: forwarded to ``data.get_dataset``.
        output_dir: if given, the directory under ``./`` that receives
            ``config.txt`` (appended), checkpoint pickles, ``final_checkpoint.pkl``
            and ``training_set.pkl``. Nothing is written when it is None.
        linearize: if True, ``net_apply`` is replaced by
            ``get_linear_forward(net_apply_base, init_params)``. Presumably this
            trains the linearisation around the initial parameters — confirm in
            ``utils``.
        loss_checkpoints: loss thresholds, consumed front to back. A checkpoint
            is saved the first time the loss falls below the current head, and at
            most one is consumed per step, so the list should be in descending
            order. Defaults to none.
        ntk_param: forwarded to ``models.MLP``.
        no_bias: forwarded to ``models.MLP``.
        iter_checkpoints: step indices, consumed front to back. A checkpoint is
            saved when the step index equals the current head, so the list should
            be in ascending order. Defaults to none.
        distilled_data_dir: if given, the training images and labels are replaced
            with the contents of ``./<dir>/distillation_result.pkl``.
            ``train_mean`` still comes from ``data.get_dataset``.
        use_dp: use ``optax.dpsgd`` (ignored when ``use_adam`` is True). This
            value is also forwarded to ``do_training_step``.
        clip_grad_norm: DP-SGD clipping norm.
        grad_noise_ratio: DP-SGD noise multiplier.
        use_adam: use Adam instead of SGD with momentum 0.9.
        second_seed: if given, overrides ``seed`` for network initialisation only.
    """
    # Fresh copies: reached checkpoints are popped off below, and that must not
    # mutate the caller's list.
    loss_checkpoints = _as_checkpoint_list(loss_checkpoints)
    iter_checkpoints = _as_checkpoint_list(iter_checkpoints)

    # Snapshot the arguments before any other local is bound, so that config.txt
    # records only the run configuration (not e.g. an open file handle).
    run_config = repr(locals())

    if output_dir is not None:
        os.makedirs('./{}'.format(output_dir), exist_ok = True)
        with open('./{}/config.txt'.format(output_dir), 'a') as config_file:
            config_file.write(run_config)

    train_images, train_labels, train_mean = data.get_dataset(dataset_name, jax.random.PRNGKey(seed), train_set_size)

    if distilled_data_dir is not None:
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted, locally produced files.
        with open(f'./{distilled_data_dir}/distillation_result.pkl', 'rb') as distilled_file:
            distilled_dict = pickle.load(distilled_file)
        train_images = distilled_dict['distilled_images']
        train_labels = distilled_dict['distilled_labels']

    print(train_labels.shape)

    model = models.MLP(width = [model_width, model_width], ntk_param = ntk_param, no_bias = no_bias, output_dim = train_labels.shape[-1])
    net_init, net_apply_base = model.init, model.apply

    # second_seed lets initialisation use a different seed from the one given to data.get_dataset.
    init_seed = seed if second_seed is None else second_seed
    # flax.core.unfreeze accepts both FrozenDict and plain-dict params (init may return either, depending on the flax version).
    init_params = flax.core.unfreeze(net_init(jax.random.PRNGKey(init_seed), train_images)['params'])

    net_apply = get_linear_forward(net_apply_base, init_params) if linearize else net_apply_base

    if use_adam:
        opt = tx.adam(lr)
    elif use_dp:
        opt = tx.dpsgd(lr, clip_grad_norm, grad_noise_ratio, seed, momentum = 0.9)
    else:
        opt = tx.sgd(lr, momentum = 0.9)

    model_train_state = training_utils.TrainStateWithBatchStats.create(apply_fn = net_apply, params = init_params, tx = opt, batch_stats = None, train_it = 0, base_params = None)

    loss = np.inf
    # Keeps 'iter' defined (as "no step taken") when int(n_epochs) < 1.
    i = -1

    for i in range(int(n_epochs)):
        model_train_state, (loss, _) = training_utils.do_training_step(model_train_state, {'images': train_images, 'labels': train_labels}, use_dp = use_dp)

        if loss_checkpoints and loss < loss_checkpoints[0]:
            print(f'saving checkpoint at loss {loss_checkpoints[0]}')
            if output_dir is not None:
                _dump_pickle(_make_checkpoint(init_params, model_train_state, loss, i), './{}/loss_{}_checkpoint.pkl'.format(output_dir, loss_checkpoints[0]))
            loss_checkpoints.pop(0)

        if iter_checkpoints and i == iter_checkpoints[0]:
            print(f'saving checkpoint at iter {iter_checkpoints[0]}')
            if output_dir is not None:
                _dump_pickle(_make_checkpoint(init_params, model_train_state, loss, i), './{}/iter_{}_checkpoint.pkl'.format(output_dir, iter_checkpoints[0]))
            iter_checkpoints.pop(0)

        if i % 10000 == 0:
            print(f'iter: {i}, loss: {loss}')

        if loss < 1e-10:
            print("Loss is really small, exiting early")
            break

    output_dict = _make_checkpoint(init_params, model_train_state, loss, i)

    training_dict = {
        'train_images': train_images,
        'train_labels': train_labels,
        'train_mean': train_mean
    }

    print('done')

    if output_dir is not None:
        _dump_pickle(output_dict, './{}/final_checkpoint.pkl'.format(output_dir))
        _dump_pickle(training_dict, './{}/training_set.pkl'.format(output_dir))

if __name__ == '__main__':
    # Expose main's keyword arguments as command-line flags, e.g.
    #   python <this_script>.py --dataset_name=cifar10 --output_dir=runs/cifar10
    fire.Fire(main)