import argparse
import math
import os, sys
import random
import time
import json
import numpy as np
import pandas as pd

import torch
import torchvision
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

import _init_paths
from dataset.mixmatch_dataset import get_datasets

from models.smodel import build_model
from models.proj_norm import proj_norm, celoss

from utils.metrics import score

from utils.logger import setup_logger
from utils.meter import AverageMeter, AverageMeterHMS, ProgressMeter
from utils.helper import clean_state_dict, get_raw_dict, ModelEma
from utils.rkloss import ranking_loss


# Restrict and order the GPUs visible to this process. These assignments run
# at import time, after `torch` has already been imported above.
# NOTE(review): the device list contains spaces after the commas -- confirm
# the CUDA runtime parses "0, 1, 2, 3" the same way as "0,1,2,3".
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"


def parser_args():
    """Parse the command line for a training run.

    Besides plain parsing, two paths are post-processed:
    ``args.dataset_dir`` becomes ``<dataset_dir>/<dataset_name>`` and
    ``args.output`` becomes ``<output>/<dataset_name>/second``.

    ``--amp``, ``--early_stop`` and ``--pretrained`` default to ``True``; the
    matching ``--no_amp``, ``--no_early_stop`` and ``--no_pretrained`` flags
    write ``False`` to the same destinations so the features can actually be
    turned off.

    Returns:
        argparse.Namespace: the parsed (and post-processed) arguments.
    """
    parser = argparse.ArgumentParser(description='MixMatch Training')

    # MixMatch hyper-parameters
    parser.add_argument('--train-iteration', type=int, default=256,
                        help='Number of iteration per epoch')
    parser.add_argument('--alpha', default=0.75, type=float)
    parser.add_argument('--lambda-u', default=75, type=float)
    parser.add_argument('--T', default=0.5, type=float)

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr', choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data')
    parser.add_argument('--img_size', default=256, type=int,
                        help='size of input images')
    parser.add_argument('--output', metavar='DIR', default='./outputs',
                        help='path to output folder')

    # train
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--val_interval', default=1, type=int, metavar='N',
                        help='interval of validation')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='batch size')
    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight_decay', default=1e-2, type=float,
                        metavar='W', help='weight decay (default: 1e-2)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--amp', action='store_true', default=True,
                        help='apply amp')
    parser.add_argument('--no_amp', dest='amp', action='store_false', default=True,
                        help='do not apply amp')
    parser.add_argument('--early_stop', action='store_true', default=True,
                        help='apply early stop')
    parser.add_argument('--no_early_stop', dest='early_stop', action='store_false', default=True,
                        help='do not apply early stop')
    parser.add_argument('--train_ensemble', action='store_true', default=False,
                        help='apply ensemble during training')
    parser.add_argument('--train_unlabel', action='store_true', default=False,
                        help="train unlabel data")
    parser.add_argument('--proj_norm', action='store_true', default=False,
                        help="apply proj_norm to the model outputs")

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # model
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                        help='use pre-trained model. default is True. ')
    parser.add_argument('--no_pretrained', dest='pretrained', action='store_false', default=True,
                        help='do not use a pre-trained model')
    parser.add_argument('--is_data_parallel', action='store_true', default=False,
                        help='on/off nn.DataParallel()')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--resume_omit', default=[], type=str, nargs='*')
    parser.add_argument('--ema_decay', default=0.997, type=float, metavar='M',
                        help='decay of model ema')

    args = parser.parse_args()

    # Scope the dataset and output directories to the selected dataset.
    args.dataset_dir = os.path.join(args.dataset_dir, args.dataset_name)
    args.output = os.path.join(args.output, args.dataset_name, 'second')

    return args


def get_args():
    """Return the parsed command-line arguments (thin wrapper over ``parser_args``)."""
    return parser_args()


def same_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # cuDNN: turn the benchmark flag off and the deterministic flag on.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators only exist when CUDA is usable.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    """Script entry point: parse arguments, prepare the output folder, start training.

    Seeds the RNGs when a seed is given, creates ``args.output``, sets up the
    logger, writes the run configuration to ``config.json`` in that folder and
    then delegates to ``main_worker``.

    Returns:
        Whatever ``main_worker`` returns.
    """
    torchvision.disable_beta_transforms_warning()
    args = get_args()

    seed = args.seed
    if seed is not None:
        same_seeds(seed)

    output_dir = args.output
    os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger(output=output_dir, color=False, name="LEModel")
    command_line = ' '.join(sys.argv)
    logger.info("Command: " + command_line)

    # Persist the configuration next to the other run artefacts.
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, 'w') as config_file:
        json.dump(get_raw_dict(args), config_file, indent=2)
    logger.info("Full config saved to {}".format(config_path))

    return main_worker(args, logger)

def main_worker(args, logger):
    """Build the model, loaders and optimizer, then run the train/validate loop.

    Every epoch is trained with ``train``. Every ``args.val_interval`` epochs
    the regular model and its EMA copy are both scored with ``validate`` on the
    validation set, the results are written to TensorBoard, whichever of the
    two has the lower loss is handed to ``save_checkpoint``, and the early-stop
    rule is checked.

    Args:
        args: Parsed command-line namespace (see ``parser_args``).
            ``args.workers`` and ``args.steps_per_epoch`` are assigned here.
        logger: Logger used for progress messages.

    Returns:
        int: 0 once the epoch loop finishes or stops early. If the selected
        validation loss becomes NaN the process exits through ``sys.exit(1)``.
    """
    # build model
    model = build_model(args)
    if args.is_data_parallel:
        model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load onto the CPU first so the checkpoint does not depend on the
            # device layout it was saved with; load_state_dict copies the
            # values into the (already CUDA-resident) model.
            checkpoint = torch.load(args.resume, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = clean_state_dict(checkpoint['state_dict'])
            elif 'model' in checkpoint:
                state_dict = clean_state_dict(checkpoint['model'])
            else:
                raise ValueError("No model or state_dicr Found!!!")
            logger.info("Omitting {}".format(args.resume_omit))
            for omit_name in args.resume_omit:
                del state_dict[omit_name]
            model.load_state_dict(state_dict, strict=False)
            # Checkpoints that only carry a 'model' entry may have no epoch.
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint.get('epoch', 'unknown')))
            del checkpoint
            del state_dict
            torch.cuda.empty_cache()
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    ema_m = ModelEma(model, args.ema_decay) # 0.9997

    # tensorboard
    summary_writer = SummaryWriter(log_dir=args.output)

    ###############################################################################

    # Data loading code
    train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset = get_datasets(args)
    print("len(train_label_dataset):", len(train_label_dataset))
    print("len(train_unlabel_dataset):", len(train_unlabel_dataset))
    print("len(val_dataset):", len(val_dataset))
    print("len(test_dataset):", len(test_dataset))

    # NOTE(review): this value is stored on args, but every loader below is
    # created with num_workers=0, so it currently has no effect on loading.
    args.workers = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 8])  # number of workers

    labeled_trainloader = torch.utils.data.DataLoader(
        train_label_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, pin_memory=True, drop_last=True)

    unlabeled_trainloader = torch.utils.data.DataLoader(
        train_unlabel_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True)

    epoch_time = AverageMeterHMS('TT')
    eta = AverageMeterHMS('ETA', val_only=True)
    losses = AverageMeter('loss', ':5.5f', val_only=True)
    losses_ema = AverageMeter('loss_ema', ':5.5f', val_only=True)
    progress = ProgressMeter(
        args.epochs,
        [eta, epoch_time, losses, losses_ema],
        prefix='=> Test Epoch: ')

    # one cycle learning rate
    # NOTE(review): args.weight_decay is parsed but not passed to the optimizer.
    args.steps_per_epoch = args.train_iteration
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=args.steps_per_epoch, epochs=args.epochs, pct_start=0.2)

    end = time.time()
    best_epoch = -1
    best_regular_loss = 1e10
    best_regular_epoch = -1
    best_ema_loss = 1e10
    regular_loss_list = []
    ema_loss_list = []
    loss_ema_test = 1e10
    best_loss = 1e10

    torch.cuda.empty_cache()
    for epoch in range(args.start_epoch, args.epochs):

        torch.cuda.empty_cache()

        # train for one epoch
        train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, ema_m, optimizer, scheduler, epoch, args, logger)
        if summary_writer:
            # tensorboard logger
            summary_writer.add_scalar('losses/train_loss', train_loss, epoch)
            summary_writer.add_scalar('losses/train_loss_x', train_loss_x, epoch)
            summary_writer.add_scalar('losses/train_loss_u', train_loss_u, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % args.val_interval == 0:

            # evaluate on validation set
            loss, (cheby, clark, can, kl, cosine, inter, spear, tau) = validate(val_loader, model, args, logger)
            loss_ema, (cheby_ema, clark_ema, can_ema, kl_ema, cosine_ema, inter_ema, spear_ema, tau_ema) = validate(val_loader, ema_m.module, args, logger)
            print(can, cheby, clark, cosine, inter, kl, spear, tau)
            print(can_ema, cheby_ema, clark_ema, cosine_ema, inter_ema, kl_ema, spear_ema, tau_ema)
            losses.update(loss)
            losses_ema.update(loss_ema)
            epoch_time.update(time.time() - end)
            end = time.time()
            eta.update(epoch_time.avg * (args.epochs - epoch - 1))

            regular_loss_list.append(loss)
            ema_loss_list.append(loss_ema)

            progress.display(epoch, logger)

            if summary_writer:
                # tensorboard logger: regular value first, then its EMA twin
                val_scalars = (
                    ('loss', loss, loss_ema),
                    ('cheby', cheby, cheby_ema),
                    ('clark', clark, clark_ema),
                    ('canberra', can, can_ema),
                    ('kl', kl, kl_ema),
                    ('cosine', cosine, cosine_ema),
                    ('intersection', inter, inter_ema),
                    ('spear', spear, spear_ema),
                    ('tau', tau, tau_ema),
                )
                for tag, regular_value, ema_value in val_scalars:
                    summary_writer.add_scalar('losses/val_{}'.format(tag), regular_value, epoch)
                    summary_writer.add_scalar('ema_losses/val_{}_ema'.format(tag), ema_value, epoch)

            # remember best (regular) loss and corresponding epochs
            if loss < best_regular_loss:
                best_regular_loss = loss
                best_regular_epoch = epoch
            if loss_ema < best_ema_loss:
                best_ema_loss = loss_ema

            # From here on `loss` is the lower of the regular and EMA losses,
            # and `state_dict` belongs to the model that produced it.
            if loss_ema < loss:
                loss = loss_ema
                state_dict = ema_m.module.state_dict()
            else:
                state_dict = model.state_dict()
            is_best = loss < best_loss
            if is_best:
                best_epoch = epoch
            best_loss = min(loss, best_loss)

            # Score the test set with the model that holds the best loss.
            if best_loss == loss_ema:
                loss_ema_test, _ = validate(test_loader, ema_m.module, args, logger)
            elif best_loss == loss:
                loss_ema_test, _ = validate(test_loader, model, args, logger)

            logger.info("{} | Set best loss {} in ep {}".format(epoch, best_loss, best_epoch))
            logger.info("   | best regular loss {} in ep {}".format(best_regular_loss, best_regular_epoch))
            logger.info("   | best test loss {} ".format(loss_ema_test))

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'best_loss': best_loss,
                'optimizer' : optimizer.state_dict(),
            }, is_best=is_best, filename=os.path.join(args.output, 'checkpoint.pth.tar'))

            if math.isnan(loss):
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer' : optimizer.state_dict(),
                }, is_best=is_best, filename=os.path.join(args.output, 'checkpoint_nan.pth.tar'))
                logger.info('Loss is NaN, break')
                sys.exit(1)

            # early stop: neither best epoch moved for more than 8 epochs and
            # the latest EMA loss is worse than the best EMA loss seen so far.
            # best_ema_loss already includes the latest value, so the test must
            # be ">" -- a "<" comparison could never be true.
            if args.early_stop:
                if best_epoch >= 0 and epoch - max(best_epoch, best_regular_epoch) > 8:
                    if len(ema_loss_list) > 1 and ema_loss_list[-1] > best_ema_loss:
                        logger.info("epoch - best_epoch = {}, stop!".format(epoch - best_epoch))
                        break

    print("Best loss:", best_loss)

    if summary_writer:
        summary_writer.close()

    return 0

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``model_best.pth.tar`` when it is the best so far.

    Only the directory part of ``filename`` is used; nothing is written under
    ``filename`` itself, and nothing at all is written when ``is_best`` is
    false.

    Args:
        state: Object passed to ``torch.save``.
        is_best: Whether ``state`` should be saved.
        filename: Path whose directory receives ``model_best.pth.tar``.
    """
    if not is_best:
        return
    target_dir = os.path.dirname(filename)
    torch.save(state, target_dir + '/model_best.pth.tar')
##################
def train(labeled_trainloader, unlabeled_trainloader, model, ema_m, optimizer, scheduler, epoch, args, logger):
    """Run ``args.train_iteration`` MixMatch-style optimisation steps.

    Each step draws one labeled batch and one unlabeled batch (two views per
    sample), guesses sharpened targets for the unlabeled views, mixes inputs
    and targets with a Beta-sampled coefficient, runs the interleaved batches
    through the model, and optimises ``Lx + w * Lu`` from ``SemiLoss``. The EMA
    model and the learning-rate scheduler are stepped once per iteration.

    Args:
        labeled_trainloader: Loader yielding ``(inputs, targets)``.
        unlabeled_trainloader: Loader yielding ``((view1, view2), _)``.
        model: Network being trained.
        ema_m: EMA wrapper; ``ema_m.update(model)`` is called every step.
        optimizer: Optimizer stepped through a ``GradScaler``.
        scheduler: Per-iteration learning-rate scheduler.
        epoch: Current epoch index, used for logging and the ``Lu`` ramp-up.
        args: Parsed arguments (``amp``, ``T``, ``alpha``, ``train_iteration``,
            ``steps_per_epoch``, ``epochs``, ``print_freq`` are read here).
        logger: Logger used for progress output.

    Returns:
        tuple: the ``.avg`` of the total, labeled and unlabeled loss meters.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    losses = AverageMeter('Loss', ':5.3f')
    losses_x = AverageMeter('Loss_x', ':5.3f')
    losses_u = AverageMeter('Loss_u', ':5.3f')
    lr = AverageMeter('LR', ':.3e', val_only=True)
    mem = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(
        args.steps_per_epoch,
        [losses, losses_x, losses_u, lr, mem],
        prefix="Epoch: [{}/{}]".format(epoch, args.epochs))

    def get_learning_rate(optimizer):
        # Learning rate of the first parameter group.
        for param_group in optimizer.param_groups:
            return param_group['lr']

    lr.update(get_learning_rate(optimizer))
    logger.info("lr:{}".format(get_learning_rate(optimizer)))

    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    train_criterion = SemiLoss(args)

    # switch to train mode
    model.train()

    for batch_idx in range(args.train_iteration):
        # Restart a loader only when it is exhausted. Any other exception
        # (data errors, KeyboardInterrupt, ...) must propagate instead of
        # being hidden behind a silent restart.
        try:
            inputs_x, targets_x = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x = next(labeled_train_iter)

        try:
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)

        batch_size = inputs_x.size(0)

        inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
        inputs_u = inputs_u.cuda()
        inputs_u2 = inputs_u2.cuda()

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=args.amp):
                # compute guessed labels of unlabel samples: average the two
                # views' predictions, then sharpen with temperature args.T
                outputs_u = model(inputs_u)
                outputs_u2 = model(inputs_u2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1/args.T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

        # mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        l = np.random.beta(args.alpha, args.alpha)

        # keep the mixed sample closer to its first component
        l = max(l, 1-l)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        # NOTE(review): only the label-guessing pass above runs under autocast;
        # this forward pass and the loss do not.
        logits = [model(mixed_input[0])]
        for mixed_chunk in mixed_input[1:]:
            logits.append(model(mixed_chunk))

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration)

        loss = Lx + w * Lu

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        mem.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # compute gradient and do the optimizer step
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        ema_m.update(model)

        # one cycle learning rate
        scheduler.step()
        lr.update(get_learning_rate(optimizer))

        if batch_idx % args.print_freq == 0:
            progress.display(batch_idx, logger)

    return losses.avg, losses_x.avg, losses_u.avg


@torch.no_grad()
def validate(val_loader, model, args, logger):
    """Evaluate ``model`` on ``val_loader`` and return averaged loss and metrics.

    The loss is a sum-reduced cross entropy per batch; the eight metrics come
    from ``score`` applied to the targets and the softmax of the outputs. All
    per-batch values are summed and divided by ``len(val_loader.dataset)``.
    (presumably ``score`` returns per-batch sums -- TODO confirm)

    Args:
        val_loader: Loader yielding ``(X, y)`` batches.
        model: Network to evaluate; switched to eval mode here.
        args: Parsed arguments (``amp``, ``proj_norm``, ``print_freq``).
        logger: Logger used for progress output.

    Returns:
        tuple: ``(loss, (cheby, clark, canberra, kl, cosine, intersection,
        spearman, tau))``.
    """
    timer = AverageMeter('Time', ':5.3f')
    gpu_mem = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(len(val_loader), [timer, gpu_mem], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    batch_losses = []
    # One bucket per metric, in the order returned by score():
    # cheby, clark, canberra, kl, cosine, intersection, spearman, tau
    metric_buckets = [[] for _ in range(8)]

    tick = time.time()
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    for step, (X, y) in enumerate(val_loader):
        X = X.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True).float()

        # forward pass (optionally under mixed precision)
        with torch.cuda.amp.autocast(enabled=args.amp):
            y_hat = model(X)

        if args.proj_norm:
            y_hat = proj_norm(y_hat)

        batch_loss = criterion(y_hat, y)

        (cheby, clark, can, kl, cosine, inter, spear, tau) = score(y, F.softmax(y_hat, dim=-1))

        # move everything to the CPU before storing it
        batch_losses.append(batch_loss.detach().cpu())
        batch_metrics = (cheby, clark, can, kl, cosine, inter, spear, tau)
        for bucket, value in zip(metric_buckets, batch_metrics):
            bucket.append(value.detach().cpu())

        # record memory
        gpu_mem.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # measure elapsed time
        timer.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step, logger)

    # average over the number of samples in the dataset
    num_samples = len(val_loader.dataset)
    loss = sum(batch_losses) / num_samples
    metrics = tuple(sum(bucket) / num_samples for bucket in metric_buckets)

    print("Calculating loss:")
    logger.info("  loss: {}".format(loss))

    return loss, metrics


def linear_rampup(current, args):
    """Return ``current / args.epochs`` clipped to ``[0, 1]`` as a float.

    When ``args.epochs`` is 0 the result is 1.0.
    """
    total_epochs = args.epochs
    if total_epochs == 0:
        return 1.0
    fraction = current / total_epochs
    return float(np.clip(fraction, 0.0, 1.0))


class SemiLoss(object):
    """Callable producing the labeled loss, unlabeled loss and unlabeled weight."""

    def __init__(self, args) -> None:
        # args supplies lambda_u here and is forwarded to linear_rampup.
        self.args = args

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        """Return ``(Lx, Lu, weight)`` for one training step.

        ``Lx`` is the mean cross entropy between the labeled logits and their
        soft targets, ``Lu`` the mean squared error between the softmax of the
        unlabeled logits and their targets, and ``weight`` is
        ``args.lambda_u`` scaled by ``linear_rampup(epoch, args)``.
        """
        # labeled term: cross entropy against soft targets
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        per_sample_ce = torch.sum(log_probs_x * targets_x, dim=1)
        Lx = -torch.mean(per_sample_ce)

        # unlabeled term: squared error between probabilities and targets
        residual = torch.softmax(outputs_u, dim=1) - targets_u
        Lu = torch.mean(residual**2)

        weight = self.args.lambda_u * linear_rampup(epoch, self.args)
        return Lx, Lu, weight


def interleave_offsets(batch, nu):
    """Return the ``nu + 2`` boundaries that split ``batch`` into ``nu + 1`` groups.

    Groups are as equal as possible; when ``batch`` is not divisible by
    ``nu + 1`` the trailing groups each get one extra element. The first
    boundary is 0 and the last equals ``batch``.
    """
    parts = nu + 1
    base, remainder = divmod(batch, parts)
    offsets = [0]
    for position in range(parts):
        # the last `remainder` groups are one element larger
        size = base + 1 if position >= parts - remainder else base
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Swap slices between the first tensor in ``xy`` and each of the others.

    Every tensor is cut along dim 0 at the boundaries given by
    ``interleave_offsets(batch, len(xy) - 1)``; slice ``i`` of tensor 0 is then
    exchanged with slice ``i`` of tensor ``i``, and each tensor is
    re-concatenated.

    Returns:
        list: one concatenated tensor per entry of ``xy``.
    """
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    bounds = list(zip(offsets[:-1], offsets[1:]))

    pieces = []
    for tensor in xy:
        pieces.append([tensor[start:stop] for start, stop in bounds])

    first = pieces[0]
    for i in range(1, nu + 1):
        other = pieces[i]
        first[i], other[i] = other[i], first[i]

    return [torch.cat(group, dim=0) for group in pieces]


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()