import torch
import os
import copy
import argparse

from model.URSL import ConResNet
from model.ood_detection import get_ood_loader
from losses import URSLLoss, NT_Xent, IRDLoss
from utils import set_seed, get_optimizer, save_model, initialize_average_meters, update_average_meters, update_tensorboard, adjust_learning_rate
from dataset.dataloaders import set_replay_samples, get_each_task_dataloader, get_unlabeled_data_loader
from dataset.datasets import datasets_dict
from torch.utils.tensorboard import SummaryWriter



def _warmup_teacher(args, teacher_model, teacher_optimizer, unsup_con_loss,
                    unlabeled_loader, average_meters, writer, task_number):
    """Train the teacher on the unlabeled loader with the unsupervised contrastive loss.

    Runs ``args.start_warmup_epochs`` epochs for task 0 and ``args.warmup_epochs``
    for every later task. Per-step losses are fed to ``average_meters``; after each
    epoch they are logged to TensorBoard under the ``'warmup'`` tag.
    Leaves ``teacher_model`` in eval mode on return.
    """
    teacher_model.train()

    task_epoch = args.start_warmup_epochs if task_number == 0 else args.warmup_epochs
    # 1-based epochs: the LR schedule and the progress print see 1..task_epoch.
    for epoch in range(1, task_epoch + 1):
        if epoch % 10 == 0:
            print('Epoch {} / {}'.format(epoch, task_epoch))

        adjust_learning_rate(args, teacher_optimizer, epoch, task_epoch)

        # Labels from this loader are not used by the unsupervised loss.
        for images, _ in unlabeled_loader:
            teacher_optimizer.zero_grad()
            # `images` is indexed as two views which are stacked along dim 0 into a
            # single batch (presumably two augmentations of the same samples - TODO confirm).
            images = torch.cat([images[0], images[1]], dim=0)
            images = images.to(args.device, non_blocking=True)

            features = teacher_model(images)
            warmup_loss = unsup_con_loss(args, features)
            warmup_loss.backward()

            # Same keys as the student phase, since both phases update the same meters.
            step_losses = {
                'sup_con_loss' : 0,
                'ird_loss' : 0,
                'unsup_con_loss' : warmup_loss,
                'total' : warmup_loss
            }

            teacher_optimizer.step()

            update_average_meters(average_meters, step_losses)

        update_tensorboard(writer, average_meters, teacher_optimizer, task_number, epoch, 'warmup')

    teacher_model.eval()


def _train_student(args, model, optimizer, ursl_loss, ird_loss, previous_model,
                   teacher_model, labeled_loader, unlabeled_iter,
                   current_task_classes, average_meters, writer, task_number):
    """Train the student for ``args.epochs`` epochs on one task.

    Each step combines the URSL loss on a labeled batch (which also receives
    ``previous_model`` and the current task's classes) with an IRD distillation
    loss from ``teacher_model`` on one unlabeled batch, scaled by ``args.kd_coeff``.
    Losses are logged to TensorBoard under the ``'task'`` tag after each epoch.
    """
    for epoch in range(1, args.epochs + 1):
        if epoch % 10 == 0:
            print('Epoch {} / {}'.format(epoch, args.epochs))

        adjust_learning_rate(args, optimizer, epoch, args.epochs)

        for images, labels in labeled_loader:
            optimizer.zero_grad()

            images = torch.cat([images[0], images[1]], dim=0)
            images = images.to(args.device, non_blocking=True)
            labels = labels.to(args.device, non_blocking=True)

            labeled_features = model(images)
            labeled_losses = ursl_loss(args, images, labeled_features, labels,
                                       previous_model, current_task_classes, task_number)
            labeled_losses['total'].backward()

            # One unlabeled batch is pulled per labeled batch, across all epochs, so
            # `unlabeled_iter` must not be exhausted before training ends (presumably
            # get_ood_loader returns a cycling/infinite iterator - TODO confirm).
            unlabeled_images, _ = next(unlabeled_iter)
            unlabeled_images = torch.cat([unlabeled_images[0], unlabeled_images[1]], dim=0)
            unlabeled_images = unlabeled_images.to(args.device, non_blocking=True)

            unlabeled_features = model(unlabeled_images)
            kd_loss = ird_loss(args, unlabeled_features, teacher_model, unlabeled_images) * args.kd_coeff
            # Gradients of both losses accumulate before the single optimizer step.
            kd_loss.backward()

            step_losses = {
                'sup_con_loss' : labeled_losses['sup_con_loss'],
                'ird_loss' : labeled_losses['ird_loss'],
                'unsup_con_loss' : kd_loss,
                'total' : labeled_losses['total'] + kd_loss
            }

            optimizer.step()

            update_average_meters(average_meters, step_losses)

        update_tensorboard(writer, average_meters, optimizer, task_number, epoch, 'task')


def main(args):
    """Run teacher warm-up and student training over every task, then save each task's model.

    For each task:
      1. snapshot the current student as a frozen (eval-mode) ``previous_model``;
      2. refresh the replay indexes and build the task's labeled/unlabeled loaders;
      3. warm up the teacher on the unlabeled data (unsupervised contrastive loss);
      4. build the loaders returned by ``get_ood_loader`` using the teacher;
      5. train the student (URSL loss + teacher distillation);
      6. save the student with ``save_model``.

    Args:
        args: parsed command-line namespace, extended with ``device``, ``directory``,
            ``labeled_dataset`` and ``unlabeled_dataset`` by the script entry point.
    """
    print("Creating Model =>")
    teacher_model = ConResNet(args, is_teacher = True).to(args.device)
    model = ConResNet(args).to(args.device)
    ursl_loss = URSLLoss().to(args.device)
    unsup_con_loss = NT_Xent().to(args.device)
    ird_loss = IRDLoss().to(args.device)

    print("Building Optimizer =>")
    teacher_optimizer = get_optimizer(teacher_model, args, args.lr)
    optimizer = get_optimizer(model, args, args.lr)

    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(args.directory, exist_ok=True)

    # Reporting. NOTE(review): one set of meters is shared by both phases and all
    # tasks; whether update_tensorboard resets it is defined in utils - verify.
    writer = SummaryWriter('./runs/' + args.experiment_name + '/')
    average_meters = initialize_average_meters()

    print('Starting Training')

    # Floor division: classes beyond the last full task are never trained on.
    number_of_tasks = args.labeled_dataset.num_classes // args.class_per_task
    replay_indexes = None

    for task_number in range(number_of_tasks):
        # Frozen copy of the student as it was before this task (for task 0 this is
        # the untrained initial student); handed to ursl_loss.
        previous_model = copy.deepcopy(model).to(args.device)
        previous_model.eval()

        replay_indexes = set_replay_samples(args, task_number, prev_indexes = replay_indexes)
        _, labeled_dataset, subset_indexes = get_each_task_dataloader(args, task_number, replay_indexes)
        unlabeled_loader, mixed_unlabeled_dataset = get_unlabeled_data_loader(args, subset_indexes)

        _warmup_teacher(args, teacher_model, teacher_optimizer, unsup_con_loss,
                        unlabeled_loader, average_meters, writer, task_number)

        # Teacher is in eval mode here (set at the end of _warmup_teacher).
        labeled_loader, unlabeled_iter = get_ood_loader(args, teacher_model, labeled_dataset, mixed_unlabeled_dataset, task_number)

        current_task_classes = list(range(task_number * args.class_per_task, (task_number + 1) * args.class_per_task))

        _train_student(args, model, optimizer, ursl_loss, ird_loss, previous_model,
                       teacher_model, labeled_loader, unlabeled_iter,
                       current_task_classes, average_meters, writer, task_number)

        save_model(model, subset_indexes, args.directory, task_number)








# Command-line interface. Building the parser has no side effects; parsing argv and
# launching training only happen when this file is executed as a script. The guard
# matters because DataLoader workers started with the "spawn" method re-import the
# main module - without it every worker would re-run training.
parser = argparse.ArgumentParser(description='Commands')

# general
parser.add_argument('--seed', type=int, default=42, help='seed for experiments')

parser.add_argument('--experiment-name', type=str, default='URSL', help='experiment name for training')

# dataset
parser.add_argument('--labeled-dataset-name', type=str, default='cifar10', help='main dataset name')

parser.add_argument('--unlabeled-dataset-name', type=str, default='cifar100', help='peripheral dataset name')

parser.add_argument('--batch-size', type=int, default=512, help='batch size for labeled samples')

parser.add_argument('--unlabeled-batch-size', type=int, default=512, help='batch size for unlabeled samples')

parser.add_argument('--num-workers', type=int, default=4, help='number of workers for dataloaders')

parser.add_argument('--num-main-unlabeled', type=int, default=9000, help='number of unlabeled data for each task from main dataset')

parser.add_argument('--num-peripheral-unlabeled', type=int, default=9000, help='number of unlabeled data for each task from peripheral dataset')

parser.add_argument('--num-labeled-per-class', type=int, default=25, help='number of labeled data in tasks for each class')

# model
parser.add_argument('--epochs', type=int, default=200, help='number of student network training epochs')

parser.add_argument('--start-warmup-epochs', type=int, default=400, help='number of teacher network training epochs for first task')

parser.add_argument('--warmup-epochs', type=int, default=100, help='number of teacher network training epochs for next tasks')

parser.add_argument('--class-per-task', type=int, default=10, help='number of classes per task')

parser.add_argument('--memory-size', type=int, default=500, help='number of memory buffer size')

parser.add_argument('--teacher-arch', type=str, default='resnet18', help='architecture for teacher model')

parser.add_argument('--student-arch', type=str, default='resnet18', help='architecture for student model')

parser.add_argument('--head', type=str, default='mlp', help='type of classifier on encoder head')

# OOD

parser.add_argument('--ood-pl-var-coeff', type=int, default=-2, help='ood pseudo label variance coefficient')

parser.add_argument('--ood-in-dist-var-coeff', type=int, default=-4, help='ood in distribution variance coefficient')

# loss
parser.add_argument('--temperature', type=float, default=0.1, help='temperature for labeled loss')

parser.add_argument('--unlabeled-temperature', type=float, default=0.1, help='temperature for unlabeled loss')

parser.add_argument('--current-model-temp', type=float, default=0.2, help='temperature of irdloss for current model')

parser.add_argument('--prev-model-temp', type=float, default=0.01, help='temperature of irdloss for previous model')

parser.add_argument('--td-coeff', type=float, default=0.2, help='coefficient for time distillation loss')

parser.add_argument('--kd-coeff', type=float, default=0.2, help='coefficient for teacher distillation loss')


# optimizer

parser.add_argument('--optimizer-name', type=str, default='adam', help='name of the optimizer passed to get_optimizer')

parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')

parser.add_argument('--lr-decay-rate', type=float, default=0.1, help='learning rate decay rate')

parser.add_argument('--weight-decay', type=float, default=1.0e-4, help='weight decay coeff')

parser.add_argument('--momentum', type=float, default=0.9, help='optimizer momentum')

# A store_true flag whose default is already True can never be switched off, so a
# paired store_false flag writes False to the same destination. Cosine decay stays
# on by default and the old --lr-cosine-decay flag is still accepted.
parser.add_argument('--lr-cosine-decay', dest='lr_cosine_decay', default=True, action='store_true', help='use cosine decay for the optimizer learning rate (default: enabled)')

parser.add_argument('--no-lr-cosine-decay', dest='lr_cosine_decay', action='store_false', help='disable cosine decay for the optimizer learning rate')


if __name__ == '__main__':
    args = parser.parse_args()
    args.directory = "../experiments/" + args.labeled_dataset_name + "/" + args.experiment_name + "/"
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.labeled_dataset = datasets_dict[args.labeled_dataset_name]()
    args.unlabeled_dataset = datasets_dict[args.unlabeled_dataset_name]()
    args.labeled_dataset.num_labeled_per_class = args.num_labeled_per_class

    set_seed(args.seed)
    print(args.experiment_name)

    main(args)