import json
import os
import time
from types import SimpleNamespace

import numpy as np
from warmup.xml_factory import *


# TODO: remove the redundancy between the per-model loaders by moving their
# constants into unified dicts; then only those dicts need to be loaded and the
# remaining logic can be model-agnostic.


def create_log_dict(*args):
    """Flatten several sequences into a single dict keyed by position.

    The element at index ``j`` of the ``i``-th positional argument is stored
    under the key ``"var_{i}_el_{j}"``.

    Args:
        *args: Iterables whose elements should be logged.

    Returns:
        dict mapping ``"var_<i>_el_<j>"`` to the corresponding element, in
        argument-then-element order.
    """
    return {
        f"var_{var_idx}_el_{el_idx}": value
        for var_idx, sequence in enumerate(args)
        for el_idx, value in enumerate(sequence)
    }


def write_xml(xml, identifier=0):
    """Write an XML model string to a timestamped file and return its path.

    The file is placed under ``xml_factory/generated_xmls`` inside the
    grandparent directory of this module. Its name combines ``identifier``
    with ``time.time_ns()`` so repeated calls normally get distinct files.

    Args:
        xml: The XML document as a string.
        identifier: Label embedded in the file name (defaults to 0).

    Returns:
        The path of the written file, as built by ``os.path.join``.
    """
    xml_path = f"xml_factory/generated_xmls/model_{identifier}_{time.time_ns()}.xml"
    path = os.path.join(os.path.dirname(os.path.dirname(__file__)), xml_path)
    # Create the output directory on first use instead of failing with
    # FileNotFoundError when it does not exist yet (e.g. a fresh checkout).
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Write explicitly as UTF-8 rather than the platform's locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        f.write(xml)
    return path


def create_xml(args, identifier):
    """Render the XML for the model named by ``args.name`` and write it to disk.

    The last three characters of ``args.name`` are dropped before the lookup
    (presumably a version suffix such as ``"-v0"`` -- TODO confirm). The
    remaining name selects the loader that renders the XML string.

    Args:
        args: Namespace with a ``name`` attribute, plus whatever the selected
            loader reads.
        identifier: Forwarded to ``write_xml`` to label the output file.

    Returns:
        The path returned by ``write_xml``.

    Raises:
        KeyError: If the stripped name is not a registered model.
    """
    loaders = dict(
        arm26=load_arm26_xml,
        arm750=load_arm750_xml,
        arm12=load_arm12_xml,
        leg10dof18m=load_leg10dof18m,
    )
    model_name = args.name[:-3]
    loader = loaders[model_name]
    xml = loader(args)
    return write_xml(xml, identifier)


def get_model_kwargs_arm26(args):
    """Select the arm26 template fragments for the requested actuator type.

    Args:
        args: Namespace providing ``timestep`` and ``actuator``.

    Returns:
        dict with ``timestep``, ``actuator``, ``actuator_default`` and
        ``tendons`` entries used to format the ``ARM26`` template. Any
        ``actuator`` other than ``"muscle"`` selects the motor fragments and
        an empty tendon section.
    """
    timestep = args.timestep
    if args.actuator == "muscle":
        actuator = ARM26_MUSCLE_ACTUATOR
        actuator_default = ARM26_MUSCLE_DEFAULT
        tendons = ARM26_MUSCLE_TENDON
    else:
        actuator = ARM26_MOTOR_ACTUATOR
        actuator_default = ARM26_MOTOR_DEFAULT
        tendons = ""
    return {
        "timestep": timestep,
        "actuator": actuator,
        "actuator_default": actuator_default,
        "tendons": tendons,
    }


def get_model_kwargs_arm12(args):
    """Select the arm12 template fragments for the requested actuator type.

    Args:
        args: Namespace providing ``timestep`` and ``actuator``.

    Returns:
        dict with ``timestep``, ``actuator``, ``actuator_default`` and
        ``tendons`` entries used to format the ``ARM12`` template. Any
        ``actuator`` other than ``"muscle"`` selects the motor fragments and
        an empty tendon section.
    """
    timestep = args.timestep
    if args.actuator == "muscle":
        actuator = ARM12_MUSCLE_ACTUATOR
        actuator_default = ARM12_MUSCLE_DEFAULT
        tendons = ARM12_MUSCLE_TENDON
    else:
        actuator = ARM12_MOTOR_ACTUATOR
        actuator_default = ARM12_MOTOR_DEFAULT
        tendons = ""
    return {
        "timestep": timestep,
        "actuator": actuator,
        "actuator_default": actuator_default,
        "tendons": tendons,
    }


def get_model_kwargs_leg10dof18m(args):
    """Select the leg10dof18m template fragments for the requested actuator type.

    Args:
        args: Namespace providing ``timestep`` and ``actuator``.

    Returns:
        dict with ``timestep``, ``actuator``, ``actuator_default`` and
        ``tendons`` entries used to format the ``LEG10DOF18M`` template. Any
        ``actuator`` other than ``"muscle"`` selects the motor fragments and
        an empty tendon section.
    """
    timestep = args.timestep
    if args.actuator == "muscle":
        actuator = LEG10DOF18M_MUSCLE_ACTUATOR
        actuator_default = LEG10DOF18M_MUSCLE_DEFAULT
        tendons = LEG10DOF18M_MUSCLE_TENDON
    else:
        actuator = LEG10DOF18M_MOTOR_ACTUATOR
        actuator_default = LEG10DOF18M_MOTOR_DEFAULT
        tendons = ""
    return {
        "timestep": timestep,
        "actuator": actuator,
        "actuator_default": actuator_default,
        "tendons": tendons,
    }


def modify_actuator_kwargs_leg10dof18m(args, model_kwargs):
    """Fill the muscle settings into the leg10dof18m actuator defaults in place.

    Only muscle actuation is handled. ``model_kwargs["actuator_default"]`` is
    treated as a ``str.format`` template and filled with the gain, bias and
    activation-dynamics settings derived from ``args``.

    NOTE(review): ``get_model_kwargs_leg10dof18m`` selects motor fragments
    for non-muscle actuators, but this function rejects ``"motor"``. Motor
    actuation is therefore effectively unsupported for this model. Confirm
    whether that is intended.

    Args:
        args: Namespace with ``actuator``, ``muscle_morphology``,
            ``dyn_time`` (indices 0 and 1 fill ``dynprm_act`` and
            ``dynprm_deact``), and whatever ``get_dyntype`` reads.
        model_kwargs: dict produced by ``get_model_kwargs_leg10dof18m``.
            It is mutated in place.

    Raises:
        ValueError: If ``args.actuator`` is not ``"muscle"``.
    """
    if args.actuator == "muscle":
        dyntype = get_dyntype(args)
        # The "mujoco" morphology uses the built-in muscle gain/bias;
        # any other morphology switches to "user" gain/bias.
        gain_bias_type = (
            '"muscle"' if args.muscle_morphology == "mujoco" else '"user"'
        )
        actuator_kwargs = {
            "gaintype": gain_bias_type,
            "biastype": gain_bias_type,
            "dyntype": dyntype,
            "dynprm_act": str(args.dyn_time[0]),
            "dynprm_deact": str(args.dyn_time[1]),
        }

        model_kwargs["actuator_default"] = model_kwargs["actuator_default"].format(
            **actuator_kwargs
        )
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(f"Invalid actuator chosen: {args.actuator}")


def get_dyntype(args):
    """Return the quoted ``dyntype`` attribute value for the muscle actuators.

    Args:
        args: Namespace with ``muscle_morphology``. For the ``"user"``
            morphology it must also have ``morph_settings``, whose third
            entry selects the activation dynamics.

    Returns:
        ``'"muscle"'`` for the ``"mujoco"`` morphology. For ``"user"``, the
        result depends on ``morph_settings[2]``:

        - ``1`` gives ``'"muscle"'``.
        - ``2`` gives ``'"user"'``.
        - Any other value gives ``'"none"'``.

    Raises:
        ValueError: If ``muscle_morphology`` is neither ``"mujoco"`` nor
            ``"user"``. Without this, ``None`` would be returned and then
            formatted into the XML template as the literal text ``None``.
    """
    morphology = args.muscle_morphology
    if morphology == "mujoco":
        return '"muscle"'
    if morphology == "user":
        dyn_setting = args.morph_settings[2]
        if dyn_setting == 1:
            return '"muscle"'
        if dyn_setting == 2:
            return '"user"'
        return '"none"'
    raise ValueError(f"Invalid muscle morphology chosen: {morphology}")


def modify_actuator_kwargs_arm12(args, model_kwargs):
    """Adjust the arm12 template fragments in ``model_kwargs`` in place.

    Muscle actuation:

    - The ``actuator_default`` template is filled with the gain, bias and
      activation-dynamics settings.
    - If ``args.monoarticular_muscle`` is set, the actuator and tendon
      fragments are replaced with their ``*_MONO`` variants.

    Motor actuation: if ``args.antagonistic_motor`` is set, the actuator
    fragment is replaced with ``ARM12_MOTOR_ACTUATOR_DOUBLE``.

    Args:
        args: Namespace with ``actuator`` plus the flags for the chosen mode:
            ``muscle_morphology``, ``dyn_time`` (indices 0 and 1 fill
            ``dynprm_act`` and ``dynprm_deact``) and ``monoarticular_muscle``
            for muscles; ``antagonistic_motor`` for motors.
        model_kwargs: dict produced by ``get_model_kwargs_arm12``. It is
            mutated in place.

    Raises:
        ValueError: If ``args.actuator`` is neither ``"muscle"`` nor ``"motor"``.
    """
    if args.actuator == "muscle":
        dyntype = get_dyntype(args)
        # The "mujoco" morphology uses the built-in muscle gain/bias;
        # any other morphology switches to "user" gain/bias.
        gain_bias_type = (
            '"muscle"' if args.muscle_morphology == "mujoco" else '"user"'
        )
        actuator_kwargs = {
            "gaintype": gain_bias_type,
            "biastype": gain_bias_type,
            "dyntype": dyntype,
            "dynprm_act": str(args.dyn_time[0]),
            "dynprm_deact": str(args.dyn_time[1]),
        }
        if args.monoarticular_muscle:
            model_kwargs["actuator"] = ARM12_MUSCLE_ACTUATOR_MONO
            model_kwargs["tendons"] = ARM12_MUSCLE_TENDON_MONO

        model_kwargs["actuator_default"] = model_kwargs["actuator_default"].format(
            **actuator_kwargs
        )
    elif args.actuator == "motor":
        if args.antagonistic_motor:
            model_kwargs["actuator"] = ARM12_MOTOR_ACTUATOR_DOUBLE
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(f"Invalid actuator chosen: {args.actuator}")


def modify_actuator_kwargs_arm26(args, model_kwargs):
    """Adjust the arm26 template fragments in ``model_kwargs`` in place.

    Muscle actuation:

    - The ``actuator_default`` template is filled with the gain, bias and
      activation-dynamics settings.
    - If ``args.monoarticular_muscle`` is set, the actuator and tendon
      fragments are replaced with their ``*_MONO`` variants.

    Motor actuation: if ``args.antagonistic_motor`` is set, the actuator
    fragment is replaced with ``ARM26_MOTOR_ACTUATOR_DOUBLE``.

    Args:
        args: Namespace with ``actuator`` plus the flags for the chosen mode:
            ``muscle_morphology``, ``dyn_time`` (indices 0 and 1 fill
            ``dynprm_act`` and ``dynprm_deact``) and ``monoarticular_muscle``
            for muscles; ``antagonistic_motor`` for motors.
        model_kwargs: dict produced by ``get_model_kwargs_arm26``. It is
            mutated in place.

    Raises:
        ValueError: If ``args.actuator`` is neither ``"muscle"`` nor ``"motor"``.
    """
    if args.actuator == "muscle":
        dyntype = get_dyntype(args)
        # The "mujoco" morphology uses the built-in muscle gain/bias;
        # any other morphology switches to "user" gain/bias.
        gain_bias_type = (
            '"muscle"' if args.muscle_morphology == "mujoco" else '"user"'
        )
        actuator_kwargs = {
            "gaintype": gain_bias_type,
            "biastype": gain_bias_type,
            "dyntype": dyntype,
            "dynprm_act": str(args.dyn_time[0]),
            "dynprm_deact": str(args.dyn_time[1]),
        }
        if args.monoarticular_muscle:
            model_kwargs["actuator"] = ARM26_MUSCLE_ACTUATOR_MONO
            model_kwargs["tendons"] = ARM26_MUSCLE_TENDON_MONO

        model_kwargs["actuator_default"] = model_kwargs["actuator_default"].format(
            **actuator_kwargs
        )
    elif args.actuator == "motor":
        if args.antagonistic_motor:
            model_kwargs["actuator"] = ARM26_MOTOR_ACTUATOR_DOUBLE
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(f"Invalid actuator chosen: {args.actuator}")


def load_arm26_xml(args):
    """Render the arm26 model XML string for the settings in ``args``."""
    template_kwargs = get_model_kwargs_arm26(args)
    # Mutates template_kwargs in place according to the actuator settings.
    modify_actuator_kwargs_arm26(args, template_kwargs)
    rendered = ARM26.format(**template_kwargs)
    return rendered


def load_arm12_xml(args):
    """Render the arm12 model XML string for the settings in ``args``."""
    template_kwargs = get_model_kwargs_arm12(args)
    # Mutates template_kwargs in place according to the actuator settings.
    modify_actuator_kwargs_arm12(args, template_kwargs)
    rendered = ARM12.format(**template_kwargs)
    return rendered


def load_leg10dof18m(args):
    """Render the leg10dof18m model XML string for the settings in ``args``."""
    template_kwargs = get_model_kwargs_leg10dof18m(args)
    # Mutates template_kwargs in place according to the actuator settings.
    modify_actuator_kwargs_leg10dof18m(args, template_kwargs)
    rendered = LEG10DOF18M.format(**template_kwargs)
    return rendered


def load_arm750_xml(args):
    """Render the arm750 model XML string for the settings in ``args``."""
    template_kwargs = get_model_kwargs_arm750(args)
    # Mutates template_kwargs in place according to the actuator settings.
    modify_actuator_kwargs_arm750(args, template_kwargs)
    rendered = ARM750.format(**template_kwargs)
    return rendered


def get_model_kwargs_arm750(args):
    """Select the arm750 template fragments for the requested actuator type.

    Args:
        args: Namespace providing ``timestep`` and ``actuator``.

    Returns:
        dict with ``timestep``, ``actuator``, ``actuator_default`` and
        ``tendons`` entries used to format the ``ARM750`` template. Any
        ``actuator`` other than ``"muscle"`` selects the motor fragments and
        an empty tendon section.
    """
    timestep = args.timestep
    if args.actuator == "muscle":
        actuator = ARM750_MUSCLE_ACTUATOR
        actuator_default = ARM750_MUSCLE_DEFAULT
        tendons = ARM750_MUSCLE_TENDON
    else:
        actuator = ARM750_MOTOR_ACTUATOR
        actuator_default = ARM750_MOTOR_DEFAULT
        tendons = ""
    return {
        "timestep": timestep,
        "actuator": actuator,
        "actuator_default": actuator_default,
        "tendons": tendons,
    }


def modify_actuator_kwargs_arm750(args, model_kwargs):
    """Adjust the arm750 template fragments in ``model_kwargs`` in place.

    Muscle actuation: the ``actuator_default`` template is filled with the
    gain, bias and activation-dynamics settings.

    Motor actuation: if ``args.antagonistic_motor`` is set, the actuator
    fragment is replaced with ``ARM750_MOTOR_ACTUATOR_DOUBLE``.

    Args:
        args: Namespace with ``actuator`` plus the flags for the chosen mode:
            ``muscle_morphology`` and ``dyn_time`` (indices 0 and 1 fill
            ``dynprm_act`` and ``dynprm_deact``) for muscles;
            ``antagonistic_motor`` for motors.
        model_kwargs: dict produced by ``get_model_kwargs_arm750``. It is
            mutated in place.

    Raises:
        ValueError: If ``args.actuator`` is neither ``"muscle"`` nor ``"motor"``.
    """
    if args.actuator == "muscle":
        dyntype = get_dyntype(args)
        # The "mujoco" morphology uses the built-in muscle gain/bias;
        # any other morphology switches to "user" gain/bias.
        gain_bias_type = (
            '"muscle"' if args.muscle_morphology == "mujoco" else '"user"'
        )
        actuator_kwargs = {
            "gaintype": gain_bias_type,
            "biastype": gain_bias_type,
            "dyntype": dyntype,
            "dynprm_act": str(args.dyn_time[0]),
            "dynprm_deact": str(args.dyn_time[1]),
        }
        model_kwargs["actuator_default"] = model_kwargs["actuator_default"].format(
            **actuator_kwargs
        )
    elif args.actuator == "motor":
        if args.antagonistic_motor:
            model_kwargs["actuator"] = ARM750_MOTOR_ACTUATOR_DOUBLE
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(f"Invalid actuator chosen: {args.actuator}")


def load_leg69_xml(args):
    """Placeholder loader for the leg69 model; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("XML loading for the leg69 model is not implemented yet")


def load_default_params(path=None):
    """Load run parameters from a JSON file.

    Args:
        path: JSON file to read. Defaults to
            ``"../param_files/default_params.json"``, which is resolved
            relative to the current working directory, not to this module.

    Returns:
        Tuple ``(params, args)``:

        - ``params`` is a SimpleNamespace of the top-level JSON keys, with
          ``model_dir`` overwritten to ``"."``.
        - ``args`` is a SimpleNamespace built from the ``params_env`` entry.

    Raises:
        AttributeError: If the JSON object has no top-level ``params_env`` key.
    """
    if path is None:
        path = "../param_files/default_params.json"
    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open(path, encoding="utf-8") as f:
        raw_params = json.load(f)
    raw_params["model_dir"] = "."
    params = SimpleNamespace(**raw_params)
    args = SimpleNamespace(**params.params_env)
    return params, args


def save_metrics(params, args, metrics):
    """Save each metric to ``<model_dir>/<folder_name>/<key>.npy`` with ``np.save``.

    Args:
        params: Object with a ``model_dir`` attribute.
        args: Object with a ``folder_name`` attribute. The target directory
            is not created here, so it must already exist.
        metrics: Mapping from metric name to an array-like accepted by
            ``np.save``.
    """
    # Build each path once and use it directly. Previously a parallel list of
    # paths was computed and zipped, but the path was then rebuilt inside the
    # loop, leaving that list unused.
    for key, value in metrics.items():
        np.save(f"{params.model_dir}/{args.folder_name}/{key}.npy", value)
    print("Saving metrics")


class DummyLogger:
    """No-op logger stand-in: every method accepts any arguments and does nothing."""

    def __init__(self):
        """Create the logger; there is no state to initialise."""

    def log_data(self, *args, **kwargs):
        """Accept and discard logged data."""

    def write_separate(self, *args, **kwargs):
        """Accept and discard a separate-write request."""

    def reset(self, *args, **kwargs):
        """Do nothing; there is no state to reset."""
