import os
from typing import Any, Dict, Optional

import requests
import torch
import torch.hub
import torchvision.models
from escnn import gspaces
from torchvision.datasets.utils import download_file_from_google_drive

from .e2resnet import E2BasicBlock, E2BottleNeck, E2ResNet

# Google Drive file ids of pretrained checkpoints, keyed by the ``model.name``
# that the model constructors in this module assign. A constructor called with
# ``pretrained=True`` looks its name up here and passes the id to
# ``load_state_dict_from_google_drive``.
# NOTE(review): there are no entries for the "-fast" (fixed_params=False)
# variants nor for "c8resnet101", so requesting pretrained weights for those
# configurations cannot succeed.
WEIGHT_FILE_IDS = {
    "d1resnet18": "1LX9--04ZOTv28kZH2WIO8IhL8UI7Wi0f",
    "c4resnet18": "1hO04WpgJHH_a0f2eYfClhwBm4SRnC9xM",
    "d4resnet18": "19TsJP49g6O16eGihP35Cg5IPXGoNgVaW",
    "c8resnet18": "1i4uboCtvyYkhWOqwOAg2A57jb-D8-xCN",
    "d1resnet50": "1q6mep0tpIoiZFYWuSi1dPnVQ1Fd60OKn",
    "c4resnet50": "1NYTjon1zvghdGmpn4OkbB4xhIX5ixAxI",
    "d4resnet50": "1Fr3JQqQFGaL_JjPelZ3gxGhUs5_o0lI8",
    "c8resnet50": "13Et3SvIoxRFEy9N8t61O6EeAeameKzLX",
    "d1resnet101": "1iRRkAM3JgU0L61YO3LC3zGFFbK0F3f1I",
    "c4resnet101": "16N9H6ac_WWzC01wBDW06tTL0HWCTcVmW",
    "d4resnet101": "1qmkLJV87lVKFnPdZMEsYoe2rrjpI96PM",
}


def load_state_dict_from_google_drive(
    file_id: str, model_dir: Optional[str] = None, map_location: Any = None
) -> Dict[str, Any]:
    """Load a checkpoint stored on Google Drive, caching it on local disk.

    Modelled on ``torch.hub.load_state_dict_from_url``
    (https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url):
    the file is downloaded once into ``model_dir`` under the name ``file_id``
    and the cached copy is reused on later calls.

    Args:
        file_id: Google Drive file id of the checkpoint.
        model_dir: Directory used as the download cache. Defaults to
            ``<torch.hub.get_dir()>/checkpoints``. Created if missing.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load
            weights that were saved from GPU tensors on a CPU-only machine.

    Returns:
        Whatever ``torch.load`` deserializes from the file (expected to be a
        state dict).
    """
    # Previously the caller's ``model_dir`` was unconditionally overwritten;
    # only fall back to the torch hub cache when none was given.
    if model_dir is None:
        model_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, file_id)
    if not os.path.exists(cached_file):
        # Name the download explicitly so it lands exactly at ``cached_file``.
        download_file_from_google_drive(file_id, root=model_dir, filename=file_id)

    # NOTE(review): torch.load unpickles the file; only use trusted file ids
    # (consider ``weights_only=True`` if the minimum torch version allows it).
    return torch.load(cached_file, map_location=map_location)


# pytorch versions, with same interface for easy loading during training
def resnet18(*args, **kwargs):
    """Build a plain torchvision ResNet-18 tagged with ``name = "resnet18"``.

    Any positional or keyword arguments are accepted and discarded so this can
    be called with the same arguments as the equivariant constructors.
    NOTE(review): that includes ``pretrained=True``, which is silently ignored.
    """
    del args, kwargs  # accepted only for interface parity
    net = torchvision.models.resnet18()
    net.name = "resnet18"
    return net


def resnet50(*args, **kwargs):
    """Build a plain torchvision ResNet-50 tagged with ``name = "resnet50"``.

    Any positional or keyword arguments are accepted and discarded so this can
    be called with the same arguments as the equivariant constructors.
    NOTE(review): that includes ``pretrained=True``, which is silently ignored.
    """
    del args, kwargs  # accepted only for interface parity
    net = torchvision.models.resnet50()
    net.name = "resnet50"
    return net


def resnet101(*args, **kwargs):
    """Build a plain torchvision ResNet-101 tagged with ``name = "resnet101"``.

    Any positional or keyword arguments are accepted and discarded so this can
    be called with the same arguments as the equivariant constructors.
    NOTE(review): that includes ``pretrained=True``, which is silently ignored.
    """
    del args, kwargs  # accepted only for interface parity
    net = torchvision.models.resnet101()
    net.name = "resnet101"
    return net



def c4resnet18(
    pretrained: bool = False, initialize: bool = True, fixed_params: bool = True,
    use_gpool: bool = False
):
    """Build a ResNet-18-style E2ResNet over the C4 rotation gspace.

    Args:
        pretrained: If True, load the checkpoint registered in
            ``WEIGHT_FILE_IDS`` under this model's name (non-strict load).
        initialize: Forwarded to ``E2ResNet``; forced to False when
            ``pretrained`` is True because the weights are overwritten anyway.
        fixed_params: If True use base width 39; otherwise use base width 16
            and append ``"-fast"`` to the model name.
        use_gpool: Forwarded to ``E2ResNet``.

    Returns:
        The ``E2ResNet`` with extra attributes ``name``, ``order`` (4) and
        ``num_out_regular_repr``.

    Raises:
        KeyError: If ``pretrained`` is True but no weights are registered for
            this configuration (e.g. ``fixed_params=False``). Raised before
            the model is built.
    """
    name = "c4resnet18" if fixed_params else "c4resnet18-fast"
    if pretrained and name not in WEIGHT_FILE_IDS:
        # Fail fast instead of after the (expensive) model construction.
        raise KeyError(f"no pretrained weights available for {name!r}")

    # if loading pretrained weights, can skip initialization
    initialize = False if pretrained else initialize

    # The width actually used by the network; num_out_regular_repr must be
    # derived from this value, not from the fixed_params default.
    base_width = 39 if fixed_params else 16
    model = E2ResNet(
        gspace=gspaces.rot2dOnR2(N=4),
        block=E2BasicBlock,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        base_width=base_width,
        initialize=initialize,
        use_gpool=use_gpool,
    )
    model.name = name
    model.order = 4
    model.num_out_regular_repr = base_width * 8

    if pretrained:
        state_dict = load_state_dict_from_google_drive(WEIGHT_FILE_IDS[model.name])
        # NOTE(review): strict=False silently tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=False)

    return model


def c4resnet50(
    pretrained: bool = False, initialize: bool = True, fixed_params: bool = True,
    use_gpool: bool = False
):
    """Build a ResNet-50-style (bottleneck) E2ResNet over the C4 rotation gspace.

    Args:
        pretrained: If True, load the checkpoint registered in
            ``WEIGHT_FILE_IDS`` under this model's name (non-strict load).
        initialize: Forwarded to ``E2ResNet``; forced to False when
            ``pretrained`` is True because the weights are overwritten anyway.
        fixed_params: If True use base width 20; otherwise use base width 16
            and append ``"-fast"`` to the model name.
        use_gpool: Forwarded to ``E2ResNet``.

    Returns:
        The ``E2ResNet`` with extra attributes ``name``, ``order`` (4) and
        ``num_out_regular_repr``.

    Raises:
        KeyError: If ``pretrained`` is True but no weights are registered for
            this configuration (e.g. ``fixed_params=False``). Raised before
            the model is built.
    """
    name = "c4resnet50" if fixed_params else "c4resnet50-fast"
    if pretrained and name not in WEIGHT_FILE_IDS:
        # Fail fast instead of after the (expensive) model construction.
        raise KeyError(f"no pretrained weights available for {name!r}")

    # if loading pretrained weights, can skip initialization
    initialize = False if pretrained else initialize

    # The width actually used by the network; num_out_regular_repr must be
    # derived from this value, not from the fixed_params default.
    base_width = 20 if fixed_params else 16
    model = E2ResNet(
        gspace=gspaces.rot2dOnR2(N=4),
        block=E2BottleNeck,
        layers=[3, 4, 6, 3],
        num_classes=1000,
        base_width=base_width,
        initialize=initialize,
        freq_cutoff=True,
        use_gpool=use_gpool,
    )
    model.name = name
    model.order = 4
    model.num_out_regular_repr = base_width * 8 * E2BottleNeck.expansion

    if pretrained:
        state_dict = load_state_dict_from_google_drive(WEIGHT_FILE_IDS[model.name])
        # NOTE(review): strict=False silently tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=False)

    return model

