import argparse
import torch
from torchsummary import summary
from xmeta.networks.simple_networks import ConvFeature as Convnet
from xmeta.protonet.protonet import pairwise_distances_logits
from xmeta.utils.data import ImpureTasksets, get_tasksets
from xmeta.utils.seed import set_seed
from xmeta.protonet.protonet\
    import xfast_adapt, ProtonetExplainer, ProtonetExplainerOPA, save_protoexpl
import os
import sys
import datetime


def _record_command_line(save_dir):
    """Append the invoking command line to ``args_<script>.txt`` in *save_dir*.

    Each invocation is written on its own line so repeated runs stay readable.
    """
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='a', encoding='utf-8') as f:
        f.write(" ".join(sys.argv) + "\n")


def _build_model(dataset, model_path, device):
    """Build the Convnet for *dataset*, load weights from *model_path*, move it to *device*.

    Args:
        dataset: dataset name; ``'omniglot'`` selects a 1-channel 28x28 network,
            anything else a 3-channel 84x84 one.
        model_path: path to a saved ``state_dict``.
        device: ``torch.device`` the model is moved to after loading.

    Returns:
        The loaded model, on *device*.
    """
    if dataset == 'omniglot':
        model = Convnet(x_dim=1)
        input_size = (1, 28, 28)
    else:
        model = Convnet()
        input_size = (3, 84, 84)
    # The model is still on the CPU at this point. torchsummary's ``summary``
    # defaults to device='cuda', which would build CUDA inputs for CPU weights
    # on a GPU machine, so pin it to the CPU explicitly.
    summary(model, input_size, device='cpu')

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f'loaded {model_path}')

    model.to(device)
    return model


def _build_explainer(args, model, save_dir, tag):
    """Return ``ProtonetExplainerOPA`` if ``args.opa`` is set, else ``ProtonetExplainer``.

    Both explainers track all of ``model.parameters()`` and write to *save_dir*
    under *tag*. ``args.num_hessian_elements`` is used only by the OPA variant.
    """
    if args.opa:
        return ProtonetExplainerOPA(model=model,
                                    meta_params=list(model.parameters()),
                                    num_hessian_elements=args.num_hessian_elements,
                                    ortho_vectors=True,
                                    savedir=save_dir,
                                    tag=tag,
                                    )
    return ProtonetExplainer(model=model,
                             meta_params=list(model.parameters()),
                             savedir=save_dir,
                             tag=tag,
                             )


def main(args):
    """Accumulate Protonet explainer statistics over training tasks and save them.

    Steps:
    - Load the checkpoint ``args.ckpt`` and sample ``args.num_tasks`` training tasks.
    - Adapt on each task with ``xfast_adapt`` and feed each task's evaluation
      error (scaled by ``1 / args.num_tasks``) to the explainer.
    - Save the explainer with ``save_protoexpl``.

    Outputs (args file, tasksets, explainer) go to the checkpoint's directory.

    Args:
        args: parsed command-line namespace (see the ``__main__`` parser).

    Raises:
        ValueError: if ``args.ckpt`` is None, or if ``--save-tasksets`` is
            combined with ``--num-tasks -1``.
        FileNotFoundError: if ``args.ckpt`` does not exist.
    """
    if args.deterministic:
        torch.use_deterministic_algorithms(True)

    # Validate the checkpoint before deriving save_dir from it:
    # os.path.dirname(None) would raise an unhelpful TypeError.
    if args.ckpt is None:
        raise ValueError('--ckpt is required')
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(f'checkpoint not found: {args.ckpt}')
    model_path = args.ckpt
    model_name = os.path.basename(model_path)

    # A bare file name has an empty dirname; use the current directory instead.
    save_dir = os.path.dirname(model_path) or os.curdir
    os.makedirs(save_dir, exist_ok=True)
    _record_command_line(save_dir)

    if args.cuda and torch.cuda.device_count():
        print("Using gpu")
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    def _preprocess(x):
        # Move a (data, labels) pair onto the selected device.
        d, lbl = x
        return [d.to(device), lbl.to(device)]

    # Seed right before model construction so initialisation (and the
    # torchsummary dummy forward pass) is reproducible.
    set_seed(args.seed)
    model = _build_model(args.dataset, model_path, device)

    if args.num_tasks == -1 and args.save_tasksets:
        raise ValueError('--save-tasksets requires a finite --num-tasks')

    tasksets = get_tasksets(seed=args.seed,
                            name=args.dataset,
                            train_ways=args.train_way,
                            train_samples=args.train_shot + args.train_query,
                            test_ways=args.test_way,
                            test_samples=args.test_shot + args.test_query,
                            num_tasks=args.num_tasks,
                            root='~/data',
                            )
    tasksets = ImpureTasksets(tasksets, num_tasks=args.num_tasks,
                              ways=args.train_way, shots=args.train_shot,
                              train_noise_tasks=args.noise_tasks,
                              savedir=save_dir,
                              seed=args.seed)
    if args.save_tasksets:  # for debugging
        tasksets.save()

    tasks_train = tasksets.train
    # Tag outputs with the checkpoint file name minus its extension.
    explainer_tag = os.path.splitext(model_name)[0]
    explainer = _build_explainer(args, model, save_dir, explainer_tag)

    # NOTE(review): with the default --num-tasks -1 this loop runs zero times,
    # so the explainer receives no task errors before normalisation. Confirm
    # whether -1 is meant to be supported by this script.
    for iteration in range(args.num_tasks):
        model.train()
        batch = _preprocess(tasks_train[iteration])
        result = xfast_adapt(model,
                             batch,
                             args.train_way,
                             args.train_shot,
                             args.train_query,
                             metric=pairwise_distances_logits,
                             device=device
                             )

        evaluation = result['evaluation']
        adaptation = result['adaptation']
        # Each task contributes its evaluation error weighted by 1/num_tasks.
        explainer.add_src_test_error(error=evaluation['error'] / args.num_tasks,
                                     model_output=evaluation['prediction'])

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Adaptation Error', adaptation['error'].item())
        print('Meta Train Adaptation Accuracy', adaptation['accuracy'].item())
        print('Meta Train Evaluation Error', evaluation['error'].item())
        print('Meta Train Evaluation Accuracy', evaluation['accuracy'].item())

    explainer.normalize_src_test_hessian()

    explainer.set_src_param_matrix(taylor_series=False)
    if args.discard_intermediate:
        explainer.discard_intermediate()

    save_protoexpl(explainer, prefix='einv', model_path=model_path)


def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` consumed by ``main``.
    """
    parser = argparse.ArgumentParser(
        description='Accumulate Protonet explainer statistics for a trained checkpoint.')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cuda', action='store_true',
                        help='use a GPU if one is available')
    # Required: main() derives its output directory from this path.
    parser.add_argument('--ckpt', type=str, required=True,
                        help='path to the trained state_dict; its directory receives all outputs')
    parser.add_argument('--deterministic', action='store_true',
                        help='enable torch.use_deterministic_algorithms')
    parser.add_argument('--num-tasks', type=int, default=-1,
                        help='number of training tasks to sample and iterate over')
    parser.add_argument('--train-way', type=int, default=5)
    parser.add_argument('--train-shot', type=int, default=5)
    parser.add_argument('--train-query', type=int, default=5)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--test-shot', type=int, default=5)
    parser.add_argument('--test-query', type=int, default=5)
    parser.add_argument('--save-tasksets', action='store_true',
                        help='save the sampled tasksets (debugging)')
    parser.add_argument('--dataset', type=str, default='mini-imagenet',
                        help="dataset name; 'omniglot' selects the 1-channel 28x28 model")
    parser.add_argument('--noise-tasks', type=int, default=None,
                        help='forwarded to ImpureTasksets as train_noise_tasks')
    parser.add_argument('--num-hessian-elements', type=int, default=None,
                        help='forwarded to ProtonetExplainerOPA (only used with --opa)')
    parser.add_argument('--discard-intermediate', action='store_true',
                        help='discard intermediate explainer state before saving')
    parser.add_argument('--opa', action='store_true',
                        help='use ProtonetExplainerOPA instead of ProtonetExplainer')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())
