import os
from argparse import ArgumentParser

import jax
import jax.numpy as jnp
import orbax.checkpoint

from models import uPDNetSR
from utils.data import get_train_test_datasets
from finetune import train_finetune_SR_single_step
from finetune.finetune_train_SR import get_optimizer

from flax.training import train_state


if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='uPDNet')
    parser.add_argument('--results_folder', type=str, default=None)
    parser.add_argument('--channels', type=int, default=3)
    parser.add_argument('--num_iter', type=int, default=20)
    parser.add_argument('--task', type=int, default=0)
    parser.add_argument('--supervised', type=int, default=0)
    parser.add_argument('--random_init', type=int, default=0)
    args = parser.parse_args()

    model_name = args.model_name
    results_folder = args.results_folder

    if args.supervised:
        source_folder_str = 'results_uPDNet_1ch_imeta_'+str(args.num_iter)+'_supervised'
    else:
        source_folder_str = 'results_uPDNet_1ch_imeta_'+str(args.num_iter)+'_unsupervised'

    if args.random_init==1:
        if results_folder is None:
            if args.supervised:
                results_folder = 'finetune_'+str(args.num_iter)+'it_supervised_random_init/'
            else:
                results_folder = 'finetune_'+str(args.num_iter)+'it_unsupervised_random_init/'
    else:
        if results_folder is None:
            if args.supervised:
                results_folder = 'finetune_'+str(args.num_iter)+'it_supervised/'
            else:
                results_folder = 'finetune_'+str(args.num_iter)+'it_unsupervised/'

    if not os.path.exists(results_folder):
        os.makedirs(results_folder)

    log_filename = results_folder+'training.log'

    print('Training logs will be saved in {}'.format(log_filename))

    rng = jax.random.PRNGKey(0)

    sr_factor = 2
    model = uPDNetSR(im_size=32, num_iter=args.num_iter, channels=1, sr_factor=sr_factor)

    rng = jax.random.PRNGKey(0)
    init_type_operator = jnp.array([3])
    init_input = jax.random.normal(rng, (1, 1, 32, 32))
    u_init = jax.random.normal(rng, (1, 40, 32, 32))
    y_init = model.forward_op(init_input, init_type_operator)
    variables = model.init(rng, y_init, init_type_operator, init_input, u_init)

    if args.random_init == 0:
        pth_ckpt = '/pth/to/ckpt_meta.ckpt'
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        variables = orbax_checkpointer.restore(pth_ckpt)
        reg_param = 1e-3
    else:
        reg_param = 0

    if not args.random_init and args.supervised:
        num_steps = 10001
    else:
        num_steps = 200

    # # Set training state
    learning_rate = 1e-3
    optimizer = get_optimizer(learning_rate, clip_params_norm=10)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=variables['params']
    )

    # Get datasets
    grayscale = True if args.channels == 1 else False
    dataset_train, dataset_test = get_train_test_datasets(fastMRI=False, grayscale=True)

    # Training
    state, info = train_finetune_SR_single_step(rng, dataset_train, state, variables['params'],
                                                batch_size=2, num_steps=num_steps, debug=True,
                                                results_folder=results_folder, log_filename=log_filename,
                                                grayscale=grayscale, sr_factor=sr_factor, supervised=args.supervised,
                                                reg_param=reg_param)
