import os
import os.path
import argparse
import random
import datetime
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR

import torchvision
import torchvision.transforms as transforms

# Silence NumPy divide-by-zero / invalid-value warnings; the per-class ratios
# computed in this script divide by counts that may be zero.
np.seterr(divide='ignore', invalid='ignore')
# Print arrays on one line, with floats formatted to two decimals.
np.set_printoptions(linewidth=np.inf)
np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
print("PyTorch version:", torch.__version__)
# Dump the GPU status to the console through the shell.
os.system('nvidia-smi')

# Command-line options. They are parsed at import time and stored in the
# module-level `args`, which the functions below read directly.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, help='clothing1m', default='clothing1m')
parser.add_argument('--num_classes', type=int, default=14)
# 'noise' makes main() hold out the last 10% of the shuffled training set as
# validation data; any other value keeps the 'clean_val_*' split.
parser.add_argument('--validation_type', type=str, help='validation type', default='clean')
parser.add_argument('--model_dir', type=str, help='dir to save model files', default='model')
parser.add_argument('--data_root', type=str, help='data location', default='data/Clothing1M_Official/')
parser.add_argument('--lr', type=float, help='initial learning rate', default=0.005)
parser.add_argument('--weight_decay', type=float, help='weight_decay for training', default=0.001)
parser.add_argument('--batch_size', type=int, default=32)
# Compared against the literal string 'Yes' in createModel().
parser.add_argument('--pretrain', type=str, help='pretrain', default='Yes')
parser.add_argument('--n_round', type=int, help='max run round', default=3)
parser.add_argument('--n_epoch', type=int, default=30)
# Number of epochs between training-label updates in main().
parser.add_argument('--update_interval', type=int, default=5)
# Weight of per-class accuracy (vs. per-class precision) in the model score.
parser.add_argument('--beta', type=float, help='beta', default=0.45)
parser.add_argument('--seed', type=int, help='seed number', default=1)
args = parser.parse_args()
print(args)


class Clothing1M_Dataset(Dataset):
    """Dataset of image paths (relative to ``root_dir``) and their labels.

    Images are opened lazily in ``__getitem__`` and converted to RGB. When no
    ``transform`` is given, ``transforms.ToTensor()`` is used.
    """

    def __init__(self, data, labels, root_dir, transform=None, target_transform=None):
        self.root_dir = root_dir
        self.train_data = np.array(data)
        self.train_labels = np.array(labels)
        self.length = len(self.train_labels)
        self.transform = transforms.ToTensor() if transform is None else transform
        self.target_transform = target_transform
        print("NewDataset length:", self.length)

    def __len__(self):
        """Return the number of labels the dataset was built with."""
        return self.length

    def __getitem__(self, index):
        """Load the ``index``-th image and return ``(image, target)``."""
        relative_path = self.train_data[index]
        target = self.train_labels[index]

        image = Image.open(os.path.join(self.root_dir, relative_path)).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def getData(self):
        """Return the stored ``(paths, labels)`` arrays."""
        return self.train_data, self.train_labels


def target_transform(label):
    """Convert a label (scalar or array-like) to a ``torch.long`` tensor.

    ``np.int64`` is used instead of the ``np.int`` alias, which was deprecated
    in NumPy 1.20 and removed in 1.24 (accessing it raises AttributeError).
    """
    label = np.array(label, dtype=np.int64)
    target = torch.from_numpy(label).long()
    return target


def createModel(pretrained):
    """Build a ResNet-50 whose final layer outputs ``args.num_classes`` logits.

    Args:
        pretrained: the string ``'Yes'`` requests torchvision's pretrained
            weights; any other value creates the network without them.

    Returns:
        The model, moved to the GPU when CUDA is available.
    """
    model = torchvision.models.resnet50(pretrained=(pretrained == 'Yes'))
    model.fc = nn.Linear(2048, args.num_classes)

    # is_available must be *called*: the bare function object is always
    # truthy, which forced .cuda() even on CPU-only hosts.
    if torch.cuda.is_available():
        model = model.cuda()
    return model


def evaluate(test_loader, model, loss_func, isTest=False):
    """Evaluate ``model`` on ``test_loader`` and print a one-line summary.

    Args:
        test_loader: iterable yielding ``(images, labels)`` batches.
        model: network to evaluate; it is switched to eval mode.
        loss_func: callable ``loss_func(logits, labels)`` returning a scalar tensor.
        isTest: chooses the ``Test`` or ``Val`` prefix of the printed line.

    Returns:
        ``(loss, acc, std, acc_category, precision_category)`` where ``loss``
        is 100 * (sum of per-batch losses / number of samples), ``acc`` is the
        overall accuracy in percent, ``std`` is the standard deviation of the
        per-class accuracies, and the last two are per-class accuracy
        (correct / samples of that class) and precision (correct / predictions
        of that class) in percent, rounded to two decimals.
    """
    model.eval()
    total = 0
    test_loss = 0
    correct = 0
    class_correct = np.zeros(args.num_classes)
    class_total = np.zeros(args.num_classes)
    class_pred = np.zeros(args.num_classes)
    # is_available must be called; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        for images, labels in test_loader:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            logits = model(images)
            test_loss += loss_func(logits, labels).item()

            outputs = F.softmax(logits, dim=1)
            _, pred = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()

            # Per-class bookkeeping, vectorised; np.add.at accumulates
            # correctly when the same class index occurs more than once.
            pred_np = pred.cpu().numpy()
            labels_np = labels.cpu().numpy()
            np.add.at(class_total, labels_np, 1)
            np.add.at(class_pred, pred_np, 1)
            np.add.at(class_correct, labels_np[pred_np == labels_np], 1)

    loss = 100 * (test_loss / total)
    acc = 100 * float(correct) / float(total)
    std = np.std(100 * class_correct / class_total)
    acc_category = np.around(100 * class_correct / class_total, decimals=2)
    precision_category = np.around(100 * class_correct / class_pred, decimals=2)

    prefix = 'Test' if isTest else 'Val'
    print(getTime(), prefix + ' Loss: {:.2f}, Acc: {:.2f}, Std: {:.2f}'.format(loss, acc, std))

    return loss, acc, std, acc_category, precision_category


def predict(train_loader, model):
    """Return the predicted class index for every sample in ``train_loader``.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        model: network used for inference; it is switched to eval mode.

    Returns:
        A flat list of Python ints in loader order (empty for an empty loader).
    """
    model.eval()
    # is_available must be called; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()
    preds = []

    with torch.no_grad():
        for images, _ in train_loader:
            if use_cuda:
                images = images.cuda()

            outputs = F.softmax(model(images), dim=1)
            _, pred = torch.max(outputs.data, 1)
            # Extending a list avoids re-copying all previous predictions on
            # every batch, as repeated np.concatenate did.
            preds.extend(pred.cpu().tolist())

    return preds


def train(train_loader, model, loss_func, optimizer):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        model: network to train; it is switched to train mode.
        loss_func: callable ``loss_func(logits, labels)`` returning a scalar tensor.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(acc, loss)`` as Python floats: accuracy in percent and
        100 * (sum of per-batch losses / number of samples).
    """
    model.train()
    # is_available must be called; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()

    train_total = 0
    train_correct = 0
    train_loss = 0.
    for images, labels in train_loader:
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        logits = model(images)
        loss = loss_func(logits, labels)
        pred = torch.max(logits, 1)[1]

        train_loss += loss.item()
        # .item() keeps the running count a plain int instead of a tensor.
        train_correct += (pred == labels).sum().item()
        train_total += labels.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss = 100 * train_loss / train_total
    acc = 100 * train_correct / train_total

    print(getTime(), 'Train Loss: {:.2f}, Acc: {:.2f}'.format(loss, acc))
    return float(acc), float(loss)


def getTime():
    """Return the current local time formatted as ``HH:MM:SS``."""
    return datetime.datetime.now().strftime('%H:%M:%S')


def evaluate_share(model, share_loader, share_nosiy_labels, share_clean_labels):
    """Report how often predictions that agree with the noisy label are right.

    Only samples whose predicted label equals the noisy label are counted.
    For those, the function prints the agreement count and the share whose
    prediction also equals the clean label, a (predicted, clean) count matrix,
    and the per-class precision for every class with at least one agreement.

    Args:
        model: network passed to ``predict``.
        share_loader: loader over the samples that have both label kinds.
        share_nosiy_labels: noisy labels, in loader order.
        share_clean_labels: clean labels, in loader order.
    """
    predicted_labels = predict(share_loader, model)

    total_number = 0
    correct_number = 0
    matrix = np.zeros([args.num_classes, args.num_classes], dtype=int)
    class_correct = [0.] * args.num_classes
    class_total = [0.] * args.num_classes

    for predicted, noisy, clean in zip(predicted_labels, share_nosiy_labels, share_clean_labels):
        if predicted != noisy:
            continue
        total_number += 1
        class_total[predicted] += 1
        matrix[predicted][clean] += 1

        if predicted == clean:
            correct_number += 1
            class_correct[predicted] += 1

    # Guard against the case where no prediction matches a noisy label, which
    # previously raised ZeroDivisionError.
    if total_number > 0:
        accuracy = round(100 * correct_number / total_number, 2)
    else:
        accuracy = 0.0
    print('Same labels number:', total_number, accuracy)
    print(matrix)
    for i in range(args.num_classes):
        if class_total[i] > 0:
            print('Label precision of %5s: %.2f  %d' % (i, 100 * class_correct[i] / class_total[i], class_total[i]))


def combinateModels(modelList, model_best_scores, modelsIndexs, dataset):
    """Rebuild the training set using the best saved model of each class.

    For every class whose entry in ``modelsIndexs`` is non-zero, the checkpoint
    ``modelList[index]`` is loaded and only those samples of the class (by
    noisy label) that the model also predicts as that class are kept. Classes
    with index 0 have no model and keep all of their samples.

    Args:
        modelList: array of checkpoint paths; position 0 is a placeholder.
        model_best_scores: unused; kept for interface compatibility.
        modelsIndexs: integer array mapping class id -> index into ``modelList``.
        dataset: object whose ``getData()`` returns ``(image_paths, labels)``.

    Returns:
        ``(data, labels)`` as NumPy arrays of the retained paths and labels.
    """
    labels = []
    data = []
    label_sizes = []
    imagePaths, noise_labels = dataset.getData()
    print("Model indexs:", modelsIndexs)
    for j in set(modelsIndexs):
        alist = np.argwhere(modelsIndexs == j)
        class_ids = alist.ravel()
        if j != 0:
            print("Load " + modelList[int(j)] + ", label classes: " + str(alist.squeeze().tolist()))
            model = createModel(args.pretrain)
            # map_location='cpu' lets the checkpoint load on hosts without the
            # GPU it was saved from; load_state_dict copies the tensors onto
            # whatever device the model lives on.
            model.load_state_dict(torch.load(modelList[int(j)], map_location='cpu'))

            for class_id in class_ids:
                # flatnonzero always yields a 1-D index array. The previous
                # argwhere(...).squeeze() collapsed a single match to 0-d, so
                # .tolist() returned a scalar and extend() either split a path
                # string into characters or raised TypeError.
                labels_index = np.flatnonzero(noise_labels == class_id)
                get_data = np.take(imagePaths, labels_index)
                get_labels = np.take(noise_labels, labels_index)
                pred_data, pred_labels, _ = predictByTarget(get_data, get_labels, model, class_id)
                pred_data = np.atleast_1d(pred_data)
                pred_labels = np.atleast_1d(pred_labels)

                data.extend(pred_data.tolist())
                labels.extend(pred_labels.tolist())
                label_sizes.append(len(pred_labels))

            del model
        else:
            print("Not found model, label classes: " + str(alist.squeeze().tolist()))
            for class_id in class_ids:
                labels_index = np.flatnonzero(noise_labels == class_id)
                get_data = np.take(imagePaths, labels_index)
                get_labels = np.take(noise_labels, labels_index)

                data.extend(get_data.tolist())
                labels.extend(get_labels.tolist())
                label_sizes.append(len(get_labels))

    print('combinate label_sizes', label_sizes)
    return np.array(data), np.array(labels)


def predictByTarget(get_data, get_labels, model, target):
    """Keep the samples that ``model`` predicts as class ``target``.

    Args:
        get_data: array of image paths relative to ``args.data_root``.
        get_labels: labels for ``get_data``; only used to build the dataset.
        model: network used for inference; it is switched to eval mode.
        target: class id the prediction is compared against.

    Returns:
        ``(data, preds, rates)``: 1-D arrays holding the retained paths, their
        predicted class (equal to ``target``, stored as float64) and the
        softmax probability of that prediction.
    """
    model.eval()

    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Prepare new data loader
    new_dataset = Clothing1M_Dataset(get_data, get_labels, args.data_root, transform_test, target_transform)
    new_dataset_loader = DataLoader(dataset=new_dataset, batch_size=args.batch_size, num_workers=32, shuffle=False)

    # is_available must be called; the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()
    # Seeding with an empty float64 array keeps the result dtype (float64) and
    # the empty-loader result identical to the previous incremental
    # concatenation, while joining all batches only once.
    pred_chunks = [np.array([])]
    rate_chunks = [np.array([])]

    with torch.no_grad():
        for images, _ in new_dataset_loader:
            if use_cuda:
                images = images.cuda()

            outputs = F.softmax(model(images), dim=1)
            rate, pred = torch.max(outputs.data, 1)
            pred_chunks.append(pred.to("cpu", torch.int).numpy())
            rate_chunks.append(rate.to("cpu", torch.float).numpy())

    preds = np.concatenate(pred_chunks, axis=0)
    rates = np.concatenate(rate_chunks, axis=0)

    # flatnonzero always gives 1-D indices, so a single match no longer turns
    # the returned arrays into 0-d scalars (as argwhere(...).squeeze() did).
    labels_index = np.flatnonzero(preds == target)
    data = np.take(get_data, labels_index)
    preds = np.take(preds, labels_index)
    rates = np.take(rates, labels_index)

    return data, preds, rates


def _count_labels(labels):
    """Return an int array with the number of samples of each class in ``labels``."""
    counts = np.zeros(args.num_classes, dtype=int)
    for item in labels:
        counts[int(item)] += 1
    return counts


def _make_loss_func(train_nums, val_nums):
    """Build the class-weighted cross-entropy loss.

    The weight of a class is (mean train count / its train count) multiplied
    by (its validation count / mean validation count).
    """
    class_weights = torch.FloatTensor(np.mean(train_nums) / train_nums * val_nums / np.mean(val_nums))
    loss_func = nn.CrossEntropyLoss(weight=class_weights)
    # is_available must be called; the bare function object is always truthy.
    if torch.cuda.is_available():
        loss_func = loss_func.cuda()
    return loss_func


def main():
    """Run the full multi-round training / label-cleaning pipeline.

    Each round trains a fresh model for ``args.n_epoch`` epochs, saving a
    checkpoint every epoch and remembering, per class, which checkpoint has
    the best validation score. Every ``args.update_interval`` epochs the
    training set is rebuilt with ``combinateModels`` and the loss weights are
    recomputed. Metrics are logged to TensorBoard.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    model_save_dir = args.model_dir + '/' + args.dataset
    # os.makedirs replaces the shell call `mkdir -p`: it is portable and does
    # not pass user-supplied paths through a shell.
    os.makedirs(model_save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(256, padding=32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Load data file
    # NOTE(review): allow_pickle=True executes pickled content; only load
    # data files from a trusted source.
    kvDic = np.load(args.data_root + 'Clothing1m-data.npy', allow_pickle=True).item()

    # Prepare train data loader
    original_train_data = kvDic['train_data']
    original_train_labels = kvDic['train_labels']
    shuffle_index = np.arange(len(original_train_labels), dtype=int)
    np.random.shuffle(shuffle_index)
    original_train_data = original_train_data[shuffle_index]
    original_train_labels = original_train_labels[shuffle_index]

    share_data = kvDic['share_data']
    share_nosiy_labels = kvDic['share_noisy_labels']
    share_clean_labels = kvDic['share_clean_labels']
    share_dataset = Clothing1M_Dataset(share_data, share_nosiy_labels, args.data_root, transform_test, target_transform)
    share_loader = DataLoader(dataset=share_dataset, batch_size=args.batch_size, num_workers=32, shuffle=False)

    val_data = kvDic['clean_val_data']
    val_labels = kvDic['clean_val_labels']
    # NOTE(review): val_nums is always computed from the clean validation
    # labels, even when validation_type == 'noise' replaces val_labels below.
    val_nums = _count_labels(val_labels)
    print("val categroy mean", np.mean(val_nums, dtype=int), "category", val_nums, "precent", val_nums / np.mean(val_nums))

    test_data = kvDic['test_data']
    test_labels = kvDic['test_labels']
    test_dataset = Clothing1M_Dataset(test_data, test_labels, args.data_root, transform_test, target_transform)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=32, shuffle=False)

    # Prepare new data loader
    if args.validation_type == 'noise':
        # Hold out the last 10% of the shuffled noisy training set.
        nosie_len = int(len(original_train_labels) * 0.9)
        train_data = original_train_data[:nosie_len]
        train_labels = original_train_labels[:nosie_len]
        val_data = original_train_data[nosie_len:]
        val_labels = original_train_labels[nosie_len:]
    else:
        train_data = original_train_data
        train_labels = original_train_labels

    new_dataset = Clothing1M_Dataset(train_data, train_labels, args.data_root, transform, target_transform)
    new_dataset_loader = DataLoader(dataset=new_dataset, batch_size=args.batch_size, num_workers=32, shuffle=True)
    # Label updates always start from this fixed copy of the initial training
    # set, not from the previously filtered one.
    predict_dataset = Clothing1M_Dataset(train_data, train_labels, args.data_root, transform_test, target_transform)
    val_dataset = Clothing1M_Dataset(val_data, val_labels, args.data_root, transform_test, target_transform)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=32, shuffle=False)

    # Loss function
    train_nums = _count_labels(train_labels)
    print("train categroy mean", np.mean(train_nums, dtype=int), "category", train_nums, "precent", np.mean(train_nums) / train_nums)
    loss_func = _make_loss_func(train_nums, val_nums)

    best_model_name = ""
    best_val_accuracy = 0
    best_test_accuracy = 0

    experiment_name = args.dataset + "_pretrian_" + str(args.pretrain) + "_lr_" + str(args.lr) + "_epoch_" + str(args.n_epoch) + "_beta_" + str(args.beta) + "_seed_" + str(args.seed)
    writer = SummaryWriter("runs/" + datetime.datetime.now().strftime('%b%d_%H:%M_') + experiment_name)

    for roundNum in range(args.n_round):
        print(getTime(), "Round " + str(roundNum + 1) + " Create a new model...")
        model = createModel(args.pretrain)
        # Index 0 is a placeholder meaning "no model for this class yet".
        modelList = np.array([""])
        model_best_scores = np.zeros(args.num_classes, dtype=float)
        model_indexs = np.zeros(args.num_classes, dtype=int)

        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

        # train and validate
        for epoch in range(args.n_epoch):
            print(getTime(), "Epoch", epoch + 1, "begin...")
            train_acc, train_loss = train(new_dataset_loader, model, loss_func, optimizer)
            val_loss, validation_acc, validation_std, val_class_acc, val_class_precision = evaluate(val_loader, model, loss_func, False)
            test_loss, test_acc, test_std, test_class_acc, test_class_precision = evaluate(test_loader, model, loss_func, True)
            scheduler.step()

            model_scores = args.beta * val_class_acc + (1 - args.beta) * val_class_precision
            filepath = model_save_dir + "/" + str(roundNum) + "-" + str(epoch) + "-" + str(round(validation_acc, 2)) + ".hdf5"
            for i in range(args.num_classes):
                if model_scores[i] > model_best_scores[i]:
                    model_best_scores[i] = model_scores[i]
                    # len(modelList) is the position this epoch's checkpoint
                    # gets once it is appended below.
                    model_indexs[i] = len(modelList)

            if validation_acc > best_val_accuracy:
                best_val_accuracy = validation_acc
                best_test_accuracy = test_acc
                best_model_name = filepath
                evaluate_share(model, share_loader, share_nosiy_labels, share_clean_labels)

            # save model
            modelList = np.append(modelList, filepath)
            torch.save(model.state_dict(), filepath)

            # write logs
            step = roundNum * args.n_epoch + epoch
            writer.add_scalar('Loss/test', test_loss, step)
            writer.add_scalar('Loss/validation', val_loss, step)
            writer.add_scalar('Loss/train', train_loss, step)
            writer.add_scalar('Accuracy/test', test_acc, step)
            writer.add_scalar('Accuracy/validation', validation_acc, step)
            writer.add_scalar('Accuracy/train', train_acc, step)
            writer.add_scalar('Best/val best', best_val_accuracy, step)
            writer.add_scalar('Best/test best', best_test_accuracy, step)
            for i in range(args.num_classes):
                writer.add_scalar('Class Validation Accuracy/Class' + str(i), val_class_acc[i], step)
                writer.add_scalar('Class Validation Precision/Class' + str(i), val_class_precision[i], step)
                writer.add_scalar('Class Test Accuracy/Class' + str(i), test_class_acc[i], step)
                writer.add_scalar('Class Test Precision/Class' + str(i), test_class_precision[i], step)

            # Skip the label update after the very last epoch of the last
            # round: nothing would train on the result.
            if roundNum == args.n_round - 1 and epoch == args.n_epoch - 1:
                break

            if (epoch + 1) % args.update_interval == 0:
                # update labels
                print(getTime(), "Model_best_scores", np.around(model_best_scores, decimals=2), np.around(np.average(model_best_scores), decimals=2), "model indexs", model_indexs, "modelList", modelList)
                train_data, train_labels = combinateModels(modelList, model_best_scores, model_indexs, predict_dataset)
                new_dataset = Clothing1M_Dataset(train_data, train_labels, args.data_root, transform, target_transform)
                new_dataset_loader = DataLoader(dataset=new_dataset, batch_size=args.batch_size, num_workers=32, shuffle=True)

                # update loss function
                train_nums = _count_labels(train_labels)
                print("train categroy mean", np.mean(train_nums, dtype=int), "category", train_nums, "precent", np.mean(train_nums) / train_nums)
                loss_func = _make_loss_func(train_nums, val_nums)

    # Flush pending TensorBoard events and release the log file.
    writer.close()
    print("Best_test_accuracy:", best_test_accuracy, ", best_model_name:", best_model_name)


# Start training only when the file is executed as a script.
if __name__ == '__main__':
    main()
