import torch
import os
import numpy as np
import argparse

from utils import set_seed, get_optimizer, adjust_learning_rate, AverageMeter
from dataset.dataloaders import get_test_dataloaders
from dataset.datasets import datasets_dict
from model.architecture.ResNet import model_dict
from model.URSL import LinearClassifier


def _train_linear_classifier(args, encoder, classifier, criterion, optimizer, train_loader):
    """Train ``classifier`` on features of the frozen ``encoder`` for ``args.epochs`` epochs."""
    average_meter = AverageMeter()
    encoder.eval()
    classifier.train()

    for epoch in range(1, args.epochs + 1):
        print('Epoch {} / {}'.format(epoch, args.epochs))

        adjust_learning_rate(args, optimizer, epoch, args.epochs)

        for images, labels in train_loader:
            images = images.to(args.device, non_blocking=True)
            labels = labels.to(args.device, non_blocking=True)

            # The encoder is frozen: no autograd graph is built for it, so
            # gradients only flow into the classifier.
            with torch.no_grad():
                features = encoder(images)

            output = classifier(features)
            loss = criterion(output, labels)

            average_meter.update(loss.item(), labels.shape[0])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Training loss accumulated by AverageMeter over this epoch.
        print(average_meter.avg)
        average_meter.reset()


def _evaluate_accuracy(encoder, classifier, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``classifier`` over ``loader``."""
    encoder.eval()
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            output = classifier(encoder(images))
            correct += (torch.argmax(output, dim=1) == labels).sum().item()
            total += labels.shape[0]
    return correct, total


def test(args):
    """Linear-probe evaluation of a saved encoder checkpoint.

    Loads the encoder weights ``encoder_<T>`` and the replay indexes
    ``index_<T>.npy`` from ``args.directory``, where
    ``T = args.labeled_dataset.num_classes // args.class_per_task``
    (presumably the checkpoint saved after the last task -- confirm against
    the training script). A linear classifier is then trained for
    ``args.epochs`` epochs on features of the frozen encoder. Its top-1
    accuracy on the validation loader is printed and written to
    ``accuracy.txt`` inside ``args.directory``.

    Args:
        args: parsed ``argparse.Namespace``, augmented at module level with
            ``device``, ``directory`` and ``labeled_dataset``.

    Raises:
        RuntimeError: if the validation loader yields no samples.
    """
    print("Creating Model =>")
    module, _ = model_dict[args.student_arch]
    criterion = torch.nn.CrossEntropyLoss()
    encoder = module().to(args.device)
    classifier = LinearClassifier(args).to(args.device)
    encoder.eval()

    print("Building Optimizer =>")
    # Only the classifier's parameters are handed to the optimizer.
    optimizer = get_optimizer(classifier, args, args.lr)

    number_of_tasks = args.labeled_dataset.num_classes // args.class_per_task

    print("Loading trained Model =>")
    model_path = os.path.join(args.directory, "encoder_" + str(number_of_tasks))
    index_path = os.path.join(args.directory, "index_" + str(number_of_tasks) + ".npy")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=args.device)
    encoder.load_state_dict(state_dict)
    replay_indexes = np.load(index_path)

    print(model_path)

    print("Creating test data =>")
    train_loader, val_loader = get_test_dataloaders(args, replay_indexes)

    print("Training Linear =>")
    _train_linear_classifier(args, encoder, classifier, criterion, optimizer, train_loader)

    # Test trained model and classifier.
    correct, total = _evaluate_accuracy(encoder, classifier, val_loader, args.device)
    if total == 0:
        raise RuntimeError("Validation loader yielded no samples; cannot compute accuracy.")

    # Normalize by the number of samples actually evaluated, not a hard-coded
    # test-set size, so the figure is right for any dataset or loader setup.
    report = 'Accuracy of Trained Classifier : {}%'.format(round(correct / total * 100, 4))
    print(report)
    accuracy_file = os.path.join(args.directory, "accuracy.txt")
    with open(accuracy_file, 'w', encoding='utf-8') as f:
        f.write(report)


# Parse command-line options. ``test`` receives the whole namespace, and so do
# the project helpers it calls (LinearClassifier, get_optimizer,
# adjust_learning_rate, get_test_dataloaders). Which options are actually
# consumed therefore depends on those helpers; the teacher/loss options
# presumably mirror the training script -- NOTE(review): confirm.
parser = argparse.ArgumentParser(description='Commands')

# general
parser.add_argument('--seed', type=int, default=42, help='seed for experiments')

parser.add_argument('--experiment-name', type=str, default='URSL', help='experiment name for training')

# dataset
parser.add_argument('--labeled-dataset-name', type=str, default='cifar10', help='main dataset name')

parser.add_argument('--unlabeled-dataset-name', type=str, default='cifar100', help='peripheral dataset name')

parser.add_argument('--batch-size', type=int, default=512, help='batch size for labeled samples')

parser.add_argument('--unlabeled-batch-size', type=int, default=512, help='batch size for unlabeled samples')

parser.add_argument('--num-workers', type=int, default=4, help='number of workers for dataloaders')

parser.add_argument('--num-main-unlabeled', type=int, default=9000, help='number of unlabeled data for each task from main dataset')

parser.add_argument('--num-peripheral-unlabeled', type=int, default=9000, help='number of unlabeled data for each task from peripheral dataset')

parser.add_argument('--num-labeled-per-class', type=int, default=25, help='number of labeled data in tasks for each class')


# model
parser.add_argument('--epochs', type=int, default=100, help='number of student classifier training epochs')

parser.add_argument('--start-warmup-epochs', type=int, default=400, help='number of teacher network training epochs for first task')

parser.add_argument('--warmup-epochs', type=int, default=100, help='number of teacher network training epochs for next tasks')

parser.add_argument('--class-per-task', type=int, default=10, help='number of classes per task')

parser.add_argument('--memory-size', type=int, default=500, help='number of memory buffer size')

parser.add_argument('--teacher-arch', type=str, default='resnet18', help='architecture for teacher model')

parser.add_argument('--student-arch', type=str, default='resnet18', help='architecture for student model')

parser.add_argument('--head', type=str, default='mlp', help='type of classifier on encoder head')


# loss
parser.add_argument('--temperature', type=float, default=0.1, help='temperature for labeled loss')

parser.add_argument('--unlabeled-temperature', type=float, default=0.1, help='temperature for unlabeled loss')

parser.add_argument('--current-model-temp', type=float, default=0.2, help='temperature of irdloss for current model')

parser.add_argument('--prev-model-temp', type=float, default=0.01, help='temperature of irdloss for previous model')

parser.add_argument('--td-coeff', type=float, default=0.2, help='coefficient for time distillation loss')

parser.add_argument('--kd-coeff', type=float, default=1, help='coefficient for teacher distillation loss')


# optimizer
parser.add_argument('--optimizer-name', type=str, default='adam', help='name of the optimizer')

parser.add_argument('--lr', type=float, default=0.003, help='initial learning rate')

parser.add_argument('--lr-decay-rate', type=float, default=0.1, help='learning rate decay rate')

parser.add_argument('--weight-decay', type=float, default=1.0e-4, help='weight decay coeff')

parser.add_argument('--momentum', type=float, default=0.9, help='optimizer momentum')

parser.add_argument('--lr-cosine-decay', default=True, action='store_true', help='using cosine decay for optimizer learning rate')

# ``--lr-cosine-decay`` is a store_true flag that already defaults to True,
# so it can never turn cosine decay off. This companion flag writes False to
# the same destination; the default stays True for backward compatibility.
parser.add_argument('--no-lr-cosine-decay', dest='lr_cosine_decay', action='store_false', help='disable cosine decay for optimizer learning rate')


args = parser.parse_args()
# Experiment directory holding the encoder checkpoints and replay indexes
# that ``test`` loads, and where it writes accuracy.txt.
args.directory = "../experiments/" + args.labeled_dataset_name + "/" + args.experiment_name + "/"
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.labeled_dataset = datasets_dict[args.labeled_dataset_name]()
args.unlabeled_dataset = datasets_dict[args.unlabeled_dataset_name]()
# Propagated onto the dataset object; presumably read during dataloader
# construction -- NOTE(review): confirm in dataset.dataloaders.
args.labeled_dataset.num_labeled_per_class = args.num_labeled_per_class

set_seed(args.seed)
print(args.experiment_name)

test(args)