import argparse
import logging
import os
import sys

from open_clip import model

sys.path.insert(0, "../")
import time

import numpy as np
import torch
from src.open_clip import factory
from torch import distributed, nn, optim
from torch.utils.tensorboard import SummaryWriter
from scipy import ndimage
from training_unicom import i_vit
from training_unicom.data import SyntheticDataIter, dali_dataloader
from training_unicom.unicom import (CombinedMarginLoss, UNICOM)

def _build_parser():
    """Return the argument parser describing every command-line option of this script."""
    p = argparse.ArgumentParser()
    add = p.add_argument
    add("--batch-size", type=int, default=256)
    add("--comm-hook", default="2d_allreduce")
    add("--comm-hook-post-local-sgd-ratio", type=float, default=0.5)
    add("--debug", type=int, default=0)
    add("--epochs", type=int, default=32)
    add("--gradient-acc", type=int, default=1)
    add("--gradient-checkpoint", type=int, default=0)
    add("--lr", type=float, default=0.1, help="Learning rate.")
    add("--lr-pfc-weight", type=float, default=1.0, help="Learning rate weight for PFC")
    add("--lr-scheduler", default="cosine")
    add("--grad-search-label-select", type=int, default=None)
    add("--grad-search-threshold", type=float, default=None)
    add("--input-size", type=int, default=224)
    add("--is-normlize", type=int, default=1)
    add("--imagenet-val", default=None)
    add("--imagenet-v2", default=None)
    add("--model", type=str, default="RN50")
    add("--margin-loss-m1", type=float, default=1.0)
    add("--margin-loss-m2", type=float, default=0.0)
    add("--margin-loss-m3", type=float, default=0.4)
    add("--margin-loss-s", type=float, default=64)
    add("--margin-loss-filter", type=float, default=0.0)
    add("--num-classes", type=int, required=True)
    add("--optimizer", default="sgd")
    add("--output-dim", type=int, default=768)
    add("--output", required=True)
    add("--pretrained-center", default="NULL")
    add("--resume", default="")
    add("--sample-rate", type=float, default=0.1)
    add("--sample-num-feat", type=int, default=None)
    add("--list-save-epochs", type=str, default=None)
    add("--train-data", required=True)
    add("--train-num-samples", type=int, required=True)
    add("--weight-decay", type=float, default=5e-4, help="Weight decay.")
    add("--workers", type=int, default=2)
    return p


# Parsed once at import time; the rest of the script reads the global ``args``.
parser = _build_parser()
args = parser.parse_args()

# Distributed context comes from environment variables (presumably set by a
# launcher such as torchrun); a missing variable raises KeyError at import.
rank, local_rank, world_size = (
    int(os.environ[env_name]) for env_name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
distributed.init_process_group(backend="nccl")
# Bind this process to its node-local GPU.
torch.cuda.set_device(local_rank)


def _build_backbone(current_device):
    """Instantiate the backbone selected by ``--model`` (not yet moved to GPU or wrapped).

    ``IViT-*`` names build an ``i_vit.VisionTransformer`` from the table below;
    any other name is passed to ``factory.create_model`` and wrapped in
    ``ModelModified`` to project to ``--output-dim``.
    """
    # name -> (img_size, patch_size, embed_dim, depth, using_checkpoint)
    ivit_configs = {
        "IViT-B-32": (224, 32, 768, 12, False),
        "IViT-B-16": (224, 16, 768, 12, True),
        "IViT-L-14": (224, 14, 1024, 24, True),
        "IViT-L-14-336px": (336, 14, 1024, 24, True),
    }
    if args.model in ivit_configs:
        img_size, patch_size, embed_dim, depth, using_checkpoint = ivit_configs[args.model]
        return i_vit.VisionTransformer(
            img_size=img_size, patch_size=patch_size, in_channels=3,
            num_classes=args.output_dim, embed_dim=embed_dim,
            depth=depth, num_heads=embed_dim // 64, drop_path_rate=0.1,
            using_checkpoint=using_checkpoint)
    model_clip = factory.create_model(args.model, device=current_device)
    return ModelModified(model_clip, args.output_dim)


def _build_lr_scheduler(opt, steps_per_epoch):
    """Create the per-step LR scheduler selected by ``--lr-scheduler``.

    Raises:
        ValueError: for an unknown scheduler name.
    """
    if args.lr_scheduler == "cosine":
        # One max_lr per optimizer param group: [backbone, UNICOM head].
        return optim.lr_scheduler.OneCycleLR(
            optimizer=opt,
            max_lr=[args.lr, args.lr * args.lr_pfc_weight],
            steps_per_epoch=steps_per_epoch,
            epochs=args.epochs,
            pct_start=0.1,
        )
    if args.lr_scheduler == "linear":
        return optim.lr_scheduler.LinearLR(
            optimizer=opt, start_factor=1.0, end_factor=0.0,
            total_iters=args.epochs * steps_per_epoch)
    raise ValueError(f"{args.lr_scheduler} is wrong")


def _load_resume_checkpoint(backbone, module_unicom):
    """Restore backbone weights and the UNICOM head weight from ``--resume``.

    Reads ``f"{args.resume}{rank}.pt"`` (one file per rank, the layout written
    by ``_save_checkpoints``). Only weights are restored; optimizer and
    scheduler state are not part of the checkpoint.
    """
    state_dict = torch.load(f"{args.resume}{rank}.pt")
    pt_center = state_dict["state_dict_softmax_fc"]["weight"]
    state_dict_backbone = state_dict["state_dict_backbone"]
    if args.input_size != 224:
        # Resample position embeddings / flattened FC for the new resolution.
        state_dict_backbone = zoom_state_dict(state_dict_backbone, args.input_size)
    backbone.load_state_dict(state_dict_backbone, strict=True)
    module_unicom.weight = torch.nn.Parameter(pt_center)


def _save_checkpoints(epoch, global_step, backbone, module_unicom):
    """Write this rank's end-of-epoch checkpoint(s) into ``args.output``.

    Always overwrites ``checkpoint_gpu_{rank}.pt``. It additionally keeps an
    ``epoch_{n}_checkpoint_gpu_{rank}.pt`` copy when the 1-based epoch ``n`` is
    listed in ``--list-save-epochs``. Rank 0 also pickles the whole unwrapped
    backbone module to ``ModelModified.pt``.
    """
    checkpoint = {
        "epoch": epoch + 1,
        "global_step": global_step.step,
        "state_dict_backbone": backbone.state_dict(),
        "state_dict_softmax_fc": module_unicom.state_dict(),
        # Optimizer and LR-scheduler state are not saved, so a resumed run
        # restarts both schedules from scratch.
    }

    if args.list_save_epochs is not None:
        save_epochs = {int(x) for x in args.list_save_epochs.split(",")}
        if epoch + 1 in save_epochs:
            fname = f"epoch_{epoch + 1}_checkpoint_gpu_{rank}.pt"
            torch.save(checkpoint, os.path.join(args.output, fname))

    torch.save(
        obj=checkpoint,
        f=os.path.join(args.output, f"checkpoint_gpu_{rank}.pt"))

    if rank == 0:
        torch.save(
            obj=backbone.module,
            f=os.path.join(args.output, "ModelModified.pt"))


def main():
    """Run distributed UNICOM training configured by the module-level ``args``.

    Builds the backbone (DDP-wrapped) and the UNICOM margin-loss head. It
    optionally loads a pretrained head center and/or resumes weights, then
    trains for ``args.epochs`` epochs with AdamW and a per-step LR schedule.
    Per-rank checkpoints are written to ``args.output`` after every epoch.
    Rank 0 additionally logs to TensorBoard.

    Raises:
        ValueError: unsupported ``--comm-hook``, ``--optimizer`` or
            ``--lr-scheduler`` choice.
        FileNotFoundError: ``--pretrained-center`` file for this rank is missing.
    """
    os.makedirs(args.output, exist_ok=True)
    init_logging(rank, args.output)

    # Reject unsupported configuration before doing any GPU work.
    if args.comm_hook == "posted_2d_allreduce":
        raise ValueError("Don't use this")
    if args.optimizer != "adamw":
        raise ValueError(f"{args.optimizer} is wrong")

    summary_writer = (
        SummaryWriter(os.path.join(args.output, "tensorboard")) if rank == 0 else None)

    backbone = _build_backbone(torch.cuda.current_device())
    backbone.cuda()
    backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone,
        bucket_cap_mb=32,
        find_unused_parameters=True,
        static_graph=True)

    global_step = GlobalStep()
    steps_per_epoch = args.train_num_samples // world_size // args.batch_size + 1
    steps_total = args.epochs * steps_per_epoch

    if world_size > 8 and args.comm_hook == "2d_allreduce":
        from ddp_comm_hooks import register_acc
        register_acc(backbone, global_step, args.gradient_acc)

    margin_loss = CombinedMarginLoss(
        args.margin_loss_s,
        args.margin_loss_m1,
        args.margin_loss_m2,
        args.margin_loss_m3,
        args.margin_loss_filter)
    module_unicom = UNICOM(
        margin_loss, args.output_dim, args.num_classes,
        args.sample_rate, True, args.is_normlize, args.sample_num_feat)
    module_unicom.train().cuda()

    # Any replacement of ``module_unicom.weight`` must happen BEFORE the
    # optimizer is built: AdamW holds references to the Parameter objects it
    # was given, so a Parameter swapped in afterwards would never be updated.
    if args.pretrained_center != "NULL":
        center_path = f"{args.pretrained_center}{rank}.npy"
        if not os.path.exists(center_path):
            raise FileNotFoundError(center_path)
        # NOTE(review): dtype of the .npy file is taken as-is; assumes it
        # matches the embeddings' float32 — confirm.
        pt_center = torch.from_numpy(np.load(center_path)).cuda(local_rank)
        module_unicom.weight = torch.nn.Parameter(pt_center)

    if len(args.resume) > 2:
        _load_resume_checkpoint(backbone, module_unicom)

    opt = torch.optim.AdamW(
        params=[
            {"params": backbone.parameters()},
            {"params": module_unicom.parameters(), "lr": args.lr * args.lr_pfc_weight}],
        lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = _build_lr_scheduler(opt, steps_per_epoch)

    callback_func = SpeedCallBack(5, steps_total, args.batch_size)
    auto_scaler = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=200)

    if args.debug:
        train_loader = SyntheticDataIter(args.batch_size, args.input_size, local_rank)
    else:
        train_loader = dali_dataloader(args)

    # Optimizer/scheduler state is not checkpointed, so training always starts
    # at epoch 0 (a resume restores weights only).
    for epoch in range(args.epochs):
        for img, local_labels in train_loader:
            with torch.cuda.amp.autocast(False):
                local_embeddings = backbone(img)
            # .float() is not in-place; keep the returned tensor.
            local_embeddings = local_embeddings.float()
            local_labels = local_labels.cuda()
            loss = module_unicom(local_embeddings, local_labels)
            auto_scaler.scale(loss).backward()

            # NOTE(review): with step starting at 0 the first update follows a
            # single micro-batch, and the loss is not divided by gradient_acc.
            # register_acc also receives global_step/gradient_acc and presumably
            # relies on the same convention, so it is left unchanged here.
            if global_step.step % args.gradient_acc == 0:
                auto_scaler.unscale_(opt)
                # Only backbone gradients are clipped; the UNICOM head is not.
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 1)
                auto_scaler.step(opt)
                auto_scaler.update()
                opt.zero_grad()

            lr_scheduler.step()
            global_step.step += 1

            with torch.no_grad():
                loss_value = float(loss)  # single host sync, reused below
                callback_func(
                    lr_scheduler,
                    loss_value,
                    global_step.step,
                    auto_scaler.get_scale())

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        tag="loss", scalar_value=loss_value,
                        global_step=global_step.step)
                    summary_writer.add_scalar(
                        tag="lr_backbone", scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step.step)

        train_loader.reset()
        _save_checkpoints(epoch, global_step, backbone, module_unicom)

    if summary_writer is not None:
        summary_writer.close()
    distributed.destroy_process_group()


def init_logging(rank, models_root):
    """Configure root logging on rank 0; do nothing on every other rank.

    On rank 0 the root logger is set to INFO and gets two handlers sharing the
    ``"Training: <asctime>-<message>"`` format. One appends to
    ``<models_root>/training.log`` and the other writes to stdout. A first
    ``rank_id`` line is then logged.
    """
    if rank != 0:
        return

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    fmt = logging.Formatter("Training: %(asctime)s-%(message)s")
    file_handler = logging.FileHandler(os.path.join(models_root, "training.log"))
    console_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, console_handler):
        handler.setFormatter(fmt)
        root_logger.addHandler(handler)
    root_logger.info("rank_id: %d", rank)


class AverageMeter:
    """Track the most recent value and the running weighted mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class ModelModified(torch.nn.Module):
    """CLIP image encoder followed by a linear projection and BatchNorm1d.

    The projection maps the CLIP visual embedding, which is
    ``model_clip.visual.output_dim`` wide, to ``output_dim`` features.
    """

    def __init__(self, model_clip: factory.CLIP, output_dim) -> None:
        super().__init__()
        visual_width = model_clip.visual.output_dim
        # Registration order (clip -> linear -> bn) fixes the state_dict key
        # order and the parameter order seen by the optimizer / DDP.
        self.model_clip = model_clip
        self.linear = nn.Linear(visual_width, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)

    def forward(self, image):
        """Encode ``image`` with CLIP, project it, then batch-normalise."""
        clip_features = self.model_clip.encode_image(image)
        projected = self.linear(clip_features)
        return self.bn(projected)


class GlobalStep:
    """Mutable integer step counter.

    Held in an object so that every holder of the same instance observes
    increments made by the training loop.
    """

    def __init__(self, step: int = 0):
        # Coerce so that floats/strings passed in still yield an int counter.
        self.step = int(step)

    def update(self):
        """Advance the counter by one."""
        self.step = self.step + 1


class SpeedCallBack(object):
    """Periodic training-progress logger.

    Every call feeds ``loss`` into a running average. On every
    ``frequent``-th global step:

    * the first such step only starts the throughput timer;
    * each later one computes samples/s since the previous one and estimates
      the remaining time from the fraction of ``steps_total`` done;
    * rank 0 then logs a single line and the loss average is reset.
    """

    def __init__(self, frequent, steps_total, batch_size):
        self.batch_size = batch_size
        self.frequent = frequent
        self.steps_total = steps_total
        self.loss_metric = AverageMeter()

        # Rank / world size are read from the environment, not passed in.
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.time_start = time.time()
        self.init = False
        self.tic = 0

    def __call__(
            self,
            lr_scheduler: optim.lr_scheduler._LRScheduler,
            loss,
            global_step,
            scale):
        """Record ``loss`` and log progress when ``global_step`` is due.

        Args:
            lr_scheduler: scheduler whose ``get_last_lr()`` has at least two
                entries (logged as backbone LR and head LR).
            loss: scalar loss of the current step as a Python ``float``.
            global_step: number of steps completed so far.
            scale: current AMP loss scale (logged truncated to int).

        Raises:
            TypeError: if ``loss`` is not a ``float``.
        """
        if not isinstance(loss, float):
            raise TypeError(f"loss must be a float, got {type(loss).__name__}")

        self.loss_metric.update(loss)
        if global_step <= 0 or global_step % self.frequent != 0:
            return
        if not self.init:
            # First due step only starts the throughput timer.
            self.init = True
            self.tic = time.time()
            return

        now = time.time()
        elapsed = now - self.tic
        self.tic = now
        if elapsed > 0:
            speed = self.frequent * self.batch_size / elapsed
            rank_speed = f"{int(speed):d}"
            total_speed = f"{int(speed * self.world_size):d}"
        else:
            # Clock did not advance (or stepped back; time.time() is not
            # monotonic). Report "inf" rather than dividing by zero or
            # calling int() on an infinite float.
            rank_speed = total_speed = "inf"

        loss_str_format = f"{self.loss_metric.avg :.3f}"
        self.loss_metric.reset()

        time_now = (time.time() - self.time_start) / 3600
        time_total = time_now / ((global_step + 1) / self.steps_total)
        time_for_end = time_total - time_now
        last_lr = lr_scheduler.get_last_lr()
        lr_1, lr_2 = last_lr[0], last_lr[1]
        msg = (
            f"rank:{rank_speed} "
            f"total:{total_speed} "
            f"lr:[{lr_1 :.8f}][{lr_2 :.8f}] "
            f"step:{global_step :d} "
            f"amp:{int(scale) :d} "
            f"required:{time_for_end :.1f} hours "
        ) + loss_str_format

        if self.rank == 0:
            logging.info(msg)


def zoom_state_dict(state_dict, input_size, patch_size=14):
    """Resample a ViT checkpoint's position embedding and flattened FC to a new input size.

    ``module.pos_embed`` and ``module.feature.0.weight`` are both linearly
    resampled (``scipy.ndimage.zoom``, ``order=1``) over the square patch grid.

    Args:
        state_dict: state dict with ``module.``-prefixed keys; it is modified in
            place and returned.
        input_size: new input resolution in pixels.
        patch_size: ViT patch size used to derive the new grid side; defaults to
            14, the value previously hard-coded here.

    Returns:
        The same ``state_dict`` object. The two entries are replaced and
        re-inserted at the end of the dict.

    Raises:
        KeyError: if either of the two entries is missing.
        ValueError: if the FC input width is not a multiple of the number of
            position-embedding tokens. This check runs before anything is
            modified.
    """
    pos_embed = state_dict["module.pos_embed"]
    fc_weight = state_dict["module.feature.0.weight"]
    device = pos_embed.device

    # pos_embed is indexed as (1, num_tokens, dim); the whole token axis is
    # treated as a square patch grid (no class token is split off).
    posemb_grid = pos_embed.detach().cpu().numpy()[0]
    num_token_old = len(posemb_grid)
    gs_old = int(np.sqrt(num_token_old))
    gs_new = input_size // patch_size
    scale = gs_new / gs_old

    # nn.Linear weight is (out_features, in_features). in_features is assumed
    # to be the tokens flattened token-major, so each token owns a contiguous
    # slice of width token_dim (which may differ from out_features).
    out_dim, in_dim = fc_weight.shape
    if in_dim % num_token_old:
        raise ValueError(
            f"module.feature.0.weight input width {in_dim} is not a multiple "
            f"of the {num_token_old} position-embedding tokens")
    token_dim = in_dim // num_token_old

    logging.info(
        "zoom_state_dict: patch grid %dx%d -> %dx%d (scale %.4f)",
        gs_old, gs_old, gs_new, gs_new, scale)

    # ZOOM POS: (N_old, D) -> (gs, gs, D) -> zoom -> (1, N_new, D)
    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
    posemb_grid = ndimage.zoom(posemb_grid, (scale, scale, 1), order=1)
    posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
    state_dict.pop("module.pos_embed")
    state_dict["module.pos_embed"] = torch.from_numpy(posemb_grid).to(device)

    # ZOOM FC: (out, N_old*token_dim) -> (gs, gs, token_dim, out) -> zoom
    #          -> (out, N_new*token_dim)
    weight = fc_weight.detach().cpu().numpy().T
    weight = weight.reshape(gs_old, gs_old, token_dim, out_dim)
    weight = ndimage.zoom(weight, (scale, scale, 1, 1), order=1)
    weight = np.ascontiguousarray(weight.reshape(-1, out_dim).T)
    state_dict.pop("module.feature.0.weight")
    state_dict["module.feature.0.weight"] = torch.from_numpy(weight).to(device)
    return state_dict


# Script entry point. Importing this module has already parsed ``args`` and
# initialised the NCCL process group (module top level); ``main`` runs training.
if __name__ == "__main__":
    main()
