import time
import tqdm
import torch
import torch.utils.data
import argparse
import os
import sys
import logging
from functools import partial
from torchvision import datasets, transforms
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.models.vision_transformer import EncoderBlock
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler, DistributedSampler
import math
# Configure the root logger at INFO level so progress messages are printed.
logging.basicConfig(level=logging.INFO)
# NOTE: `logger` is the `logging` module itself, not a `logging.Logger`;
# `logger.info(...)` therefore goes through the root logger.
logger = logging
from timm.utils import accuracy, AverageMeter
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


# User home directory without a trailing slash. Falls back to
# `os.path.expanduser("~")` when $HOME is unset so that importing this module
# does not fail with a KeyError.
HOME = os.environ.get("HOME", os.path.expanduser("~")).rstrip("/")
# Absolute directory that contains this script, without a trailing slash.
basicpath = os.path.abspath(os.path.dirname(__file__)).rstrip("/")

# These cuDNN flags can greatly influence inference speed.
# `benchmark=True` lets cuDNN try several algorithms and keep the fastest one
# for the observed input shapes; `deterministic=True` restricts cuDNN to
# deterministic algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Project-local helpers from `utils`: `BuildModels` provides the backbone
# builders used by `_extract_feature`; `ExtractFeatures` provides the
# feature-extraction routine.
from utils import ExtractFeatures, BuildModels

# Module-level alias so the routine can be called as `extract_feature(...)`.
extract_feature = ExtractFeatures.extract_feature


def import_abspy(name="models", path="classification/"):
    """Import module ``name`` from directory ``path`` and return it.

    ``path`` is made absolute and temporarily prepended to ``sys.path`` for
    the duration of the import. The entry is always removed again, even when
    the import raises.

    Args:
        name: dotted name of the module/package to import.
        path: directory containing that module/package.

    Returns:
        The imported module object.

    Raises:
        AssertionError: if ``path`` is not an existing directory.
        ImportError: (or whatever the module raises while executing) if the
            import fails.
    """
    import importlib

    path = os.path.abspath(path)
    assert os.path.isdir(path), f"not a directory: {path}"
    sys.path.insert(0, path)
    try:
        return importlib.import_module(name)
    finally:
        # Remove the exact entry we added instead of blindly popping index 0:
        # the imported module may itself have modified sys.path.
        try:
            sys.path.remove(path)
        except ValueError:
            pass


# Adapted from https://github.com/microsoft/Swin-Transformer/blob/main/main.py
def reduce_tensor(tensor):
    """Return the mean of ``tensor`` over all ranks of the default process group.

    The input is not modified: a clone is summed in place with ``all_reduce``
    and then divided in place by the world size.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged


# WARNING: with more than one process the accuracy is inaccurate, because the
# distributed sampler pads the dataset.
# Adapted from https://github.com/microsoft/Swin-Transformer/blob/main/main.py
@torch.no_grad()
def validate(config, data_loader, model):
    """Run one evaluation pass and return ``(acc1, acc5, loss)`` averages.

    Each batch is moved to the current CUDA device and forwarded under
    ``torch.cuda.amp.autocast(enabled=config.AMP_ENABLE)``. It is then scored
    with cross-entropy and top-1/top-5 accuracy. Per-batch metrics are
    averaged across ranks with ``reduce_tensor`` before being accumulated
    (weighted by batch size). A progress line is logged every
    ``config.PRINT_FREQ`` batches.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    model.eval()

    time_meter, loss_meter, top1_meter, top5_meter = [AverageMeter() for _ in range(4)]

    tick = time.time()
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        count = labels.size(0)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            logits = model(images)

        # Metrics are computed outside autocast, then averaged over all ranks
        # (collectives issued in the order acc1, acc5, loss on every rank).
        loss = loss_fn(logits, labels)
        top1, top5 = accuracy(logits, labels, topk=(1, 5))
        top1 = reduce_tensor(top1)
        top5 = reduce_tensor(top5)
        loss = reduce_tensor(loss)

        for meter, value in ((loss_meter, loss), (top1_meter, top1), (top5_meter, top5)):
            meter.update(value.item(), count)

        # Wall-clock time for this batch, including data loading.
        time_meter.update(time.time() - tick)
        tick = time.time()

        if batch_idx % config.PRINT_FREQ == 0:
            mem_mb = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{batch_idx}/{len(data_loader)}]\t'
                f'Time {time_meter.val:.3f} ({time_meter.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f})\t'
                f'Acc@5 {top5_meter.val:.3f} ({top5_meter.avg:.3f})\t'
                f'Mem {mem_mb:.0f}MB')
    logger.info(f' * Acc@1 {top1_meter.avg:.3f} Acc@5 {top5_meter.avg:.3f}')
    return top1_meter.avg, top5_meter.avg, loss_meter.avg


def get_dataloader(batch_size=64, root="./val", img_size=224, sequential=True, num_workers=0):
    """Build an evaluation DataLoader over an ``ImageFolder`` dataset.

    Images are resized with bicubic interpolation to ``int(img_size * 256 / 224)``.
    They are then center-cropped to ``img_size`` x ``img_size``, converted to
    tensors and normalized with the ImageNet default mean/std.

    Args:
        batch_size: samples per batch (per process).
        root: ``ImageFolder`` root directory.
        img_size: side length of the center crop.
        sequential: if True use a ``SequentialSampler``. Otherwise use a
            ``DistributedSampler``, which needs an initialized process group
            and pads the dataset so every rank gets the same number of samples.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset, in a fixed order.
    """
    resize_to = int((256 / 224) * img_size)
    transform = transforms.Compose([
        transforms.Resize(resize_to, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    dataset = datasets.ImageFolder(root, transform=transform)
    if sequential:
        sampler = torch.utils.data.SequentialSampler(dataset)
    else:
        # DistributedSampler shuffles by default; evaluation should be deterministic.
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)

    return torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )


def _validate(
    model: nn.Module = None,
    freq=10,
    amp=True,
    img_size=224,
    batch_size=128,
    data_path="/dataset/ImageNet2012",
):
    """Evaluate ``model`` on ``<data_path>/val``, halving the batch size on CUDA OOM.

    The model is moved to CUDA, put in eval mode and wrapped in
    ``DistributedDataParallel``, so a default process group must already be
    initialized. Evaluation uses ``get_dataloader(..., sequential=False)`` and
    ``validate``. When a CUDA out-of-memory ``RuntimeError`` is raised, the
    batch size is halved and the pass restarts. Any other exception propagates.

    Args:
        model: network to evaluate (required despite the ``None`` default).
        freq: log a progress line every ``freq`` batches.
        amp: enable CUDA autocast for the forward pass.
        img_size: evaluation crop size.
        batch_size: initial per-process batch size.
        data_path: dataset root; images are read from its ``val`` sub-directory.

    Returns:
        ``(acc1, acc5, loss)`` as returned by ``validate``, or ``None`` if the
        batch size was halved down to 0 without a completed pass.
    """
    class Args():
        AMP_ENABLE = amp
        PRINT_FREQ = freq
    config = Args()

    model.cuda().eval()
    model = torch.nn.parallel.DistributedDataParallel(model)
    val_root = os.path.join(os.path.abspath(data_path), "val")
    _batch_size = batch_size
    while _batch_size > 0:
        try:
            _dataloader = get_dataloader(
                batch_size=_batch_size,
                root=val_root,
                img_size=img_size,
                sequential=False,
            )
            logging.info(f"starting loop: img_size {img_size}; len(dataset) {len(_dataloader.dataset)}")
            return validate(config, data_loader=_dataloader, model=model)
        except RuntimeError as err:
            # Only CUDA OOM is recoverable by shrinking the batch; anything else
            # (real bugs, interrupted runs, ...) must surface.
            if "out of memory" not in str(err):
                raise
        # Reached only after an OOM. The handled exception has been released at
        # this point, so releasing cached allocator memory is more effective.
        torch.cuda.empty_cache()
        _batch_size = _batch_size // 2
        print(f"batch_size {_batch_size}", flush=True)
    logging.error(f"validation not completed: batch size reduced to 0 (initial batch_size {batch_size})")
    return None


# NOTE: `ExtractFeatures`, `BuildModels` and the `extract_feature` alias are
# imported and bound once near the top of this module.


def _split_ranges(start, end, step):
    """Split ``(start, end)`` into consecutive chunks of length ``step``.

    Returns ``(starts, ends)``. With ``step <= 0`` the whole range is a single
    chunk ``([start], [end])``. With ``step > 0`` the last chunk is clipped to
    ``end``, and both lists are empty when ``start >= end``.
    """
    if step <= 0:
        return [start], [end]
    starts = list(range(start, end, step))
    return starts, [min(s + step, end) for s in starts]


def _extract_feature(data_path="ImageNet_ILSVRC2012", start=0, end=200, step=-1, img_size=224, batch_size=16, train=True, aug=False):
    """Build the pretrained backbones and run ``extract_feature`` over index ranges.

    Args:
        data_path: dataset root forwarded to ``extract_feature``.
        start: first range bound forwarded as ``ranges=(s, e)``.
        end: last range bound; the chunking below treats ``(start, end)`` as
            half-open (presumably matching ``extract_feature`` — confirm).
        step: if > 0, split the range into chunks of ``step`` and call
            ``extract_feature`` once per chunk; otherwise make a single call.
        img_size: input size passed to the builders that accept ``size`` and
            to ``extract_feature``. The vmamba/convnext/intern builders are
            called without it.
        batch_size: forwarded to ``extract_feature``.
        train: forwarded to ``extract_feature``.
        aug: forwarded to ``extract_feature``.
    """
    # Compute the ranges first so an empty request does not load any model.
    starts, ends = _split_ranges(start, end, step)
    if not starts:
        print(f"empty range ({start}, {end}) with step {step}; nothing to extract", flush=True)
        return
    if step > 0:
        print(f"multiple ranges: {starts} {ends} ==============", flush=True)

    resnet50 = BuildModels.build_resnet_mmpretrain(with_ckpt=True, remove_head=True, scale="r50", size=img_size).cuda().eval()
    deitsmall = BuildModels.build_deit_mmpretrain(with_ckpt=True, remove_head=True, scale="small", size=img_size).cuda().eval()
    vmambav0tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv0").cuda().eval()
    vmambav2l5tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv1").cuda().eval()
    vmambav2tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv2").cuda().eval()
    convnexttiny = BuildModels.build_convnext(with_ckpt=True, remove_head=True, scale="tiny").cuda().eval()
    swintiny = BuildModels.build_swin_mmpretrain(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
    hivittiny = BuildModels.build_hivit_mmpretrain(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
    interntiny = BuildModels.build_intern(with_ckpt=True, remove_head=True, scale="tiny").cuda().eval()
    xcittiny = BuildModels.build_xcit(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
    deitbase = BuildModels.build_deit_mmpretrain(with_ckpt=True, remove_head=True, scale="base", size=img_size).cuda().eval()

    # Backbones and their feature dimensions, keyed (and ordered) identically.
    backbones = dict(
        vmambav2tiny=vmambav2tiny,
        convnexttiny=convnexttiny,
        swintiny=swintiny,
        interntiny=interntiny,
        vmambav0tiny=vmambav0tiny,
        vmambav2l5tiny=vmambav2l5tiny,
        deitsmall=deitsmall,
        hivittiny=hivittiny,
        resnet50=resnet50,
        xcittiny=xcittiny,
        deitbase=deitbase,
    )
    dims = dict(
        vmambav2tiny=768,
        convnexttiny=768,
        swintiny=768,
        interntiny=768,
        vmambav0tiny=768,
        vmambav2l5tiny=768,
        deitsmall=384,
        hivittiny=384,
        resnet50=2048,
        xcittiny=384,
        deitbase=768,
    )

    for s, e in zip(starts, ends):
        # Pass fresh shallow copies each call, in case extract_feature mutates them.
        extract_feature(
            backbones=dict(backbones),
            dims=dict(dims),
            batch_size=batch_size,
            img_size=img_size,
            data_path=data_path,
            ranges=(s, e),
            train=train,
            aug=aug,
        )


def main(argv=None):
    """Parse command-line options and run ``_extract_feature``.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # `--batch-size` and `--batch_size` are two spellings of ONE option writing
    # to `args.batch_size`. Registered as two separate options, argparse would
    # silently keep only the first option's default.
    parser.add_argument('--batch-size', '--batch_size', type=int, default=32, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, default="ImageNet_ILSVRC2012", help='path to dataset')
    # `--mode` and `--func` are parsed but not used by this script.
    parser.add_argument('--mode', type=str, default="", help='model name')
    parser.add_argument('--func', type=str, default="", help='function')
    parser.add_argument('--start', type=int, default=0, help='start range')
    parser.add_argument('--end', type=int, default=200, help='end range')
    parser.add_argument('--step', type=int, default=-1, help='step range')
    parser.add_argument('--size', type=int, default=224, help='image size')
    parser.add_argument('--val', action="store_true", help='...')
    parser.add_argument('--aug', action="store_true", help='...')
    args = parser.parse_args(argv)
    print(args, flush=True)
    _extract_feature(
        data_path=args.data_path,
        start=args.start,
        end=args.end,
        step=args.step,
        img_size=args.size,
        batch_size=args.batch_size,
        train=not args.val,
        aug=args.aug,
    )


def run_code_dist_one(func, max_port_retries=100):
    """Run ``func()`` inside a single-process NCCL process group.

    With more than one visible CUDA device this prints a warning and exits
    (status 0) without running ``func``, because the sampler's padding would
    make the accuracy inaccurate. Otherwise it:

    1. initializes a world of size 1 (rank 0) via ``env://`` on 127.0.0.1,
       starting at port 61234 and moving one port lower after each failure;
    2. selects the CUDA device matching the rank and waits on a barrier;
    3. calls ``func()``.

    Args:
        func: zero-argument callable to run once the process group is ready.
        max_port_retries: maximum number of initialization attempts (at least
            one is always made). When all attempts fail, the last error is
            re-raised instead of retrying forever.
    """
    if torch.cuda.device_count() > 1:
        print("WARNING!!!  acc score would be inaccurate if num_procs > 1, as sampler always pads the dataset")
        sys.exit()

    os.environ['MASTER_ADDR'] = "127.0.0.1"
    os.environ['MASTER_PORT'] = "61234"
    attempts = max(1, int(max_port_retries))
    for attempt in range(attempts):
        try:
            dist.init_process_group(backend='nccl', init_method='env://', world_size=1, rank=0)
            break
        except Exception as e:
            print(e, flush=True)
            if attempt == attempts - 1:
                raise
            # Presumably the port is already in use: try the next lower one.
            os.environ['MASTER_PORT'] = f"{int(os.environ['MASTER_PORT']) - 1}"

    torch.cuda.set_device(dist.get_rank())
    dist.barrier()
    func()


if __name__ == "__main__":
    # Script entry point: set up the single-process group, then run `main`.
    run_code_dist_one(main)

