import numpy as np
import jax.numpy as jnp
import haiku as hk
import os
import optax
import math

from jax import random, tree_util
from tqdm import tqdm
from scipy.stats import t

from comln.datasets.leo import LEOMiniImagenet
from comln.datasets.miniimagenet import MiniImagenet
from comln.datasets.tieredimagenet import TieredImagenet
from comln.metalearners.comln import COMLNArguments, COMLN, COMLNMetaParameters
from comln.utils.hk_utils import save_params_state
from comln.modules.resnet import ResNet12
from comln.modules.conv import Conv4


def _mean_and_ci95(values):
    """Return the mean of `values` and the half-width of its 95% confidence interval.

    The half-width is the two-sided Student-t critical value with
    ``values.size - 1`` degrees of freedom multiplied by the standard error
    of the mean. The standard error is computed from the sample standard
    deviation (``ddof=1``) so that it matches those degrees of freedom.
    """
    values = np.asarray(values)
    num_samples = values.size
    t_95 = t.ppf(1 - 0.05 / 2, df=num_samples - 1)
    std_error = np.std(values, ddof=1) / math.sqrt(num_samples)
    return np.mean(values), t_95 * std_error


def main(args):
    """Meta-train a COMLN model, then report meta-test loss and accuracy.

    The function builds the backbone selected by `args.model`, creates the
    meta-train / meta-validation / meta-test datasets selected by
    `args.dataset`, runs the training loop with periodic validation and
    early stopping, and finally prints the mean and 95% confidence interval
    of the meta-test loss and accuracy.

    Args:
        args: parsed command-line arguments. The attributes read here are
            `model`, `dataset`, `folder`, `batch_size`, `shots`, `ways`,
            `test_shots`, `num_batches`, `seed`, `meta_lr`, `meta_lr_t`,
            `metalearner`, `log_freq` and `patience`.

    Raises:
        ValueError: if `args.model` or `args.dataset` is not a supported name.
    """
    # NOTE: the candidates are resolved lazily (if/elif rather than a lookup
    # table) so that only the selected class is ever touched.
    if args.model == 'resnet12':
        model_cls = ResNet12
    elif args.model == 'conv4':
        model_cls = Conv4
    else:
        raise ValueError(f'Unknown model `{args.model}`')

    if args.dataset == 'miniimagenet':
        dataset_cls = MiniImagenet
    elif args.dataset == 'tieredimagenet':
        dataset_cls = TieredImagenet
    else:
        raise ValueError(f'Unknown dataset `{args.dataset}`')

    # Model
    @hk.transform_with_state
    def model(inputs, is_training):
        return model_cls()(inputs, is_training)

    # Meta-dataset
    meta_train_dataset = dataset_cls(
        args.folder,
        batch_size=args.batch_size,
        shots=args.shots,
        ways=args.ways,
        test_shots=args.test_shots,
        size=args.num_batches,
        split='train',
        seed=args.seed,
        download=True
    )

    meta_val_dataset = dataset_cls(
        args.folder,
        batch_size=10,
        shots=args.shots,
        ways=args.ways,
        test_shots=args.test_shots,
        size=100,
        split='val',
        seed=args.seed
    )

    # Meta-learner
    # NOTE(review): optax.piecewise_constant_schedule interprets the dict
    # values as multiplicative scales applied cumulatively to the initial
    # value. The numbers below look like they could have been meant as
    # absolute learning rates -- confirm the intended schedule.
    scheduler = optax.piecewise_constant_schedule(args.meta_lr, {
        int(20 * 1000 / 3): 0.06,
        int(40 * 1000 / 3): 0.012,
        int(50 * 1000 / 3): 0.0024
    })
    # One SGD (Nesterov momentum) optimizer per group of meta-parameters;
    # `t_final` uses its own constant learning rate.
    optimizer = optax.multi_transform({
        'model': optax.sgd(scheduler, momentum=0.9, nesterov=True),
        'classifier': optax.sgd(scheduler, momentum=0.9, nesterov=True),
        't_final': optax.sgd(args.meta_lr_t, momentum=0.9, nesterov=True)
    }, COMLNMetaParameters(model='model', classifier='classifier', t_final='t_final'))

    metalearner = COMLN.from_args(args.ways, args.metalearner,
        model=model, optimizer=optimizer)

    # Initialization: a first batch is drawn only to provide example inputs.
    key = random.PRNGKey(args.seed)
    meta_train_dataset.reset()
    batch = next(iter(meta_train_dataset))
    params, state = metalearner.init(key, batch['train'].inputs[0], True)

    # Train
    meta_train_dataset.reset()
    best_accuracy, best_params, best_state, n_patience = None, None, None, 0
    for idx, batch in enumerate(tqdm(meta_train_dataset, desc='Train')):
        params, state, _ = metalearner.train_step(
            params, state, batch['train'], batch['test'], True)

        # Validate every `log_freq` training steps and keep the best
        # meta-parameters seen so far.
        if (idx + 1) % args.log_freq == 0:
            val_results = metalearner.evaluate(params, state, meta_val_dataset, False)
            n_patience += 1

            val_accuracy = jnp.mean(val_results['after/accuracy'], axis=0).item()
            if (best_accuracy is None) or (val_accuracy > best_accuracy):
                best_params, best_state = params, state
                best_accuracy = val_accuracy
                n_patience = 0

        # Early stopping: `n_patience` counts consecutive validation passes
        # without improvement; a non-positive `patience` disables it.
        if (args.patience > 0) and (n_patience >= args.patience):
            print(f'No improvement after {args.patience} steps. Stop training')
            break

    if best_params is None:
        # No validation pass ran (fewer training batches than `log_freq`),
        # so fall back to the final parameters instead of evaluating `None`.
        best_params, best_state = params, state

    # Meta-test dataset
    meta_test_dataset = dataset_cls(
        args.folder,
        batch_size=10,
        shots=args.shots,
        ways=args.ways,
        test_shots=args.test_shots,
        size=100,
        split='test',
        seed=args.seed
    )

    meta_test_logs = metalearner.evaluate(best_params, best_state, meta_test_dataset, False)
    loss_mean, loss_ci95 = _mean_and_ci95(meta_test_logs['after/loss'])
    accuracy_mean, accuracy_ci95 = _mean_and_ci95(meta_test_logs['after/accuracy'])
    results = {
        'meta-test/loss/mean': loss_mean,
        'meta-test/loss/ci95': loss_ci95,

        'meta-test/accuracy/mean': accuracy_mean,
        'meta-test/accuracy/ci95': accuracy_ci95
    }
    print(results)


if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them and hand the
    # result to `main`. `simple_parsing` is imported here so that it is only
    # required when the file is run as a script.
    from simple_parsing import ArgumentParser

    parser = ArgumentParser('COMLN - Continuous-Time Meta-Learning')

    # General
    general = parser.add_argument_group('General')
    general.add_argument('--folder', type=str, required=False,
        help='data folder.')
    general.add_argument('--dataset', type=str, default='miniimagenet',
        choices=['miniimagenet', 'tieredimagenet'],
        help='dataset (default: %(default)s).')
    general.add_argument('--ways', type=int, default=5,
        help='number of classes per task (N in "N-way", default: %(default)s).')
    general.add_argument('--shots', type=int, default=5,
        help='number of training examples per class (k in "k-shot", default: %(default)s).')
    general.add_argument('--test_shots', type=int, default=15,
        help='number of test examples per class (default: %(default)s).')
    general.add_argument('--model', type=str,
        choices=['conv4', 'resnet12'], default='resnet12',
        help='type of model (default: %(default)s).')

    # Metalearner: the fields of `COMLNArguments` are exposed as options and
    # collected under `args.metalearner`.
    parser.add_arguments(COMLNArguments, dest='metalearner')

    # Optimization
    optim = parser.add_argument_group('Optimization')
    optim.add_argument('--batch_size', type=int, default=25,
        help='number of tasks in a batch of tasks (default: %(default)s).')
    optim.add_argument('--num_batches', type=int, default=100,
        help='number of batch of tasks (default: %(default)s).')
    optim.add_argument('--meta_lr', type=float, default=1e-3,
        help='learning rate for the meta-optimizer (optimization of the outer '
        'loss). The default optimizer is SGD with Nesterov momentum '
        '(default: %(default)s).')
    optim.add_argument('--meta_lr_t', type=float, default=1e-2,
        help='learning rate for the meta-optimizer for t_final. The '
        'default optimizer is SGD (default: %(default)s).')

    # Miscellaneous
    misc = parser.add_argument_group('Miscellaneous')
    misc.add_argument('--seed', type=int, default=1,
        help='random seed (default: %(default)s).')
    misc.add_argument('--log_freq', type=int, default=100,
        help='frequency of logs (default: %(default)s).')
    misc.add_argument('--patience', type=int, default=15,
        help='patience for early stopping (default: %(default)s).')

    args = parser.parse_args()
    main(args)
