import argparse
import math
import os
import shutil
import sys
import time

from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import Dataset
import torchvision.transforms as transforms
# import wandb
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms import InterpolationMode
from utils import AverageMeter, accuracy, get_parameters
from utils_auto_logit import AverageMeter, accuracy, get_parameters, compute_adjustment_auto

sys.path.append('..')
# It is imported for you to access and modify the PyTorch source code (via Ctrl+Click), more details in README.md
from torch.utils.data._utils.fetch import _MapDatasetFetcher

from relabel.utils_fkd import (ComposeWithCoords, ImageFolder_FKD_MIX,
                               RandomHorizontalFlipWithRes,
                               RandomResizedCropWithCoords, mix_aug)

class LT_Dataset(Dataset):
    """Image dataset described by a plain-text index file (e.g. ImageNet-LT splits).

    Each non-blank line of ``txt`` is ``<relative image path> <integer label>``;
    the relative path is joined onto ``root``. Items are returned as
    ``(image, label, path)`` where ``path`` is the joined path.

    Args:
        root: directory the relative image paths are resolved against.
        txt: path of the index file.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root, txt, transform=None):
        self.img_path = []
        self.labels = []
        self.transform = transform
        with open(txt) as f:
            for line in f:
                fields = line.split()
                # Skip blank lines (e.g. a trailing newline-only line) instead
                # of failing with IndexError.
                if not fields:
                    continue
                self.img_path.append(os.path.join(root, fields[0]))
                self.labels.append(int(fields[1]))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        path = self.img_path[index]
        label = self.labels[index]

        # convert() forces the pixel data to load while the file is still open.
        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label, path

def get_args():
    """Parse the command-line options for FKD student training.

    Returns:
        argparse.Namespace holding every option below, plus ``mode`` fixed to
        ``'fkd_load'`` (the mode ``main`` passes to ``ImageFolder_FKD_MIX`` for
        the training set).
    """
    parser = argparse.ArgumentParser("FKD Training on ImageNet-1K")
    parser.add_argument('--batch-size', type=int,
                        default=1024, help='batch size')
    parser.add_argument('--gradient-accumulation-steps', type=int,
                        default=1, help='gradient accumulation steps for small gpu memory')
    parser.add_argument('--start-epoch', type=int,
                        default=0, help='start epoch')
    parser.add_argument('--epochs', type=int, default=300, help='total epoch')
    parser.add_argument('-j', '--workers', default=16, type=int,
                        help='number of data loading workers')

    parser.add_argument('--train-dir', type=str, default=None,
                        help='path to training dataset')
    parser.add_argument('--val-dir', type=str,
                        default='/path/to/imagenet/val', help='path to validation dataset')
    parser.add_argument('--output-dir', type=str,
                        default='./save/1024', help='path to output dir')

    parser.add_argument('--cos', default=False,
                        action='store_true', help='cosine lr scheduler')
    parser.add_argument('--sgd', default=False,
                        action='store_true', help='sgd optimizer')
    parser.add_argument('-lr', '--learning-rate', type=float,
                        default=1.024, help='sgd init learning rate')
    parser.add_argument('--momentum', type=float,
                        default=0.875, help='sgd momentum')
    parser.add_argument('--weight-decay', type=float,
                        default=3e-5, help='sgd weight decay')
    parser.add_argument('--adamw-lr', type=float,
                        default=0.001, help='adamw learning rate')
    parser.add_argument('--adamw-weight-decay', type=float,
                        default=0.01, help='adamw weight decay')

    parser.add_argument('--model', type=str,
                        default='resnet18', help='student model name')

    parser.add_argument('--keep-topk', type=int, default=1000,
                        help='keep topk logits for kd loss')
    parser.add_argument('-T', '--temperature', type=float,
                        default=3.0, help='temperature for distillation loss')
    parser.add_argument('--fkd-path', type=str,
                        default=None, help='path to fkd label')
    parser.add_argument('--wandb-project', type=str,
                        default='Temperature', help='wandb project name')
    parser.add_argument('--wandb-api-key', type=str,
                        default=None, help='wandb api key')
    parser.add_argument('--mix-type', default=None, type=str,
                        choices=['mixup', 'cutmix', None], help='mixup or cutmix or None')
    parser.add_argument('--fkd_seed', default=42, type=int,
                        help='seed for batch loading sampler')
    parser.add_argument('--repeat', default=1, type=int,
                        help='the number of repeat for total epochs')
    parser.add_argument('--sub-epoch', default=300, type=int,
                        help='sub epoch number')
    parser.add_argument('--arch-name', type=str, default='resnet18',
                        help='arch name from pretrained torchvision models')
    # Previously hard-coded inside main(); defaults keep the old values.
    parser.add_argument('--teacher-model-path', type=str,
                        default='/data0/xxx/work/imagenet-lt-pretrain/output/resnet18_Imagenet1k_lt/checkpoint.pth',
                        help='path to the teacher checkpoint (a dict whose "model" entry is the state_dict)')
    parser.add_argument('--lt-root', type=str, default='/data0/ImageNet',
                        help='image root for --lt-train-txt paths (only read with --auto)')
    parser.add_argument('--lt-train-txt', type=str,
                        default='./data/ImageNet_LT/ImageNet_LT_train.txt',
                        help='long-tailed train index used for label frequencies (only read with --auto)')
    parser.add_argument('--auto', default=False, action='store_true',
                        help='whether to use auto adjustment')

    args = parser.parse_args()

    args.mode = 'fkd_load'
    return args


def _build_fkd_train_transform(normalize):
    """Return a fresh ``*WithCoords`` crop/flip/normalize pipeline for FKD training images."""
    return ComposeWithCoords(transforms=[
        RandomResizedCropWithCoords(size=224,
                                    scale=(0.08, 1),
                                    interpolation=InterpolationMode.BILINEAR),
        RandomHorizontalFlipWithRes(),
        transforms.ToTensor(),
        normalize,
    ])


def _make_seeded_train_loader(dataset, args, pin_memory=True):
    """Wrap ``dataset`` in a DataLoader whose RandomSampler uses a new generator seeded with ``args.fkd_seed``.

    Because the generator is recreated on each call, loaders built by separate
    calls start from the same sample order.
    """
    generator = torch.Generator()
    generator.manual_seed(args.fkd_seed)
    sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
    # shuffle must stay False when an explicit sampler is supplied.
    return torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, sampler=sampler,
        num_workers=args.workers, pin_memory=pin_memory)


def _load_teacher(args):
    """Build torchvision ``args.arch_name``, load ``args.teacher_model_path`` into it, and return it on GPU in eval mode."""
    model_teacher = torchvision.models.__dict__[args.arch_name](pretrained=False)
    checkpoint = torch.load(args.teacher_model_path, map_location='cpu')
    model_teacher.load_state_dict(checkpoint["model"])
    model_teacher = nn.DataParallel(model_teacher).cuda()
    model_teacher.eval()
    return model_teacher


def _compute_lt_adjustment(args, model_teacher, normalize, class_number=1000):
    """Return the teacher logit adjustment produced by ``compute_adjustment_auto``.

    Builds a second FKD dataset in ``'fkd_save'`` mode plus the normalized label
    frequencies of the long-tailed index ``args.lt_train_txt``, then delegates.
    """
    train_dataset_ = ImageFolder_FKD_MIX(
        fkd_path=args.fkd_path,
        mode='fkd_save',
        args_epoch=args.epochs,
        args_bs=args.batch_size,
        root=args.train_dir,
        transform=_build_fkd_train_transform(normalize))
    train_loader_ = _make_seeded_train_loader(train_dataset_, args, pin_memory=False)

    lt_labels = np.array(LT_Dataset(args.lt_root, args.lt_train_txt).labels)
    counts = np.bincount(lt_labels)
    original_label_freq_array = np.zeros(class_number)
    original_label_freq_array[:len(counts)] = counts
    original_label_freq_array = original_label_freq_array / np.sum(original_label_freq_array)

    print("original label freq array")
    print(original_label_freq_array)

    return compute_adjustment_auto(
        model_teacher=model_teacher,
        original_label_freq_array=original_label_freq_array,
        train_loader=train_loader_,
        device='cuda',
        class_number=class_number,
    )


def main():
    """Distil a torchvision student from a pretrained long-tailed teacher.

    Builds the FKD training set and the validation loader, and loads the
    teacher. With ``--auto`` it computes a logit adjustment; otherwise the
    adjustment is zero. It then trains for ``args.repeat`` rounds of
    ``[start_epoch, min(epochs, sub_epoch))`` epochs each. It validates every
    10th global epoch, at global epoch ``epochs - 1``, and at the real final
    epoch, and it checkpoints after every epoch.
    """
    # train()/validate() write into this module-level dict; it is reset every epoch.
    global wandb_metrics
    args = get_args()

    if not torch.cuda.is_available():
        raise Exception("need gpu to train!")

    # Explicit check instead of assert: survives `python -O` and gives a clear
    # message when --train-dir is missing (os.path.exists(None) raises TypeError).
    if args.train_dir is None or not os.path.exists(args.train_dir):
        raise FileNotFoundError(f"training directory not found: {args.train_dir!r}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Data loading
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImageFolder_FKD_MIX(
        fkd_path=args.fkd_path,
        mode=args.mode,
        args_epoch=args.epochs,
        args_bs=args.batch_size,
        root=args.train_dir,
        transform=_build_fkd_train_transform(normalize))

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=int(args.batch_size / 4), shuffle=False,
        num_workers=args.workers, pin_memory=True)
    print('load data successfully')

    # Load the teacher model.
    model_teacher = _load_teacher(args)

    if args.auto:
        adjustment = _compute_lt_adjustment(args, model_teacher, normalize)
    else:
        # All-zero adjustment: teacher logits are used unchanged.
        adjustment = torch.zeros(1000).cuda()

    # Load the student model.
    print("=> loading student model '{}'".format(args.model))
    model = torchvision.models.__dict__[args.model](pretrained=False)
    model = nn.DataParallel(model).cuda()
    model.train()

    if args.sgd:
        optimizer = torch.optim.SGD(get_parameters(model),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(get_parameters(model),
                                      lr=args.adamw_lr,
                                      weight_decay=args.adamw_weight_decay)

    # Per-epoch LR multiplier; drops to 0 once the step count passes args.epochs.
    if args.cos:
        def lr_lambda(step):
            return 0.5 * (1. + math.cos(math.pi * step / args.epochs)) if step <= args.epochs else 0
    else:
        def lr_lambda(step):
            return (1.0 - step / args.epochs) if step <= args.epochs else 0
    scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    args.best_acc1 = 0
    args.optimizer = optimizer
    args.scheduler = scheduler
    args.val_loader = val_loader

    epochs_per_round = range(args.start_epoch, min(args.epochs, args.sub_epoch))
    last_global_epoch = args.repeat * len(epochs_per_round) - 1

    global_epoch = 0  # counts epochs across all repeat rounds
    for _ in range(args.repeat):
        # Fresh seeded sampler so every round starts from the same sample order.
        args.train_loader = _make_seeded_train_loader(train_dataset, args)
        for epoch in epochs_per_round:
            print(f"\nEpoch: {global_epoch}")
            wandb_metrics = {}

            # `epoch` (position within the round) goes to the dataset's
            # set_epoch inside train(); `global_epoch` drives logging/validation.
            train(model, args, epoch, model_teacher, adjustment)

            if (global_epoch % 10 == 0
                    or global_epoch in (args.epochs - 1, last_global_epoch)):
                top1 = validate(model, args, global_epoch)
            else:
                top1 = 0

            print(wandb_metrics)

            scheduler.step()

            # Remember best acc@1 and save checkpoint.
            is_best = top1 > args.best_acc1
            args.best_acc1 = max(top1, args.best_acc1)
            save_checkpoint({
                'epoch': global_epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': args.best_acc1,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, output_dir=args.output_dir)
            global_epoch += 1

def adjust_bn_momentum(model, iters):
    """Set ``momentum = 1 / iters`` on every ``nn.BatchNorm2d`` inside ``model``.

    Other module types (including other BatchNorm variants) are left untouched.
    The division only happens per matched layer, so a model without
    BatchNorm2d layers accepts any ``iters``.
    """
    bn_layers = (mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.momentum = 1 / iters


def train(model, args, epoch=None, model_teacher=None, adjustment=0):
    """Run one distillation epoch of the student ``model`` against ``model_teacher``.

    Targets are the teacher's logits on the (mix-augmented) batch with
    ``adjustment`` subtracted; the soft labels stored in the FKD dataset are
    ignored. The loss is ``KLDivLoss(batchmean)`` between temperature-scaled
    student log-probs and teacher probs (no T^2 rescaling). Each batch can be
    split into ``args.gradient_accumulation_steps`` micro-batches; every
    micro-batch is weighted by its share of the batch, so the accumulated
    gradient equals the full-batch gradient, including for a short final batch.

    Args:
        model: student network, updated in place by ``args.optimizer``.
        args: namespace providing ``optimizer``, ``scheduler``, ``train_loader``,
            ``batch_size``, ``gradient_accumulation_steps``, ``temperature`` and
            whatever ``mix_aug`` reads.
        epoch: forwarded to ``args.train_loader.dataset.set_epoch`` and used in logs.
        model_teacher: network producing the target logits.
        adjustment: subtracted in place from the teacher logits; must broadcast
            against them (a scalar or a per-class tensor, presumably).

    Raises:
        ValueError: if ``batch_size`` is not divisible by ``gradient_accumulation_steps``.

    Side effects: merges ``train/*`` metrics into the module-level
    ``wandb_metrics`` dict (created by ``main``) and prints a summary line.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    optimizer = args.optimizer
    scheduler = args.scheduler
    loss_function_kl = nn.KLDivLoss(reduction='batchmean')

    # Loop-invariant: validate the accumulation configuration once.
    if args.batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be divisible by "
            f"gradient_accumulation_steps ({args.gradient_accumulation_steps})")
    small_bs = args.batch_size // args.gradient_accumulation_steps

    model.train()
    t1 = time.time()
    args.train_loader.dataset.set_epoch(epoch)
    for batch_data in args.train_loader:
        images, target, _flip_status, _coords_status = batch_data[0]
        # The dataset's stored soft labels (last item) are not used: targets are
        # recomputed from the teacher on the augmented images below.
        mix_index, mix_lam, mix_bbox, _stored_soft_label = batch_data[1:]

        images = images.cuda()
        target = target.cuda()
        images, _, _, _ = mix_aug(images, args, mix_index, mix_lam, mix_bbox)

        with torch.no_grad():
            soft_label = model_teacher(images).cuda().float()
            soft_label -= adjustment

        optimizer.zero_grad()
        cur_bs = images.shape[0]
        # Equals gradient_accumulation_steps for a full batch; fewer for a
        # short final batch.
        accum_step = math.ceil(cur_bs / small_bs)

        for accum_id in range(accum_step):
            chunk = slice(accum_id * small_bs, (accum_id + 1) * small_bs)
            partial_images = images[chunk]
            partial_target = target[chunk]
            partial_soft_label = soft_label[chunk]

            output = model(partial_images)
            prec1, prec5 = accuracy(output, partial_target, topk=(1, 5))

            output = F.log_softmax(output / args.temperature, dim=1)
            partial_soft_label = F.softmax(partial_soft_label / args.temperature, dim=1)
            loss = loss_function_kl(output, partial_soft_label)
            # T^2 rescaling (loss * T * T) is intentionally not applied.

            n = partial_images.size(0)
            # 'batchmean' averages over this micro-batch only; weighting by
            # n / cur_bs makes the summed gradients equal the full-batch one.
            # (Dividing by gradient_accumulation_steps under-weights a short
            # final batch.)
            (loss * (n / cur_bs)).backward()

            # Log the unscaled per-sample loss so it does not depend on the
            # accumulation setting.
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

        optimizer.step()

    metrics = {
        "train/loss": objs.avg,
        "train/Top1": top1.avg,
        "train/Top5": top5.avg,
        "train/lr": scheduler.get_last_lr()[0],
        "train/epoch": epoch,
    }
    wandb_metrics.update(metrics)

    printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(epoch, scheduler.get_last_lr()[0], objs.avg) + \
                'Top-1 err = {:.6f},\t'.format(100 - top1.avg) + \
                'Top-5 err = {:.6f},\t'.format(100 - top5.avg) + \
                'train_time = {:.6f}'.format((time.time() - t1))
    print(printInfo)


def validate(model, args, epoch=None):
    """Evaluate ``model`` on ``args.val_loader`` with cross-entropy and top-1/top-5 accuracy.

    Prints a summary line and merges ``val/*`` metrics into the module-level
    ``wandb_metrics`` dict (created by ``main``), so they are logged together
    with the training metrics for the epoch. Leaves ``model`` in eval mode.

    Returns:
        The sample-weighted mean top-1 accuracy reported by ``accuracy``.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_function = nn.CrossEntropyLoss()

    model.eval()
    t1 = time.time()
    with torch.no_grad():
        for data, target in args.val_loader:
            data = data.cuda()
            target = target.long().cuda()

            output = model(data)
            loss = loss_function(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format(epoch, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(100 - top1.avg) + \
              'Top-5 err = {:.6f},\t'.format(100 - top5.avg) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    print(logInfo)

    # Previously this dict was built and discarded, and the training metrics
    # were printed instead; record the validation metrics alongside them.
    metrics = {
        'val/loss': objs.avg,
        'val/top1': top1.avg,
        'val/top5': top5.avg,
        'val/epoch': epoch,
    }
    wandb_metrics.update(metrics)

    return top1.avg

def save_checkpoint(state, is_best, output_dir=None, epoch=None):
    """Save ``state`` under ``output_dir`` and optionally mirror it as the best model.

    Args:
        state: object passed to ``torch.save`` (``main`` passes a dict of
            model/optimizer/scheduler state).
        is_best: if true, also copy the saved file to ``model_best.pth.tar``.
        output_dir: existing destination directory.
        epoch: when given, save to ``checkpoint_{epoch}.pth.tar`` instead of
            overwriting ``checkpoint.pth.tar``.

    Each file is written to a ``.tmp`` sibling first and then moved into place
    with ``os.replace``. An interrupted save therefore leaves the previous
    checkpoint intact instead of a truncated file.
    """
    if epoch is None:
        filename = 'checkpoint.pth.tar'
    else:
        filename = f'checkpoint_{epoch}.pth.tar'
    path = os.path.join(output_dir, filename)

    tmp_path = path + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)

    if is_best:
        path_best = os.path.join(output_dir, 'model_best.pth.tar')
        tmp_best = path_best + '.tmp'
        shutil.copyfile(path, tmp_best)
        os.replace(tmp_best, path_best)



if __name__ == "__main__":
    # Select the 'spawn' start method before main() creates any DataLoader
    # worker processes; set_start_method may only be called once per program.
    from multiprocessing import set_start_method
    set_start_method('spawn')
    main()