import os
from argparse import ArgumentParser

import jax
import jax.numpy as jnp
import orbax.checkpoint

from torch.utils.data import DataLoader

from models import uPDNetMRI
from utils.data import get_train_test_datasets
from finetune import train_finetune_MRI_single_step
from finetune.finetune_train_MRI import get_optimizer

from flax.training import train_state
from utils.data import ToyfastMRI

if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='uPDNet')
    parser.add_argument('--results_folder', type=str, default=None)
    parser.add_argument('--channels', type=int, default=3)
    parser.add_argument('--task', type=int, default=0)
    parser.add_argument('--num_iter', type=int, default=20)
    parser.add_argument('--supervised', type=int, default=0)
    parser.add_argument('--random_init', type=int, default=0)
    args = parser.parse_args()

    model_name = args.model_name
    results_folder = args.results_folder

    if args.supervised:
        source_folder_str = 'results_uPDNet_1ch_imeta_' + str(args.num_iter) + '_supervised'
    else:
        source_folder_str = 'results_uPDNet_1ch_imeta_' + str(args.num_iter) + '_supervised'

    if args.random_init == 1:
        if results_folder is None:
            if args.supervised:
                results_folder = 'finetune_MRI_' + str(args.num_iter) + 'it_supervised_random_init/'
            else:
                results_folder = 'finetune_MRI_' + str(args.num_iter) + 'it_unsupervised_random_init/'
    else:
        if results_folder is None:
            if args.supervised:
                results_folder = 'finetune_MRI_' + str(args.num_iter) + 'it_supervised/'
            else:
                results_folder = 'finetune_MRI_' + str(args.num_iter) + 'it_unsupervised/'

    if not os.path.exists(results_folder):
        os.makedirs(results_folder)

    log_filename = results_folder+'training.log'

    print('Training logs will be saved in {}'.format(log_filename))

    rng = jax.random.PRNGKey(0)

    # Create the dataset
    dataset = ToyfastMRI(train=True)

    # Create a DataLoader to iterate through the dataset
    batch_size = 1
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Loop through the dataset
    for batch_idx, (input_im, meas, mask, backproj) in enumerate(dataloader):
        mask = jnp.array(mask)[0]
        break

    model = uPDNetMRI(mask=mask, im_size=320, num_iter=args.num_iter, channels=1)

    rng = jax.random.PRNGKey(0)
    init_type_operator = jnp.array([3])
    init_input = jax.random.normal(rng, (1, 1, 320, 320))
    u_init = jax.random.normal(rng, (1, 40, 320, 320))
    y_init = model.forward_op(init_input, mask)
    variables = model.init(rng, y_init, init_type_operator, init_input, u_init)

    if args.random_init == 0:
        pth_ckpt = '/pth/to/ckpt_meta.ckpt'
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        variables = orbax_checkpointer.restore(pth_ckpt)
        reg_param = 1e-3
    else:
        reg_param = 0


    # # Set training state
    learning_rate = 1e-3
    optimizer = get_optimizer(learning_rate, clip_params_norm=1)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=variables['params']
    )

    # Get datasets
    grayscale = True if args.channels == 1 else False
    dataset_train, dataset_test = get_train_test_datasets(fastMRI=True)

    # Training
    state, info = train_finetune_MRI_single_step(rng, dataset_train, state, variables['params'],
                                             batch_size=2, num_steps=200, debug=True,
                                             results_folder=results_folder, log_filename=log_filename,
                                             grayscale=grayscale, supervised=args.supervised, reg_param=reg_param)
