import argparse
import math
import os, sys
import random
import time
import json
import numpy as np

import torch
import torchvision
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

import _init_paths
from dataset.CPS_dataset import get_datasets

from models.smodel import build_model

from utils.metrics import score

from utils.logger import setup_logger
from utils.meter import AverageMeter, AverageMeterHMS, ProgressMeter
from utils.helper import clean_state_dict, get_raw_dict, ModelEma


# Pin the CUDA device enumeration order and the set of GPUs this process
# may see.
# NOTE(review): these variables are written after `import torch`; this
# presumably still takes effect because no CUDA call has happened yet at
# import time -- confirm.
# NOTE(review): the device list contains a space after each comma --
# confirm the CUDA runtime accepts that form on the target machines.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
})


def parser_args(argv=None):
    """Parse the command line and derive the dataset/output directories.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) ``sys.argv[1:]`` is parsed, which is what the script
            entry point relies on.

    Returns:
        argparse.Namespace with every option below. ``dataset_dir`` is
        rewritten to ``<dataset_dir>/<dataset_name>`` and ``output`` to
        ``<output>/<dataset_name>/first``.
    """
    parser = argparse.ArgumentParser(description='Pre Training for Cross Pseudo Supervision')

    parser.add_argument('--train-iteration', type=int, default=256,
                        help='Number of iteration per epoch')
    parser.add_argument('--lamda', default=0.1, type=float,
                        help='cofficient of unlabel loss')

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr',
                        choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data')
    parser.add_argument('--img_size', default=256, type=int,
                        help='size of input images')
    parser.add_argument('--output', metavar='DIR', default='./outputs',
                        help='path to output folder')

    # train
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--val_interval', default=1, type=int, metavar='N',
                        help='interval of validation')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='batch size')
    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight_decay', default=1e-2, type=float,
                        metavar='W', help='weight decay (default: 1e-2)',
                        dest='weight_decay')
    # The help text used to advertise a default of 10 while the real
    # default is 100.
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    # `--amp` is store_true with default=True, so on its own it can never
    # be switched off; `--no_amp` is the matching off switch.
    parser.add_argument('--amp', action='store_true', default=True,
                        help='apply amp')
    parser.add_argument('--no_amp', dest='amp', action='store_false',
                        help='disable amp')

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # model
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    # Same situation as `--amp`: provide an explicit way to turn it off.
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                        help='use pre-trained model. default is True. ')
    parser.add_argument('--no_pretrained', dest='pretrained', action='store_false',
                        help='do not use a pre-trained model')
    parser.add_argument('--is_data_parallel', action='store_true', default=False,
                        help='on/off nn.DataParallel()')
    parser.add_argument('--resume1', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--resume2', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--resume_omit', default=[], type=str, nargs='*')

    args = parser.parse_args(argv)

    # Everything is stored per dataset.
    args.dataset_dir = os.path.join(args.dataset_dir, args.dataset_name)
    args.output = os.path.join(args.output, args.dataset_name, 'first')

    return args


def get_args():
    """Return the parsed command-line namespace (delegates to parser_args)."""
    return parser_args()


def same_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    Also disables cuDNN benchmarking and switches cuDNN to deterministic
    mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Seed the current device and then every device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    """Script entry point: parse args, seed, set up logging, start training.

    Creates the output directory, writes the parsed configuration to
    ``config.json`` inside it and returns whatever ``main_worker`` returns.
    """
    args = get_args()

    seed = args.seed
    if seed is not None:
        same_seeds(seed)

    os.makedirs(args.output, exist_ok=True)

    logger = setup_logger(output=args.output, color=False, name="LEModel")
    command_line = ' '.join(sys.argv)
    logger.info("Command: " + command_line)

    # Persist the configuration next to the run's other artifacts.
    config_path = os.path.join(args.output, "config.json")
    with open(config_path, 'w') as config_file:
        json.dump(get_raw_dict(args), config_file, indent=2)
    logger.info(f"Full config saved to {config_path}")

    return main_worker(args, logger)

def _resume_from_checkpoint(model, path, omit_keys, logger):
    """Load weights from the checkpoint at ``path`` into ``model``.

    The weights are read from the checkpoint's 'state_dict' entry, or from
    its 'model' entry when 'state_dict' is absent, passed through
    ``clean_state_dict`` and loaded non-strictly. A missing file is only
    logged.

    Args:
        model: module that receives the weights.
        path: checkpoint file path.
        omit_keys: state-dict keys to delete before loading.
        logger: logger used for progress messages.

    Raises:
        ValueError: the checkpoint has neither 'state_dict' nor 'model'.
        KeyError: a name in ``omit_keys`` is not in the state dict.
    """
    if not os.path.isfile(path):
        logger.info("=> no checkpoint found at '{}'".format(path))
        return

    logger.info("=> loading checkpoint '{}'".format(path))
    # Deserialize onto the CPU so a checkpoint written from a GPU index
    # that does not exist on this machine can still be read;
    # load_state_dict copies the values into the model's own parameters.
    checkpoint = torch.load(path, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = clean_state_dict(checkpoint['state_dict'])
    elif 'model' in checkpoint:
        state_dict = clean_state_dict(checkpoint['model'])
    else:
        raise ValueError("No model or state_dicr Found!!!")

    logger.info("Omitting {}".format(omit_keys))
    for omit_name in omit_keys:
        del state_dict[omit_name]

    # NOTE(review): strict=False hides key mismatches; if the model is
    # wrapped in DataParallel its keys may not match the cleaned state
    # dict -- confirm what clean_state_dict produces.
    model.load_state_dict(state_dict, strict=False)
    # Checkpoints that only carry a 'model' entry may have no epoch field.
    logger.info("=> loaded checkpoint '{}' (epoch {})"
                .format(path, checkpoint.get('epoch', 'unknown')))


def main_worker(args, logger):
    """Train both models and checkpoint them whenever validation improves.

    Builds two models, optionally resumes each from ``args.resume1`` /
    ``args.resume2``, optimizes them jointly with AdamW under a one-cycle
    learning-rate schedule, validates every ``args.val_interval`` epochs
    and, on a new best validation loss, evaluates the test set and writes
    ``checkpoint1.pth.tar`` / ``checkpoint2.pth.tar`` to ``args.output``.

    Side effects on ``args``: sets ``lr_mult``, ``workers`` and
    ``steps_per_epoch``.

    Args:
        args: parsed command-line namespace.
        logger: logger used for progress messages.

    Returns:
        0 on completion.
    """
    # build model
    model1 = build_model(args)
    model2 = build_model(args, True)
    if args.is_data_parallel:
        # Leave device_ids unset so DataParallel uses every visible GPU
        # instead of assuming exactly four are present.
        model1 = torch.nn.DataParallel(model1)
        model2 = torch.nn.DataParallel(model2)
    model1 = model1.cuda()
    model2 = model2.cuda()

    # optionally resume from a checkpoint
    for model, resume_path in ((model1, args.resume1), (model2, args.resume2)):
        if resume_path:
            _resume_from_checkpoint(model, resume_path, args.resume_omit, logger)

    # optimizer: one parameter group holding both models' trainable weights
    args.lr_mult = args.batch_size / 256
    trainable_params = [
        p
        for model in (model1, model2)
        for p in model.parameters()
        if p.requires_grad
    ]
    optimizer = torch.optim.AdamW(
        [{"params": trainable_params}],
        args.lr_mult * args.lr,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay
    )

    # tensorboard
    summary_writer = SummaryWriter(log_dir=args.output)

    # Data loading code
    train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset = get_datasets(args)
    print("len(train_label_dataset):", len(train_label_dataset))
    print("len(train_unlabel_dataset):", len(train_unlabel_dataset))
    print("len(val_dataset):", len(val_dataset))
    print("len(test_dataset):", len(test_dataset))

    # Cap the worker count by the CPU count, the batch size and the
    # requested --workers value. os.cpu_count() may return None.
    args.workers = min(
        os.cpu_count() or 1,
        args.batch_size if args.batch_size > 1 else 0,
        args.workers,
    )

    train_label_loader = torch.utils.data.DataLoader(
        train_label_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    train_unlabel_loader = torch.utils.data.DataLoader(
        train_unlabel_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    epoch_time = AverageMeterHMS('TT')
    eta = AverageMeterHMS('ETA', val_only=True)
    vlosses = AverageMeter('vloss', ':5.5f', val_only=True)
    progress = ProgressMeter(
        args.epochs,
        [eta, epoch_time, vlosses],
        prefix='=> Test Epoch: ')

    # one cycle learning rate
    # train() steps the scheduler once per iteration and runs exactly
    # args.train_iteration iterations per epoch, so that is the number the
    # schedule must be sized with. Sizing it from the loader lengths makes
    # OneCycleLR either raise once it is stepped past its total or never
    # finish its cycle.
    args.steps_per_epoch = args.train_iteration
    # NOTE(review): max_lr is args.lr, not the lr_mult-scaled value handed
    # to the optimizer above -- confirm which one is intended.
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs, pct_start=0.2)

    val_metric_tags = (
        'val_cheby', 'val_clark', 'val_canberra', 'val_kl',
        'val_cosine', 'val_intersection', 'val_spear', 'val_tau',
    )

    end = time.time()
    best_regular_vloss = 1e10
    best_regular_epoch = -1
    # Test results belonging to the best validation epoch so far; they
    # stay None until a first best epoch exists.
    vloss_test = None
    test_metrics = None

    torch.cuda.empty_cache()
    try:
        for epoch in range(args.start_epoch, args.epochs):

            torch.cuda.empty_cache()

            # train for one epoch
            loss = train(train_label_loader, train_unlabel_loader, model1, model2,
                         optimizer, scheduler, epoch, args, logger)

            # tensorboard logger
            summary_writer.add_scalar('train_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

            if epoch % args.val_interval != 0:
                continue

            # evaluate on validation set
            vloss, val_metrics = validate(val_loader, model1, model2, args, logger)
            print(*val_metrics)
            vlosses.update(vloss)
            epoch_time.update(time.time() - end)
            end = time.time()
            eta.update(epoch_time.avg * (args.epochs - epoch - 1))

            progress.display(epoch, logger)

            # tensorboard logger
            summary_writer.add_scalar('val_vloss', vloss, epoch)
            for tag, value in zip(val_metric_tags, val_metrics):
                summary_writer.add_scalar(tag, value, epoch)

            # remember best (regular) vloss and corresponding epoch
            is_best = bool(vloss < best_regular_vloss)
            if is_best:
                best_regular_vloss = vloss
                best_regular_epoch = epoch
                # Test metrics are refreshed only together with the saved
                # weights, so the two always describe the same model.
                vloss_test, test_metrics = validate(test_loader, model1, model2, args, logger)

            logger.info("{} | Set best vloss {} in ep {}".format(epoch, best_regular_vloss, best_regular_epoch))
            if test_metrics is not None:
                logger.info("   | best test vloss {} ".format(vloss_test))
                logger.info("   | best test metrics: cheby:{} clark:{} can:{} kl:{} cos:{} inter:{} spear:{} tau:{}".format(*test_metrics))

            if is_best:
                for filename, model in (('checkpoint1.pth.tar', model1),
                                        ('checkpoint2.pth.tar', model2)):
                    save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        # Store the loss that was actually achieved; this
                        # used to be a constant that was never updated.
                        'best_CEloss': float(best_regular_vloss),
                        'optimizer': optimizer.state_dict(),
                    }, is_best=True, filename=os.path.join(args.output, filename))

        print("Best vloss:", best_regular_vloss)
    finally:
        # Flush tensorboard events even when training is interrupted.
        summary_writer.close()
    return 0


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename``, but only when ``is_best`` is truthy."""
    if not is_best:
        return
    torch.save(state, filename)


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # Only exhaustion restarts the loader; real data-loading errors
        # and KeyboardInterrupt propagate to the caller.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train(train_label_loader, train_unlabel_loader, model1, model2, optimizer, scheduler, epoch, args, logger):
    """Run ``args.train_iteration`` optimization steps on both models.

    Each step draws one labeled and one unlabeled batch (loaders are
    restarted when they run out), feeds their concatenation through both
    models and minimizes::

        CE(pred1_l, y) + CE(pred2_l, y)
        + args.lamda * (CE(pred1_u, softmax(pred2_u)) + CE(pred2_u, softmax(pred1_u)))

    The scheduler is stepped once per iteration.

    Args:
        train_label_loader: loader yielding ``(inputs, targets)``.
        train_unlabel_loader: loader yielding ``(inputs, _)``; the second
            element is ignored.
        model1, model2: the two models being trained.
        optimizer: optimizer to step.
        scheduler: per-iteration learning-rate scheduler.
        epoch: current epoch index (used for the progress prefix only).
        args: namespace providing ``amp``, ``train_iteration``, ``epochs``,
            ``lamda`` and ``print_freq``.
        logger: logger used for progress output.

    Returns:
        The average of the per-iteration loss as tracked by the loss meter.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    losses = AverageMeter('Loss', ':5.3f')
    lr = AverageMeter('LR', ':.3e', val_only=True)
    mem = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(
        # The loop below runs exactly train_iteration steps, so that is
        # the total the progress display has to show.
        args.train_iteration,
        [losses, lr, mem],
        prefix="Epoch: [{}/{}]".format(epoch, args.epochs))

    def get_learning_rate(optimizer):
        # Learning rate of the first parameter group.
        for param_group in optimizer.param_groups:
            return param_group['lr']

    lr.update(get_learning_rate(optimizer))
    logger.info("lr:{}".format(get_learning_rate(optimizer)))

    labeled_train_iter = iter(train_label_loader)
    unlabeled_train_iter = iter(train_unlabel_loader)

    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    # switch to train mode
    model1.train()
    model2.train()

    for batch_idx in range(args.train_iteration):
        (inputs_x, targets_x), labeled_train_iter = _next_batch(
            labeled_train_iter, train_label_loader)
        (inputs_u, _), unlabeled_train_iter = _next_batch(
            unlabeled_train_iter, train_unlabel_loader)

        # ------------------------------ compute loss ------------------------------
        batch_size = inputs_x.shape[0]
        inputs = torch.cat([inputs_x, inputs_u], dim=0).cuda(non_blocking=True)
        targets_x = targets_x.cuda(non_blocking=True).float()
        # mixed precision ---- compute outputs
        with torch.cuda.amp.autocast(enabled=args.amp):
            pred1 = model1(inputs)
            pred2 = model2(inputs)
            # The first `batch_size` rows belong to the labeled batch,
            # the remainder to the unlabeled one.
            pred1_l, pred1_u = pred1[:batch_size], pred1[batch_size:]
            pred2_l, pred2_u = pred2[:batch_size], pred2[batch_size:]
            loss_l = loss_fn(pred1_l, targets_x) + loss_fn(pred2_l, targets_x)
            # NOTE(review): the soft pseudo-labels are not detached, so
            # gradients also flow through the target branch -- confirm
            # this is intended.
            loss_u = (loss_fn(pred1_u, F.softmax(pred2_u, dim=-1))
                      + loss_fn(pred2_u, F.softmax(pred1_u, dim=-1)))
            loss = loss_l + args.lamda * loss_u

        # record loss
        losses.update(loss.item(), batch_size)
        mem.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # compute gradient and take an optimizer step
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # one cycle learning rate
        scheduler.step()
        lr.update(get_learning_rate(optimizer))

        if batch_idx % args.print_freq == 0:
            progress.display(batch_idx, logger)

    return losses.avg


@torch.no_grad()
def validate(val_loader, model1, model2, args, logger):
    """Evaluate both models on ``val_loader`` without gradients.

    For every batch the summed cross-entropy of both models is
    accumulated, and ``score`` is applied to the targets and the softmax
    of the two models' averaged outputs. All accumulated values are
    divided by ``len(val_loader.dataset)``.
    NOTE(review): dividing the metric totals by the dataset size assumes
    ``score`` returns per-batch sums -- TODO confirm.

    Returns:
        ``(loss, (cheby, clark, canberra, kl, cosine, intersection,
        spearman, tau))``.
    """
    timer = AverageMeter('Time', ':5.3f')
    mem_meter = AverageMeter('Mem', ':.0f', val_only=True)
    progress = ProgressMeter(
        len(val_loader),
        [timer, mem_meter],
        prefix='Test: ')

    # switch to evaluate mode
    for net in (model1, model2):
        net.eval()

    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    # Running totals, accumulated in batch order starting from 0.
    loss_total = 0
    metric_totals = [0, 0, 0, 0, 0, 0, 0, 0]

    tick = time.time()
    for step, (images, targets) in enumerate(val_loader):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True).float()

        # compute output
        with torch.cuda.amp.autocast(enabled=args.amp):
            out1 = model1(images)
            out2 = model2(images)
            batch_loss = criterion(out1, targets) + criterion(out2, targets)

        # Metrics are computed on the average of the two models' outputs.
        ensemble = (out1 + out2) / 2
        cheby, clark, can, kl, cosine, inter, spear, tau = score(
            targets, F.softmax(ensemble, dim=-1))
        batch_metrics = (cheby, clark, can, kl, cosine, inter, spear, tau)

        loss_total = loss_total + batch_loss.detach().cpu()
        metric_totals = [
            total + value.detach().cpu()
            for total, value in zip(metric_totals, batch_metrics)
        ]

        # record memory
        mem_meter.update(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)

        # measure elapsed time
        now = time.time()
        timer.update(now - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step, logger)

    # normalize by the number of samples
    num_samples = len(val_loader.dataset)
    loss = loss_total / num_samples
    metrics = tuple(total / num_samples for total in metric_totals)

    print("Calculating loss:")
    logger.info("  loss: {}".format(loss))

    return loss, metrics


# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()