from __future__ import print_function

import argparse
import logging
import math
import os
import sys
import time

import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms

from dataset import ExtraLabelDatasetWrapper
from dataset.attribute_data import (AttributeCIFAR10DataModule,
                                    AttributeCIFAR100DataModule)
from dataset.data import DataLoaderCreator
from calibrate.losses import HMLC, SupConLoss
from calibrate.util import (AverageMeter, ExponentialMovingAverage,
                            SaveFeaturesInputHook, SaveFeaturesListInputHook,
                            TwoCropTransform, accuracy, adjust_learning_rate,
                            save_model, set_optimizer, warmup_learning_rate)
from models.module import all_classifiers
from parse_tree.tree_builder import TreeBuilder

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass


def parse_option():
    """Parse command-line options and derive run-specific settings.

    Besides the raw CLI flags, the returned namespace gains:

    * ``lr_decay_epochs`` -- converted from a comma-separated string to a
      list of ints.
    * ``model_path`` / ``tb_path`` -- root folders for checkpoints and
      tensorboard logs, keyed by dataset.
    * ``model_name`` -- a run identifier encoding the main hyper-parameters,
      suffixed with ``_cosine`` / ``_warm`` when those are active.
    * ``warmup_from`` / ``warm_epochs`` / ``warmup_to`` -- only set when
      warm-up is enabled (explicitly, or implicitly when
      ``batch_size > 256``).
    * ``tb_folder`` / ``save_folder`` -- created on disk if missing.

    Returns:
        argparse.Namespace: the parsed and augmented options.

    Exits via ``parser.error`` (status 2) when ``--dataset path`` is given
    without all of ``--data_folder``, ``--mean`` and ``--std``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'path'], help='dataset')
    parser.add_argument('--mean', type=str,
                        help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str,
                        help='std of dataset in path in form of str tuple')
    parser.add_argument('--ema_decay', type=float, default=0.99,
                        help='Decay rate for exponential moving average of model weights')
    parser.add_argument('--contrast_weight', type=float, default=0.1,
                        help='Weighting for the supervised contrast')
    parser.add_argument('--data_folder', type=str,
                        default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32,
                        help='parameter for RandomResizedCrop')
    parser.add_argument("--support_data_dir", type=str,
                        default="/Checkpoint/user/data/tree_data/cifar10_images/")
    parser.add_argument("--prompt_file", type=str,
                        default="cifar10/cifar10_prompt.json")

    # method
    parser.add_argument('--method', type=str, default='HSupCon',
                        choices=['SupCon', 'SimCLR', 'HSupCon'], help='choose method')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')
    parser.add_argument('--hierarchical_freq', type=int, default=10,
                        help='frequency of adding hierarchical labels to the training set, \
                    0 means no hierarchical labels')

    opt = parser.parse_args()

    # A custom image folder has no built-in normalization statistics, so the
    # folder and both mean/std must be supplied. parser.error (rather than
    # assert) keeps this check active under `python -O` and prints usage.
    if opt.dataset == 'path' and (opt.data_folder is None
                                  or opt.mean is None
                                  or opt.std is None):
        parser.error('--dataset path requires --data_folder, --mean and --std')

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/HSupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/HSupCon/{}_tensorboard'.format(opt.dataset)

    # "700,800,900" -> [700, 800, 900]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    # Set the name of the model and related parameters
    opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}_ema{}_epoch{}_cw{}'.\
        format(opt.method, opt.dataset, opt.model, opt.learning_rate,
               opt.weight_decay, opt.batch_size, opt.temp, opt.trial, opt.ema_decay, opt.epochs, opt.contrast_weight)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine schedule would have reached at
            # the end of the warm-up period, so the hand-off is continuous.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the training dataset and the validation data loader.

    The training set is returned as a dataset (not a loader). Each training
    sample is wrapped in ``TwoCropTransform`` so that it yields two augmented
    views, which ``train`` concatenates along the batch axis.

    Args:
        opt: parsed options; reads ``dataset`` and ``data_folder``. For
            ``dataset == 'path'`` it also reads ``mean`` / ``std``, which are
            strings holding a tuple literal such as ``"(0.5, 0.5, 0.5)"``, and
            the optional attribute ``val_folder``.

    Returns:
        tuple: ``(train_dataset, val_loader)``.

    Raises:
        ValueError: for an unsupported dataset, or for ``dataset == 'path'``
            when no ``opt.val_folder`` is available to validate on.
    """
    # construct data loader
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        import ast
        # literal_eval accepts the same tuple/list literals eval() did, but
        # cannot execute arbitrary code supplied on the command line.
        mean = ast.literal_eval(opt.mean)
        std = ast.literal_eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15),  # Add random rotation
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Add color jitter
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(
                0.9, 1.1), shear=(-10, 10)),  # Add random affine transformations
            transforms.ToTensor(),
            normalize,
            # RandomErasing operates on tensors, so it must follow ToTensor.
            transforms.RandomErasing(p=0.5, scale=(0.02, 0.33)),
        ]
    )

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=TwoCropTransform(
                                             train_transform),
                                         download=True)
        val_dataset = datasets.CIFAR10(root=opt.data_folder,
                                       train=False,
                                       transform=val_transform)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=TwoCropTransform(
                                              train_transform),
                                          download=True)

        val_dataset = datasets.CIFAR100(root=opt.data_folder,
                                        train=False,
                                        transform=val_transform)
    elif opt.dataset == 'path':
        # A custom image folder has no bundled validation split, so a
        # separate folder is required; fail early with a clear message
        # instead of leaving val_dataset undefined.
        val_folder = getattr(opt, 'val_folder', None)
        if val_folder is None:
            raise ValueError(
                "dataset 'path' requires opt.val_folder for validation")
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                             transform=TwoCropTransform(train_transform))
        val_dataset = datasets.ImageFolder(root=val_folder,
                                           transform=val_transform)
    else:
        raise ValueError(opt.dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_dataset, val_loader


def set_model(opt):
    """Build the pretrained classifier, attach a feature hook, and create the loss.

    The architecture is taken from ``all_classifiers`` when ``opt.model`` is
    listed there; otherwise it is fetched via torch hub from
    ``chenyaofo/pytorch-cifar-models``. Pretrained weights are then loaded from
    a local state-dict file chosen by ``opt.dataset``. A
    ``SaveFeaturesInputHook`` is registered as a forward hook on the module
    named ``linear`` / ``fc`` / ``classifier`` (depending on the architecture);
    ``train`` later reads ``hook.features`` from it.

    Args:
        opt: parsed options; reads ``model``, ``dataset``, ``temp`` and
            ``syncBN``.

    Returns:
        tuple: ``(model, criterion, hook)``. The model and criterion are
        moved to CUDA when it is available.

    Raises:
        NotImplementedError: if no pretrained checkpoint is defined for
            ``opt.dataset``, or the architecture's head module is not
            recognised.
    """
    if opt.model in all_classifiers.keys():
        model = all_classifiers[opt.model]()
    else:
        model = torch.hub.load("chenyaofo/pytorch-cifar-models", opt.model)

    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; the model is moved to CUDA further below when available.
    if opt.dataset == "cifar10":
        state_dict = os.path.join(
            "models", "state_dicts", opt.model + ".pt"
        )
        model.load_state_dict(torch.load(state_dict, map_location='cpu'))
    elif opt.dataset == "cifar100":
        state_dict = os.path.join(
            "cifar100_models", "state_dicts", opt.model + ".pth"
        )
        model.load_state_dict(
            torch.load(state_dict, map_location='cpu')['model'])
    else:
        # No pretrained checkpoint is defined for other datasets; stop here
        # rather than continuing silently with unloaded weights.
        raise NotImplementedError(f"{opt.dataset} not implemented.")

    hook = SaveFeaturesInputHook()
    logging.info(f"Using {opt.model}")
    # "repvgg" must be tested before "vgg", since "repvgg" contains "vgg".
    if "repvgg" in opt.model:
        model.linear.register_forward_hook(hook)
    elif "resnet" in opt.model or "shufflenetv2" in opt.model:
        model.fc.register_forward_hook(hook)
    elif "vgg" in opt.model or "mobilenetv2" in opt.model or "densenet" in opt.model:
        model.classifier.register_forward_hook(hook)
    else:
        raise NotImplementedError(f"The {opt.model} is not implemented")

    criterion = HMLC(temperature=opt.temp)

    # enable synchronized Batch Normalization
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    hook.enable()

    return model, criterion, hook


def add_hierarchical_to_train(model, opt, train_dataset, device):
    """Attach tree-derived hierarchical labels to the training set.

    Features are extracted from ``model`` on the attribute data module's
    test split (``data.test_dataloader()``), a parse tree is built per
    category, and the loader from ``DataLoaderCreator`` is then queried
    against those trees to obtain hierarchical label paths. That loader is
    presumably over the training split, in the same order as
    ``train_dataset`` — TODO confirm.

    Args:
        model: the classifier. While the trees are built, a temporary forward
            hook is attached to its head module and the model is put in eval
            mode. Both the hook and the model's previous train/eval mode are
            restored afterwards, even on error.
        opt: parsed options; reads ``dataset``, ``model``, ``prompt_file``,
            ``batch_size`` and ``num_workers``, and is also passed to the
            data-module / ``DataLoaderCreator`` helpers.
        train_dataset: the base training dataset to wrap.
        device: device passed to ``TreeBuilder.extract_feature`` and
            ``TreeBuilder.query_tree_label``.

    Returns:
        torch.utils.data.DataLoader: shuffled loader over
        ``ExtraLabelDatasetWrapper(train_dataset, extra_labels=all_paths)``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``cifar10`` or ``cifar100``.
        NotImplementedError: if the model architecture's head is unknown.
    """
    # Define the required arguments for the dataset
    if opt.dataset == 'cifar10':
        data = AttributeCIFAR10DataModule(opt, prefix=None)
    elif opt.dataset == 'cifar100':
        data = AttributeCIFAR100DataModule(opt, prefix=None)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    # Create data loader and query loader using DataLoaderCreator class
    data_loader_creator = DataLoaderCreator(opt)
    data_loader = data_loader_creator.create_data_loader()

    query_loader = data.test_dataloader()
    # Save the features of the query loader using the SaveFeaturesListInputHook
    hook = SaveFeaturesListInputHook()
    # "repvgg" must be tested before "vgg", since "repvgg" contains "vgg".
    if "repvgg" in opt.model:
        hook_handler = model.linear.register_forward_hook(hook)
    elif "resnet" in opt.model or "shufflenetv2" in opt.model:
        hook_handler = model.fc.register_forward_hook(hook)
    elif "vgg" in opt.model or "mobilenetv2" in opt.model or "densenet" in opt.model:
        hook_handler = model.classifier.register_forward_hook(hook)
    else:
        raise NotImplementedError(f"The {opt.model} is not implemented")

    was_training = model.training
    try:
        # Extract features in eval mode so BatchNorm/Dropout behave the same
        # on every call: otherwise the first call would run in train mode
        # (fresh model) while later calls, made after validate(), run in eval.
        model.eval()

        tree_builder = TreeBuilder(model, query_loader, hook,
                                   prompt_file=opt.prompt_file,
                                   categories=data.categories)

        # Extract the intermediate activations from the pre-trained model
        tree_builder.extract_feature(device)

        # Construct Parse_Tree objects for each category
        tree_builder.construct_tree()

        # Query hierarchical labels from the constructed trees
        all_paths = tree_builder.query_tree_label(data_loader, device)
    finally:
        # Always detach the temporary hook and restore the caller's mode so
        # a failure here cannot leave extra hooks on the model.
        hook_handler.remove()
        model.train(was_training)

    # Wrap the training dataset with the extra hierarchical labels
    labelled_dataset = ExtraLabelDatasetWrapper(train_dataset,
                                                extra_labels=all_paths)

    train_loader = torch.utils.data.DataLoader(
        labelled_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True
    )

    return train_loader


def train(train_loader, model, criterion, optimizer, epoch, opt, hook, ema):
    """Train the model for one epoch.

    Each batch from ``train_loader`` is ``(images, labels, hierarchy)``, where
    ``images`` holds two augmented views of every sample. The two views are
    stacked along the batch axis and passed through ``model``. The loss is:

    * cross-entropy of the logits against the labels (duplicated for both
      views), plus
    * ``opt.contrast_weight`` times ``criterion(features, hierarchy)``, where
      ``features`` are the L2-normalised tensors captured by ``hook``,
      arranged as ``(bsz, 2, ...)``. This assumes the hooked features are
      2-D ``(N, D)`` — TODO confirm.

    After every optimizer step the EMA shadow weights are updated.

    Args:
        train_loader: loader yielding ``(images, labels, hierarchy)``.
        model: network returning class logits.
        criterion: hierarchical contrastive loss.
        optimizer: optimizer over ``model`` parameters.
        epoch: current (1-based) epoch, used for warm-up and logging.
        opt: parsed options; reads ``contrast_weight``, ``print_freq`` and
            the warm-up settings used by ``warmup_learning_rate``.
        hook: feature hook on the model head exposing ``features`` and
            ``reset()``.
        ema: EMA tracker whose ``update()`` is called after every step.

    Returns:
        tuple: ``(average loss, average top-1 accuracy)`` over the epoch.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()

    for idx, (images, labels, hierarchy) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Stack the two views: rows [0, bsz) are view 1, rows [bsz, 2*bsz)
        # are view 2.
        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            hierarchy = hierarchy.cuda(non_blocking=True)
        # Drop the singleton dim on every device, not only on the CUDA path.
        hierarchy = hierarchy.squeeze(1)
        bsz = labels.shape[0]

        # Warm-up the learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # Discard any features captured by forward passes outside this loop
        # (e.g. validation or tree building) so the hook only holds this
        # batch's features when they are split below.
        hook.reset()
        output = model(images)

        # Gather the hooked features on the input device (they may arrive
        # as several chunks), normalise, and regroup as (bsz, 2, ...).
        features = torch.cat([f.to(images.device) for f in hook.features],
                             dim=0)
        features = F.normalize(features, dim=1)
        view1, view2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.stack([view1, view2], dim=1)
        con_loss = criterion(features, hierarchy)

        # Cross-entropy on both views plus the weighted contrastive term.
        targets = torch.cat([labels, labels])
        loss = F.cross_entropy(output, targets) + \
            opt.contrast_weight * con_loss

        losses.update(loss.item(), bsz)
        acc1, _ = accuracy(output, targets, topk=(1, 5))
        top1.update(acc1[0], bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        hook.reset()
        # Update the EMA shadow variables
        ema.update()

        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            logging.info('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg


def set_ema(model, decay):
    """Create an exponential-moving-average tracker for ``model``'s weights.

    Args:
        model: the network whose weights are tracked.
        decay: EMA decay rate passed to ``ExponentialMovingAverage``.

    Returns:
        The tracker after ``register()`` has been called on it.
    """
    tracker = ExponentialMovingAverage(model, decay)
    # Register before returning so the tracker is ready for update().
    tracker.register()
    return tracker


def validate(val_loader, model, classifier, opt):
    """Evaluate ``model`` on ``val_loader`` with cross-entropy and top-1 accuracy.

    Runs without gradients. Tensors are moved to CUDA only when it is
    available, matching the guard used in ``train``.

    Args:
        val_loader: loader yielding ``(images, labels)`` batches.
        model: network returning class logits.
        classifier: module switched to eval mode alongside ``model``; it is
            not otherwise used here (``main`` passes the training criterion).
        opt: parsed options; reads ``print_freq``.

    Returns:
        tuple: ``(average loss, average top-1 accuracy)``.
    """
    model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float()
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            output = model(images)
            loss = F.cross_entropy(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1, _ = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                logging.info('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          idx, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1))

    logging.info(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg


def main():
    """Run the full fine-tuning loop.

    Steps:

    * Parse options and configure logging to ``<save_folder>/logfile.txt``
      and the console.
    * Build the data, model, optimizer, EMA tracker and tensorboard logger.
    * For ``opt.epochs`` epochs:
      - periodically rebuild the hierarchical-label training loader;
      - train one epoch;
      - validate with the EMA weights applied;
      - save ``best.pth``, periodic ``ckpt_epoch_<n>.pth`` files and a final
        ``last.pth``.
    """
    opt = parse_option()

    # configure logging to save output to a local file
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(opt.save_folder, 'logfile.txt')),
            logging.StreamHandler()
        ]
    )

    # build data loader
    train_dataset, val_loader = set_loader(opt)

    # build model and criterion
    model, criterion, hook = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    ema = set_ema(model, opt.ema_decay)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    best_val_acc = 0.0
    train_loader = None

    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # Build hierarchical labels on epochs 1, k+1, 2k+1, ... where
        # k = opt.hierarchical_freq. The (epoch - 1) % k form also covers
        # k == 1 (refresh every epoch). train() always needs hierarchical
        # labels, so for k <= 0 they are built once up front and never
        # refreshed (and no modulo by zero is attempted).
        freq = opt.hierarchical_freq
        if train_loader is None or (freq > 0 and (epoch - 1) % freq == 0):
            logging.info("="*10 + "Add hierarchical label to the train label" + "="*10)
            train_loader = add_hierarchical_to_train(
                model, opt, train_dataset, device='cuda:0')

        # train for one epoch
        time1 = time.time()
        train_loss, train_acc = train(train_loader, model,
                                      criterion, optimizer, epoch, opt, hook, ema)
        time2 = time.time()
        logging.info('epoch %s, total time %.2f', epoch, time2 - time1)

        # tensorboard logger
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value(
            'learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # Validate with the EMA weights. Restoring in `finally` guarantees
        # training never continues on the shadow weights if validation fails.
        ema.apply_shadow_weights()
        try:
            val_loss, val_acc = validate(val_loader, model, criterion, opt)
        finally:
            ema.restore_original_weights()

        logger.log_value('val_loss', val_loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)

        # NOTE(review): val_acc is measured with the EMA weights, but
        # save_model runs after they were restored, so best.pth holds the raw
        # training weights — confirm this is intended.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_file = os.path.join(opt.save_folder, 'best.pth')
            save_model(model, optimizer, opt, epoch, save_file)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)


if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()
