import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os

def data_loader(a, batch_size, is_train=True):
    """Build a shuffled ``DataLoader`` for one of three hard-coded datasets.

    Args:
        a: Dataset selector. ``0`` = ImageNet read with ``ImageFolder`` from
            ``/home/iihihiuh/Tools/imagenet/{train,val}``; ``1`` = CIFAR-10
            and ``2`` = MNIST, both rooted at ``/home/yongqin`` with
            ``download=True``.
        batch_size: Number of samples per batch.
        is_train: When True, load the training split with random
            augmentation. Otherwise load the validation/test split with
            deterministic preprocessing only.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and
        ``num_workers=12`` (CIFAR-10 and MNIST also use ``pin_memory=True``).

    Raises:
        ValueError: If ``a`` is not 0, 1 or 2.
    """
    if a == 0:
        # ImageNet
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if is_train:
            data_dir = '/home/iihihiuh/Tools/imagenet/train'
            transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            data_dir = '/home/iihihiuh/Tools/imagenet/val/'
            # Evaluation must be deterministic: no random crop/flip.
            # Resize + center crop still yields 224x224 inputs.
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        data_set = datasets.ImageFolder(data_dir, transform)
        return torch.utils.data.DataLoader(
            data_set, batch_size=batch_size, shuffle=True, num_workers=12)

    if a == 1:
        # CIFAR-10; normalized with the same channel mean/std as the ImageNet branch.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if is_train:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            # No augmentation on the test split so evaluation is repeatable.
            transform = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
        data_set = datasets.CIFAR10(root='/home/yongqin/', train=is_train,
                                    transform=transform, download=True)
        return torch.utils.data.DataLoader(
            data_set,
            batch_size=batch_size, shuffle=True,
            num_workers=12, pin_memory=True)

    if a == 2:
        # MNIST: no augmentation in either split.
        data_set = datasets.MNIST('/home/yongqin', train=is_train, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))
                                  ]))
        return torch.utils.data.DataLoader(
            data_set, batch_size=batch_size, shuffle=True,
            num_workers=12, pin_memory=True)

    raise ValueError(
        "unknown dataset selector a=%r; expected 0 (ImageNet), "
        "1 (CIFAR-10) or 2 (MNIST)" % (a,))

def get_topk_acc(preds, true_labels, k=5, label_offset=1):
    """Count top-1 and top-k hits for a batch of class scores.

    Args:
        preds: NumPy array of shape ``(batch, num_classes)``; higher score
            means more likely. Must be a NumPy array, because the body relies
            on ``argsort(axis=...)`` and negative-step slicing.
        true_labels: Sequence or array of ``batch`` integer labels.
        k: Number of highest-scoring classes that count as a hit for the
            second (top-k) figure. Must be >= 1.
        label_offset: Added to column indices to turn them into labels.
            The default ``1`` preserves the original behaviour, which treats
            column ``i`` as label ``i + 1`` (1-based labels). Pass ``0`` when
            labels are 0-based. NOTE(review): torchvision datasets normally
            yield 0-based targets; confirm which convention callers use.

    Returns:
        ``(top1_hits, topk_hits)``: NumPy integer counts, not fractions.

    Raises:
        ValueError: If ``preds`` is not 2-D, its first dimension differs from
            ``len(true_labels)``, or ``k < 1``.
    """
    batch_size = len(true_labels)
    if len(preds.shape) != 2:
        raise ValueError("preds must be 2-D (batch, num_classes), got shape %r"
                         % (preds.shape,))
    if preds.shape[0] != batch_size:
        raise ValueError("preds has %d rows but there are %d labels"
                         % (preds.shape[0], batch_size))
    if k < 1:
        # k == 0 would slice [:, -0:], i.e. every column, silently turning
        # top-k into top-num_classes.
        raise ValueError("k must be >= 1, got %r" % (k,))

    labels = np.asarray(true_labels)
    # Labels of the k highest-scoring classes per row, best first.
    top = preds.argsort(axis=1)[:, -k:][:, ::-1] + label_offset
    # hits[i, j] is True when the j-th best guess for sample i is correct.
    hits = top == labels[:, None]
    return np.sum(hits[:, 0]), np.sum(np.any(hits, axis=1))

def save_model(model_list, epoch, root="saved"):
    """Save each model's ``state_dict`` to ``<root>/epoch<epoch>/model<index>``.

    Args:
        model_list: Iterable of objects exposing ``state_dict()`` (e.g.
            ``torch.nn.Module``). The position in the iterable becomes the
            ``model<index>`` file name.
        epoch: Epoch number embedded in the directory name.
        root: Parent checkpoint directory. Defaults to ``"saved"``, relative
            to the current working directory.
    """
    epoch_dir = os.path.join(root, f"epoch{epoch}")
    # makedirs creates ``root`` if needed and tolerates an existing epoch
    # directory, so re-saving an epoch overwrites instead of raising.
    os.makedirs(epoch_dir, exist_ok=True)
    for index, model in enumerate(model_list):
        torch.save(model.state_dict(), os.path.join(epoch_dir, f"model{index}"))

def load_model(model_list, epoch, root="saved", map_location=None):
    """Load ``<root>/epoch<epoch>/model<index>`` into each model and set eval mode.

    Args:
        model_list: Iterable of modules. The module at position ``index``
            receives the state dict stored in file ``model<index>``.
        epoch: Epoch number embedded in the directory name.
        root: Parent checkpoint directory. Defaults to ``"saved"``, relative
            to the current working directory.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            torch's default behaviour. Pass e.g. ``"cpu"`` to load checkpoints
            saved on a GPU onto a CPU-only host.
    """
    epoch_dir = os.path.join(root, f"epoch{epoch}")
    for index, model in enumerate(model_list):
        path = os.path.join(epoch_dir, f"model{index}")
        model.load_state_dict(torch.load(path, map_location=map_location))
        model.eval()

def copy_seqlist(state_dicts, seqs):
    """Load ``state_dicts[i]`` into ``seqs[i]`` and put each module in eval mode.

    ``state_dicts`` is indexed by each module's position in ``seqs``. It must
    therefore be at least as long as ``seqs``; otherwise an ``IndexError`` is
    raised after the earlier modules have already been updated. Extra entries
    in ``state_dicts`` are ignored.
    """
    for slot, module in enumerate(seqs):
        weights = state_dicts[slot]
        module.load_state_dict(weights)
        module.eval()

def copy_sequentials(model_lists):
    """Copy weights from the first list of modules into every other list.

    ``model_lists[0]`` is the reference. Its modules' state dicts are captured
    once, then loaded position-by-position into each later list via
    ``copy_seqlist``, which also switches those modules to eval mode. The
    reference list itself is left untouched.
    """
    reference = model_lists[0]
    reference_states = [module.state_dict() for module in reference]

    for position, replica in enumerate(model_lists):
        # Position 0 is the reference list; only the later lists receive copies.
        if position > 0:
            copy_seqlist(reference_states, replica)
