import os
from collections import Counter
from math import ceil

import decorator
import torch
import torch.distributed as dist
import logging as logger
from torch.utils.checkpoint import checkpoint
from transformers.models.vit.modeling_vit import PatchEmbeddings

from MQSP_evaluation.ColaiSeqParalSAR import TransformerSelfAttentionRing
from MQSP_evaluation.distributed_utils import DistGroups
from distributed_transformer import DistributedSelfAttention


class ConcatSequenceOrigin(torch.autograd.Function):
    """Gather per-rank sub-sequences along dim 1; reduce-scatter the gradient back.

    Each rank of ``group`` holds one ``[batch, sub_seq_length, hidden_size]`` slice.
    ``forward`` all-gathers the slices and concatenates them in rank order along
    dim 1. ``backward`` splits the incoming gradient into per-rank chunks along
    dim 1 and ``reduce_scatter``s them. With the default reduce op (sum), each rank
    receives the sum over all ranks of the gradient for its own slice.
    """

    @staticmethod
    def forward(ctx, sub_seq_output, group=None):
        """Concatenate the sub-sequences of every rank in ``group``.

        Args:
            sub_seq_output: this rank's slice, e.g. ``[batch, sub_seq_length, hidden_size]``.
            group: process group to gather over; ``None`` selects the default group.

        Returns:
            The concatenation along dim 1, e.g. ``[batch, seq_length, hidden_size]``,
            identical on every rank of ``group``.
        """
        ctx.group = group
        # get_world_size accepts group=None (default group); group.size() would not.
        seq_world_size = dist.get_world_size(group=group)
        local = sub_seq_output.contiguous()
        # all_gather overwrites these buffers completely, so no zero-init or
        # requires_grad is needed; gradients are handled by ``backward``.
        gathered = [torch.empty_like(local) for _ in range(seq_world_size)]
        dist.all_gather(gathered, local, group=group)
        return torch.cat(gathered, dim=1)

    @staticmethod
    def backward(ctx, seq_grad):
        """Split the full-sequence gradient and reduce-scatter it across ranks.

        Args:
            seq_grad: gradient w.r.t. the gathered output, e.g. ``[batch, seq_length, hidden_size]``.

        Returns:
            ``(sub_seq_grad, None)``. ``sub_seq_grad`` has this rank's slice shape,
            e.g. ``[batch, sub_seq_length, hidden_size]``. ``None`` is for ``group``.
        """
        seq_world_size = dist.get_world_size(group=ctx.group)
        split_grads = [t.contiguous() for t in torch.chunk(seq_grad, seq_world_size, dim=1)]
        # reduce_scatter fully overwrites the output buffer.
        sub_seq_grad = torch.empty_like(split_grads[0])
        dist.reduce_scatter(sub_seq_grad, split_grads, group=ctx.group)
        return sub_seq_grad, None


def change_img_size(args, config, model):
    """Rebuild the ViT patch/position embeddings for ``args.max_image_size``.

    Does nothing beyond a log print when ``args.max_image_size == config.image_size``.
    Otherwise it:

    - replaces ``model.vit.embeddings.patch_embeddings`` with a new ``PatchEmbeddings``
      for the new image size;
    - with ``args.sp_size > 1``, patches that module's ``forward`` so its output is
      chunked into ``sp_size`` pieces along dim 1, and each rank keeps only the piece
      indexed by its rank in ``DistGroups["sp"]``;
    - replaces ``model.vit.embeddings.position_embeddings`` with a zero-initialised
      parameter of shape ``[1, local_num_patches + 1, hidden_size]``;
    - sets ``config.image_size`` to the new size.

    Raises:
        ValueError: if the new number of patches is not divisible by ``args.sp_size``.
            This is checked before ``model`` or ``config`` is modified.
    """
    print("change_img_size", args.max_image_size, config.image_size)
    if args.max_image_size == config.image_size:
        return

    sp_size = args.sp_size
    # Build into a local first so validation can fail without mutating the model.
    patch_embeddings = PatchEmbeddings(
        image_size=args.max_image_size,
        patch_size=config.patch_size,
        num_channels=config.num_channels,
        embed_dim=config.hidden_size,
    )
    num_patches = patch_embeddings.num_patches
    if num_patches % sp_size != 0:
        raise ValueError(
            "Image num_patches must be divisible by sp_size, now: %s/%s" % (num_patches, sp_size)
        )

    if sp_size > 1:
        original_forward = patch_embeddings.forward

        def split_sp_fn(pixel_values, *fn_args, **fn_kwargs):
            # Keep only this rank's slice of the patch sequence (dim 1).
            embedding_output = original_forward(pixel_values, *fn_args, **fn_kwargs)
            chunks = embedding_output.chunk(sp_size, dim=1)
            return chunks[DistGroups["sp"].rank()]

        patch_embeddings.forward = split_sp_fn

    model.vit.embeddings.patch_embeddings = patch_embeddings
    print("split num_patches", args.max_image_size, num_patches, sp_size, flush=True)

    local_num_patches = num_patches // sp_size if sp_size > 1 else num_patches
    # NOTE(review): the extra +1 slot is presumably for the [CLS] token. Zero init
    # discards any pretrained position embeddings; confirm this is intended.
    model.vit.embeddings.position_embeddings = torch.nn.Parameter(
        torch.zeros(1, local_num_patches + 1, config.hidden_size)
    )
    config.image_size = args.max_image_size


def change_position_embedding(model, args):
    """Grow ``model.bert.embeddings.position_embeddings`` to ``args.max_seq_length`` rows.

    If ``args.max_seq_length`` is larger than the current table, the existing table
    is tiled along the position axis (not randomly initialised) and truncated to
    ``args.max_seq_length`` rows. Both the embedding module and the ``position_ids``
    buffer of ``model.bert.embeddings`` are then replaced. Otherwise the model is
    left unchanged.

    Requires an initialised ``torch.distributed`` process group; only rank 0 prints.
    """
    bert_emb = model.bert.embeddings
    old_emb = bert_emb.position_embeddings
    position_weight = old_emb.weight
    is_rank0 = torch.distributed.get_rank() == 0
    if is_rank0:
        print("position_embeddings size:", position_weight.size(), old_emb.padding_idx)
        print(
            f"exp: method={args.sp_method} batch={args.per_device_train_batch_size} seq_length={args.max_seq_length} "
        )

    old_len = position_weight.size(0)
    if args.max_seq_length <= old_len:
        return

    # Repeat the pretrained table rather than random init. Detach so the new
    # parameter is a fresh leaf with no autograd link to the old Parameter, and
    # clone so it does not keep the whole repeated buffer alive.
    repeats = ceil(args.max_seq_length / old_len)
    post_weight = position_weight.detach().repeat(repeats, 1)[: args.max_seq_length, :].clone()
    if is_rank0:
        print(f"post_weight.shape={post_weight.shape}")

    bert_emb.position_embeddings = torch.nn.Embedding(
        args.max_seq_length,
        post_weight.shape[1],
        _weight=post_weight,
        max_norm=old_emb.max_norm,
        norm_type=old_emb.norm_type,
        scale_grad_by_freq=old_emb.scale_grad_by_freq,
        sparse=old_emb.sparse,
        padding_idx=old_emb.padding_idx,
    )
    # Create the ids on the same device as the embedding weights so an already
    # moved (e.g. CUDA) model keeps working.
    bert_emb.register_buffer(
        "position_ids",
        torch.arange(args.max_seq_length, device=position_weight.device).expand((1, -1)),
    )


def save_model(model, optimizer, args):
    """Save model and optimizer state to ``<args.output_dir>/best_<args.max_seq_length>.th``.

    Only rank 0 writes when ``torch.distributed`` is initialised. A single
    non-distributed process also writes. Nothing is written when ``args.output_dir``
    is ``None``. Missing parent directories are created.
    """
    is_writer = not dist.is_initialized() or dist.get_rank() == 0
    if not is_writer or args.output_dir is None:
        return
    # makedirs (not mkdir) so nested output paths work and existing dirs are fine.
    os.makedirs(args.output_dir, exist_ok=True)
    save_file = os.path.join(args.output_dir, "best_%s.th" % (args.max_seq_length))
    with open(save_file, "wb") as fout:
        torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, fout)


def load_model(args, model, optimizer):
    """Restore model and optimizer state from ``<args.output_dir>/best_<args.max_seq_length>.th``.

    Tensors are mapped to the current CUDA device when CUDA is available, otherwise
    to CPU. ``model`` and ``optimizer`` are updated in place and returned.
    """
    load_file = os.path.join(args.output_dir, "best_%s.th" % (args.max_seq_length))
    # map_location must be a device/str/dict/callable (a bare int is not a documented
    # option), and a CPU-only host has no current CUDA device at all.
    if torch.cuda.is_available():
        map_location = torch.device("cuda", torch.cuda.current_device())
    else:
        map_location = torch.device("cpu")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    with open(load_file, "rb") as fin:
        state_dict = torch.load(fin, map_location=map_location)
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    return model, optimizer


def tensor_diff(tensor0, tensor1, layer_number, tensor_idx, atol=1e-3, rtol=1e-3):
    """Compare two tensors and log a warning if they differ.

    Args:
        tensor0, tensor1: tensors to compare.
        layer_number, tensor_idx: identifiers used only in the log messages.
        atol, rtol: tolerances forwarded to ``torch.allclose``.

    Returns:
        ``True`` if the shapes match and ``torch.allclose`` holds. Otherwise
        ``False``, after logging a warning that describes the mismatch.
    """

    def _stats(t):
        return "mean=%.3f std=%.3f max=%.3f" % (torch.mean(t), torch.std(t), torch.max(t))

    tensor0_shape = list(tensor0.shape)
    tensor1_shape = list(tensor1.shape)
    if tensor0_shape != tensor1_shape:
        logger.warning(
            "Layer%s tensor%s shape not equal %s %s", layer_number, tensor_idx, tensor0_shape, tensor1_shape
        )
        return False

    if torch.allclose(tensor0, tensor1, atol=atol, rtol=rtol):
        return True

    # Summary statistics are only needed for the mismatch report, so compute them here.
    delta = tensor0 - tensor1
    logger.warning(
        "Layer%s tensor%s not allclose abs=%.3f max=%.4f min=%.4f 0=%s 1=%s",
        layer_number,
        tensor_idx,
        float(torch.sum(torch.abs(delta))),
        float(torch.max(delta)),
        float(torch.min(delta)),
        _stats(tensor0),
        _stats(tensor1),
    )
    return False


def show_diff(tensor_list0, tensor_list1, layer_number, atol=1e-3, rtol=1e-3):
    """Log differences between two layer outputs using ``tensor_diff``.

    If ``tensor_list0`` is a tuple or list, both sequences are compared element by
    element. When their lengths differ, a warning is logged and only the common
    prefix is compared. Otherwise both arguments are compared as single tensors
    (index 0). ``atol``/``rtol`` are forwarded to ``tensor_diff`` in both cases.
    """
    if not isinstance(tensor_list0, (tuple, list)):
        # Forward the caller's tolerances here too, not just in the sequence case.
        tensor_diff(tensor_list0, tensor_list1, layer_number, 0, atol=atol, rtol=rtol)
        return

    length0 = len(tensor_list0)
    length1 = len(tensor_list1)
    if length0 != length1:
        logger.warning("Layer%s tensor length not equal %s %s", layer_number, length0, length1)
    for tensor_idx, (tensor0, tensor1) in enumerate(zip(tensor_list0, tensor_list1)):
        tensor_diff(tensor0, tensor1, layer_number, tensor_idx, atol=atol, rtol=rtol)


def wrap_bert_layer_output(bert_encoder, save_tensor_path=None, load_tensor_path=None):
    """Patch ``forward`` of every module in ``bert_encoder.layer`` to dump or check its output.

    Both path arguments are %-format templates filled with ``(layer_index, rank)``,
    where ``rank`` is ``dist.get_rank()`` (so a process group must be initialised).
    If ``save_tensor_path`` is truthy, each patched forward ``torch.save``s its
    output to that file. Otherwise each patched forward loads the file at
    ``load_tensor_path`` and passes both outputs to ``show_diff``. In both modes the
    freshly computed output is returned unchanged.
    """
    assert (
        save_tensor_path or load_tensor_path
    ), "using wrap_bert_layer_output must one of (save_tensor_path,load_tensor_path)"
    rank = dist.get_rank()

    def make_saving_forward(layer_idx, inner_forward):
        def saving_forward(*args, **kwargs):
            output = inner_forward(*args, **kwargs)  # tuple or ModelOutput
            target = save_tensor_path % (layer_idx, rank)
            with open(target, "wb") as fout:
                torch.save(output, fout)
            logger.info("saving output tensor to file %s", target)
            return output

        return saving_forward

    def make_checking_forward(layer_idx, inner_forward):
        def checking_forward(*args, **kwargs):
            output = inner_forward(*args, **kwargs)  # tuple or ModelOutput
            with open(load_tensor_path % (layer_idx, rank), "rb") as fin:
                reference = torch.load(fin)
                show_diff(output, reference, layer_number=layer_idx)
            return output

        return checking_forward

    # Saving takes precedence when both templates are supplied.
    make_forward = make_saving_forward if save_tensor_path else make_checking_forward
    for layer_idx, layer in enumerate(bert_encoder.layer):
        logger.info("wrap layer%s output", layer_idx)
        # A factory binds layer_idx per layer, avoiding the late-binding closure trap.
        layer.forward = make_forward(layer_idx, layer.forward)


@decorator.decorator
def wrap_ckpt(fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` through ``torch.utils.checkpoint.checkpoint``.

    Because this is a ``decorator.decorator`` caller, ``wrap_ckpt(f)`` returns a
    signature-preserving wrapper around ``f``. Each call to the wrapper goes through
    activation checkpointing: ``f``'s intermediate activations are recomputed in
    backward instead of being stored.

    NOTE(review): ``use_reentrant`` is not passed, so ``checkpoint``'s default
    applies. Any keyword arguments that reach this caller are handed to
    ``checkpoint`` itself. Confirm this matches the installed torch version.
    """
    checkpointed_output = checkpoint(fn, *args, **kwargs)
    return checkpointed_output


def replace_bert_attn(bert_encoder):
    """Replace each layer's ``attention`` module with ``TransformerSelfAttentionRing``.

    Supports BERT-style layers (``layer.attention.self``) and ViT-style layers
    (``layer.attention.attention``). The original query/key/value weights and biases
    are concatenated in q, k, v order along dim 0 and copied into the new module's
    ``query_key_value``. The original ``layer.attention.output`` module is reused as
    ``new_attn.dense``. Finally ``layer.attention`` is replaced by the new module.

    Raises:
        TypeError: if a layer's attention has neither a ``self`` nor an ``attention``
            submodule.
    """
    for i, layer in enumerate(bert_encoder.layer):
        pre_attn = layer.attention
        if hasattr(pre_attn, "self"):  # BertSelfAttn
            ori_self = pre_attn.self
        elif hasattr(pre_attn, "attention"):  # ViTAttention
            ori_self = pre_attn.attention
        else:
            # Falling through would leave ori_self unbound (layer 0) or stale from the
            # previous layer, silently copying the wrong weights.
            raise TypeError(
                f"layer {i}: unsupported attention module {type(pre_attn).__name__} "
                "(expected a 'self' or 'attention' submodule)"
            )
        hidden_size = ori_self.all_head_size
        num_attention_heads = ori_self.num_attention_heads
        attention_dropout = ori_self.dropout.p
        old_qkw_weight = torch.cat([ori_self.query.weight.data, ori_self.key.weight.data, ori_self.value.weight.data])
        old_qkw_bias = torch.cat([ori_self.query.bias.data, ori_self.key.bias.data, ori_self.value.bias.data])
        layer_number = i + 1
        new_attn = TransformerSelfAttentionRing(hidden_size, num_attention_heads, attention_dropout, layer_number)
        logger.info(
            f"replace bert's layer{i}'s attn to SAR "
            f"old_qkw_weight:{old_qkw_weight.shape} to{new_attn.query_key_value.weight.shape}"
        )
        logger.info(
            f"replace bert's layer{i}'s attn to SAR "
            f"old_qkw_bias:{old_qkw_bias.shape} to{new_attn.query_key_value.bias.shape}"
        )

        new_attn.query_key_value.weight.data.copy_(old_qkw_weight)
        new_attn.query_key_value.bias.data.copy_(old_qkw_bias)
        # HF attention is self-attention -> output projection. Reuse the original
        # output module as the ring attention's ``dense``.
        new_attn.dense = pre_attn.output
        layer.attention = new_attn


def replace_bert_attn_qallgather(bert_encoder, num_micro_q=4, **kwargs):
    """Swap each layer's inner self-attention for ``DistributedSelfAttention``.

    For BERT layers ``layer.attention.self`` is replaced. For ViT layers
    ``layer.attention.attention`` is replaced. The surrounding attention wrapper,
    including its output module, is kept. Each new module is built from
    ``bert_encoder.config`` and loads the replaced module's ``state_dict``. Layers
    matching neither layout are left untouched.

    Args:
        bert_encoder: encoder exposing ``.layer`` (iterable of layers) and ``.config``.
        num_micro_q: forwarded to ``DistributedSelfAttention``.
        **kwargs: applied to ``bert_encoder.config`` with ``config.update``. This
            mutates the encoder's config in place.
    """
    config = bert_encoder.config
    # Update once; the original per-layer update applied the same kwargs repeatedly.
    config.update(kwargs)

    counter = Counter()
    for layer in bert_encoder.layer:
        attention = layer.attention
        if hasattr(attention, "self"):  # BertSelfAttn
            attr_name, layer_kind = "self", "bert"
        elif hasattr(attention, "attention"):  # ViTAttention
            attr_name, layer_kind = "attention", "vit"
        else:
            continue
        new_self = DistributedSelfAttention(config, num_micro_q=num_micro_q, group=DistGroups["sp"])
        new_self.load_state_dict(getattr(attention, attr_name).state_dict())
        setattr(attention, attr_name, new_self)
        counter[layer_kind] += 1

    for layername, count in counter.items():
        logger.info("replace %d %s layers' attn to DistributedSelfAttention", count, layername)


def replace_bert_attn_MQSP_ckptother(bert_encoder, num_micro_q=4, **kwargs):
    """Use ``DistributedSelfAttention`` in BERT layers and checkpoint the rest of each layer.

    For every BERT layer (one with ``layer.attention.self``):

    - ``layer.attention.self`` is replaced by a ``DistributedSelfAttention`` that loads
      the old module's ``state_dict``;
    - ``layer.attention.output.forward`` and ``layer.feed_forward_chunk`` are wrapped
      with ``wrap_ckpt``.

    Layers matching neither the BERT nor the ViT layout are skipped.

    Args:
        bert_encoder: encoder exposing ``.layer`` (iterable of layers) and ``.config``.
        num_micro_q: forwarded to ``DistributedSelfAttention``.
        **kwargs: applied to ``bert_encoder.config`` with ``config.update``, in place.

    Raises:
        NotImplementedError: on the first ViT layer (``layer.attention.attention``).
            The error is raised before that layer is modified; earlier BERT layers
            may already have been converted.
    """
    config = bert_encoder.config
    config.update(kwargs)

    for i, layer in enumerate(bert_encoder.layer):
        if hasattr(layer.attention, "self"):  # BertSelfAttn
            new_self = DistributedSelfAttention(config, num_micro_q=num_micro_q, group=DistGroups["sp"])
            new_self.load_state_dict(layer.attention.self.state_dict())
            layer.attention.self = new_self
            logger.info(f"replace bert's layer {i}'s attn to DistributedSelfAttention")
            # Activation checkpointing for everything after self-attention.
            layer.attention.output.forward = wrap_ckpt(layer.attention.output.forward)
            layer.feed_forward_chunk = wrap_ckpt(layer.feed_forward_chunk)
            logger.info(f"ckpt wrap bert's layer {i}'s attn.output and feed_forward_chunk")
        elif hasattr(layer.attention, "attention"):  # ViTAttention
            # TODO: checkpoint wrapping for ViT layers. Fail before swapping the
            # module so the layer is not left half-converted.
            raise NotImplementedError(f"ckpt wrapping for vit layer {i} is not implemented")
