'''Train CIFAR10 with PyTorch.'''
import argparse
import os
import pickle
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import *
from partition import *


# Command-line interface: learning rate and the index of the model to train.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
parser.add_argument('--midx', type=int, default=-1, help='model index')
args = parser.parse_args()

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Data
print('==> Preparing data..')

# Values shared by the train and test pipelines below.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)
_DATA_ROOT = '/cifar10'
_BATCH_SIZE = 128

# Training images are randomly cropped (with 4 px padding) and flipped before
# being converted to tensors and normalized.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Test images only get the tensor conversion and the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=_BATCH_SIZE, shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=_BATCH_SIZE, shuffle=False, num_workers=1)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


# Training
def train(trainloader, epoch, optimizer, scheduler, criterion, model_name=""):
    '''Run one training epoch on the module-level ``net`` and checkpoint it.

    Reads the module globals ``net`` (the model) and ``device`` (where each
    batch is moved before the forward pass).

    Args:
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        epoch: epoch number; printed and stored in the checkpoint.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        scheduler: accepted but not used inside this function.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        model_name: sub-directory of ``checkpoint/`` that receives ``ckpt.pth``.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``. Both are ``0.0``
        when ``trainloader`` yields no batches.
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Predicted class = index of the largest output along dim 1.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # max(..., 1) keeps an empty loader from dividing by zero.
    acc = 100. * correct / max(total, 1)
    avg_loss = train_loss / max(num_batches, 1)
    print('train/acc', acc)
    print('train/loss', avg_loss)

    print('Saving..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    # makedirs also creates a missing 'checkpoint' parent and tolerates the
    # directory already existing (no check-then-create race).
    os.makedirs(os.path.join('checkpoint', model_name), exist_ok=True)
    torch.save(state, './checkpoint/' + model_name + '/ckpt.pth')

    return avg_loss, acc


def test(testloader, epoch, criterion, model_name=""):
    '''Evaluate the module-level ``net`` and checkpoint it on a new best accuracy.

    Reads the module globals ``net`` and ``device``, and reads and updates the
    module global ``best_acc``. When the measured accuracy exceeds
    ``best_acc``, the model state is written to
    ``checkpoint/<model_name>/ckpt_best.pth`` and ``best_acc`` is raised.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches.
        epoch: epoch number stored in the checkpoint.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        model_name: sub-directory of ``checkpoint/`` used for the checkpoint.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``. Both are ``0.0``
        when ``testloader`` yields no batches.
    '''
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # max(..., 1) keeps an empty loader from dividing by zero.
    acc = 100. * correct / max(total, 1)
    avg_loss = test_loss / max(num_batches, 1)

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # makedirs also creates a missing 'checkpoint' parent and tolerates
        # the directory already existing (no check-then-create race).
        os.makedirs(os.path.join('checkpoint', model_name), exist_ok=True)
        torch.save(state, './checkpoint/' + model_name + '/ckpt_best.pth')
        best_acc = acc

    mode = 'test'
    print(mode + '/acc', acc)
    print(mode + '/loss', avg_loss)

    return avg_loss, acc



def _seed_everything(seed_val):
    '''Seed torch (CPU and CUDA), NumPy and ``random`` from one integer.

    ``np.random.seed`` only accepts values in ``[0, 2**32 - 1]``, so the NumPy
    seed is reduced modulo ``2**32``. Without that, a negative ``seed_val``
    (``--midx`` defaults to -1) raises ``ValueError``. Non-negative seeds below
    ``2**32`` are passed through unchanged.
    '''
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    np.random.seed(seed_val % 2**32)
    random.seed(seed_val)


if __name__ == "__main__":
    num_epochs = 2
    num_models = 2

    # Fixed seed so that every invocation, whatever --midx is, draws the same
    # forget set and hands the same inputs to allocate_datas.
    _seed_everything(0)

    forget_set = random.sample(range(len(trainset)), 1000)
    print(forget_set)

    # exist_ok avoids the check-then-create race when several model indices
    # are launched side by side.
    os.makedirs('checkpoint', exist_ok=True)

    print(len(trainset))
    # allocate_datas comes from the project's partition module; it is given
    # all training indices, the number of models and the forget set.
    # NOTE(review): presumably allocations[i] is the index list for model i -
    # confirm against partition.allocate_datas.
    allocations, mapping = allocate_datas(list(range(len(trainset))), num_models, forget_set)
    with open('partitions.pkl', 'wb') as f:
        pickle.dump(allocations, f)
    with open('removed_mapping.pkl', 'wb') as f:
        pickle.dump(mapping, f)

    model_idx = args.midx
    seed = model_idx
    print('seed.....', seed)
    best_acc = 0  # best test accuracy; read and updated by test()
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Per-model seed: the model index doubles as the RNG seed.
    _seed_everything(seed)

    if model_idx >= 0:
        # Train on this model's partition. The in-memory allocations are used
        # directly rather than re-reading the pickle just written, which a
        # concurrently started run could be rewriting at the same moment.
        train_indices = allocations[model_idx]
    else:
        # Negative index: train on a random half-dataset-sized sample drawn
        # from everything outside the forget set.
        remaining_data = list(set(range(len(trainset))) - set(forget_set))
        train_indices = random.sample(remaining_data, len(trainset)//2)
    trainset_subset = torch.utils.data.Subset(trainset, train_indices)
    trainloader = torch.utils.data.DataLoader(trainset_subset, batch_size=128, shuffle=True, num_workers=1)

    # Model
    print('==> Building model..')
    # To try another architecture, replace SimpleDLA() below. Earlier versions
    # of this script listed VGG('VGG19'), ResNet18(), PreActResNet18(),
    # GoogLeNet(), DenseNet121(), MobileNetV2(), EfficientNetB0(), ... as
    # alternatives - presumably provided by the models package.
    net = SimpleDLA()
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    # Checkpoints for this run go under checkpoint/model_<seed>/.
    model_name = "model_" + str(seed)
    for epoch in range(start_epoch, start_epoch + num_epochs):
        train(trainloader, epoch, optimizer, scheduler, criterion, model_name=model_name)
        test(testloader, epoch, criterion, model_name=model_name)

        # The learning-rate schedule advances once per epoch.
        scheduler.step()



