import torch
import numpy as np
from instance_cifar10 import *
import collections
from matplotlib import pyplot as plt
from model.resnet import *
from complementary_loss import *
import random
from torch.optim.lr_scheduler import MultiStepLR
import argparse
import torchvision.datasets as datasets
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Command-line interface of the experiment. Only the loss type is mandatory;
# every other option has a default.
parser = argparse.ArgumentParser(description='IDCLL for CIFAR10')

parser.add_argument(
    '-lr', '--learning_rate',
    type=float,
    default=0.01,
    help="optimizer's learning rate",
)
parser.add_argument(
    '-bs', '--batch_size',
    type=int,
    default=128,
    help='batch_size of ordinary labels.',
)
parser.add_argument(
    '-k', '--k',
    type=int,
    default=3,
    help='mink',
)
parser.add_argument(
    '-lo', '--loss',
    type=str,
    required=True,
    choices=['forward_loss', 'scl_nl', 'scl_exp', 'pc_loss', 'w_loss', 'porden', 'nn', 'ovr_loss'],
    help='loss type',
)
parser.add_argument(
    '-e', '--epochs',
    type=int,
    default=200,
    help='number of epochs',
)
parser.add_argument(
    '-wd', '--weight_decay',
    type=float,
    default=5e-4,
    help='weight decay',
)
parser.add_argument(
    '-se', '--seed',
    type=int,
    default=3,
    help='seed',
)
parser.add_argument(
    '-pre', '--pretrain',
    type=str,
    default='resnet18',
    choices=['resnet18', 'resnet34', 'vgg16', 'googlenet'],
    help='pretrain model',
)

# Parsed once at import time; the rest of the script reads options from here.
args = parser.parse_args()

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and sets the cuDNN ``deterministic`` flag.

    Args:
        seed: integer seed handed to each generator.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: default one first, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN to stick to deterministic algorithms.
    torch.backends.cudnn.deterministic = True
# Seed all random number generators with the user-supplied --seed value.
setup_seed(args.seed)

# evaluation
def train_accuracy(train_loader, model):
    """Return the model's top-1 accuracy, in percent, over ``train_loader``.

    The model is switched to evaluation mode and left in that mode.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches; labels
            are compared element-wise against the predicted indices.
        model: callable mapping an input batch to a score tensor; the
            prediction is the index of the maximum along dimension 1.

    Returns:
        ``100 * correct / seen`` rounded to two decimals.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()
    total, num_samples = 0, 0
    # Pure inference: disable autograd so no graph is recorded per batch,
    # which would only cost memory and time here.
    with torch.no_grad():
        for train_X, train_Y in train_loader:
            train_X, train_Y = train_X.to(device), train_Y.to(device)
            outputs = model(train_X)
            _, predicted = torch.max(outputs, 1)
            total += (predicted == train_Y).sum().item()
            num_samples += train_Y.size(0)
    return round(100*total/num_samples, 2)

def test_accuracy(test_loader, model):
    """Return the model's top-1 accuracy, in percent, over ``test_loader``.

    The model is switched to evaluation mode and left in that mode.

    Args:
        test_loader: iterable yielding ``(inputs, labels)`` batches; labels
            are compared element-wise against the predicted indices.
        model: callable mapping an input batch to a score tensor; the
            prediction is the index of the maximum along dimension 1.

    Returns:
        ``100 * correct / seen`` rounded to two decimals.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()
    total, num_samples = 0, 0
    # Pure inference: disable autograd so no graph is recorded per batch,
    # which would only cost memory and time here.
    with torch.no_grad():
        for test_X, test_Y in test_loader:
            test_X, test_Y = test_X.to(device), test_Y.to(device)
            outputs = model(test_X)
            _, predicted = torch.max(outputs, 1)
            total += (predicted == test_Y).sum().item()
            num_samples += test_Y.size(0)
    return round(100*total/num_samples, 2)

# train
def train(train_loader, model, optimizer, epoch, lr, me,ccp_com,ccp_mincom):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    The loss function is chosen from the module-level ``args.loss`` option.

    Args:
        train_loader: iterable yielding 5-tuples ``(x, mincom, com, y, id)``;
            only ``x`` and ``mincom`` are used.
        model: network to optimise; put into training mode.
        optimizer: optimizer stepped once per batch.
        epoch: unused; kept for interface compatibility.
        lr: unused; kept for interface compatibility.
        me: unused; the loss is selected through ``args.loss`` instead.
        ccp_com: unused; kept for interface compatibility.
        ccp_mincom: forwarded to ``non_negative_loss`` when the loss is 'nn'.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.

    Raises:
        ValueError: if ``args.loss`` names a loss this function does not handle.
        ZeroDivisionError: if the loader yields no batches.
    """
    train_loss = 0
    model.train()
    total = 0
    for x, mincom, _com, _y, _sample_id in train_loader:
        x = x.to(device)
        mincom = mincom.to(device)
        outputs = model(x)
        if args.loss=='forward_loss':
            loss = forward_loss(outputs,10,mincom)
        elif args.loss=='scl_nl':
            loss = scl_nl(outputs,mincom)
        elif args.loss=='scl_exp':
            loss = scl_exp(outputs,mincom)
        elif args.loss=='pc_loss':
            loss = pc_loss(outputs,10,mincom)
        elif args.loss=='w_loss':
            loss = w_loss(outputs, 10,mincom)
        elif args.loss == 'porden':
            loss = partial_loss(outputs,mincom)
        elif args.loss == 'nn':
            loss = non_negative_loss(outputs,10,mincom,ccp_mincom,beta=0)
        elif args.loss == 'ovr_loss':
            loss = ovr_loss(outputs,mincom)
        else:
            # Fail loudly: without a loss there is nothing to back-propagate,
            # and continuing would only hit an unbound ``loss`` below.
            raise ValueError("no such method: {!r}".format(args.loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss = train_loss + loss.item()
        total+=1
    return train_loss/total

# Build the data pipeline, the network and the optimisation schedule.
train_loader, test_loader, one_hot_mincom, one_hot_com,ccp_com,ccp_mincom = load_cifar10(args.batch_size,args.pretrain,args.k)
model = resnet18().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
# Learning rate is multiplied by 0.1 at epochs 40 and 80.
scheduler = MultiStepLR(optimizer, milestones=[40,80],gamma=0.1,last_epoch=-1)
best_acc = 0
ls_best_acc = []  # best test accuracy seen so far, recorded once per epoch
for i in range(args.epochs):
    print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
    loss = train(train_loader, model, optimizer, i, args.learning_rate, me=args.loss,ccp_com=ccp_com,ccp_mincom=ccp_mincom)
    scheduler.step()
    test_nat_acc = test_accuracy(test_loader, model)
    if test_nat_acc>best_acc:
        best_acc = test_nat_acc
    print("epoch:",i,";","test_acc:",test_nat_acc, '; best_acc:', best_acc)
    ls_best_acc.append(best_acc)
# Persist the best-accuracy curve in a file named after the selected loss, so
# the results of a finished run are not lost to an undefined file name.
np.savetxt("{}".format(args.loss), np.array(ls_best_acc))