import argparse
import math
import os, sys
import random
import time
import json
import numpy as np
import pandas as pd

import torch
import torchvision
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

import _init_paths
from dataset.new_dataset import get_datasets

from models.smodel import build_model
from models.proj_norm import proj_norm, celoss

from utils.metrics import score

from utils.logger import setup_logger
from utils.meter import AverageMeter, AverageMeterHMS, ProgressMeter
from utils.helper import clean_state_dict, get_raw_dict, ModelEma
from utils.rkloss import ranking_loss


# Configure GPU visibility through environment variables at import time.
# CUDA_DEVICE_ORDER=PCI_BUS_ID asks CUDA to number devices by PCI bus id.
# NOTE(review): the device list below contains spaces after the commas —
# confirm the CUDA runtime parses it as the intended four devices.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"


def parser_args(argv=None):
    """Define and parse the command-line options of the first training stage.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        argparse.Namespace: the parsed options, with two fields post-processed:
        ``dataset_dir`` becomes ``<dataset_dir>/<dataset_name>`` and ``output``
        becomes ``<output>/<dataset_name>/first``.
    """
    parser = argparse.ArgumentParser(description='First Training')

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr',
                        choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data')
    parser.add_argument('--img_size', default=256, type=int,
                        help='size of input images')
    parser.add_argument('--output', metavar='DIR', default='./outputs',
                        help='path to output folder')

    # loss
    parser.add_argument('--lambda_rk', default=0, type=float,
                        help='coefficient of Ranking loss')

    # train
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--val_interval', default=1, type=int, metavar='N',
                        help='interval of validation')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='batch size')
    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight_decay', default=1e-2, type=float,
                        metavar='W', help='weight decay (default: 1e-2)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    # NOTE: `store_true` combined with default=True means --amp, --early_stop
    # and --pretrained are always True; they cannot be disabled from the CLI.
    parser.add_argument('--amp', action='store_true', default=True,
                        help='apply amp')
    parser.add_argument('--early_stop', action='store_true', default=True,
                        help='apply early stop')
    parser.add_argument('--train_ensemble', action='store_true', default=False,
                        help='apply ensemble during training')
    parser.add_argument('--train_unlabel', action='store_true', default=False,
                        help="train unlabel data")
    parser.add_argument('--proj_norm', action='store_true', default=False,
                        help="apply proj_norm to the model outputs")

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # model
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                        help='use pre-trained model. default is True. ')
    parser.add_argument('--is_data_parallel', action='store_true', default=False,
                        help='on/off nn.DataParallel()')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--resume_omit', default=[], type=str, nargs='*')
    parser.add_argument('--ema_decay', default=0.997, type=float, metavar='M',
                        help='decay of model ema')

    args = parser.parse_args(argv)

    # Every dataset lives in its own sub-directory; outputs of this stage go
    # under <output>/<dataset_name>/first.
    args.dataset_dir = os.path.join(args.dataset_dir, args.dataset_name)
    args.output = os.path.join(args.output, args.dataset_name, 'first')

    return args


def get_args():
    """Return the parsed command-line options (thin wrapper over parser_args)."""
    return parser_args()


def same_seeds(seed):
    """Seed every random number generator used by the script.

    Seeds Python's `random`, NumPy and PyTorch (CPU, plus the CUDA generators
    when CUDA is available), then switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    # CPU-side generators
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators: current device first, then all devices
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # cuDNN flags are set regardless of CUDA availability
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main():
    """Script entry point: parse options, prepare the output folder, start training.

    Steps, in order: silence the torchvision beta-transforms warning, parse the
    arguments, seed the RNGs (when a seed is set), create the output directory,
    set up the logger, dump the configuration to ``config.json`` in the output
    directory and hand over to ``main_worker``.

    Returns:
        Whatever ``main_worker`` returns.
    """
    torchvision.disable_beta_transforms_warning()

    args = get_args()
    if args.proj_norm:
        print('use proj_norm')

    seed = args.seed
    if seed is not None:
        same_seeds(seed)

    os.makedirs(args.output, exist_ok=True)

    logger = setup_logger(output=args.output, color=False, name="LEModel")
    command_line = ' '.join(sys.argv)
    logger.info("Command: " + command_line)

    # Persist the configuration next to the other run artifacts.
    config_path = os.path.join(args.output, "config.json")
    config_text = json.dumps(get_raw_dict(args), indent=2)
    with open(config_path, 'w') as config_file:
        config_file.write(config_text)
    logger.info("Full config saved to {}".format(config_path))

    return main_worker(args, logger)

def main_worker(args, logger):
    """Build model, optimizer and data loaders, then run the training loop.

    Every ``args.val_interval`` epochs both the regular model and its EMA copy
    are validated. The one with the lower validation loss is the candidate
    checkpoint; whenever it improves on the best loss seen so far it is saved
    (see ``save_checkpoint``) and evaluated on the test set.

    Args:
        args: namespace produced by ``parser_args``. This function also sets
            ``args.lr_mult``, ``args.workers`` and ``args.steps_per_epoch``.
        logger: logger used for progress and result messages.

    Returns:
        int: 0 on normal completion. If the training loss becomes NaN the
        process is terminated with ``sys.exit(1)`` after dumping a checkpoint.
    """

    # build model
    pi = compute_pi(args)
    model = build_model(args)
    if args.is_data_parallel:
        model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(security): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            checkpoint = torch.load(args.resume)

            if 'state_dict' in checkpoint:
                state_dict = clean_state_dict(checkpoint['state_dict'])
            elif 'model' in checkpoint:
                state_dict = clean_state_dict(checkpoint['model'])
            else:
                raise ValueError("No model or state_dict Found!!!")
            logger.info("Omitting {}".format(args.resume_omit))
            for omit_name in args.resume_omit:
                del state_dict[omit_name]
            model.load_state_dict(state_dict, strict=False)
            # 'epoch' is informational only, so tolerate checkpoints without it.
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint.get('epoch', 'unknown')))
            del checkpoint
            del state_dict
            torch.cuda.empty_cache()
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    ema_m = ModelEma(model, args.ema_decay)

    # optimizer
    # NOTE(review): OneCycleLR below is built with max_lr=args.lr, so the
    # lr_mult-scaled value passed here appears not to drive the schedule —
    # confirm this is intended.
    args.lr_mult = args.batch_size / 256

    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if p.requires_grad]},
    ]
    optimizer = getattr(torch.optim, 'AdamW')(
        param_dicts,
        args.lr_mult * args.lr,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay
    )

    # tensorboard
    summary_writer = SummaryWriter(log_dir=args.output)

    # Data loading code
    train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset, _ = get_datasets(args)
    print("len(train_label_dataset):", len(train_label_dataset))
    print("len(train_unlabel_dataset):", len(train_unlabel_dataset))
    print("len(val_dataset):", len(val_dataset))
    print("len(test_dataset):", len(test_dataset))

    # Cap the worker count by the CPU count and the batch size. os.cpu_count()
    # may return None, hence the fallback to 1. The requested --workers value
    # (default 8) is the upper bound instead of a hard-coded 8.
    args.workers = min([os.cpu_count() or 1,
                        args.batch_size if args.batch_size > 1 else 0,
                        args.workers])

    train_loader = torch.utils.data.DataLoader(
        train_label_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    epoch_time = AverageMeterHMS('TT')
    eta = AverageMeterHMS('ETA', val_only=True)
    vlosses = AverageMeter('vloss', ':5.5f', val_only=True)
    vlosses_ema = AverageMeter('vloss_ema', ':5.5f', val_only=True)
    progress = ProgressMeter(
        args.epochs,
        [eta, epoch_time, vlosses, vlosses_ema],
        prefix='=> Test Epoch: ')

    # one cycle learning rate
    args.steps_per_epoch = len(train_loader)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=args.steps_per_epoch, epochs=args.epochs, pct_start=0.2)

    # Tensorboard tag stems, in the order validate() returns the metrics.
    metric_tags = ('cheby', 'clark', 'canberra', 'kl', 'cosine', 'intersection', 'spear', 'tau')

    end = time.time()
    best_epoch = -1
    best_regular_vloss = 1e10
    best_regular_epoch = -1
    best_ema_vloss = 1e10
    regular_vloss_list = []
    ema_vloss_list = []
    vloss_ema_test = 1e10
    test_metrics = None  # stays None until a best model has been tested
    best_vloss = 1e10

    torch.cuda.empty_cache()
    for epoch in range(args.start_epoch, args.epochs):

        torch.cuda.empty_cache()

        # train for one epoch
        loss = train(train_loader, model, ema_m, optimizer, scheduler, epoch, args, logger, pi)

        if summary_writer:
            # tensorboard logger
            summary_writer.add_scalar('train_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % args.val_interval == 0:

            # evaluate the regular and the EMA model on the validation set
            vloss, val_metrics = validate(val_loader, model, args, logger, pi)
            vloss_ema, val_metrics_ema = validate(val_loader, ema_m.module, args, logger, pi)
            print(*val_metrics)
            vlosses.update(vloss)
            vlosses_ema.update(vloss_ema)
            epoch_time.update(time.time() - end)
            end = time.time()
            eta.update(epoch_time.avg * (args.epochs - epoch - 1))

            regular_vloss_list.append(vloss)
            ema_vloss_list.append(vloss_ema)

            progress.display(epoch, logger)

            if summary_writer:
                # tensorboard logger
                summary_writer.add_scalar('val_vloss', vloss, epoch)
                summary_writer.add_scalar('val_vloss_ema', vloss_ema, epoch)
                for tag, value, value_ema in zip(metric_tags, val_metrics, val_metrics_ema):
                    summary_writer.add_scalar('val_{}'.format(tag), value, epoch)
                    summary_writer.add_scalar('val_{}_ema'.format(tag), value_ema, epoch)

            # remember best (regular) vloss and corresponding epochs
            if vloss < best_regular_vloss:
                best_regular_vloss = vloss
                best_regular_epoch = epoch
            if vloss_ema < best_ema_vloss:
                best_ema_vloss = vloss_ema

            # the candidate checkpoint is whichever model validates better
            if vloss_ema < vloss:
                vloss = vloss_ema
                state_dict = ema_m.module.state_dict()
            else:
                state_dict = model.state_dict()
            is_best = vloss < best_vloss
            if is_best:
                best_epoch = epoch
            best_vloss = min(vloss, best_vloss)

            # re-evaluate on the test set when this epoch holds the best loss
            if best_vloss == vloss_ema:
                vloss_ema_test, test_metrics = validate(test_loader, ema_m.module, args, logger, pi)
            elif best_vloss == vloss:
                vloss_ema_test, test_metrics = validate(test_loader, model, args, logger, pi)

            logger.info("{} | Set best vloss {} in ep {}".format(epoch, best_vloss, best_epoch))
            logger.info("   | best regular vloss {} in ep {}".format(best_regular_vloss, best_regular_epoch))
            logger.info("   | best test vloss {} ".format(vloss_ema_test))
            # Log all eight test metrics; the test set may not have been
            # evaluated yet, in which case test_metrics is still None.
            if test_metrics is None:
                metrics_text = "n/a"
            else:
                metrics_text = ", ".join(
                    "{}={}".format(tag, value) for tag, value in zip(metric_tags, test_metrics))
            logger.info("   | best test metrics: {}".format(metrics_text))

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'best_CEloss': best_vloss,
                'optimizer': optimizer.state_dict(),
            }, is_best=is_best, filename=os.path.join(args.output, 'checkpoint.pth.tar'))

            if math.isnan(loss):
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_CEloss': best_vloss,
                    'optimizer': optimizer.state_dict(),
                }, is_best=is_best, filename=os.path.join(args.output, 'checkpoint_nan.pth.tar'))
                logger.info('Loss is NaN, break')
                if summary_writer:
                    summary_writer.close()  # flush pending events before exiting
                sys.exit(1)

            # early stop: no new best for more than 8 epochs and the latest EMA
            # loss is worse (higher) than the best EMA loss. The comparison
            # must be `>`: best_ema_vloss already includes the latest value,
            # so `<` could never be true.
            if args.early_stop:
                if best_epoch >= 0 and epoch - max(best_epoch, best_regular_epoch) > 8:
                    if len(ema_vloss_list) > 1 and ema_vloss_list[-1] > best_ema_vloss:
                        logger.info("epoch - best_epoch = {}, stop!".format(epoch - best_epoch))
                        break

    print("Best vloss:", best_vloss)

    if summary_writer:
        summary_writer.close()

    return 0

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save `state` as the best-model checkpoint when `is_best` is true.

    Only the directory part of `filename` is used: the file is always written
    as ``pretrain_model_best.pth.tar`` inside that directory. Nothing is
    written when `is_best` is false.

    Args:
        state: object to serialize with ``torch.save``.
        is_best: truthy when `state` belongs to the best model so far.
        filename: path whose directory receives the best-model file. A bare
            file name (such as the default) means the current directory.
    """
    if is_best:
        # os.path.join keeps a bare file name relative to the working
        # directory; plain string concatenation with '/' would point at the
        # filesystem root in that case.
        best_path = os.path.join(os.path.dirname(filename), 'pretrain_model_best.pth.tar')
        torch.save(state, best_path)
##################
def train(train_loader, model, ema_m, optimizer, scheduler, epoch, args, logger, pi):
    """Train `model` for one epoch and return the average total loss.

    Per batch: forward pass under autocast (optionally followed by
    ``proj_norm``), L1 loss between the softmax of the prediction and the
    target plus ``args.lambda_rk`` times the ranking loss, a scaled backward
    pass and optimizer step, an EMA update and a scheduler step.

    Args:
        train_loader: iterable yielding ``((X_w, X_s), y)``; only ``X_w`` and
            ``y`` are used here.
        model: network being trained.
        ema_m: EMA wrapper, updated from `model` after every optimizer step.
        optimizer: optimizer stepped through the gradient scaler.
        scheduler: learning-rate scheduler, stepped once per batch.
        epoch: current epoch index (used for the progress prefix).
        args: run options (amp, proj_norm, lambda_rk, print_freq, epochs,
            steps_per_epoch are read).
        logger: logger for progress output.
        pi: tensor forwarded to ``ranking_loss``.

    Returns:
        The running average of the total loss over the epoch (``losses.avg``).
    """
    def current_lr():
        # learning rate of the first parameter group (None if there is none)
        return next((group['lr'] for group in optimizer.param_groups), None)

    # A fresh gradient scaler is created for every epoch.
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    loss_ces = AverageMeter('L_ce', ':5.3f')
    loss_rks = AverageMeter('L_rk', ':5.3f')
    losses = AverageMeter('Loss', ':5.3f')
    lr = AverageMeter('LR', ':.3e', val_only=True)
    mem = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(
        args.steps_per_epoch,
        [loss_ces, loss_rks, lr, losses, mem],
        prefix="Epoch: [{}/{}]".format(epoch, args.epochs))

    lr.update(current_lr())
    logger.info("lr:{}".format(current_lr()))

    loss_fn = torch.nn.L1Loss(reduction='mean')

    # switch to train mode
    model.train()

    for step, ((X_w, _X_s), y) in enumerate(train_loader):
        n_samples = X_w.size(0)

        images = X_w.cuda(non_blocking=True)
        targets = y.cuda(non_blocking=True).float()

        # mixed precision forward pass and loss computation
        with torch.cuda.amp.autocast(enabled=args.amp):
            logits = model(images)
            if args.proj_norm:
                logits = proj_norm(logits)
            loss_ce = loss_fn(F.softmax(logits, dim=-1), targets)
            loss_rk = ranking_loss(F.softmax(logits, dim=-1), targets, pi)
            loss = loss_ce + args.lambda_rk * loss_rk

        # bookkeeping (memory is reported in MiB)
        loss_ces.update(loss_ce.item(), n_samples)
        loss_rks.update(args.lambda_rk * loss_rk.item(), n_samples)
        losses.update(loss.item(), n_samples)
        mem.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # scaled backward pass and parameter update
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        ema_m.update(model)

        # one cycle learning rate: advance once per batch
        scheduler.step()
        lr.update(current_lr())

        if step % args.print_freq == 0:
            progress.display(step, logger)

    return losses.avg


@torch.no_grad()
def validate(val_loader, model, args, logger, pi):
    """Evaluate `model` on `val_loader` and return the loss and eight metrics.

    The model is put in eval mode. For every batch the ranking loss and the
    outputs of ``score`` are collected; each collected quantity is then summed
    over the batches and divided by ``len(val_loader.dataset)``.
    NOTE(review): this normalization presumes ``ranking_loss`` and ``score``
    return per-batch sums — confirm against their implementations.

    Args:
        val_loader: loader yielding ``(X, y)`` batches; must expose ``dataset``.
        model: network to evaluate (regular or EMA module).
        args: run options (amp, proj_norm and print_freq are read).
        logger: logger for progress and the final loss.
        pi: tensor forwarded to ``ranking_loss``.

    Returns:
        tuple: ``(loss, (cheby, clark, canberra, kl, cosine, intersection,
        spearman, tau))``, all normalized as described above.
    """
    batch_time = AverageMeter('Time', ':5.3f')
    mem = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(len(val_loader), [batch_time, mem], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    # one list of per-batch CPU values for every tracked quantity
    names = ('loss', 'cheby', 'clark', 'can', 'kl', 'cosine', 'inter', 'spear', 'tau')
    history = {name: [] for name in names}

    tic = time.time()
    loss_fn = torch.nn.L1Loss(reduction='sum')
    for step, (X, y) in enumerate(val_loader):
        X = X.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True).float()

        # forward pass under autocast; proj_norm is applied outside of it
        with torch.cuda.amp.autocast(enabled=args.amp):
            y_hat = model(X)
        if args.proj_norm:
            y_hat = proj_norm(y_hat)

        # NOTE: the L1 term is computed but does not enter the returned loss;
        # the validation loss is the ranking loss alone.
        loss_ce = loss_fn(F.softmax(y_hat, dim=-1), y)
        loss = ranking_loss(F.softmax(y_hat, dim=-1), y, pi)

        cheby, clark, can, kl, cosine, inter, spear, tau = score(y, F.softmax(y_hat, dim=-1))

        batch_values = (loss, cheby, clark, can, kl, cosine, inter, spear, tau)
        for name, value in zip(names, batch_values):
            history[name].append(value.detach().cpu())

        # record memory (MiB)
        mem.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step, logger)

    # normalize every accumulated quantity by the dataset size
    averaged = {name: sum(history[name]) / len(val_loader.dataset) for name in names}
    loss = averaged['loss']

    print("Calculating loss:")
    logger.info("  loss: {}".format(loss))

    return loss, tuple(averaged[name] for name in names[1:])


def compute_pi(args):
    """Estimate pairwise label-ordering frequencies from the training labels.

    Reads ``train_label_data.csv`` from ``args.dataset_dir``, drops the first
    column and treats every remaining column as one label dimension.

    Args:
        args: object with a ``dataset_dir`` attribute pointing at the dataset
            folder.

    Returns:
        torch.Tensor: a ``(C, C)`` tensor of the default float dtype where
        entry ``[i, j]`` is the fraction of rows whose column ``i`` is strictly
        greater than column ``j``, clipped to ``[2e-16, 1]``.

    Raises:
        ValueError: if the CSV contains no data rows.
    """
    label_data = pd.read_csv(os.path.join(args.dataset_dir, 'train_label_data.csv'))
    label = label_data.iloc[:, 1:].values
    num_samples, num_classes = label.shape
    if num_samples == 0:
        raise ValueError("train_label_data.csv contains no rows; cannot compute pi")

    # One vectorized comparison per class replaces the Python-level double loop
    # with builtin sum(): column i, kept as (N, 1), is broadcast against all
    # columns, and counting along the rows yields row i of the matrix.
    freq = np.zeros((num_classes, num_classes), dtype=np.float64)
    for i in range(num_classes):
        freq[i] = (label[:, [i]] > label).sum(axis=0) / num_samples

    pi = torch.tensor(freq, dtype=torch.get_default_dtype())
    # Clip away exact zeros (the diagonal is always zero before clipping).
    pi = torch.clip(pi, 2e-16, 1)
    return pi


# Start training only when the file is executed as a script, not on import.
# The value returned by main() is not used here.
if __name__ == '__main__':
    main()