import argparse
import logging
import os
import sys
import time
from typing import List

import numpy as np
import torch
from torch import distributed
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import \
    fp16_compress_hook
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter


from lr_scheduler import PolynomialLRWarmup
from partial_fc_v4 import MultiFilterArcLoss, PartialFC_V4
from properties import *
from vit import ViT_L_14

if os.getenv("USE_TORCHACC"):
    from torchacc.torch_xla.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


# Command-line interface of the training script. Every option is declared in
# the table below as (flag, converter, default); a converter of ``None`` means
# the raw command-line string is kept and parsed later by the training code.
parser = argparse.ArgumentParser(description='')

_CLI_OPTIONS = (
    ('--backward_passes_per_step', int, 1),
    ('--debug', int, 0),
    ('--dataloader-type', None, "dali"),
    # Comma-separated lists, kept as strings here.
    ('--list_batch_size', None, "128"),
    ('--list_dataset', None, 'LAION400M'),
    ('--list_filter', None, '0'),
    ('--list_head_name', None, None),
    ('--list_margin', None, '0.3'),
    ('--list_scale', None, '32'),
    ('--list_num_class', None, None),
    ('--list_sample_rate', None, "1"),
    ('--list_lr_pfc_weight', None, "1"),
    ('--list_mean', None, "103.53,116.28,123.675"),
    ('--list_std', None, "57.375,57.375,57.375"),
    # Model / optimisation settings.
    ('--embedding_size', int, 512),
    ('--gradient_checkpoint', int, 0),
    ('--lr', float, 1e-3),
    ('--local_rank', int, 0),
    ('--input_gray', int, 0),
    ('--image_size', None, "224"),
    ('--momentum', float, 0.9),
    ('--num_epochs', int, 32),
    ('--num_pos', int, 8),
    ('--num_class', int, 1000000),
    ('--opt', None, 'adamw'),
    ('--output', None, 'output'),
    ('--random_diff', int, 0),
    ('--resume_dir', None, 'null'),
    # Logging / checkpointing settings.
    ('--save_pfc', int, 1),
    ('--frequent', int, 100),
    ('--warmup_ratio', float, 0.2),
    ('--weight_decay', float, 5e-4),
    ('--workers', int, 2),
    ('--ckpt_interval', int, 200),
)

for _flag, _converter, _default in _CLI_OPTIONS:
    # Only pass ``type`` when a converter is given so string options keep
    # argparse's default handling.
    _type_kwargs = {} if _converter is None else {'type': _converter}
    parser.add_argument(_flag, default=_default, **_type_kwargs)

args = parser.parse_args()

# Process-level distributed setup, executed at import time.
#
# ``rank`` and ``world_size`` are read from the launcher's environment before
# branching on the backend: the rank-0 logging setup below, main() and
# BatchEndCallBack all use ``rank`` regardless of which backend is active, so
# it must be defined on the TorchAcc path as well as the NCCL path.
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

if os.getenv("USE_TORCHACC"):
    import torchacc.torch_xla.core.xla_model as xm
    import torchacc.torch_xla.distributed.xla_backend
    local_rank = args.local_rank
    distributed.init_process_group(backend="xla")
    device = xm.xla_device()
    xm.set_replication(device, [device])
    device = torch.device(device)
else:
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

torch.backends.cudnn.benchmark = True

os.makedirs(args.output, exist_ok=True)
if rank == 0:
    # Only rank 0 gets handlers on the root logger: one writing to
    # <output>/training.log and one echoing to stdout.
    log = logging.getLogger()
    formatter = logging.Formatter(f"rank-id:{rank}:%(asctime)s-%(message)s")
    file_handler = logging.FileHandler(os.path.join(args.output, "training.log"))
    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler.setFormatter(formatter)
    stream_handler.setFormatter(formatter)
    log.addHandler(file_handler)
    log.addHandler(stream_handler)
    log.setLevel(logging.INFO)


def _parse_float_list(raw: str, num_head: int) -> List[float]:
    """Parse comma-separated floats; a single value is repeated once per head."""
    values = [float(x) for x in raw.split(",")]
    if len(values) == 1:
        values = values * num_head
    return values


def main():
    """Run the training loop for the ViT backbone and its PartialFC head(s).

    Uses the module-level ``args``, ``rank``, ``local_rank`` and ``world_size``.
    The string-valued ``--list_*`` / ``--image_size`` options on ``args`` are
    parsed in place into Python lists, so this function is meant to be called
    once per process.

    Raises:
        ValueError: if ``--list_num_class`` is missing or ``--opt`` names an
            unsupported optimizer.
    """
    use_torchacc = bool(os.getenv("USE_TORCHACC"))

    args.image_size = [int(x) for x in args.image_size.split(",")]
    if len(args.image_size) == 1:
        # A single value is used for both image dimensions.
        args.image_size = args.image_size * 2

    # SECURITY NOTE(review): eval() turns each dataset name into an object
    # (presumably one exported by ``from properties import *``), but it will
    # execute any expression given on the command line. Only run this script
    # with trusted arguments, or replace this with an explicit name lookup.
    args.list_dataset = [eval(x) for x in args.list_dataset.split(",")]
    args.num_head = 1

    # Per-head hyper-parameters: comma-separated floats, one value broadcast
    # to every head.
    args.list_sample_rate = _parse_float_list(args.list_sample_rate, args.num_head)
    args.list_margin = _parse_float_list(args.list_margin, args.num_head)
    args.list_scale = _parse_float_list(args.list_scale, args.num_head)
    args.list_filter = _parse_float_list(args.list_filter, args.num_head)
    args.list_lr_pfc_weight = _parse_float_list(args.list_lr_pfc_weight, args.num_head)

    # Parse the batch sizes into integers once; ``args.list_batch_size`` itself
    # stays the raw string, so indexing it would yield single characters.
    list_batch_size = [int(x) for x in args.list_batch_size.split(",")]
    if len(list_batch_size) == 1:
        list_batch_size = list_batch_size * args.num_head
    args.batch_size = list_batch_size[0]
    args.list_prefix = [x.prefix for x in args.list_dataset]

    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if args.list_num_class is None:
        raise ValueError(
            "--list_num_class is required (comma-separated class counts)")
    args.list_num_class = [int(x) for x in args.list_num_class.split(",")]

    if args.list_head_name is None:
        args.list_head_name = [f"{args.list_dataset[0].name}_{x}" for x in range(args.num_head)]
    elif isinstance(args.list_head_name, str):
        # A value given on the command line arrives as one string; split it so
        # that indexing by head id yields a name, not a single character.
        args.list_head_name = args.list_head_name.split(",")

    args.num_examples = args.list_dataset[0].num_example
    args.total_steps = int(args.num_examples * args.num_epochs / args.batch_size / world_size)

    if rank == 0:
        for arg in vars(args):
            msg = f"{format(arg, '<30')}  {format(str(getattr(args, arg)))}"
            logging.info(msg)

    backbone: torch.nn.Module = ViT_L_14()
    # if hasattr(torch, "compile"):
    #     backbone = torch.compile(backbone)

    backbone.train().cuda()

    list_module_pfc: List[PartialFC_V4] = []

    # The first optimizer parameter group holds the trainable backbone
    # weights; one extra group is appended per PartialFC head below.
    backbone_parameters = filter(lambda p: p.requires_grad, backbone.parameters())
    if args.opt == "ref_lars":
        parameters: List[dict] = [{"params": backbone_parameters, "norm_clip": 5}, ]
    else:
        parameters: List[dict] = [{"params": backbone_parameters}, ]

    for head_id, num_class in enumerate(args.list_num_class):
        head_name = args.list_head_name[head_id]
        arcloss_s = args.list_scale[head_id]
        arcloss_m = args.list_margin[head_id]
        arcloss_f = args.list_filter[head_id]
        sample_rate = args.list_sample_rate[head_id]

        margin_loss = MultiFilterArcLoss(
            arcloss_s,
            arcloss_m,
            arcloss_f)

        partial_fc = PartialFC_V4(
            margin_loss,
            args.embedding_size,
            num_class,
            sample_rate,
            is_normlize=True
            )

        partial_fc.train().cuda()

        if args.resume_dir != "null":
            # Each rank restores its own slice of the head weight, matching
            # the per-rank file names written at checkpoint time.
            arr_path = os.path.join(args.resume_dir, f"{head_name}_{rank}_lastest.npy")
            if os.path.exists(arr_path):
                np_arr = np.load(arr_path).astype(np.float32)
                pt_arr = torch.from_numpy(np_arr).cuda()
                partial_fc.weight = torch.nn.Parameter(pt_arr)
            else:
                logging.warning(f"loading {arr_path} error")

        list_module_pfc.append(partial_fc)
        if args.opt == "ref_lars":
            parameters.append({
                "params": partial_fc.parameters(),
                "lr": args.lr * list_batch_size[head_id] / args.batch_size,
                "norm_clip": 100
                })
        elif args.opt == "adamw":
            parameters.append({
                "params": partial_fc.parameters(),
                "lr": args.lr * args.list_lr_pfc_weight[head_id]})
        else:
            raise ValueError(f"{args.opt} not support!")

    if args.opt == "sgd":
        raise ValueError("SGD not support!")

    elif args.opt == "adamw":
        if use_torchacc:
            from torchacc.torch_xla.amp import syncfree
            optimizer_cls = syncfree.AdamW
        else:
            optimizer_cls = torch.optim.AdamW

        opt = optimizer_cls(
            parameters,
            lr=args.lr,
            weight_decay=args.weight_decay)
        lr_scheduler = PolynomialLRWarmup(
            opt,
            int(args.total_steps * args.warmup_ratio),
            args.total_steps,
            2)
    else:
        # No optimizer is constructed for any other value (including
        # "ref_lars"), so fail with an explicit error instead of a bare
        # ``raise`` that has no active exception to re-raise.
        raise ValueError(f"{args.opt} not support!")

    if args.resume_dir != "null":
        path = os.path.join(args.resume_dir, "model_lastest.pt")
        if os.path.exists(path):
            state_dict = torch.load(path, "cpu")
            backbone.load_state_dict(state_dict, strict=True)
        else:
            logging.warning(f"loading {path} error")
    if hasattr(backbone, "gradient_checkpoint") and args.gradient_checkpoint:
        backbone.gradient_checkpoint = True

    if use_torchacc:
        # No DDP wrapper here; gradients are all-reduced manually in the loop.
        backbone_ddp = backbone
    else:
        backbone_ddp = torch.nn.parallel.DistributedDataParallel(
            module=backbone,
            broadcast_buffers=False,
            device_ids=[local_rank],
            bucket_cap_mb=64,
            find_unused_parameters=True,
            static_graph=True)
        # Communication hooks exist only on the DDP wrapper, so register the
        # fp16 gradient-compression hook on this path alone.
        backbone_ddp.register_comm_hook(None, fp16_compress_hook)

    if args.debug:
        train_iter = SyntheticDataIter(
            args.batch_size,
            args.image_size,
            num_label=8,
            )
    else:
        from data import MultiRecDALIWarper
        train_iter = MultiRecDALIWarper(
            args.list_dataset[0].prefix,
            args.batch_size,
            args.image_size,
            num_workers=args.workers)

    batch_end_callback = BatchEndCallBack(
        frequent=args.frequent, list_head_name=args.list_head_name, output=args.output,
        total_steps=args.total_steps, batch_size=args.batch_size)

    list_num_epoch = [0, ] * args.num_head
    global_step = 0

    if use_torchacc:
        import torchacc.torch_xla.core.xla_model as xm

    amp = GradScaler(init_scale=512, growth_interval=100)
    for img, label in train_iter:

        with torch.cuda.amp.autocast(True):
            embedding = backbone_ddp(img)
        # The heads receive the embedding in fp32.
        embedding = embedding.float()

        list_loss = []
        list_loss_float = []
        label = label.long()
        for pfc in list_module_pfc:
            # Only the first 8 label columns are passed to the head; the copy
            # keeps the iterator's label tensor untouched if the head writes
            # into its input.
            label_head = label[:, :8].clone().contiguous()
            head_loss = pfc(embedding, label_head)
            list_loss.append(head_loss)
            list_loss_float.append(head_loss.item())

        amp.scale(sum(list_loss)).backward()

        # NOTE(review): the optimizer steps whenever global_step is a multiple
        # of backward_passes_per_step, so with a value > 1 the very first
        # update happens after a single backward pass, while the LR scheduler
        # advances on every iteration. Confirm this is the intended schedule.
        if global_step % args.backward_passes_per_step == 0:
            if use_torchacc:
                gradients = xm._fetch_gradients(opt)
                xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())

            if args.opt != "ref_lars":
                # Unscale before clipping so max_norm applies to true gradients.
                amp.unscale_(opt)
                clip_grad_norm_(backbone_ddp.parameters(), max_norm=5, norm_type=2)
                for pfc in list_module_pfc:
                    clip_grad_norm_(pfc.parameters(), max_norm=5, norm_type=2)
            amp.step(opt)
            amp.update()
            opt.zero_grad()
        lr_scheduler.step()

        batch_end_callback(
            global_step, lr_scheduler, list_loss_float, list_num_epoch, amp)

        global_step += 1

        if global_step % args.ckpt_interval == 0:
            if rank == 0:
                # ``backbone`` is the module wrapped by DDP (and the model
                # itself under TorchAcc), so its state_dict is valid on both
                # paths, unlike ``backbone_ddp.module``.
                sd_b = backbone.state_dict()
                path_sd_b = os.path.join(args.output, "model_lastest.pt")
                torch.save(sd_b, path_sd_b)

            if args.save_pfc:
                for head_id, head_name in enumerate(args.list_head_name):
                    # Every rank writes its own copy of the head weight.
                    pt_pfc = list_module_pfc[head_id].state_dict()["weight"]
                    np_pfc = pt_pfc.cpu().numpy()
                    path_np_pfc = os.path.join(args.output, f"{head_name}_{rank}_lastest.npy")
                    np.save(path_np_pfc, np_pfc)

        if global_step == args.total_steps:
            if rank == 0:
                torch.save(backbone.state_dict(), os.path.join(args.output, "final_model.pt"))
            if args.save_pfc:
                for head_id, head_name in enumerate(args.list_head_name):
                    np.save(
                        file=os.path.join(args.output, f"final_{head_name}_{rank}.npy"),
                        arr=list_module_pfc[head_id].state_dict()["weight"].cpu().numpy())
            break


class SyntheticDataIter(object):
    """Endless iterator that yields the same random batch on every step.

    One batch of random image data and all-zero labels is created once in
    ``__init__`` and returned unchanged by every ``next()`` call; the iterator
    never raises ``StopIteration``.

    Args:
        batch_size: number of samples in the batch.
        image_size: sequence whose first two items give the two spatial
            dimensions of the generated images.
        num_label: number of label columns per sample.
        device: where the tensors are placed. ``None`` (the default) moves
            them to the current CUDA device via ``.cuda()``.
    """

    def __init__(self, batch_size, image_size, num_label=1, device=None):
        # Gradient tracking is disabled here rather than by decorating the
        # class: ``@torch.no_grad()`` on a class replaces the class object
        # with a wrapper callable, which breaks ``isinstance`` checks.
        with torch.no_grad():
            data = torch.randint(
                low=0, high=255, size=(batch_size, 3, image_size[0], image_size[1]),
                dtype=torch.float32)
            data = data.cuda() if device is None else data.to(device)
            # Per-channel offset followed by a global scale.
            data[:, 0, :, :] -= 123.
            data[:, 1, :, :] -= 116.
            data[:, 2, :, :] -= 103.
            data *= 0.01
            data = data.contiguous()

            label = torch.zeros(size=(batch_size, num_label), dtype=torch.long)
            label = label.cuda() if device is None else label.to(device)

        self.tensor_data = data
        self.tensor_label = label

    def __next__(self):
        """Return the cached ``(data, label)`` pair."""
        return self.tensor_data, self.tensor_label

    def __iter__(self):
        return self

    def reset(self):
        """No-op; the iterator keeps no position to rewind."""
        return


class BatchEndCallBack(object):
    """Per-iteration hook that accumulates head losses and reports progress.

    Every call feeds the latest per-head loss into a running average. On steps
    that are a positive multiple of ``frequent`` it also writes learning rates
    and averaged losses to TensorBoard and logs a throughput / remaining-time
    line; both outputs are emitted by rank 0 only.
    """

    def __init__(self,
                 frequent: int,
                 list_head_name: List[str],
                 output: str,
                 total_steps: int,
                 batch_size: int):
        self.frequent: int = frequent
        self.list_head_name: List[str] = list_head_name
        self.output: str = output
        self.total_steps: int = total_steps
        self.batch_size: int = batch_size

        self.num_head = len(self.list_head_name)
        self.time_start = time.time()
        # One running-average tracker per head.
        self.list_loss_metric = [ScalaMetric() for _ in self.list_head_name]
        # ``init`` flips to True on the first reporting step, which only
        # starts the throughput timer ``tic``.
        self.init = False
        self.tic = 0
        # Only rank 0 owns a TensorBoard writer.
        self.summary = SummaryWriter(os.path.join(output, "tensorboard")) if rank == 0 else None

    def __call__(self,
                 global_step: int,
                 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
                 list_loss_float: List[float],
                 list_num_epoch: List[int],
                 amp: GradScaler):
        """Record this step's losses and, on reporting steps, emit statistics."""
        for head_id, metric in enumerate(self.list_loss_metric):
            metric.update(list_loss_float[head_id])

        # Nothing more to do unless this is a reporting step.
        if global_step <= 0 or global_step % self.frequent != 0:
            return

        is_master = rank == 0
        if is_master:
            self.summary.add_scalar("backbone_lr", lr_scheduler.get_last_lr()[0], global_step, new_style=True)

        if not self.init:
            # First reporting step: just start the timer.
            self.init = True
            self.tic = time.time()
            return

        try:
            speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
            self.tic = time.time()
            speed_total = speed * world_size
        except ZeroDivisionError:
            speed = float('inf')
            speed_total = float('inf')

        head_reports = []
        for head_id, name in enumerate(self.list_head_name):
            metric = self.list_loss_metric[head_id]
            if is_master:
                self.summary.add_scalar(f"loss_{name}", metric.avg, global_step, new_style=True)
                self.summary.add_scalar(f"lr_{name}", lr_scheduler.get_last_lr()[head_id + 1], global_step, new_style=True)
            head_reports.append(
                "\n"
                + format(f"name: {name}", "<20")
                + format(f"epoch: {list_num_epoch[head_id]}", "<20")
                + format(f"lr: {lr_scheduler.get_last_lr()[head_id + 1]:.4f}", "<20")
                + format(f"loss: {metric.avg:.2f}", "<20"))
            # Start a fresh averaging window for the next report.
            metric.reset()
        loss_str_format = "".join(head_reports)

        # Extrapolate the total run time from the fraction of steps completed.
        elapsed_hours = (time.time() - self.time_start) / 3600
        projected_hours = elapsed_hours / ((global_step + 1) / self.total_steps)
        remaining_hours = projected_hours - elapsed_hours
        msg = "rank %.2f total %.2f its/s lr: %f step: %d required: %1.f hours scale: %d %s" % (
            speed,
            speed_total,
            lr_scheduler.get_last_lr()[0],
            global_step,
            remaining_hours,
            amp.get_scale(),
            loss_str_format,)

        if is_master:
            logging.info(msg)
            self.summary.flush()


class ScalaMetric(object):
    """Running weighted average of a scalar.

    Attributes (all zero after construction or ``reset()``):
        val:   the most recent value passed to ``update``.
        sum:   weighted sum of all values since the last reset.
        count: total weight since the last reset.
        avg:   ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Fold in ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count



def hierarchical_allreduce_hook(
    process_group_state: tuple,
    bucket: distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """Communication hook that reduces a gradient bucket in three chained stages.

    The stages are a reduce-scatter inside the caller's local group, an fp16
    all-reduce across the caller's reduce group, and an all-gather inside the
    local group whose result is copied back into the bucket buffer.

    Args:
        process_group_state: 4-tuple; item 0 is the list of local groups
            (indexed by ``rank // 8``) and item 1 the list of reduce groups
            (indexed by ``rank % 8``). Items 2 and 3 are not read here.
        bucket: the gradient bucket to reduce.

    Returns:
        A future that resolves to the bucket buffer after the final copy.
    """
    # ReduceScatter - Parallelized MPI Allreduce - NCCL Allgather
    # Both group sizes are hard-coded to 8.
    # NOTE(review): presumably 8 ranks per node and 8 nodes; confirm against the launch topology.
    local_world_size = 8
    reduce_world_size = 8
    rank = distributed.get_rank()
    device = torch.cuda.current_device()
    list_local_group: List[distributed.ProcessGroup] = process_group_state[0]
    list_reduce_group: List[distributed.ProcessGroup] = process_group_state[1]
    group_local = list_local_group[rank // local_world_size]
    group_reduce = list_reduce_group[rank % local_world_size]
    # NOTE(review): this length check runs after the tuple has already been indexed above.
    assert len(process_group_state) == 4, ""
    # ``rest`` is the number of zero elements appended so the length is a multiple of 8.
    rest = 0
    if bucket.buffer().size(0) % 8 == 0:
        tensor = bucket.buffer()
    else:
        rest = 8 - bucket.buffer().size(0) % 8
        tensor = torch.zeros(bucket.buffer().size(0) + rest, device=device)
        tensor[: -rest] = bucket.buffer()
    assert tensor.size(0) % 8 == 0
    # Pre-divide by the local group size so the SUM reductions produce an average.
    # NOTE(review): in the unpadded case ``tensor`` is the object returned by
    # bucket.buffer(), so this in-place division presumably modifies the bucket itself.
    tensor.div_(local_world_size)
    # Stage 1: each local rank receives one reduced chunk.
    tensor_each = torch.zeros(tensor.size(0) // local_world_size, device=tensor.device)
    fut: torch.futures.Future = distributed.reduce_scatter(
        output=tensor_each,
        input_list=list(tensor.chunk(local_world_size)),
        group=group_local,
        async_op=True).get_future()
    def _fut_allreduce(fut):
        # Stage 2: cast this rank's chunk to fp16, divide by the reduce-group
        # size and all-reduce (SUM) it across the reduce group.
        tensor_reduce_scatter = fut.wait()[0]
        compressed_tensor = tensor_reduce_scatter.to(torch.float16).div_(reduce_world_size)
        fut = distributed.all_reduce(
            tensor=compressed_tensor,
            op=distributed.ReduceOp.SUM,
            group=group_reduce,
            async_op=True).get_future()
        # NOTE(review): waits on the inner future inside the callback rather than chaining it.
        return fut.wait()
    def _fut_allgather(fut):
        # Stage 3: back to fp32, then gather every local rank's chunk into a
        # fresh tensor shaped like the (possibly padded) input.
        tensor_allreduce: torch.Tensor = fut.wait()[0].float()
        final_tensor = torch.zeros_like(tensor)
        fut = distributed.all_gather(
            list(final_tensor.chunk(local_world_size)), tensor_allreduce,
            group=group_local, async_op=True).get_future()
        return fut.wait()
    def _output(fut):
        # Restore the input shape, drop the padding and write the result into the bucket.
        # NOTE(review): takes item 0 of the all_gather future's value and reshapes it to
        # the full tensor size; verify what that future actually resolves to.
        gather_tensor = fut.wait()[0]
        gather_tensor = torch.reshape(gather_tensor, tensor.size())
        if rest != 0:
            gather_tensor = gather_tensor[: -rest]
        buffer = bucket.buffer()
        buffer.copy_(gather_tensor)
        return buffer
    return fut.then(_fut_allreduce).then(_fut_allgather).then(_output)


def register_2d_allreduce(model):
    """Create the 2-D process-group grid and register the hierarchical hook on ``model``.

    Ranks are laid out in blocks of 8. One "local" group is created per block
    of 8 consecutive ranks, and one "reduce" group per position within a
    block, containing the rank at that position in every block. The group
    lists are logged on rank 0 and handed to ``hierarchical_allreduce_hook``
    as its state.
    """
    assert distributed.is_initialized()
    num_blocks = distributed.get_world_size() // 8

    # Groups are created in the same order on every rank: all local groups
    # first, then all reduce groups.
    local_groups = [
        distributed.new_group(
            [slot + block * 8 for slot in range(8)], backend="nccl")
        for block in range(num_blocks)
    ]
    reduce_groups = [
        distributed.new_group(
            [slot + block * 8 for block in range(num_blocks)], backend="nccl")
        for slot in range(8)
    ]

    if distributed.get_rank() == 0:
        logging.info(local_groups)
        logging.info(reduce_groups)

    hook_state = (local_groups, reduce_groups, None, None)
    model.register_comm_hook(hook_state, hierarchical_allreduce_hook)

# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
