import torch
import pickle
import os
import numpy as np
import random

from collections import OrderedDict
from torchmeta.modules import MetaModule
from torchmeta.utils.data import MetaDataset

def apply_grad(model, grad):
    """Accumulate externally computed gradients into ``model``'s ``.grad`` fields.

    Args:
        model: module whose ``parameters()`` are paired positionally with ``grad``.
        grad: iterable of tensors, one per parameter, in ``model.parameters()``
            order. The caller's tensors are never modified.

    Returns:
        The global L2 norm of ``grad`` (sqrt of the sum of squared entries).
        Previously this was computed and thrown away; returning it is
        backward-compatible for callers that ignore the result.
    """
    grad_norm = 0
    for p, g in zip(model.parameters(), grad):
        if p.grad is None:
            # Clone so that later in-place accumulation into p.grad cannot
            # mutate the caller's gradient tensor through aliasing.
            p.grad = g.clone()
        else:
            p.grad += g
        grad_norm += torch.sum(g ** 2)
    return grad_norm ** (1 / 2)

def mix_grad(grad_list, weight_list):
    """Return the weighted sum of several per-parameter gradient lists.

    Args:
        grad_list: iterable of gradient lists, one per source; each inner list
            holds one tensor per parameter, all in the same parameter order.
        weight_list: sequence of scalar weights, one per entry of ``grad_list``.

    Returns:
        list with one tensor per parameter j:
        ``sum_i weight_list[i] * grad_list[i][j]``.

    Raises:
        ValueError: if ``grad_list`` and ``weight_list`` differ in length
            (previously extra gradients were silently ignored, or an
            IndexError was raised).
    """
    grad_list = list(grad_list)
    if len(grad_list) != len(weight_list):
        raise ValueError(
            'mix_grad: got {0} gradient lists but {1} weights'.format(
                len(grad_list), len(weight_list)))

    mixed_grad = []
    # zip(*grad_list) regroups the lists so each step yields one parameter's
    # gradient from every source.
    for per_source in zip(*grad_list):
        weighted = torch.stack([w * g for w, g in zip(weight_list, per_source)])
        mixed_grad.append(torch.sum(weighted, dim=0))
    return mixed_grad

def compute_accuracy(logits, targets):
    """Return the fraction of rows whose highest-scoring class equals the target.

    The arg-max is taken over dim 1 of ``logits`` and compared element-wise to
    ``targets``; the mean of the matches is returned as a Python float.
    Evaluated under ``torch.no_grad()``, so no autograd graph is recorded.
    """
    with torch.no_grad():
        predicted = logits.max(dim=1)[1]
        hits = (predicted == targets).float()
        return hits.mean().item()

def tensors_to_device(tensors, device=torch.device('cpu')):
    """Recursively move every tensor in a (possibly nested) container to ``device``.

    Supported inputs are tensors, lists, tuples (including namedtuples) and
    dicts (including OrderedDict). The container type is preserved, and so is
    the key order for dicts.

    Raises:
        NotImplementedError: for any other leaf type (e.g. ints, strings, None).
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device=device)
    if isinstance(tensors, (list, tuple)):
        moved = [tensors_to_device(tensor, device=device) for tensor in tensors]
        if isinstance(tensors, tuple) and hasattr(tensors, '_fields'):
            # namedtuple constructors take their fields positionally rather
            # than a single iterable.
            return type(tensors)(*moved)
        return type(tensors)(moved)
    if isinstance(tensors, dict):  # OrderedDict is a dict subclass
        return type(tensors)([(name, tensors_to_device(tensor, device=device))
                              for (name, tensor) in tensors.items()])
    raise NotImplementedError(
        'tensors_to_device does not support type {0}'.format(type(tensors).__name__))


def gradient_update_parameters(model,
                               loss,
                               params=None,
                               step_size=0.5,
                               first_order=False):
    """Take one gradient-descent step on ``params`` and return the new values.

    Args:
        model: must be a ``MetaModule``; otherwise ``ValueError`` is raised.
        loss: scalar tensor to differentiate.
        params: mapping name -> parameter tensor. Defaults to
            ``OrderedDict(model.meta_named_parameters())``.
        step_size: either one scalar applied to all parameters, or a dict /
            OrderedDict keyed by parameter name giving a per-parameter step.
        first_order: when True, ``create_graph`` is False, so the update is not
            itself differentiable (first-order approximation).

    Returns:
        OrderedDict name -> ``param - step * grad``, in the order of ``params``.
        The input tensors are not modified.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    per_param_step = isinstance(step_size, (dict, OrderedDict))
    grads = torch.autograd.grad(loss,
                                params.values(),
                                create_graph=not first_order)

    new_params = OrderedDict()
    for (name, param), grad in zip(params.items(), grads):
        lr = step_size[name] if per_param_step else step_size
        new_params[name] = param - lr * grad
    return new_params

def ensure_directory_exists():
    """Create the task-pool directory returned by ``full_path()`` if it is missing.

    ``exist_ok=True`` replaces the previous exists-then-create sequence, which
    could raise ``FileExistsError`` when another process created the directory
    between the check and the ``makedirs`` call.
    """
    os.makedirs(full_path(), exist_ok=True)

def full_path(file_name=None):
    """Return ``./taskpool``, or ``file_name`` joined onto it.

    A falsy ``file_name`` (``None`` or ``''``) yields the bare directory path.
    """
    base_dir = './taskpool'
    return os.path.join(base_dir, file_name) if file_name else base_dir
def save_task_pool(task_pool, file_name):
    """Pickle ``task_pool`` to ``full_path(file_name)``, creating the directory first.

    The pickle is written to a sibling ``<path>.tmp`` file and then moved into
    place with ``os.replace``. As a result, a dump that fails part-way never
    leaves a truncated file at ``path``, which ``task_pool_exists`` would
    otherwise report as present.
    """
    ensure_directory_exists()
    path = full_path(file_name)
    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(task_pool, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def load_task_pool(file_name):
    """Unpickle and return the object stored at ``full_path(file_name)``.

    Note: ``pickle.load`` can execute arbitrary code, so only load files this
    project wrote itself (e.g. via ``save_task_pool``).
    """
    with open(full_path(file_name), 'rb') as handle:
        task_pool = pickle.load(handle)
    return task_pool

def task_pool_exists(file_name):
    """Return True if a regular file exists at ``full_path(file_name)``.

    ``os.path.isfile`` is used instead of ``os.path.exists`` so that a directory
    does not count as a saved pool. In particular, with a falsy ``file_name``
    the path resolves to the ``./taskpool`` directory itself, and the old check
    reported True even though ``load_task_pool`` would then fail.
    """
    return os.path.isfile(full_path(file_name))
        
def omniglot_collate_fn(batch):
    """Collate a list of Omniglot episodes into batched train/test tensors.

    Each element of ``batch`` is indexed as ``item[split][0]`` (inputs) and
    ``item[split][1]`` (labels) for ``split`` in ``'train'`` and ``'test'``.
    For each item:

    * inputs whose leading dimension is 1 have that dimension squeezed;
    * labels with more than one dimension have *all* size-1 dimensions
      squeezed (``Tensor.squeeze()`` with no argument).

    The per-item tensors are then stacked along a new dimension 0.

    Returns:
        dict ``{'train': [inputs, labels], 'test': [inputs, labels]}``.
    """
    collected = {'train': ([], []), 'test': ([], [])}
    for item in batch:
        for split in ('train', 'test'):
            pair = item[split]
            x, y = pair[0], pair[1]
            if x.shape[0] == 1:
                x = x.squeeze(0)
            if y.dim() > 1:
                y = y.squeeze()
            inputs, labels = collected[split]
            inputs.append(x)
            labels.append(y)

    return {split: [torch.stack(inputs), torch.stack(labels)]
            for split, (inputs, labels) in collected.items()}

def sinusoid_collate_fn(batch):
    """Collate a list of sinusoid episodes into batched train/test tensors.

    Each element of ``batch`` is indexed as ``item[split][0]`` (inputs) and
    ``item[split][1]`` (labels) for ``split`` in ``'train'`` and ``'test'``.
    Per item, dim 0 is squeezed when its size is 1 for train inputs, train
    labels and test inputs. Test labels are squeezed at dim 0 only when they
    have more than one dimension. The per-item tensors are then stacked along
    a new dimension 0.

    Returns:
        dict ``{'train': [inputs, labels], 'test': [inputs, labels]}``.
    """
    def drop_leading_one(t):
        # Squeeze dim 0 only when it has size 1.
        return t.squeeze(0) if t.shape[0] == 1 else t

    support_x, support_y, query_x, query_y = [], [], [], []
    for episode in batch:
        support = episode['train']
        x_s = drop_leading_one(support[0])
        y_s = drop_leading_one(support[1])
        support_y.append(y_s)
        support_x.append(x_s)

        query = episode['test']
        x_q = drop_leading_one(query[0])
        y_q = query[1]
        # NOTE(review): test labels use a different rule (rank > 1) than train
        # labels (leading dim == 1); behaviour kept as-is — confirm intent.
        if y_q.dim() > 1:
            y_q = y_q.squeeze(0)
        query_y.append(y_q)
        query_x.append(x_q)

    return {
        'train': [torch.stack(support_x), torch.stack(support_y)],
        'test': [torch.stack(query_x), torch.stack(query_y)],
    }


def set_seed(seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``seed``.

    Also sets cuDNN to deterministic mode and disables its autotuner
    (``benchmark``), so repeated runs choose the same kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class ToTensor1D:
    """Transform that converts a NumPy array into a float32 ``torch.Tensor``."""

    def __call__(self, array):
        # torch.from_numpy shares memory with the float32 copy made here.
        as_float32 = array.astype('float32')
        return torch.from_numpy(as_float32)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)

