import gc
import argparse
import os
import pickle
from os import path

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
from torchmeta.modules import (MetaModule, MetaConv2d, MetaBatchNorm2d,
                               MetaSequential, MetaLinear)
from torchmeta.datasets import Omniglot, MiniImagenet, TieredImagenet
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import ToTensor, Resize, Compose
import torchvision.transforms as transforms
from torchmeta.utils.data import BatchMetaDataLoader
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
import hp_search_grid as hpsearch
from utils.cub_downloader import CUB, CARS

def list_to_str(list_arg, delim=' '):
    """Convert a list of numbers into a string.

    Each element is converted with ``str()`` and the results are joined
    with ``delim``; an empty list yields ``''``.

    Args:
        list_arg: List (or any iterable) of numbers.
        delim (optional): Delimiter placed between consecutive numbers.

    Returns:
        List converted to string.
    """
    # str.join builds the result in one pass instead of repeated += copies.
    return delim.join(str(e) for e in list_arg)

def get_subdict(adict, name):
    """Extract the entries of ``adict`` that live under the prefix ``name``.

    Only keys of the form ``"<name>.<rest>"`` are kept, and they are
    re-keyed as ``"<rest>"``. For example, ``{'fc.weight': w}`` with
    ``name='fc'`` becomes ``{'weight': w}``. Keys that merely contain
    ``name`` somewhere else (``'conv.fc.weight'``, ``'fc2.weight'``) are
    not included.

    Args:
        adict: Mapping from dotted names to values, or ``None``.
        name: Prefix (without trailing dot) selecting the sub-dictionary.

    Returns:
        A new ``dict`` with the prefix stripped from the selected keys, or
        ``None`` if ``adict`` is ``None``.
    """
    if adict is None:
        return adict
    prefix = name + '.'
    # Match on the full "<name>." prefix so the slice below always removes
    # exactly that prefix (a plain substring test mis-slices other keys).
    return {k[len(prefix):]: adict[k] for k in adict if k.startswith(prefix)}

def save_performance_summary(args):
    """Write the run's performance summary file into ``args.out_dir``.

    The file ``hpsearch._SUMMARY_FILENAME`` receives one line per keyword in
    ``hpsearch._SUMMARY_KEYWORDS``, formatted as ``"<keyword> <value>"``.
    ``num_train_epochs`` is written as an integer (from ``args.epochs``).
    Other values are written with ``%f`` when possible, otherwise with
    ``%s``; the accuracy lists, for instance, are pre-joined strings.

    Args:
        args: Namespace providing ``out_dir``, ``epochs``,
            ``mean_sparcity_best``, ``mean_sparcity_end``, ``best_acc_epoch``,
            ``best_acc``, ``end_acc``, ``end_acc_list`` and ``best_acc_list``.

    Raises:
        KeyError: If ``hpsearch._SUMMARY_KEYWORDS`` contains a keyword (other
            than ``num_train_epochs``) that is not one of the keys below.
    """
    train_epochs = args.epochs

    # Note, the keywords of this dictionary are defined by the array:
    #   hpsearch._SUMMARY_KEYWORDS
    tp = {
        "mean_sparcity_best": args.mean_sparcity_best,
        "mean_sparcity_end": args.mean_sparcity_end,
        "best_acc_epoch": args.best_acc_epoch,
        "best_acc": args.best_acc,
        "end_acc": args.end_acc,
        "end_acc_list": list_to_str(args.end_acc_list),
        "best_acc_list": list_to_str(args.best_acc_list),
        "finished": 1,
    }

    summary_path = os.path.join(args.out_dir, hpsearch._SUMMARY_FILENAME)
    with open(summary_path, 'w') as f:
        for kw in hpsearch._SUMMARY_KEYWORDS:
            if kw == 'num_train_epochs':
                f.write('%s %d\n' % ('num_train_epochs', train_epochs))
                continue
            value = tp[kw]
            try:
                line = '%s %f\n' % (kw, value)
            except (TypeError, ValueError):
                # Values %f cannot convert (strings, multi-element arrays,
                # None, ...) are written verbatim instead.
                line = '%s %s\n' % (kw, value)
            f.write(line)


def _meta_split(dataset_cls, root, args, transform, dataset_transform, split,
                **extra_kwargs):
    """Instantiate one meta-split (``'train'``, ``'val'`` or ``'test'``) of a dataset.

    Args:
        dataset_cls: Dataset class, e.g. ``MiniImagenet``, ``Omniglot``, ``CUB``.
        root: Root folder, passed as the first positional argument.
        args: Namespace providing ``num_ways``.
        transform: Per-image transform.
        dataset_transform: Task-level transform (the shared ``ClassSplitter``).
        split: Selects which ``meta_<split>=True`` flag is passed.
        **extra_kwargs: Forwarded verbatim (e.g. ``download=True`` or
            ``class_augmentations``).

    Returns:
        The constructed dataset object.
    """
    kwargs = {'meta_%s' % split: True}
    kwargs.update(extra_kwargs)
    # A fresh Categorical is created for every split, as in the original code.
    return dataset_cls(root,
                       transform=transform,
                       target_transform=Categorical(args.num_ways),
                       num_classes_per_task=args.num_ways,
                       dataset_transform=dataset_transform,
                       **kwargs)


def _meta_loader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a shuffling, pinned-memory ``BatchMetaDataLoader``."""
    return BatchMetaDataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=num_workers,
                               pin_memory=True)


def load_data(args):
    """Build the train/val/test meta-dataloaders for ``args.dataset``.

    Supported datasets:
      * ``"MiniImagenet"``, ``"TieredImagenet"``: root ``"data"``, ``Resize(84)``.
      * ``"CUB"``, ``"CARS"``: root ``"data"``, ``Resize((84, 84))``.
      * ``"Omniglot"``: root ``"data_o"``, ``Resize(28)``.
    Only the training split is constructed with ``download=True``.

    Args:
        args: Namespace providing ``dataset``, ``num_ways``, ``num_shots_train``,
            ``num_shots_test``, ``batch_size``, ``test_batch_size``,
            ``num_workers`` and ``hidden_size``.

    Returns:
        tuple: ``(meta_dataloader, feature_size, input_channels)``.
        ``meta_dataloader`` maps ``"train"``, ``"val"`` and ``"test"`` to
        ``BatchMetaDataLoader`` objects. ``feature_size`` is
        ``5*5*args.hidden_size`` for the 84-pixel datasets and
        ``args.hidden_size`` for Omniglot; presumably the flattened size of
        the backbone output (TODO confirm against the model). ``input_channels``
        is 3 for the 84-pixel datasets and 1 for Omniglot.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    # name -> (dataset class, argument passed to Resize). Note that Resize(84)
    # scales the shorter edge, while Resize((84, 84)) forces a square image.
    rgb_datasets = {
        "MiniImagenet": (MiniImagenet, 84),
        "TieredImagenet": (TieredImagenet, 84),
        "CUB": (CUB, (84, 84)),
        "CARS": (CARS, (84, 84)),
    }

    if args.dataset not in rgb_datasets and args.dataset != "Omniglot":
        raise ValueError("Unknown dataset %r; expected one of %s"
                         % (args.dataset,
                            sorted(list(rgb_datasets) + ["Omniglot"])))

    # The same task splitter is shared by all three meta-splits.
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=args.num_shots_train,
                                      num_test_per_class=args.num_shots_test)

    if args.dataset == "Omniglot":
        # Rotation([90, 180, 270]) class augmentation is currently disabled.
        class_augmentations = []
        transform = Compose([Resize(28), ToTensor()])
        train_set = _meta_split(Omniglot, "data_o", args, transform,
                                dataset_transform, 'train',
                                class_augmentations=class_augmentations,
                                download=True)
        val_set = _meta_split(Omniglot, "data_o", args, transform,
                              dataset_transform, 'val',
                              class_augmentations=class_augmentations)
        test_set = _meta_split(Omniglot, "data_o", args, transform,
                               dataset_transform, 'test')
        # NOTE(review): Omniglot evaluates with args.batch_size, whereas the
        # other datasets use args.test_batch_size; kept as-is, confirm intent.
        eval_batch_size = args.batch_size
        feature_size = args.hidden_size
        input_channels = 1
    else:
        dataset_cls, size = rgb_datasets[args.dataset]
        train_transform = Compose([Resize(size), ToTensor()])
        test_transform = Compose([Resize(size), ToTensor()])
        train_set = _meta_split(dataset_cls, "data", args, train_transform,
                                dataset_transform, 'train', download=True)
        val_set = _meta_split(dataset_cls, "data", args, test_transform,
                              dataset_transform, 'val')
        test_set = _meta_split(dataset_cls, "data", args, test_transform,
                               dataset_transform, 'test')
        eval_batch_size = args.test_batch_size
        feature_size = 5 * 5 * args.hidden_size
        input_channels = 3

    meta_dataloader = {
        "train": _meta_loader(train_set, args.batch_size, args.num_workers),
        "val": _meta_loader(val_set, eval_batch_size, args.num_workers),
        "test": _meta_loader(test_set, eval_batch_size, args.num_workers),
    }
    return meta_dataloader, feature_size, input_channels
