import torch
import random
import numpy as np
import os
import os
import os.path

from torchvision import datasets, transforms

def make_state_id(args):
    """Compose the underscore-joined identifier of an experiment.

    The identifier is built, in order, from: the optional ``additional_info``
    tag, the architecture (with a depth suffix when one is available), the
    activation type (plus ``tau`` for ``'shifted_tanh'``), the operation
    order, the seed, the dataset, the decay setting and the learning rate.

    Args:
        args: Object exposing the attributes read below (``additional_info``,
            ``arch``, ``vgg_cut_block``, ``depth_wide``, ``activation_type``,
            ``tau``, ``operation_order``, ``seed``, ``dataset``,
            ``weight_decay``, ``gamma_decay``, ``beta_decay``, ``lr``).

    Returns:
        The pieces joined with ``'_'``.
    """
    parts = []
    if args.additional_info is not None:
        parts.append(args.additional_info)

    # Architecture tag: VGG with a cut block reports the remaining depth.
    if args.arch == 'VGG' and args.vgg_cut_block is not None:
        remaining_depth = int(args.depth_wide) - int(args.vgg_cut_block)
        arch_tag = args.arch + str(remaining_depth)
    elif args.depth_wide is not None:
        arch_tag = args.arch + str(args.depth_wide)
    else:
        arch_tag = args.arch
    parts.append(arch_tag)

    parts.append(args.activation_type)
    if args.activation_type == 'shifted_tanh':
        parts.append('tau' + str(args.tau))

    parts.extend([args.operation_order, 'seed' + str(args.seed), args.dataset])

    # A single number is enough when all three decay values coincide;
    # otherwise each one is spelled out.
    w_decay, g_decay, b_decay = args.weight_decay, args.gamma_decay, args.beta_decay
    if w_decay == g_decay and g_decay == b_decay and w_decay == b_decay:
        decay_tag = str(w_decay)
    else:
        decay_tag = ''.join([
            'weight', str(w_decay),
            'gamma', str(g_decay),
            'beta', str(b_decay),
        ])
    parts.append(decay_tag)

    parts.append(str(args.lr))

    return '_'.join(parts)

def fix_randomness(seed, cuda_available):
    """Seed the random number generators used by this project.

    Args:
        seed: Value handed to the Python, NumPy and PyTorch seeding calls.
        cuda_available: When truthy, the CUDA generator is seeded as well and
            cuDNN is set to deterministic mode with benchmarking disabled.
    """
    # Host-side generators first: stdlib, NumPy, then PyTorch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)

    if not cuda_available:
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def _build_loader_pair(train_data, test_data, batch_size, test_batch_size, num_workers, pin_memory=False):
    """Wrap a train/test dataset pair in DataLoaders (train shuffled, test in order)."""
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=test_batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader


def get_data_loader(dataset, batch_size, test_batch_size, num_workers):
    """Build the data loaders for the requested dataset.

    Data is looked up under ``./data`` when that directory exists and under
    ``/data`` otherwise.

    Args:
        dataset: One of 'cifar10', 'cifar100', 'tinyImageNet', 'cifar100-c',
            or any other name containing 'ImageNet'.
        batch_size: Batch size of the training loader.
        test_batch_size: Batch size of every test loader, including the
            per-corruption loaders of 'cifar100-c'.
        num_workers: Worker count handed to each DataLoader.

    Returns:
        ``(train_loader, test_loader, num_classes)``. For 'cifar100-c',
        ``train_loader`` is None and ``test_loader`` is a dict mapping each
        corruption name to its loader.

    Raises:
        SystemExit: With exit status 1 when ``dataset`` is not recognised.
    """
    data_dir = './data' if os.path.isdir('./data') else '/data'

    if dataset in ('cifar10', 'cifar100'):
        # Both CIFAR variants use the same augmentation and the same
        # normalization constants; only the dataset class differs.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        if dataset == 'cifar10':
            dataset_cls, num_classes = datasets.CIFAR10, 10
        else:
            dataset_cls, num_classes = datasets.CIFAR100, 100
        data_path = os.path.join(data_dir, dataset)

        train_data = dataset_cls(root=data_path, train=True, download=True, transform=transform_train)
        test_data = dataset_cls(root=data_path, train=False, download=True, transform=transform_test)

        train_loader, test_loader = _build_loader_pair(
            train_data, test_data, batch_size, test_batch_size, num_workers)

    elif dataset == 'tinyImageNet':
        transform_train = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262))
        ])
        transform_test = transforms.Compose([
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262))
        ])
        data_path = os.path.join(data_dir, 'tiny-imagenet')

        train_data = datasets.ImageFolder(os.path.join(data_path, 'train'), transform_train)
        test_data = datasets.ImageFolder(os.path.join(data_path, 'val'), transform_test)

        train_loader, test_loader = _build_loader_pair(
            train_data, test_data, batch_size, test_batch_size, num_workers, pin_memory=True)

        num_classes = 200

    elif 'ImageNet' in dataset:
        # Reached only for names other than 'tinyImageNet', which is matched above.
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        data_path = os.path.join(data_dir, 'imagenet')

        num_classes = 1000

        train_data = datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=transform_train)
        # NOTE(review): the validation split is read from 'val3' - confirm this
        # directory name is intentional.
        test_data = datasets.ImageFolder(root=os.path.join(data_path, 'val3'), transform=transform_test)

        train_loader, test_loader = _build_loader_pair(
            train_data, test_data, batch_size, test_batch_size, num_workers, pin_memory=True)

    elif dataset == 'cifar100-c':
        corruptions = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
            'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
            'brightness', 'contrast', 'elastic_transform', 'pixelate',
            'jpeg_compression',
        ]

        cifar_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                                 (0.24703223, 0.24348513, 0.26158784)),
        ])

        corrupted_dir = os.path.join(data_dir, 'CIFAR-100-C')
        # One label file serves every corruption, so it is read once.
        labels = torch.LongTensor(np.load(os.path.join(corrupted_dir, 'labels.npy')))

        train_loader = None
        test_loader = {}
        num_classes = 100

        for corruption_name in corruptions:
            print(corruption_name)
            # The CIFAR100 object only serves as a container: its images and
            # labels are replaced by the corrupted arrays right below.
            test_data = datasets.CIFAR100(root=os.path.join(corrupted_dir, corruption_name),
                                          train=False,
                                          transform=cifar_transforms,
                                          download=True)
            test_data.data = np.load(os.path.join(corrupted_dir, f"{corruption_name}.npy"))
            test_data.targets = labels

            # Evaluation loaders use the test batch size, like every other branch.
            test_loader[corruption_name] = torch.utils.data.DataLoader(
                test_data, batch_size=test_batch_size, shuffle=False,
                num_workers=num_workers, pin_memory=True)

    else:
        print('Invalid dataset!')
        # Non-zero status so wrapper scripts see the failure; raising directly
        # avoids relying on the interactive `exit` helper.
        raise SystemExit(1)

    return train_loader, test_loader, num_classes

def get_pretrained_weight(model, args):
    """Load a checkpoint's weights into ``model`` and return the checkpoint.

    The checkpoint's ``'state_dict'`` is loaded directly first. If that is
    rejected, a second attempt is made with a remapped copy of the model's
    own state dict: checkpoint entries whose key exists in the model are
    copied by name, and every other entry is assigned to the model key found
    at the same position.

    Args:
        model: Module that receives the weights through ``load_state_dict``.
        args: Object with ``pretrained`` (checkpoint path) and ``cuda``
            (truthy maps tensors to "cuda", otherwise to "cpu").

    Returns:
        The whole object returned by ``torch.load``, not only its weights.

    Raises:
        IndexError: During the fallback, if an unmatched checkpoint key sits
            at a position beyond the end of the model's key list.
    """
    # NOTE(review): depending on the PyTorch version / weights_only setting,
    # torch.load may unpickle arbitrary objects - only load trusted files.
    pretrained_model = torch.load(args.pretrained, map_location="cuda" if args.cuda else "cpu")
    pretrained_weight = pretrained_model['state_dict']

    try:
        model.load_state_dict(pretrained_weight)
    except RuntimeError:
        # Only key/shape mismatches reported by load_state_dict trigger the
        # fallback; interrupts and unrelated errors propagate.
        print("layer adjusting for data parallel (module)..")

        model_weight = model.state_dict()
        model_weight_key = list(model_weight)

        for ind, (key, weight) in enumerate(pretrained_weight.items()):
            # Only existing keys are ever assigned, so the dict's key set is
            # stable and can be used for the membership test.
            target_key = key if key in model_weight else model_weight_key[ind]
            model_weight[target_key] = weight

        model.load_state_dict(model_weight)

    return pretrained_model

def save_state(model, acc, model_filename):
    """Save the accuracy and the model weights under ``saved_models/``.

    Args:
        model: The network to save. When it exposes a ``module`` attribute,
            the weights of that inner module are stored; otherwise the
            model's own weights are stored.
        acc: Value stored under the ``'acc'`` key.
        model_filename: File name inside ``saved_models/``.
    """
    print('==> Saving model ...')

    # Unwrap if wrapped; a bare model is saved as-is instead of failing.
    net = getattr(model, 'module', model)
    state = {
        'acc': acc,
        'state_dict': net.state_dict(),
    }

    save_dir = 'saved_models/'
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, model_filename))

def print_model_parameters(model):
    """Print the number of trainable parameter elements in ``model``.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: Object whose ``parameters()`` yields the tensors to count.
    """
    # numel() is an exact integer count; taking a product over the shape
    # would turn a 0-dim parameter into the float 1.0 and make the printed
    # total a float.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Total parameter number:', params, '\n')
