import yaml
import torch
import torch.nn as nn
from torchvision.models.resnet import _resnet, Bottleneck
from utils import count_model_layers, count_model_params, count_flops
from psychonet import PsychoNet


# A minimal attribute-style config class, plus helpers that build models from configs
class Config:
    """
    Attribute-style view over a (possibly nested) dict of config values.

    Every key of ``fields`` becomes an attribute of the instance; nested dicts
    are wrapped recursively in ``Config`` so they can be read as e.g.
    ``cfg.MODEL.ARCH``. The raw dict is kept unchanged in ``self.fields``.
    """

    # Class-level fallbacks: reading ``cfg.ARCH`` / ``cfg.MODEL`` yields None
    # instead of raising when the config does not define those keys.
    ARCH = None
    MODEL = None

    def __init__(self, fields: dict):
        """
        Args:
            fields: mapping of config keys to values. ``None`` (which is what
                ``yaml.safe_load`` returns for an empty file) is treated as an
                empty config rather than crashing on ``.items()``.
        """
        if fields is None:
            fields = {}
        self.fields = fields

        # Nested mappings become nested Config objects (recursively).
        for k, v in fields.items():
            setattr(self, k, Config(v) if isinstance(v, dict) else v)

    @classmethod
    def from_yaml(cls, file_path):
        """
        Load a config from a YAML file.

        ``yaml.safe_load`` is used so that only plain YAML types (no arbitrary
        Python objects) are constructed from the file.

        Args:
            file_path: path to the YAML file

        Returns:
            Config: the parsed config (empty if the file is empty)
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            fields = yaml.safe_load(file)
        return cls(fields)

    def __str__(self, level=0):
        """
        Render the config as indented ``key: value`` lines.

        Args:
            level: nesting depth; each level adds two spaces of indentation

        Returns:
            str: one line per leaf value, and a ``key:`` header line followed
            by the indented children for each nested mapping
        """
        output = ''
        indent = '  ' * level
        for k, v in self.fields.items():
            if isinstance(v, dict):
                output += f'{indent}{k}:\n{Config(v).__str__(level + 1)}'
            else:
                output += f'{indent}{k}: {v}\n'
        return output

# Wrappers for non-PsychoNet models
class ResNetWrapper(nn.Module):
    """
    Randomly initialised ResNet whose classification head has ``n_class`` outputs.

    'resnet50', 'resnet101' and 'resnet152' are loaded through ``torch.hub``;
    'resnet270' is assembled locally from torchvision's ``Bottleneck`` blocks.
    Any other architecture name is rejected with ``NotImplementedError``.
    """

    def __init__(self,
                 repo_or_dir=None,
                 arch=None,
                 n_class=1000):
        """
        Build the backbone and replace its final ``fc`` layer.

        Args:
            repo_or_dir: torch hub repo or local directory to load the model
                from (default ``'pytorch/vision:v0.10.0'``); not used for
                'resnet270'
            arch: architecture name: 'resnet50' (default), 'resnet101',
                'resnet152' or 'resnet270'
            n_class: number of output classes of the new ``fc`` layer

        Raises:
            NotImplementedError: if ``arch`` is not one of the supported names
                (previously this silently produced a module with no ``model``)
        """

        super().__init__()
        if repo_or_dir is None:
            repo_or_dir = 'pytorch/vision:v0.10.0'
        if arch is None:
            arch = 'resnet50'

        if arch in ['resnet50', 'resnet101', 'resnet152']:
            self.model = torch.hub.load(repo_or_dir, arch, pretrained=False)
        elif arch == 'resnet270':
            # custom depth: Bottleneck blocks per stage = [4, 29, 53, 4]
            self.model = _resnet(Bottleneck, [4, 29, 53, 4],
                                 weights=None,
                                 progress=False)
        else:
            raise NotImplementedError(f'ResNet architecture {arch} not implemented')

        # swap the classification head for one with n_class outputs
        self.model.fc = nn.Linear(self.model.fc.in_features, n_class)

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        return self.model(x)

    @classmethod
    def from_cfg(cls, cfg):
        """
        Build the wrapper from a config exposing ``cfg.MODEL.ARCH``,
        ``cfg.MODEL.N_CLASS`` and, optionally, ``cfg.MODEL.REPO_OR_DIR``.
        """
        return cls(repo_or_dir=getattr(cfg.MODEL, 'REPO_OR_DIR', None),
                   arch=cfg.MODEL.ARCH,
                   n_class=cfg.MODEL.N_CLASS)

def build_model(cfg_pth, verbose=True, ret_cfg=False):
    """
    Construct a model described by a YAML config file.

    The top-level ``ARCH`` entry of the config selects the model family
    ('resnet' or 'psychonet'); that family's ``from_cfg`` constructor then
    receives the whole config.

    Args:
        cfg_pth: path to the YAML config file
        verbose: if True, print the parsed config before building
        ret_cfg: if True, also return the parsed config

    Returns:
        The built ``torch.nn.Module``, or a ``(model, cfg)`` tuple when
        ``ret_cfg`` is True.

    Raises:
        NotImplementedError: if ``ARCH`` names no known model family
    """
    config = Config.from_yaml(cfg_pth)
    if verbose:
        print(f'\nCONFIG FILE:\n{config}\n')

    family = config.ARCH
    # (family name, class providing from_cfg), checked in order
    registry = (
        ('resnet', ResNetWrapper),
        ('psychonet', PsychoNet),
    )
    for family_name, model_cls in registry:
        if family == family_name:
            built = model_cls.from_cfg(config)
            break
    else:
        raise NotImplementedError(f'Architecture {family} not implemented')

    if ret_cfg:
        return built, config
    return built


if __name__ == '__main__':
    # Demo: build Psycho-S from its config, then report its size and cost.
    model = build_model('configs/psycho_s.yaml')
    print(model)

    # size/complexity report (parameters, layer count, FLOPs for one 224x224 RGB image)
    n_params = count_model_params(model)
    print(f"{n_params / 1e6:.4f}M parameters")
    print(count_model_layers(model))
    input_shape = (1, 3, 224, 224)
    n_flops = count_flops(model, input_shape)
    print(f"{n_flops / 1e9:.4f}G FLOPS")
