
import math

import torch
from timm.data import Mixup
from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy
from timm.models.layers import drop
from timm.models.resnet import ResNet

from .convnext_official import ConvNeXt


def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str):
    """Assign a ConvNeXt parameter to a layer-decay group and return its scale exponent.

    ``N`` (the id of the last backbone group) is 12 when ``self.stages[-2]`` holds
    more than 9 blocks and 6 otherwise. Group ids are assigned as follows:

    * ``downsample_layers.0``                   -> 0
    * ``stages.0``                              -> 1
    * ``stages.1`` / ``downsample_layers.1``    -> 2
    * ``downsample_layers.2``                   -> 3
    * ``stages.2`` block ``b``                  -> 3 + b // 3
    * ``stages.3`` / ``downsample_layers.3``    -> N
    * any other parameter (after backbone)      -> N + 1

    Args:
        self: the ConvNeXt model; only ``len(self.stages[-2])`` is read.
        para_name: dotted parameter name, e.g. ``'stages.2.5.dwconv.weight'``.

    Returns:
        ``(layer_id, N + 1 - layer_id)``.
    """
    last_id = 12 if len(self.stages[-2]) > 9 else 6
    pieces = para_name.split('.')

    if para_name.startswith("downsample_layers"):
        stage = int(pieces[1])
        if stage == 0:
            layer_id = 0
        elif stage in (1, 2):
            layer_id = stage + 1
        else:  # the final downsample shares the last backbone group
            layer_id = last_id
    elif para_name.startswith("stages"):
        stage, block = int(pieces[1]), int(pieces[2])
        if stage in (0, 1):
            layer_id = stage + 1
        elif stage == 2:
            # every three consecutive blocks of the deep stage form one group
            layer_id = 3 + block // 3
        else:
            layer_id = last_id
    else:
        layer_id = last_id + 1  # after backbone

    return layer_id, last_id + 1 - layer_id


def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str):
    """Map a ResNet parameter name to its layer-decay group id and scale exponent.

    Blocks of ``layer2`` and ``layer3`` are bundled into groups of ``blk2`` /
    ``blk3`` consecutive blocks, chosen per depth from the table below. That
    yields ``N2`` / ``N3`` groups and ``N = 2 + N2 + N3``. Group ids:

    * any parameter not under ``layer*`` / ``fc.`` (presumably the stem) -> 0
    * ``layer1``                                   -> 1
    * ``layer2`` block ``b``                       -> 2 + b // blk2
    * ``layer3`` block ``b``                       -> 2 + N2 + b // blk3
    * ``layer4``                                   -> N
    * ``fc.``                                      -> N + 1

    Args:
        self: the ResNet; only ``len(self.layer2)`` and ``len(self.layer3)`` are read.
        para_name: dotted parameter name such as ``'layer3.7.conv1.weight'``.

    Returns:
        ``(layer_id, N + 1 - layer_id)``. For example, r50 spans ids 0..7 and r101 spans 0..13.

    Raises:
        NotImplementedError: if ``(len(layer2), len(layer3))`` is not a supported depth.
    """
    # (len(layer2), len(layer3)) -> (blocks per group in layer2, blocks per group in layer3)
    # full stage depths [layer1, layer2, layer3, layer4] shown for reference
    block_groups = {
        (4, 6): (2, 3),     # 50     : [3, 4, 6, 3]
        (4, 23): (2, 3),    # 101    : [3, 4, 23, 3]
        (8, 36): (4, 4),    # 152    : [3, 8, 36, 3]
        (24, 36): (4, 4),   # 200    : [3, 24, 36, 3]
        (30, 48): (5, 6),   # eca269d: [3, 30, 48, 8]
    }
    L2, L3 = len(self.layer2), len(self.layer3)
    group_sizes = block_groups.get((L2, L3))
    if group_sizes is None:
        raise NotImplementedError(
            f'unsupported ResNet depth: len(layer2)={L2}, len(layer3)={L3}; '
            f'supported (len(layer2), len(layer3)) pairs: {sorted(block_groups)}'
        )
    blk2, blk3 = group_sizes

    # Exact integer ceil division; no float rounding involved.
    N2, N3 = -(-L2 // blk2), -(-L3 // blk3)
    N = 2 + N2 + N3  # layer1 + N2 groups + N3 groups + layer4

    if para_name.startswith('layer'):  # layer1 .. layer4
        parts = para_name.split('.')
        stage_id, block_id = int(parts[0][5:]), int(parts[1])
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // blk2  # r50: 2, 3
        elif stage_id == 3:
            layer_id = 2 + N2 + block_id // blk3  # r50: 4, 5    r101: 4, 5, ..., 11
        else:  # layer4
            layer_id = N  # r50: 6       r101: 12
    elif para_name.startswith('fc.'):
        layer_id = N + 1  # r50: 7       r101: 13
    else:
        layer_id = 0

    return layer_id, N + 1 - layer_id  # r50: 0-7, 7-0   r101: 0-13, 13-0


def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )


def _patch_third_party_classes():
    """Monkey-patch third-party classes in place; invoked once below at import time.

    * The loss / mixup / drop-path classes get a repr that lists their public
      attributes via ``_ex_repr``. Classes that already define ``extra_repr``
      (the ``torch.nn.Module`` repr hook) have it replaced. The others get a
      whole ``__repr__`` of the form ``ClassName(k=v, ...)``.
    * ``ResNet`` and ``ConvNeXt`` gain a ``get_layer_id_and_scale_exp`` method.

    Each step is a plain attribute assignment. Calling this again (e.g. after
    ``importlib.reload``) therefore just re-applies the patches with the current
    function objects.
    """
    def _full_repr(self):
        return f'{type(self).__name__}({_ex_repr(self)})'

    for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, BinaryCrossEntropy, Mixup, drop.DropPath):
        if hasattr(clz, 'extra_repr'):
            clz.extra_repr = _ex_repr
        else:
            clz.__repr__ = _full_repr
    ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
    ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp


# IMPORTANT: update some member functions of the third-party classes above.
# Module-level code runs once per import, so no re-entry guard is needed; the
# patch loop lives in a function so its loop variable does not leak into this
# module's namespace. The flag is still set for any code that inspects it.
_patch_third_party_classes()
__UPDATED = True
