import json
import torch
from models.cql_model import CQLModel
from models.iql_model import IQL_Evaluator, IQL_Policy, PerTokenIQL, TopAdvantageNGrams
from models.tf_iql_model import TFILQL
from models.is_iql_model import IS_ILQL, IS_IQL_Policy, IS_IQL_Evaluator
from models.utterance_iql_model import PerUtteranceIQL, PerUtteranceIQL_Policy, UtteranceIQL_Evaluator
from data.rl_data import ConstantTokenReward, SepcifiedTokenReward
from utils.cache import Cache
from utils.misc import convert_path
from models.gpt2_optional_final_ln import GPT2LMHeadModel, GPT2Config, GPT2Model

# Maps a loader name (the string passed to ``register``) to its loader callable.
registry: dict = {}
# Objects already built by ``load_item``, keyed by ``(name, cache_id)``.
cache: dict = {}


def register(name):
    """Return a decorator that stores the decorated callable in ``registry``.

    The callable is stored under ``name`` (replacing any previous entry with
    the same name) and is returned unchanged, so the decorated function stays
    directly callable.
    """

    def _record(loader):
        registry[name] = loader
        return loader

    return _record


def load_item(config, *args, verbose=True):
    """Build the object described by ``config``, or return it from the cache.

    ``config`` is shallow-copied, so the caller's dict is not modified. The
    ``"name"`` entry is removed from the copy and selects the loader in
    ``registry``; the rest of the copy is passed to that loader together with
    ``*args`` and ``verbose``.

    If ``config`` has a ``"cache_id"`` entry, the built object is stored in the
    module-level ``cache`` under ``(name, cache_id)``, and later calls with the
    same pair return the stored object without calling the loader again.

    All progress messages are printed only when ``verbose`` is true.

    Raises:
        KeyError: if ``config`` has no ``"name"`` entry.
        NotImplementedError: if no loader is registered under that name.
    """
    config = config.copy()
    name = config.pop("name")
    if name not in registry:
        raise NotImplementedError(f"no loader registered under name {name!r}")

    # Work out the cache key once, before the loader sees the config.
    use_cache = "cache_id" in config
    cache_key = (name, config["cache_id"]) if use_cache else None

    if use_cache and cache_key in cache:
        if verbose:
            print(f"loading from cache ({name}, {cache_key[1]})")
        return cache[cache_key]

    if verbose:
        print(f"loading {name}: {config}")
    item = registry[name](config, *args, verbose=verbose)

    if use_cache:
        # Respect ``verbose`` here too, like every other message in this function.
        if verbose:
            print(f"saving to cache ({name}, {cache_key[1]})")
        cache[cache_key] = item
    return item


def load_model(config, model, device, verbose=True):
    """Move ``model`` to ``device`` and optionally restore a saved state dict.

    When ``config["checkpoint_path"]`` is ``None`` the model is returned right
    after the device move. Otherwise the path is resolved with
    ``convert_path``, the checkpoint is read with ``torch.load`` using
    ``map_location="cpu"``, and it is applied with ``load_state_dict`` using
    ``strict=config["strict_load"]``.

    ``config["name"]`` is read only for the progress message, i.e. only when
    ``verbose`` is true and a checkpoint path is set.

    Returns:
        The result of ``model.to(device)``, with weights loaded if requested.
    """
    model = model.to(device)
    if config["checkpoint_path"] is None:
        # Nothing to restore; keep the freshly initialised weights.
        return model

    if verbose:
        print("loading %s state dict from: %s" % (config["name"], convert_path(config["checkpoint_path"])))
    state_dict = torch.load(convert_path(config["checkpoint_path"]), map_location="cpu")
    model.load_state_dict(state_dict, strict=config["strict_load"])
    if verbose:
        print("loaded.")
    return model


@register("constant_token_reward")
def load_constant_token_reward(config, device, verbose=True):
    """Build a ``ConstantTokenReward`` from ``config["c"]``.

    ``device`` and ``verbose`` are accepted to match the loader calling
    convention but are not used here.
    """
    constant = config["c"]
    return ConstantTokenReward(constant)


@register("specified_token_reward")
def load_specified_token_reward(config, device, verbose=True):
    """Build a ``SepcifiedTokenReward`` from a JSON file of per-key values.

    The file at ``convert_path(config["token_file"])`` must hold a JSON object.
    Its keys are converted to ``int`` and the resulting dict is passed to the
    reward together with ``config["scale"]`` and ``config["shift"]``.

    ``device`` and ``verbose`` are accepted to match the loader calling
    convention but are not used here.
    """
    token_path = convert_path(config["token_file"])
    with open(token_path, "r") as handle:
        raw_entries = json.load(handle)
    # JSON object keys are always strings; the reward expects int keys.
    token_data = {int(key): value for key, value in raw_entries.items()}
    return SepcifiedTokenReward(token_data, config["scale"], config["shift"])


@register("gpt2")
def load_gpt2(config, verbose=True):
    """Build a GPT-2 module, either from pretrained weights or from its config only.

    ``config["lm_head"]`` selects ``GPT2LMHeadModel`` (truthy) or ``GPT2Model``
    (falsy). If ``config["from_pretrained"]`` is truthy, the class's
    ``from_pretrained`` is called with ``config["gpt2_type"]``. Otherwise only
    the ``GPT2Config`` for that type is fetched and the class is instantiated
    from it.

    ``verbose`` is accepted to match the loader calling convention but is not
    used here.
    """
    model_cls = GPT2LMHeadModel if config["lm_head"] else GPT2Model
    if not config["from_pretrained"]:
        architecture = GPT2Config.from_pretrained(config["gpt2_type"])
        return model_cls(architecture)
    return model_cls.from_pretrained(config["gpt2_type"])


@register("per_token_iql")
def load_per_token_iql(config, device, verbose=True):
    """Build a ``PerTokenIQL`` model from ``config`` and pass it to ``load_model``.

    The ``"gpt2"`` and ``"dataset"`` sub-configs are resolved through
    ``load_item``; the dataset loader also receives ``device``. The remaining
    constructor arguments are read from ``config`` by key and passed
    positionally. ``config["load"]`` is forwarded to ``load_model``, which
    handles the device move and optional checkpoint restore.
    """
    backbone = load_item(config["gpt2"], verbose=verbose)
    rl_dataset = load_item(config["dataset"], device, verbose=verbose)
    # Keys in the exact positional order the constructor is called with.
    positional_keys = (
        "alpha",
        "gamma",
        "beta",
        "transition_weight",
        "clip_weight",
        "value_max",
        "value_min",
        "detach_v",
        "detach_pi",
        "detach_q",
        "double_q",
        "tau",
        "seperate_policy",
        "seperate_target",
        "exp_weights",
        "dm_margin",
        "advanced_mlp",
        "cql_temp",
    )
    hyperparams = [config[key] for key in positional_keys]
    iql = PerTokenIQL(backbone, rl_dataset, device, *hyperparams)
    return load_model(config["load"], iql, device, verbose=verbose)


@register("is_ilql")
def load_is_ilql(config, device, verbose=True):
    """Build an ``IS_ILQL`` model from ``config`` and pass it to ``load_model``.

    The ``"gpt2"`` and ``"dataset"`` sub-configs are resolved through
    ``load_item``; the dataset loader also receives ``device``. The remaining
    constructor arguments are read from ``config`` by key and passed
    positionally, except ``config["K"]``, which is passed as the keyword
    argument ``K``. ``config["load"]`` is forwarded to ``load_model``.
    """
    backbone = load_item(config["gpt2"], verbose=verbose)
    rl_dataset = load_item(config["dataset"], device, verbose=verbose)
    # Keys in the exact positional order the constructor is called with.
    positional_keys = (
        "alpha",
        "gamma",
        "beta",
        "transition_weight",
        "clip_weight",
        "value_max",
        "value_min",
        "detach_v",
        "detach_pi",
        "detach_q",
        "double_q",
        "tau",
        "seperate_policy",
        "seperate_target",
        "exp_weights",
        "dm_margin",
        "advanced_mlp",
        "cql_temp",
    )
    hyperparams = [config[key] for key in positional_keys]
    ilql = IS_ILQL(backbone, rl_dataset, device, *hyperparams, K=config["K"])
    return load_model(config["load"], ilql, device, verbose=verbose)


@register("tf_ilqn")
def load_tf_ilqn(config, device, verbose=True):
    """Build a ``TFILQL`` model from ``config`` and pass it to ``load_model``.

    The ``"gpt2"`` and ``"dataset"`` sub-configs are resolved through
    ``load_item``; the dataset loader also receives ``device``. The remaining
    constructor arguments are read from ``config`` by key and passed
    positionally. ``config["load"]`` is forwarded to ``load_model``.
    """
    backbone = load_item(config["gpt2"], verbose=verbose)
    rl_dataset = load_item(config["dataset"], device, verbose=verbose)
    # Keys in the exact positional order the constructor is called with.
    positional_keys = (
        "alpha",
        "gamma",
        "beta",
        "transition_weight",
        "clip_weight",
        "value_max",
        "value_min",
        "detach_v",
        "detach_pi",
        "detach_q",
        "double_q",
        "tau",
        "seperate_policy",
        "seperate_target",
        "exp_weights",
        "dm_margin",
        "advanced_mlp",
        "cql_temp",
    )
    hyperparams = [config[key] for key in positional_keys]
    tf_model = TFILQL(backbone, rl_dataset, device, *hyperparams)
    return load_model(config["load"], tf_model, device, verbose=verbose)


@register("per_token_cql")
def load_per_token_cql(config, device, verbose=True):
    """Build a ``CQLModel`` from ``config`` and pass it to ``load_model``.

    The ``"gpt2"`` and ``"dataset"`` sub-configs are resolved through
    ``load_item``; the dataset loader also receives ``device``. The remaining
    constructor arguments are read from ``config`` by key and passed
    positionally. This loader reads neither ``"tau"`` nor ``"dm_margin"``.
    ``config["load"]`` is forwarded to ``load_model``.
    """
    backbone = load_item(config["gpt2"], verbose=verbose)
    rl_dataset = load_item(config["dataset"], device, verbose=verbose)
    # Keys in the exact positional order the constructor is called with.
    positional_keys = (
        "alpha",
        "gamma",
        "beta",
        "transition_weight",
        "clip_weight",
        "value_max",
        "value_min",
        "detach_v",
        "detach_pi",
        "detach_q",
        "double_q",
        "seperate_policy",
        "seperate_target",
        "exp_weights",
        "advanced_mlp",
        "cql_temp",
    )
    hyperparams = [config[key] for key in positional_keys]
    cql = CQLModel(backbone, rl_dataset, device, *hyperparams)
    return load_model(config["load"], cql, device, verbose=verbose)


@register("per_utterance_iql")
def load_per_utterance_iql(config, device, verbose=True):
    """Build a ``PerUtteranceIQL`` model from ``config`` and pass it to ``load_model``.

    The ``"gpt2"`` and ``"dataset"`` sub-configs are resolved through
    ``load_item``; the dataset loader also receives ``device``. The remaining
    constructor arguments are read from ``config`` by key and passed
    positionally. This loader reads neither ``"dm_margin"`` nor ``"cql_temp"``.
    ``config["load"]`` is forwarded to ``load_model``.
    """
    backbone = load_item(config["gpt2"], verbose=verbose)
    rl_dataset = load_item(config["dataset"], device, verbose=verbose)
    # Keys in the exact positional order the constructor is called with.
    positional_keys = (
        "alpha",
        "gamma",
        "beta",
        "transition_weight",
        "clip_weight",
        "value_max",
        "value_min",
        "detach_v",
        "detach_pi",
        "detach_q",
        "double_q",
        "tau",
        "seperate_policy",
        "seperate_target",
        "exp_weights",
        "advanced_mlp",
    )
    hyperparams = [config[key] for key in positional_keys]
    utterance_iql = PerUtteranceIQL(backbone, rl_dataset, device, *hyperparams)
    return load_model(config["load"], utterance_iql, device, verbose=verbose)


@register("iql_policy")
def load_iql_policy(config, device, verbose=True):
    """Build an ``IQL_Policy`` around the model described by ``config["iql_model"]``.

    The model is resolved through ``load_item`` with ``device``.
    ``config["kind"]`` is passed positionally and the entries of
    ``config["generation_kwargs"]`` are expanded as keyword arguments.
    """
    model = load_item(config["iql_model"], device, verbose=verbose)
    policy_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return IQL_Policy(model, policy_kind, **gen_kwargs)


@register("is_iql_policy")
def load_is_iql_policy(config, device, verbose=True):
    """Build an ``IS_IQL_Policy`` around the model described by ``config["iql_model"]``.

    The model is resolved through ``load_item`` with ``device``.
    ``config["kind"]`` is passed positionally and the entries of
    ``config["generation_kwargs"]`` are expanded as keyword arguments.
    """
    model = load_item(config["iql_model"], device, verbose=verbose)
    policy_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return IS_IQL_Policy(model, policy_kind, **gen_kwargs)


@register("utterance_iql_policy")
def load_utterance_iql_policy(config, device, verbose=True):
    """Build a ``PerUtteranceIQL_Policy`` around the model in ``config["iql_model"]``.

    The model is resolved through ``load_item`` with ``device``.
    ``config["kind"]`` is passed positionally and the entries of
    ``config["generation_kwargs"]`` are expanded as keyword arguments.
    """
    model = load_item(config["iql_model"], device, verbose=verbose)
    policy_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return PerUtteranceIQL_Policy(model, policy_kind, **gen_kwargs)


@register("iql_evaluator")
def load_iql_evaluator(config, device, verbose=True):
    """Build an ``IQL_Evaluator`` for the environment described by ``config["env"]``.

    The environment is resolved through ``load_item`` with ``device``. The
    evaluator's own verbosity comes from ``config["verbose"]``, not from the
    ``verbose`` argument, which only controls loader messages.
    ``config["kind"]`` is passed positionally and ``config["generation_kwargs"]``
    is expanded as keyword arguments.
    """
    environment = load_item(config["env"], device, verbose=verbose)
    evaluator_verbose = config["verbose"]
    eval_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return IQL_Evaluator(environment, evaluator_verbose, eval_kind, **gen_kwargs)


@register("is_iql_evaluator")
def load_is_iql_evaluator(config, device, verbose=True):
    """Build an ``IS_IQL_Evaluator`` for the environment described by ``config["env"]``.

    The environment is resolved through ``load_item`` with ``device``. The
    evaluator's own verbosity comes from ``config["verbose"]``, not from the
    ``verbose`` argument, which only controls loader messages.
    ``config["kind"]`` is passed positionally and ``config["generation_kwargs"]``
    is expanded as keyword arguments.
    """
    environment = load_item(config["env"], device, verbose=verbose)
    evaluator_verbose = config["verbose"]
    eval_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return IS_IQL_Evaluator(environment, evaluator_verbose, eval_kind, **gen_kwargs)


@register("top_advantage_n_grams")
def load_top_advantage_n_grams(config, device, verbose=True):
    """Build a ``TopAdvantageNGrams`` over the data described by ``config["data"]``.

    The data object is resolved through ``load_item`` with ``device``. It is
    passed to the constructor followed by ``config["print_every"]``,
    ``config["print_k"]`` and ``config["n_gram"]``, in that order.
    """
    source_data = load_item(config["data"], device, verbose=verbose)
    print_every = config["print_every"]
    print_k = config["print_k"]
    n_gram = config["n_gram"]
    return TopAdvantageNGrams(source_data, print_every, print_k, n_gram)


@register("utterance_iql_evaluator")
def load_utterance_iql_evaluator(config, device, verbose=True):
    """Build an ``UtteranceIQL_Evaluator`` for the environment in ``config["env"]``.

    The environment is resolved through ``load_item`` with ``device``. The
    evaluator's own verbosity comes from ``config["verbose"]``, not from the
    ``verbose`` argument, which only controls loader messages.
    ``config["kind"]`` is passed positionally and ``config["generation_kwargs"]``
    is expanded as keyword arguments.
    """
    environment = load_item(config["env"], device, verbose=verbose)
    evaluator_verbose = config["verbose"]
    eval_kind = config["kind"]
    gen_kwargs = config["generation_kwargs"]
    return UtteranceIQL_Evaluator(environment, evaluator_verbose, eval_kind, **gen_kwargs)
