import warnings

import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18

"""
Adapted from: XXXX (TODO: replace this placeholder with the original source reference).
"""
# @register_model_trunk("eshednet")
class EshedNet(nn.Module):
    """ResNet-18 trunk that returns pooled features instead of class logits.

    Wraps a randomly initialised torchvision ``resnet18`` whose final ``fc``
    layer has been replaced by ``nn.Identity``.
    """

    def __init__(self):
        """Create a new EshedNet with randomly initialised weights.

        Takes no arguments.

        NOTE(review): if this class is registered as a VISSL trunk again (the
        ``register_model_trunk`` decorator is currently commented out), it will
        need to accept ``model_config`` and ``model_name``, and
        ``model_config.TRUNK.TRUNK_PARAMS.position_dir`` would feed the
        disabled position loading below.
        """
        super().__init__()

        # TODO: restore once __init__ accepts model_config again:
        # self.positions = self._load_positions(model_config.TRUNK.TRUNK_PARAMS.position_dir)
        self.base_model = resnet18(weights=None)

        # Drop the classification head so forward() returns the pooled features.
        self.base_model.fc = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return flattened, average-pooled layer4 features for a batch.

        Args:
            x: input batch; presumably images shaped (N, 3, H, W), as
                torchvision's resnet18 expects — TODO confirm against the
                data pipeline.

        Returns:
            Tensor of shape (N, C), where C is layer4's channel count.
        """
        # NOTE(review): VISSL trunks are normally invoked as
        # forward(x, out_feat_keys) and return a list of features; confirm the
        # expected signature before re-enabling trunk registration.
        base = self.base_model

        # Stem: conv -> batch norm -> ReLU -> max-pool.
        x = base.maxpool(base.relu(base.bn1(base.conv1(x))))

        # Run every residual block of every stage in order (layer1 -> layer4).
        for stage in (base.layer1, base.layer2, base.layer3, base.layer4):
            for block in stage:
                x = block(x)

        x = base.avgpool(x)
        # Collapse all non-batch dimensions into one feature vector per sample.
        return torch.flatten(x, 1)


def load_eshed_checkpoint(filename: str):
    """Build an EshedNet and load trunk weights from a checkpoint file.

    Args:
        filename: path to a checkpoint readable by ``torch.load`` that contains
            ``["classy_state_dict"]["base_model"]["model"]["trunk"]``.

    Returns:
        An ``EshedNet`` with the checkpoint's trunk weights loaded (on CPU).

    Raises:
        KeyError: if the checkpoint lacks the nested keys above.
        RuntimeError: if none of the checkpoint's weights match the model,
            i.e. loading would leave the model at its random initialisation.
    """
    model = EshedNet()

    checkpoint = torch.load(filename, map_location="cpu", weights_only=True)
    trunk_state = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]

    # The trunk was saved with keys such as "_feature_blocks.conv1.weight",
    # while EshedNet nests the same torchvision layers under "base_model.".
    # Loading those keys unchanged with strict=False silently matched nothing,
    # so rename the prefix first. Keys already using "base_model." pass through.
    saved_prefix = "_feature_blocks."
    state_dict = {}
    for key, value in trunk_state.items():
        if key.startswith(saved_prefix):
            key = "base_model." + key[len(saved_prefix):]
        state_dict[key] = value

    # strict=False is kept for compatibility, but mismatches are now surfaced.
    result = model.load_state_dict(state_dict, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"No weights from {filename!r} matched EshedNet; "
            f"unexpected keys include {result.unexpected_keys[:5]}"
        )
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Partial load from {filename!r}: "
            f"{len(result.missing_keys)} missing key(s) {result.missing_keys[:5]}, "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"{result.unexpected_keys[:5]}"
        )
    return model
