import os
from argparse import ArgumentParser

import numpy as np

import jax

from torch.utils.data import DataLoader

from models import save_model, model_apply, get_model_from_name
from utils.data import get_train_test_datasets
from imaml import train_outer_imaml, init_train_state
from utils.utils import save_image, tensor_to_img, gkern, save_dict_to_file, compute_metrics
from utils.data import ImageDataset
from utils.tasks import get_task_batches


def test_model(state_model_test, outer_state, inner_state, results_folder='results_denoising', grayscale=False,
               path_test_data='pth/to/Set3C/'):
    r"""
    Test the meta model and the task-adapted inner models on the test dataset.

    Every test image is turned into a batch of tasks by ``get_task_batches``. Both models are run on every task and
    metrics are computed with ``compute_metrics``. The input, target and output images of every task are saved
    inside ``results_folder``.

    :param state_model_test: State of the model built for the test image size. It is passed as the first,
        non-mapped argument of ``model_apply``.
    :param outer_state: State of the outer (meta) model. It is shared by all tasks (not mapped by ``jax.vmap``).
    :param inner_state: State of the inner models. Its leading axis is mapped over the tasks by ``jax.vmap``.
    :param str results_folder: Folder where the result images are saved.
    :param bool grayscale: Whether the images are grayscale or not.
    :param str path_test_data: Folder containing the test images. The default is a placeholder; point it to the
        local copy of the test set.
    :return: Tuple ``(metrics_meta_all, metrics_inner_all)``. Each is a dict mapping ``'image <index>'`` to the
        result of ``compute_metrics`` for the meta and the inner model respectively.
    """
    dataset_test = ImageDataset(path_test_data, grayscale=grayscale)
    # batch_size=1 because batches of images with varying sizes cannot be handled yet; drop_last is therefore
    # irrelevant and every image is evaluated.
    dataloader = DataLoader(dataset_test, batch_size=1, shuffle=False, drop_last=False)

    blur_kernel = gkern(kernlen=7, std=0.5)

    # Meta model: the same outer parameters are used for every task (in_axes=None on the parameters).
    batched_model_forward = jax.vmap(model_apply, in_axes=(None, None, 0, 0))
    # Inner models: one set of adapted parameters per task (in_axes=0 on the parameters).
    batched_model_forward_params = jax.vmap(model_apply, in_axes=(None, 0, 0, 0))

    rng = jax.random.PRNGKey(0)
    # Random mask with values in {0, 1} drawn with equal probability; the 256x256 size is fixed here
    # (presumably matching the test image size — TODO confirm).
    mask = jax.random.choice(rng, 2, shape=(1, 1, 256, 256), p=np.asarray([0.5, 0.5]))

    metrics_meta_all = {}
    metrics_inner_all = {}

    for im_index, im in enumerate(dataloader):
        x_target_batch_tasks, x_input_batch_tasks, type_operators_batch_tasks = get_task_batches(
            rng, im, x_target_batch=None, kernel=blur_kernel, mask=mask)

        x_output_meta = batched_model_forward(state_model_test, outer_state, x_input_batch_tasks,
                                              type_operators_batch_tasks)
        x_output_inner = batched_model_forward_params(state_model_test, inner_state, x_input_batch_tasks,
                                                      type_operators_batch_tasks)

        image_key = f'image {im_index}'
        metrics_meta_all[image_key] = compute_metrics(x_target_batch_tasks, x_input_batch_tasks, x_output_meta)
        metrics_inner_all[image_key] = compute_metrics(x_target_batch_tasks, x_input_batch_tasks, x_output_inner)

        _save_task_images(results_folder, im_index, x_input_batch_tasks, x_target_batch_tasks,
                          x_output_meta, x_output_inner)

    # Return only after the loop so that every test image is evaluated, not just the first one.
    return metrics_meta_all, metrics_inner_all


def _save_task_images(results_folder, im_index, x_input, x_target, x_output_meta, x_output_inner):
    """Save the meta output, input, inner output and target of every task of test image ``im_index``."""
    # os.path.join inserts the separator that plain string concatenation was missing.
    prefix = os.path.join(results_folder, f'x_test_{im_index}')
    for task in range(x_input.shape[0]):
        suffix = f'_{task}_task_{task}.png'
        # The task index is part of every filename so that tasks do not overwrite each other.
        save_image(tensor_to_img(x_output_meta[task]), prefix + '_output_meta' + suffix)
        save_image(tensor_to_img(x_input[task]), prefix + '_input' + suffix)
        save_image(tensor_to_img(x_output_inner[task]), prefix + '_output' + suffix)
        save_image(tensor_to_img(x_target[task]), prefix + '_target' + suffix)


if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='PDNet')
    parser.add_argument('--results_folder', type=str, default='results_PDNet_4tasks')
    parser.add_argument('--channels', type=int, default=3)
    parser.add_argument('--num_inner_steps', type=int, default=1)
    # type=int so that values given on the command line match the integer default.
    # NOTE(review): args.tasks is parsed but not used in this script.
    parser.add_argument('--tasks', nargs='+', type=int, default=[0, 1, 2, 3])
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--batch_size_inner', type=int, default=32)
    parser.add_argument('--supervised', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=1001)
    args = parser.parse_args()

    model_name = args.model_name
    # Guarantee a trailing path separator. results_folder is passed on to other functions, which may (like the
    # original version of this script) build paths as ``results_folder + filename``. Without the separator those
    # files would be written next to the folder instead of inside it.
    results_folder = os.path.join(args.results_folder, '')
    os.makedirs(results_folder, exist_ok=True)

    log_filename = os.path.join(results_folder, 'training.log')

    rng = jax.random.PRNGKey(0)

    # Training uses 32x32 images, testing uses 256x256 images.
    model, model_test, mask, learning_rate = get_model_from_name(rng, model_name, im_size_test=256, im_size=32,
                                                                 channels=args.channels, num_iter=args.num_layers)

    meta_state_test = init_train_state(model_test, rng, (1, args.channels, 256, 256), 0)
    meta_state = init_train_state(model, rng, (1, args.channels, 32, 32), learning_rate)

    # Get datasets
    grayscale = args.channels == 1
    dataset_train, dataset_test = get_train_test_datasets(grayscale=grayscale)

    # Training
    meta_state, inner_state, _ = train_outer_imaml(rng, dataset_train, dataset_test, model, meta_state, meta_state_test,
                                                   batch_size_inner=args.batch_size_inner,
                                                   num_inner_steps=args.num_inner_steps, batch_size=args.batch_size,
                                                   epochs=args.epochs, debug=True, mask=mask,
                                                   results_folder=results_folder, log_filename=log_filename,
                                                   reg_param_inner=1e-3, grayscale=grayscale,
                                                   supervised=args.supervised)

    save_model(meta_state, os.path.join(results_folder, 'ckpt_meta.ckpt'))
    save_model(inner_state, os.path.join(results_folder, 'ckpt_inner.ckpt'))

    # Testing the model
    metrics_meta, metrics_inner = test_model(meta_state_test, meta_state, inner_state, results_folder=results_folder,
                                             grayscale=grayscale)

    save_dict_to_file(metrics_meta, os.path.join(results_folder, 'test_metrics_meta.txt'))
    save_dict_to_file(metrics_inner, os.path.join(results_folder, 'test_metrics_inner.txt'))
