import os

import torch.nn

from .model import model as module_arch
from .parse_config import ConfigParser


class Dummy_Arg:
    """Stand-in for an argparse parser, used to feed ``ConfigParser.from_args``.

    It holds the ``config``, ``resume`` and ``device`` values as plain
    attributes. ``parse_args`` hands back the same object, so it can be used
    where an ``ArgumentParser`` (or its parsed namespace) is expected —
    presumably ``from_args`` calls ``parse_args()`` on it; confirm in
    ``parse_config``.
    """

    def __init__(self, config=None, resume=None, device=None):
        self.config = config
        self.resume = resume
        self.device = device

    def parse_args(self, **kwargs):
        # Any keyword arguments are accepted only for call compatibility and ignored.
        return self

def get_abs_config_path(cfg_r_p):
    """Resolve ``cfg_r_p`` against the directory that contains this module.

    The relative path is appended verbatim with ``os.path.join`` (it is not
    normalized). For example, ``'./configs/x.json'`` becomes
    ``'<module dir>/./configs/x.json'``. An absolute ``cfg_r_p`` is returned
    unchanged by ``os.path.join``.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, cfg_r_p)

# Default setting: CIFAR-100-LT with imbalance ratio 100 (cifar100-im100).

class ride_net_wrapped(torch.nn.Module):
    """Wrap a RIDE model so that ``forward`` returns only the logits.

    Args:
        build_func: Either a callable that builds the inner model, or the
            name of such a callable in this module's global namespace. A
            string may be a plain name (e.g. ``'build_ride_resx50_imagenet_lt'``)
            or a dotted attribute path rooted at a module-level name.
        *args, **kwargs: Forwarded unchanged to the builder.

    Raises:
        TypeError: ``build_func`` is neither callable nor a string.
        ValueError: the string is not a dotted sequence of identifiers.
        NameError: the first name is not defined at module level.
        AttributeError: a later dotted component does not exist.
    """

    def __init__(self, build_func, *args, **kwargs):
        super().__init__()
        builder = build_func if callable(build_func) else self._resolve_builder(build_func)
        self.model = builder(*args, **kwargs)

    @staticmethod
    def _resolve_builder(name):
        """Look up a (dotted) builder name in module globals without ``eval``.

        Only identifiers joined by dots are accepted, so an arbitrary
        expression string is rejected instead of being executed.
        """
        if not isinstance(name, str):
            raise TypeError(f'build_func must be callable or str, got {type(name).__name__}')
        parts = name.split('.')
        if not all(part.isidentifier() for part in parts):
            raise ValueError(f'invalid builder name: {name!r}')
        head, *rest = parts
        try:
            obj = globals()[head]
        except KeyError:
            raise NameError(f'name {head!r} is not defined') from None
        for attr in rest:
            obj = getattr(obj, attr)
        return obj

    def forward(self, x, target=None):
        # Assumes the inner model returns a 2-tuple whose first element is the
        # logits; the second element is discarded — TODO confirm for each arch.
        logits, _ = self.model(x, target)
        return logits

def build_ride_res32_cifar100(model_path=None):
    """Build the RIDE model for imbalanced CIFAR-100 and optionally load weights.

    The architecture is created from
    ``configs/config_imbalance_cifar100_ride_ea.json``, resolved relative to
    this file. If the config's ``arch.args`` contain ``returns_feat``, it is
    overridden to ``False``.

    Args:
        model_path: Path to a checkpoint file, or ``None`` to skip loading.
            Weights are taken from the checkpoint's ``'state_dict'`` entry
            when present; otherwise the loaded object is used directly as
            the state dict.

    Returns:
        The constructed model.
    """
    config_abs_path = get_abs_config_path('./configs/config_imbalance_cifar100_ride_ea.json')
    config = ConfigParser.from_args(Dummy_Arg(config=config_abs_path))
    # build model architecture
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False)
    else:
        model = config.init_obj('arch', module_arch)

    if model_path is None:
        return model

    # map_location='cpu' lets a GPU-saved checkpoint load on CUDA-less hosts;
    # load_state_dict then copies the tensors into the model's own parameters.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Newer PyTorch versions default to weights_only=True, which
    # may reject checkpoints that pickle non-tensor objects — confirm.
    checkpoint = torch.load(model_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state = checkpoint['state_dict']
    else:
        state = checkpoint
    model.load_state_dict(state)
    return model


def build_ride_resx50_imagenet_lt(model_path=None):
    """Build the RIDE ResNeXt-50 model for ImageNet-LT and optionally load weights.

    The architecture is created from
    ``configs/config_imagenet_lt_resnext50_ride_ea.json``, resolved relative
    to this file. If the config's ``arch.args`` contain ``returns_feat``, it
    is overridden to ``False``.

    Args:
        model_path: Optional checkpoint path. The default ``None`` returns the
            model without loading any weights, as before. When given, weights
            are taken from the checkpoint's ``'state_dict'`` entry if present;
            otherwise the loaded object is used directly as the state dict.

    Returns:
        The constructed model.
    """
    config_abs_path = get_abs_config_path('./configs/config_imagenet_lt_resnext50_ride_ea.json')
    config = ConfigParser.from_args(Dummy_Arg(config=config_abs_path))
    # build model architecture
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False)
    else:
        model = config.init_obj('arch', module_arch)

    if model_path is not None:
        # Load to CPU so GPU-saved checkpoints also work on CUDA-less hosts.
        # NOTE(review): torch.load unpickles the file — trusted checkpoints only.
        checkpoint = torch.load(model_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state = checkpoint['state_dict']
        else:
            state = checkpoint
        model.load_state_dict(state)

    return model


if __name__ == '__main__':
    import sys

    # Smoke test: construct both models. build_ride_res32_cifar100 needs a
    # checkpoint path, so it only runs when one is given on the command line
    # (the previous no-argument call always raised TypeError).
    build_ride_resx50_imagenet_lt()
    if len(sys.argv) > 1:
        build_ride_res32_cifar100(sys.argv[1])
    else:
        print('No checkpoint path given; skipping build_ride_res32_cifar100.', file=sys.stderr)

