from collections import defaultdict
from functools import partial, wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from scipy import interpolate


def max_stack(tensors):
    """Reduce a list of tensors to their element-wise maximum.

    The tensors are stacked along a new trailing axis and the maximum is
    taken over that axis. A single-element list is returned as-is (the very
    same tensor object, no copy).
    """
    if len(tensors) != 1:
        stacked = torch.stack(tensors, dim=-1)
        return stacked.max(dim=-1).values
    return tensors[0]


def last_stack(tensors):
    """Select the final entry of ``tensors``; all other entries are ignored."""
    final = tensors[-1]
    return final


def first_stack(tensors):
    """Select the leading entry of ``tensors``; all other entries are ignored."""
    leading = tensors[0]
    return leading


def softmax_stack(tensors, temperature=1.0):
    """Combine a list of tensors with softmax weights (a smooth maximum).

    The tensors are stacked along a new trailing axis, softmax weights are
    computed over that axis from ``stacked / temperature``, and the weighted
    sum of the stacked values is returned.

    Args:
        tensors: list of tensors that can be stacked with ``torch.stack``.
        temperature: softmax temperature; smaller values push the result
            towards the element-wise maximum, larger values towards the mean.

    Returns:
        The single input tensor itself when the list has one element,
        otherwise the softmax-weighted combination.
    """
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    weights = F.softmax(stacked / temperature, dim=-1)
    # The weights sum to 1 along the stacked axis, so summing them alone
    # would always yield 1; they have to multiply the values being combined.
    return (weights * stacked).sum(dim=-1)


def mean_stack(tensors):
    """Reduce a list of tensors to their element-wise mean.

    A single-element list short-circuits and returns that tensor unchanged.
    """
    if len(tensors) != 1:
        stacked = torch.stack(tensors, dim=-1)
        return stacked.mean(dim=-1)
    return tensors[0]


def sum_stack(tensors):
    """Reduce a list of tensors to their element-wise sum.

    A single-element list short-circuits and returns that tensor unchanged.
    """
    if len(tensors) != 1:
        stacked = torch.stack(tensors, dim=-1)
        return stacked.sum(dim=-1)
    return tensors[0]


def convert_module_to_f16(l):
    """
    Cast the weight (and bias, when present) of a convolution to float16.

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are touched;
    any other module is left unchanged. The conversion is done in place.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.half()
    bias = l.bias
    if bias is not None:
        bias.data = bias.data.half()


def convert_module_to_f32(l):
    """
    Cast the weight (and bias, when present) of a convolution to float32.

    This is the inverse of convert_module_to_f16(): only ``nn.Conv1d``,
    ``nn.Conv2d`` and ``nn.Conv3d`` instances are touched, in place.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    l.weight.data = l.weight.data.float()
    bias = l.bias
    if bias is not None:
        bias.data = bias.data.float()


def format_seconds(seconds):
    """Format a duration in seconds as ``H:MM:SS``.

    Args:
        seconds: duration as an int or float. Fractional seconds are
            truncated, so float durations (e.g. a difference of timestamps)
            are accepted instead of failing on the integer format code.

    Returns:
        A string with unpadded hours and zero-padded minutes and seconds.
    """
    # The ``d`` format code below only accepts integers.
    total = int(seconds)
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:d}:{minutes:02d}:{secs:02d}"


def get_params(module, lr, wd):
    """Split the trainable parameters of ``module`` into two optimizer groups.

    A parameter goes into the no-decay group when its name is listed by
    ``module.no_weight_decay()``, contains a keyword from
    ``module.no_weight_decay_keywords()`` (both hooks are optional), or when
    it is one-dimensional. Parameters with ``requires_grad=False`` are
    skipped entirely.

    Args:
        module: object exposing ``named_parameters()``.
        lr: learning rate stored in both groups.
        wd: weight decay stored in the decay group.

    Returns:
        ``([decay_group, no_decay_group], [lr, lr])``.
    """
    skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    skip_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )

    has_decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            # Frozen weights are not handed to the optimizer.
            continue
        exempt = (
            name in skip_list
            or any(kw in name for kw in skip_keywords)
            or param.ndim == 1
        )
        bucket = no_decay if exempt else has_decay
        bucket.append(param)

    decay_group = {
        "params": has_decay,
        "weight_decay": wd,
        "lr": lr,
        "weight_decay_init": wd,
        "weight_decay_base": wd,
        "lr_init": lr,
        "lr_base": lr,
    }
    no_decay_group = {
        "params": no_decay,
        "weight_decay": 0.0,
        "lr": lr,
        "weight_decay_init": 0.0,
        "weight_decay_base": 0.0,
        "weight_decay_final": 0.0,
        "lr_init": lr,
        "lr_base": lr,
    }
    return [decay_group, no_decay_group], [lr, lr]


def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    """Map a parameter name to a layer id used for layer-wise lr decay.

    Args:
        var_name: dotted parameter name, e.g. ``layers.1.blocks.0.attn.qkv.weight``.
        num_max_layer: total number of layer ids; ``num_max_layer - 1`` is
            the id given to parameters outside the embeddings and stages.
        layers_per_stage: number of blocks in each stage.

    Returns:
        ``0`` for embedding/token parameters, ``1 + global block index`` for
        ``layers.<stage>.blocks.<i>`` parameters, the cumulative block count
        up to and including the stage for any other ``layers.<stage>.*``
        parameter (``downsample`` and the like), and ``num_max_layer - 1``
        for everything else.
    """
    if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
        return 0
    if var_name.startswith("patch_embed"):
        return 0
    if not var_name.startswith("layers"):
        return num_max_layer - 1

    parts = var_name.split(".")
    if len(parts) < 2 or not parts[1].isdigit():
        # Starts with "layers" but carries no stage index: not a stage parameter.
        return num_max_layer - 1

    stage_id = int(parts[1])
    if len(parts) > 3 and parts[2] == "blocks":
        return int(parts[3]) + sum(layers_per_stage[:stage_id]) + 1

    # ``downsample`` and any other stage-level parameter share the id of the
    # stage's last block. Previously anything but ``downsample`` fell through
    # and returned None, which broke the lr computation in the caller.
    return sum(layers_per_stage[: stage_id + 1])


def get_params_layerdecayswin(module, lr, wd, ld):
    """Build one optimizer group per trainable parameter with layer-wise lr decay.

    Each parameter's learning rate is ``lr * ld ** (num_layers - layer_id - 1)``
    where ``layer_id`` comes from get_num_layer_for_swin() and ``num_layers``
    is ``sum(module.depths) + 1``. Weight decay is disabled for names listed
    by ``module.no_weight_decay()`` or matching a keyword from
    ``module.no_weight_decay_keywords()`` (both hooks are optional).

    Returns:
        ``(param_groups, lrs)`` with one entry per trainable parameter.
    """
    skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    skip_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )
    depths = module.depths
    total_layers = sum(depths) + 1

    param_groups, group_lrs = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            print(f"{name} frozen")
            continue
        depth_index = get_num_layer_for_swin(name, total_layers, depths)
        scaled_lr = lr * ld ** (total_layers - depth_index - 1)
        exempt = name in skip_list or any(kw in name for kw in skip_keywords)
        decay = 0.0 if exempt else wd
        param_groups.append({"params": param, "weight_decay": decay, "lr": scaled_lr})
        group_lrs.append(scaled_lr)
    return param_groups, group_lrs


def log(t, eps: float = 1e-5):
    """Natural logarithm of ``t`` after clamping its values from below at ``eps``."""
    clamped = torch.clamp(t, min=eps)
    return torch.log(clamped)


def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2.0, dim=-1)


def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values such as 0 still count)."""
    return not (val is None)


def identity(t, *args, **kwargs):
    """Return ``t`` untouched; any extra positional or keyword arguments are ignored."""
    return t


def divisible_by(numer, denom):
    """Report whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0


def first(arr, d=None):
    """Return the first element of ``arr``, or ``d`` when ``arr`` has length zero."""
    if len(arr) > 0:
        return arr[0]
    return d


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    A callable ``d`` is treated as a factory: it is invoked with no arguments
    and its result is the fallback.
    """
    if not exists(val):
        if callable(d):
            return d()
        return d
    return val


def maybe(fn):
    """Decorate a one-argument function so that a ``None`` argument passes through.

    The wrapped function is only invoked when its argument is not ``None``;
    otherwise the argument (``None``) is returned directly.
    """

    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x

    return inner


def once(fn):
    """Decorate a one-argument function so that it only ever runs on the first call.

    The first invocation forwards to ``fn`` and returns its result; every
    later invocation does nothing and returns ``None``. The call is marked
    as done before ``fn`` runs.
    """
    already_run = False

    @wraps(fn)
    def inner(x):
        nonlocal already_run
        if not already_run:
            already_run = True
            return fn(x)
        return None

    return inner


def _many(fn):
    @wraps(fn)
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)

    return inner


# Multi-tensor variants of the einops functions imported above: each applies
# the same pattern (and keyword arguments) to every tensor of an iterable and
# returns a generator over the individual results.
rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)


def load_pretrained(state_dict, checkpoint):
    """Prepare the weights stored under ``checkpoint["model"]`` for a SWIN backbone.

    When the checkpoint holds keys prefixed with ``encoder.``, only those keys
    are kept and the prefix is stripped. The result is then passed through
    load_checkpoint_swin() for key remapping.

    Args:
        state_dict: despite its name, this is forwarded as the ``model``
            argument of load_checkpoint_swin(), which calls ``.state_dict()``
            on it.
        checkpoint: mapping with a ``"model"`` entry holding the weights.

    Returns:
        The remapped weights dictionary. (It used to be computed and then
        dropped, so the remapping was lost whenever the ``encoder.`` branch
        built a new dictionary.)
    """
    prefix = "encoder."
    checkpoint_model = checkpoint["model"]
    # Detect with the same predicate that is used for filtering; a key that
    # merely contains "encoder." somewhere in the middle must not trigger the
    # branch and wipe out every other key.
    if any(k.startswith(prefix) for k in checkpoint_model):
        checkpoint_model = {
            # Strip only the leading prefix, not later occurrences.
            k[len(prefix):]: v
            for k, v in checkpoint_model.items()
            if k.startswith(prefix)
        }
        print("Detect pre-trained model, remove [encoder.] prefix.")
    else:
        print("Detect non-pre-trained model, pass without doing anything.")
    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
    return load_checkpoint_swin(state_dict, checkpoint_model)


def _geometric_source_coords(src_size, dst_size):
    """Return 1-D coordinates for the pre-trained grid, spaced geometrically.

    The coordinates are symmetric around 0; the gaps between successive
    positions grow by a ratio ``q`` found by bisection so that the outermost
    position lands at ``dst_size // 2``.
    """
    half_src = src_size // 2
    half_dst = dst_size // 2

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        if geometric_progression(1, q, half_src) > half_dst:
            right = q
        else:
            left = q

    dis = []
    cur = 1
    for i in range(half_src):
        dis.append(cur)
        cur += q ** (i + 1)

    r_ids = [-d for d in reversed(dis)]
    return r_ids + [0] + dis


def _resize_relative_position_bias_table(table, dst_len):
    """Resample a ``(L1, num_heads)`` bias table to ``(dst_len, num_heads)``.

    Each head is viewed as a square grid placed on geometrically spaced
    coordinates and resampled with a bicubic spline onto a unit-spaced target
    grid. The result is a float32 tensor on the device of ``table``.
    """
    src_len, num_heads = table.size()
    src_size = int(src_len**0.5)
    dst_size = int(dst_len**0.5)

    x = _geometric_source_coords(src_size, dst_size)
    y = x

    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    print("Original positions = %s" % str(x))
    print("Target positions = %s" % str(dx))

    resized_heads = []
    for head in range(num_heads):
        # detach()/cpu() make the numpy conversion work for CUDA tensors and
        # tensors that track gradients; FITPACK works in float64.
        z = (
            table[:, head]
            .detach()
            .cpu()
            .reshape(src_size, src_size)
            .float()
            .numpy()
            .astype(np.float64)
        )
        # interp2d was removed from SciPy. On a rectangular grid the
        # equivalent is RectBivariateSpline with transposed data and a
        # transposed result (interp2d indexes z as [y, x]).
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
        resized = np.ascontiguousarray(spline(dx, dy).T)
        resized_heads.append(
            torch.from_numpy(resized).float().view(-1, 1).to(table.device)
        )

    return torch.cat(resized_heads, dim=-1)


def load_checkpoint_swin(model, checkpoint_model):
    """Remap a SWIN checkpoint dictionary so that it fits ``model``.

    ``checkpoint_model`` is modified in place and also returned:

    * ``relative_position_bias_table`` entries whose length differs from the
      model's are resampled (geometric interpolation); entries whose number
      of heads differs are reported and left untouched.
    * ``relative_position_index``, ``relative_coords_table`` and
      ``attn_mask`` entries are dropped.
    * ``cpb_mlp`` is renamed to ``rpe_mlp`` in key names.
    * a leading ``encoder.`` prefix is stripped from key names.

    Args:
        model: object exposing ``state_dict()``.
        checkpoint_model: mutable mapping from parameter names to tensors.

    Returns:
        The same ``checkpoint_model`` mapping after remapping.

    Raises:
        KeyError: if a ``relative_position_bias_table`` key of the checkpoint
            is absent from the model's state dict.
    """
    state_dict = model.state_dict()

    # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
    for key in list(checkpoint_model.keys()):
        if "relative_position_bias_table" not in key:
            continue
        pretrained_table = checkpoint_model[key]
        current_table = state_dict[key]
        L1, nH1 = pretrained_table.size()
        L2, nH2 = current_table.size()
        if nH1 != nH2:
            print(f"Error in loading {key}, passing......")
        elif L1 != L2:
            print(f"{key}: Interpolate relative_position_bias_table using geo.")
            checkpoint_model[key] = _resize_relative_position_bias_table(
                pretrained_table, L2
            )

    # delete relative_position_index / relative_coords_table since we always re-init them
    for pattern in ("relative_position_index", "relative_coords_table"):
        for k in [k for k in checkpoint_model if pattern in k]:
            del checkpoint_model[k]

    # re-map keys due to name change
    for k in [k for k in checkpoint_model if "cpb_mlp" in k]:
        checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)

    # delete attn_mask since we always re-init it
    for k in [k for k in checkpoint_model if "attn_mask" in k]:
        del checkpoint_model[k]

    # Strip only the leading prefix; str.replace would also eat any later
    # "encoder." inside the key.
    prefix = "encoder."
    for k in [k for k in checkpoint_model if k.startswith(prefix)]:
        checkpoint_model[k[len(prefix):]] = checkpoint_model.pop(k)

    return checkpoint_model


def add_padding_metas(out, image_metas):
    """Zero-pad every sample of ``out`` by the padding stored in its metadata.

    Args:
        out: batched tensor; it is iterated sample by sample.
        image_metas: one mapping per sample. ``"padding_size"`` holds four
            values that are handed to ``F.pad`` unchanged, i.e. interpreted
            as (left, right, top, bottom) on the last two dimensions. It may
            be a tensor, list or tuple; a missing entry means no padding.

    Returns:
        The padded samples stacked back into one tensor, so all samples must
        end up with the same shape.
    """
    outs = []
    for img_meta, o in zip(image_metas, out):
        # as_tensor accepts tensors as well as the plain-list default, which
        # torch.stack could not handle.
        padding = torch.as_tensor(img_meta.get("padding_size", [0] * 4))
        # F.pad expects a sequence of Python ints, not a tensor.
        pad = [int(p) for p in padding.flatten().tolist()]
        outs.append(F.pad(o, pad, value=0.0))
    return torch.stack(outs)


def remove_padding(out, paddings):
    """Crop every sample of a 4-D batch by its own padding.

    Args:
        out: tensor of shape ``(B, C, H, W)``.
        paddings: list of 4-element tensors, one per sample.

    Returns:
        The cropped samples stacked into one tensor, so all crops must have
        the same shape.

    NOTE(review): the slicing below uses elements 1 and 3 for the rows and
    elements 0 and 2 for the columns, i.e. it reads each padding as
    (left, top, right, bottom). The comment this replaces said
    "left, right, top, bottom" - confirm which order callers provide.
    """
    _, _, height, width = out.shape
    stacked = torch.stack(paddings).to(out.device)
    cropped = []
    for pad, sample in zip(stacked, out):
        rows = slice(pad[1], height - pad[3])
        cols = slice(pad[0], width - pad[2])
        cropped.append(sample[:, rows, cols])
    return torch.stack(cropped)


def remove_padding_metas(out, image_metas):
    """Crop ``out`` with remove_padding() using the padding stored in each metadata mapping.

    The ``"padding_size"`` entry of every mapping is converted to a tensor;
    a missing entry is treated as four zeros (no cropping).
    """
    paddings = []
    for img_meta in image_metas:
        padding_size = img_meta.get("padding_size", [0] * 4)
        paddings.append(torch.tensor(padding_size))
    return remove_padding(out, paddings)


def ssi_helper(tensor1, tensor2):
    """Fit ``scale`` and ``shift`` so that ``scale * tensor2 + shift`` approximates ``tensor1``.

    Solves the regularised normal equations of the least-squares problem
    with design matrix ``[tensor2, 1]``; ``1e-4`` is added to the diagonal of
    the 2x2 system before it is inverted.

    Assumes both inputs are 1-D tensors of equal length - TODO confirm
    against callers.

    Returns:
        ``(scale, shift)``, the two halves of the 2-element solution.
    """
    design = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
    design_t = design.T
    regulariser = 1e-4 * torch.eye(2, device=tensor1.device)
    gram = design_t @ design + regulariser
    rhs = design_t @ tensor1.unsqueeze(1)
    solution = torch.inverse(gram) @ rhs
    scale, shift = solution.squeeze().chunk(2, dim=0)
    return scale, shift


def calculate_mean_values(names, values):
    """Average the ``values`` that share the same entry in ``names``.

    ``names`` and ``values`` are paired positionally. The result maps each
    distinct name, in order of first appearance, to the mean of its values.
    """
    # One accumulator per distinct name, filled lazily on first use.
    accumulators = {name: {} for name in names}

    for name, value in zip(names, values):
        entry = accumulators[name]
        entry["sum"] = entry.get("sum", 0.0) + value
        entry["count"] = entry.get("count", 0.0) + 1

    return {
        name: entry["sum"] / entry["count"] for name, entry in accumulators.items()
    }


def remove_leading_dim(infos):
    """Recursively apply ``squeeze(0)`` to every tensor in a nested dict.

    Dictionaries are rebuilt with the same keys; values that are neither
    dictionaries nor tensors are passed through unchanged.
    """
    if isinstance(infos, torch.Tensor):
        return infos.squeeze(0)
    if isinstance(infos, dict):
        return {key: remove_leading_dim(value) for key, value in infos.items()}
    return infos


def to_cpu(infos):
    """Recursively detach every tensor in a nested dict and move it to the CPU.

    Dictionaries are rebuilt with the same keys; values that are neither
    dictionaries nor tensors are passed through unchanged.

    Previously tensors were only detached and stayed on their original
    device, which contradicted the function's name.
    """
    if isinstance(infos, dict):
        return {k: to_cpu(v) for k, v in infos.items()}
    elif isinstance(infos, torch.Tensor):
        return infos.detach().cpu()
    else:
        return infos
