import os
import os.path
import argparse
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

import data_load
from tools import createModel, evaluate, predict, train, isSame, NewDataset, adjust_learning_rate
from transformer import transform_train, transform_test, transform_target
# Print NumPy float arrays with two decimals.
np.set_printoptions(formatter={'float': '{: 0.2f}'.format})


parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, help='initial learning rate', default=0.01)
parser.add_argument('--weight_decay', type=float, help='weight_decay for training', default=1e-4)
parser.add_argument('--n_epoch', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--model_dir', type=str, help='dir to save model files', default='model')
# Only these three datasets are handled by the dataset-selection code; rejecting
# anything else at parse time gives a clear argparse error instead of a later
# NameError when the (never-assigned) datasets are used.
parser.add_argument('--dataset', type=str, help='mnist, cifar10, or cifar100', default='mnist',
                    choices=['mnist', 'cifar10', 'cifar100'])
parser.add_argument('--modelName', type=str, help='9-layer, Lenet, ResNet18, ResNet34', default='default')
parser.add_argument('--noise_type', type=str, help='pairflip, symmetric, instance', default='symmetric')
parser.add_argument('--noise_rate', type=float, help='corruption rate, should be less than 1', default=0.2)
parser.add_argument('--seed', type=int, help='seed number', default=1)
parser.add_argument('--maxCleanNum', type=int, help='Clean round times including the first round', default=3)
args = parser.parse_args()

print(args)

# Seed every RNG the script touches: Python, NumPy, and PyTorch (CPU and all CUDA devices).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Log the GPU status at start-up (constant command, no user input involved).
    os.system('nvidia-smi')


# Dataset selection: mnist, cifar10, cifar100.
# Each branch fixes the network input (channels, number of classes), the model
# used when --modelName is 'default', and which data_load factories to call.
if args.dataset == 'mnist':
    input_channel = 1
    num_classes = 10
    _default_model_name = 'Lenet'
    _noisy_dataset_factory = data_load.mnist_dataset
    _test_dataset_factory = data_load.mnist_test_dataset
elif args.dataset == 'cifar10':
    input_channel = 3
    num_classes = 10
    _default_model_name = 'ResNet18'
    _noisy_dataset_factory = data_load.cifar10_dataset
    _test_dataset_factory = data_load.cifar10_test_dataset
elif args.dataset == 'cifar100':
    input_channel = 3
    num_classes = 100
    _default_model_name = 'ResNet34'
    _noisy_dataset_factory = data_load.cifar100_dataset
    _test_dataset_factory = data_load.cifar100_test_dataset
else:
    # Fail fast: otherwise input_channel, num_classes and the datasets stay
    # undefined and the script crashes later with a confusing NameError.
    raise ValueError("Unsupported --dataset %r; expected 'mnist', 'cifar10' or 'cifar100'" % (args.dataset,))

if args.modelName == 'default':
    args.modelName = _default_model_name

# The noisy train/val splits use the *test* transform; augmentation is applied
# separately when main() wraps the training data in NewDataset with
# transform_train (train_loader itself is used for label prediction).
train_dataset = _noisy_dataset_factory(True, transform=transform_test(args.dataset), target_transform=transform_target,
                                       noise_rate=args.noise_rate, noise_type=args.noise_type, random_seed=args.seed)
val_dataset = _noisy_dataset_factory(False, transform=transform_test(args.dataset), target_transform=transform_target,
                                     noise_rate=args.noise_rate, noise_type=args.noise_type, random_seed=args.seed)
test_dataset = _test_dataset_factory(transform=transform_test(args.dataset), target_transform=transform_target)


def getEvenData(noise_data_original, noise_labels_original, num_classes, precents=None, min_class_number=None):
    """Subsample a labelled dataset so each class keeps a controlled number of samples.

    For every class ``c`` in ``range(num_classes)`` the first
    ``int(min_class_number * precents[c])`` samples of that class (in their
    original order) are kept and the rest are dropped. Samples whose label is
    outside ``range(num_classes)`` are left untouched.

    Args:
        noise_data_original: Array-like of samples, indexed along axis 0.
        noise_labels_original: Array-like of integer labels, one per sample.
        num_classes: Number of classes to balance.
        precents: Optional per-class multipliers applied to ``min_class_number``;
            defaults to 1.0 for every class.
        min_class_number: Base per-class sample count. When omitted it is the
            size of the smallest class, and the original sizes are printed.

    Returns:
        ``(data, labels)`` as NumPy arrays containing only the kept samples.
    """
    noise_data_original = np.asarray(noise_data_original)
    noise_labels_original = np.asarray(noise_labels_original)
    if precents is None:
        precents = np.ones(num_classes)

    def _class_sizes(labels):
        # count_nonzero is correct for classes with 0 or 1 samples; an
        # argwhere(...).squeeze() approach collapses a single match to a 0-d
        # array whose .shape[0] raises IndexError.
        return np.array([np.count_nonzero(labels == c) for c in range(num_classes)], dtype=float)

    if min_class_number is None:
        label_sizes = _class_sizes(noise_labels_original)
        min_class_number = label_sizes.min()
        print('noise_data_original', noise_data_original.shape, "min_class_number", int(min_class_number), 'label_sizes', label_sizes)

    # Mark samples to keep in a single mask rather than calling np.delete once
    # per class, which copies the whole dataset every time.
    keep = np.ones(len(noise_labels_original), dtype=bool)
    for target_value in range(num_classes):
        idx = np.flatnonzero(noise_labels_original == target_value)
        need_num = int(min_class_number * precents[target_value])
        keep[idx[need_num:]] = False

    noise_data_original = noise_data_original[keep]
    noise_labels_original = noise_labels_original[keep]

    label_sizes = _class_sizes(noise_labels_original)
    print('New data size:', len(noise_labels_original), 'every class:', label_sizes)
    return noise_data_original, noise_labels_original

# Loaders over the fixed splits: train_loader is used to predict new labels,
# val/test loaders for evaluation. shuffle=False keeps batches in dataset
# order, presumably so predictions line up with train_dataset's stored labels.
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=8, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=8, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=8, shuffle=False)


# Checkpoint directory: <model_dir>/<dataset>/noise_rate_<noise_rate>.
model_save_dir = args.model_dir + '/' + args.dataset + '/' + 'noise_rate_%s' % (args.noise_rate)
# Create it directly rather than through a shell `mkdir -p`: the path comes
# from the command line, and spaces or shell metacharacters in it would break
# (or be interpreted by) the shell command, with failures going unnoticed.
os.makedirs(model_save_dir, exist_ok=True)


def main():
    """Iteratively train a model on (re)labelled data until rounds stop improving.

    Each iteration of the outer loop is a "round":

    1. If a checkpoint has been saved, reload it, evaluate it on the test set,
       predict labels for the training set and pass them to ``tools.isSame``,
       whose returned data/labels become this round's training set
       (presumably the samples whose prediction agrees with the noisy label —
       confirm in ``tools.isSame``). Before any checkpoint exists the
       original noisy labels are used.
    2. Train for up to ``maxEpoch`` epochs. A "clean" round starts from a
       freshly created model; otherwise training continues from the reloaded
       checkpoint. '9-layer' models use Adam (lr 0.001) with
       ``adjust_learning_rate``; others use SGD, plus a MultiStepLR schedule in
       non-clean rounds. A checkpoint is written whenever validation accuracy
       exceeds the running best by more than the margin ``n_add``.
    3. If the round's best validation accuracy equals the previous round's,
       count a non-improving round and make the next round a clean restart.

    The loop ends after ``args.maxCleanNum`` non-improving rounds, then the
    best checkpoint across all rounds is printed. Relies on module-level
    globals (``args``, the datasets/loaders, ``model_save_dir``,
    ``input_channel``, ``num_classes``).
    """
    # best_model_name / best_val_accuracy: latest saved checkpoint and the
    # accuracy a new checkpoint must beat. The "2" variants track the best
    # checkpoint over all rounds and are reported at the end.
    best_model_name = None
    best_val_accuracy = 0
    best_model_name2 = None
    best_val_accuracy2 = 0

    isClean = True                 # next round trains a freshly created model
    lastTimeAccuracy = 0           # best validation accuracy at the end of the previous round
    maxCleanEpoch = args.n_epoch   # epoch budget for clean rounds; updated during round 1

    breakNum = 0                   # rounds whose best validation accuracy did not change
    roundNum = 0
    while(breakNum < args.maxCleanNum):
        roundNum += 1
        # predict new labels
        if(best_model_name is not None):
            print("Round：" + str(roundNum) + ", Load " + best_model_name + ", evaluate it and predict labels...")
            cnn = createModel(args.modelName, input_channel, num_classes)
            cnn.load_state_dict(torch.load(best_model_name))

            # Use test dataset to evaluate no noise accuracy, lines could be removed in final version.
            evaluate(test_loader, cnn, num_classes)
            train_labels = predict(train_loader, cnn)
            train_data, train_labels = isSame(train_labels, train_dataset.train_noisy_labels, train_dataset.train_clean_labels,
                                              train_dataset.train_data, num_classes)
        else:
            # No checkpoint yet: train on the original noisy labels.
            train_data = train_dataset.train_data
            train_labels = train_dataset.train_noisy_labels

        # Prepare new data loader (training transform, shuffled)
        new_dataset = NewDataset(train_data, train_labels, transform_train(args.dataset))
        new_dataset_loader = torch.utils.data.DataLoader(dataset=new_dataset, batch_size=args.batch_size, num_workers=8, shuffle=True)

        # Calculate epoch number
        if(isClean is True):
            print("Clean precedure, run " + str(maxCleanEpoch) + " epochs")
            # Scale the threshold down, presumably so the freshly initialised
            # model can save checkpoints before matching the previous best.
            best_val_accuracy = best_val_accuracy * 0.6
            maxEpoch = maxCleanEpoch

            # Define models
            cnn = createModel(args.modelName, input_channel, num_classes)
        else:
            # Continue training the checkpoint reloaded at the start of this round.
            maxEpoch = args.n_epoch

        if(args.modelName == '9-layer'):
            # The 9-layer model always uses Adam at lr 0.001 (overrides --lr).
            args.lr = 0.001
            optimizer = torch.optim.Adam(cnn.parameters(), lr=args.lr)
        else:
            optimizer = optim.SGD(cnn.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
            if(isClean is False):
                # Step decay only in continuation rounds; stepped once per epoch below.
                scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)

        # train
        n_add = 0              # extra margin a new checkpoint must beat (grows only in clean rounds)
        best_around_acc = 0    # best validation accuracy of this round (updated when saving at epoch > 0)
        isImproved = False     # set once this round beats lastTimeAccuracy by > 0.5 (non-clean rounds)
        for epoch in range(maxEpoch):
            time_stamp = datetime.datetime.now()
            print(time_stamp.strftime('%H:%M:%S') + " Epoch", epoch, "begin...")
            train(new_dataset_loader, cnn, optimizer)
            validation_acc, validation_std = evaluate(val_loader, cnn, num_classes)

            if(args.modelName == '9-layer'):
                adjust_learning_rate(optimizer, epoch)
            else:
                if(isClean is False):
                    scheduler.step()

            if(isClean is True):
                # Clean rounds: raise the required margin by 0.1 for each epoch
                # without a new checkpoint and give up once it exceeds 2 (about
                # 20 epochs; the exact count depends on float rounding of the
                # repeated 0.1 additions).
                n_add += 0.1
                if n_add > 2:
                    print("No improved in 20 epoch, break!")
                    break

            # Save model, because the model has been improved, the predicted labels will get more,
            # then stop and use more labels to train a better model.
            if(validation_acc > best_val_accuracy + n_add):
                filepath = model_save_dir + "/" + args.dataset + "-" + str(roundNum) + "-" + str(epoch) + "-" + str(round(validation_acc, 2)) + ".hdf5"
                best_val_accuracy = validation_acc
                best_model_name = filepath
                print("Save Model:" + best_model_name)
                torch.save(cnn.state_dict(), best_model_name)
                if(best_val_accuracy > best_val_accuracy2):
                    best_val_accuracy2 = best_val_accuracy
                    best_model_name2 = best_model_name

                if(roundNum == 1):
                    # Avoid missing the best validation accuracy: later clean
                    # rounds train as many epochs as round 1 needed to reach this
                    # checkpoint. `epoch` is 0-based, so that is epoch + 1 epochs
                    # (plain `epoch` stopped one epoch short and gave zero epochs
                    # when the last save happened at epoch 0).
                    maxCleanEpoch = epoch + 1

                # reset n_add
                n_add = 0

                # For early stop, this could speed up training process
                if(validation_acc > best_around_acc and epoch > 0):
                    best_around_acc = validation_acc
                    if(best_around_acc - lastTimeAccuracy > 0.5 and isClean is False):
                        # New model has improved, so it doesn't need to train longer,
                        # which is better for less noise into model and save time.
                        if(isImproved is False):
                            print("The performance has improved.")

                        isImproved = True

            # Once improved, leave the round at the next multiple of 20 epochs.
            if(epoch % 20 == 0 and epoch > 0 and isImproved is True and isClean is False):
                print("The performance has improved, go to the next round:" + str(round(best_around_acc-lastTimeAccuracy, 2)))
                break

        # Examine improvements
        # NOTE(review): exact float equality. In a clean round best_val_accuracy
        # was scaled by 0.6 above, so unless lastTimeAccuracy is 0 a clean round
        # that saves no checkpoint still differs here and is treated as an
        # improvement (lastTimeAccuracy drops to the scaled value) — confirm
        # this is intended.
        if(best_val_accuracy == lastTimeAccuracy):
            breakNum += 1
            isClean = True
            print("Val accrucy were not improved! Num:" + str(breakNum))
        else:
            lastTimeAccuracy = best_val_accuracy
            isClean = False

    # output best model in the all rounds
    print("best_val_accuracy2", str(best_val_accuracy2), ", best_model_name2", best_model_name2)


# Run training only when executed as a script. With num_workers > 0, DataLoader
# worker processes started via the 'spawn' method re-import this module and
# must not launch another training run.
if __name__ == '__main__':
    main()
