import argparse
import torch
from torchsummary import summary
from xmeta.networks.simple_networks import ConvFeature as Convnet
from xmeta.protonet.protonet import pairwise_distances_logits
from xmeta.utils.data import ImpureTasksets, get_tasksets
from xmeta.utils.seed import set_seed
from xmeta.protonet.protonet import xfast_adapt, fast_adapt
import os
import sys
import numpy as np


def _evaluate(task_set, model, ways, shots, queries, num_tasks, device):
    """Run ``xfast_adapt`` on tasks ``0 .. num_tasks-1`` of ``task_set``.

    Prints per-task adaptation/evaluation metrics plus the running averages of
    the evaluation error and accuracy.

    Args:
        task_set: indexable task collection; ``task_set[i]`` is unpacked as a
            ``(data, labels)`` pair of tensors (assumed from the unpacking
            below -- confirm against ``ImpureTasksets``).
        model: network handed to ``xfast_adapt``.
        ways, shots, queries: task layout forwarded to ``xfast_adapt``; must
            match the way ``task_set`` was sampled.
        num_tasks: number of tasks to evaluate; must be positive.
        device: device the task tensors are moved to.

    Returns:
        ``(mean_error, mean_accuracy)`` of the evaluation split over all
        evaluated tasks.

    Raises:
        ValueError: if ``num_tasks`` is not positive (nothing to average).
    """
    if num_tasks <= 0:
        raise ValueError(f'num_tasks must be positive, got {num_tasks}')

    errors = []
    accuracies = []
    for iteration in range(num_tasks):
        data, labels = task_set[iteration]
        batch = [data.to(device), labels.to(device)]
        result = xfast_adapt(model,
                             batch,
                             ways,
                             shots,
                             queries,
                             metric=pairwise_distances_logits,
                             device=device,
                             )
        adaptation = result['adaptation']
        evaluation = result['evaluation']
        errors.append(evaluation['error'].item())
        accuracies.append(evaluation['accuracy'].item())
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Adaptation Error', adaptation['error'].item())
        print('Meta Train Adaptation Accuracy', adaptation['accuracy'].item())
        print('Meta Train Evaluation Error', evaluation['error'].item())
        print('Meta Train Evaluation Accuracy', evaluation['accuracy'].item())
        # Running averages are printed after every task so progress is
        # visible during long evaluations.
        print(f'average error: {np.mean(errors)}')
        print(f'average accuracy: {np.mean(accuracies)}')

    return np.mean(errors), np.mean(accuracies)


def main(args):
    """Evaluate a (optionally checkpointed) ProtoNet on test and train tasks.

    The command line is appended to ``args_<script>.txt``, and the final
    train/test error and accuracy are appended to ``result_<script>.txt``.
    Both files live in the checkpoint's directory, or in ``./cache/tmp`` when
    no checkpoint is given. The test split is evaluated in ``eval()`` mode and
    the train split in ``train()`` mode.

    Args:
        args: parsed command-line namespace (see the ``__main__`` parser).

    Raises:
        ValueError: if ``args.num_tasks`` is not positive.
        FileNotFoundError: if ``args.ckpt`` is given but does not exist.
    """
    # Validate cheap inputs before touching datasets or models, so a bad
    # invocation fails immediately with a clear message.
    if args.num_tasks <= 0:
        raise ValueError(
            f'--num-tasks must be a positive integer, got {args.num_tasks}')
    if args.ckpt is not None and not os.path.exists(args.ckpt):
        raise FileNotFoundError(f'checkpoint not found: {args.ckpt}')

    if args.deterministic:
        # NOTE: on CUDA some cuBLAS ops additionally need the
        # CUBLAS_WORKSPACE_CONFIG environment variable (see PyTorch docs).
        torch.use_deterministic_algorithms(True)

    if args.ckpt is not None:
        # Logs go next to the checkpoint; a bare filename means the CWD.
        save_dir = os.path.dirname(args.ckpt) or '.'
    else:
        save_dir = './cache/tmp'
        os.makedirs(save_dir, exist_ok=True)

    script_name = os.path.basename(__file__)
    args_file_path = os.path.join(save_dir, 'args_' + script_name + '.txt')
    result_file_path = os.path.join(save_dir, 'result_' + script_name + '.txt')
    with open(args_file_path, mode='a', encoding='utf-8') as f:
        f.write("\n" + " ".join(sys.argv))

    set_seed(args.seed)
    if args.cuda and torch.cuda.device_count():
        print("Using gpu")
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    set_seed(args.seed)
    if args.dataset == 'omniglot':
        model = Convnet(x_dim=1)
        input_shape = (1, 28, 28)
    else:
        model = Convnet()
        input_shape = (3, 84, 84)
    # The model is still on the CPU here. torchsummary defaults to
    # device='cuda' whenever CUDA is available, which would feed CUDA inputs
    # to CPU weights, so pin the summary to the CPU explicitly.
    summary(model, input_shape, device='cpu')

    # load trained weights
    if args.ckpt is not None:
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        model.load_state_dict(torch.load(args.ckpt, map_location=device))
        print(f'loaded {args.ckpt}')

    model.to(device)

    tasksets = get_tasksets(seed=args.seed,
                            name=args.dataset,
                            train_ways=args.train_way,
                            train_samples=args.train_shot + args.train_query,
                            test_ways=args.test_way,
                            test_samples=args.test_shot + args.test_query,
                            num_tasks=args.num_tasks,
                            root='~/data',
                            )
    tasksets = ImpureTasksets(tasksets, num_tasks=args.num_tasks,
                              ways=args.train_way, shots=args.train_shot,
                              train_noise_tasks=args.noise_tasks,
                              savedir=save_dir,
                              seed=args.seed)

    model.eval()
    # Test tasks were sampled with the test-way/shot/query layout, so they
    # must be adapted/evaluated with that same layout.
    test_error, test_acc = _evaluate(tasksets.test, model,
                                     args.test_way, args.test_shot,
                                     args.test_query, args.num_tasks, device)
    # [Note]
    # Running updates of computed means and variances in batch normalization layers
    # are essential to reproduce overfitted performances
    model.train()
    train_error, train_acc = _evaluate(tasksets.train, model,
                                       args.train_way, args.train_shot,
                                       args.train_query, args.num_tasks, device)

    result_str_train = f'train_error: {train_error}, train_acc: {train_acc}'
    result_str_test = f'test_error: {test_error}, test_acc: {test_acc}'
    print(result_str_train)
    print(result_str_test)

    with open(result_file_path, mode='a', encoding='utf-8') as f:
        f.write("\n" + "\n".join([result_str_train, result_str_test]))


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered (and therefore shown by --help).
    # NOTE(review): --save-tasksets is accepted but never read by main().
    option_specs = (
        ('--seed', dict(type=int, default=42)),
        ('--cuda', dict(action='store_true')),
        ('--ckpt', dict(type=str, default=None)),
        ('--deterministic', dict(action='store_true')),
        ('--num-tasks', dict(type=int, default=-1)),
        ('--train-way', dict(type=int, default=5)),
        ('--train-shot', dict(type=int, default=5)),
        ('--train-query', dict(type=int, default=5)),
        ('--test-way', dict(type=int, default=5)),
        ('--test-shot', dict(type=int, default=5)),
        ('--test-query', dict(type=int, default=5)),
        ('--save-tasksets', dict(action='store_true')),
        ('--dataset', dict(type=str, default='mini-imagenet')),
        ('--noise-tasks', dict(type=int, default=None)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    main(args)
