import os
import subprocess
import sys
from copy import deepcopy
from pathlib import Path

import ase.io
import numpy as np
import pytest
from ase.atoms import Atoms

from mace.calculators.mace import MACECalculator
from mace.cli.run_train import run as run_mace_train
from mace.data.utils import KeySpecification
from mace.tools import build_default_arg_parser

run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"


_mace_params = {
    "name": "MACE",
    "valid_fraction": 0.05,
    "energy_weight": 1.0,
    "forces_weight": 10.0,
    "stress_weight": 1.0,
    "model": "MACE",
    "hidden_irreps": "128x0e",
    "max_num_epochs": 10,
    "swa": None,
    "start_swa": 5,
    "ema": None,
    "ema_decay": 0.99,
    "amsgrad": None,
    "device": "cpu",
    "seed": 5,
    "loss": "weighted",
    "energy_key": "REF_energy",
    "forces_key": "REF_forces",
    "stress_key": "REF_stress",
    "interaction_first": "RealAgnosticResidualInteractionBlock",
    "batch_size": 1,
    "valid_batch_size": 1,
    "num_samples_pt": 50,
    "subselect_pt": "random",
    "eval_interval": 2,
    "num_radial_basis": 10,
    "r_max": 6.0,
    "default_dtype": "float64",
}


def configs_numbered_keys():
    np.random.seed(0)
    water = Atoms(
        numbers=[8, 1, 1],
        positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]],
        cell=[4] * 3,
        pbc=[True] * 3,
    )

    energies = list(np.random.normal(0.1, size=15))
    forces = list(np.random.normal(0.1, size=(15, 3, 3)))

    trial_configs_lists = []
    # some keys present, some not
    keys_to_use = (
        ["REF_energy"]
        + ["2_energy"] * 2
        + ["3_energy"] * 3
        + ["4_energy"] * 4
        + ["5_energy"] * 5
    )

    force_keys_to_use = (
        ["REF_forces"]
        + ["2_forces"] * 2
        + ["3_forces"] * 3
        + ["4_forces"] * 4
        + ["5_forces"] * 5
    )

    for ind in range(15):
        c = deepcopy(water)
        c.info[keys_to_use[ind]] = energies[ind]
        c.arrays[force_keys_to_use[ind]] = forces[ind]
        c.positions += np.random.normal(0.1, size=(3, 3))
        trial_configs_lists.append(c)

    return trial_configs_lists


def trial_yamls_and_and_expected():
    yamls = {}
    command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"}

    yamls["no_heads"] = {}

    yamls["one_head_no_dicts"] = {
        "heads": {
            "Default": {
                "energy_key": "3_energy",
            }
        }
    }

    yamls["one_head_with_dicts"] = {
        "heads": {
            "Default": {
                "info_keys": {
                    "energy": "3_energy",
                },
                "arrays_keys": {
                    "forces": "3_forces",
                },
            }
        }
    }

    yamls["two_heads_no_dicts"] = {
        "heads": {
            "dft": {
                "train_file": "fit_multihead_dft.xyz",
                "energy_key": "3_energy",
            },
            "mp2": {
                "train_file": "fit_multihead_mp2.xyz",
                "energy_key": "4_energy",
            },
        }
    }

    yamls["two_heads_mixed"] = {
        "heads": {
            "dft": {
                "train_file": "fit_multihead_dft.xyz",
                "info_keys": {
                    "energy": "3_energy",
                },
                "arrays_keys": {
                    "forces": "3_forces",
                },
                "forces_key": "4_forces",
            },
            "mp2": {
                "train_file": "fit_multihead_mp2.xyz",
                "energy_key": "4_energy",
            },
        }
    }
    all_arg_sets = {
        "with_command_line": {
            key: {**command_line_kwargs, **value} for key, value in yamls.items()
        },
        "without_command_line": yamls,
    }

    all_expected_outputs = {
        "with_command_line": {
            "no_heads": [
                1.0037831178668188,
                1.0183291323603265,
                1.0120784084221528,
                0.9935695881012243,
                1.0021641561865526,
                0.9999135609205868,
                0.9809440616323108,
                1.0025784765050076,
                1.0017901145495376,
                1.0136913185404515,
                1.006798563238269,
                1.0187758397828384,
                1.0180201540775071,
                1.0132368725061702,
                0.9998734173248169,
            ],
            "one_head_no_dicts": [
                1.0028437510688613,
                1.0514693378041775,
                1.059933403321331,
                1.034719940573569,
                1.0438040675561824,
                1.019719477728329,
                0.9841759692947915,
                1.0435266573857496,
                1.0339501989779065,
                1.0501795448530264,
                1.0402594216704781,
                1.0604998765679152,
                1.0633411200246015,
                1.0539071190201297,
                1.0393496428177804,
            ],
            "one_head_with_dicts": [
                0.8638341551096959,
                1.0078341354784144,
                1.0149701178418595,
                0.9945723048460148,
                1.0184158011731292,
                0.9992135295205004,
                0.8943420783639198,
                1.0327920054084088,
                0.9905731198078909,
                0.9838325204450648,
                1.0018725575620482,
                1.007263052421034,
                1.0335213929231966,
                1.0033503312511205,
                1.0174433894759563,
            ],
            "two_heads_no_dicts": [
                0.9836377578288774,
                1.0196844186291318,
                1.0151628222871238,
                0.957307281711648,
                0.985574141310865,
                0.9629670134047853,
                0.9242583185138095,
                0.9807770070311039,
                0.9973679440479541,
                1.0221127246963275,
                1.0031807967874216,
                1.0358701219543687,
                1.0434208761164758,
                1.0235606028124515,
                0.9797494630655053,
            ],
            "two_heads_mixed": [
                0.8664108574741868,
                0.9907166576278023,
                1.0051969372365164,
                0.978702477000018,
                1.025500166764692,
                0.9940095566375018,
                0.9034029726954119,
                1.0391739502744488,
                0.9717327061183668,
                0.972292103670355,
                1.0012510461663253,
                0.9978051155885286,
                1.0378611651753475,
                1.0003207628186224,
                1.0209509292189651,
            ],
        },
        "without_command_line": {
            "no_heads": [
                0.9352605307451007,
                0.991084559389268,
                0.9940350095024881,
                0.9953849198103668,
                0.9954705498032904,
                0.9964815693808411,
                0.9663142667436776,
                0.9947223808739147,
                0.9897776682803257,
                0.989027769690667,
                0.9910280920241263,
                0.992067980667518,
                0.9917276132506404,
                0.9902848752169671,
                0.9928585982942544,
            ],
            "one_head_no_dicts": [
                0.9425342207393083,
                1.0149788456087416,
                1.0249228965652788,
                1.0247924743285792,
                1.02732103964481,
                1.0168852937950326,
                0.9771283495170653,
                1.0261776335561517,
                1.0130461033368028,
                1.0162619153561783,
                1.019995179866916,
                1.0209512298344965,
                1.0219971755636952,
                1.0195791901659124,
                1.0234662527729408,
            ],
            "one_head_with_dicts": [
                0.8638341551096959,
                1.0078341354784144,
                1.0149701178418595,
                0.9945723048460148,
                1.0184158011731292,
                0.9992135295205004,
                0.8943420783639198,
                1.0327920054084088,
                0.9905731198078909,
                0.9838325204450648,
                1.0018725575620482,
                1.007263052421034,
                1.0335213929231966,
                1.0033503312511205,
                1.0174433894759563,
            ],
            "two_heads_no_dicts": [
                0.9933763730233168,
                0.9986480398559268,
                1.0042486164355315,
                1.0025568793877726,
                1.0032598081704625,
                0.9926714183717912,
                0.9920385249670881,
                1.0020278841030676,
                1.0012474150830537,
                1.0039289677261019,
                1.0022718878661814,
                1.003586385624809,
                1.003436450009097,
                1.003805673887942,
                1.001450261102316,
            ],
            "two_heads_mixed": [
                0.8781767864616707,
                0.9843563603794138,
                1.0145197579049248,
                0.9835060778675391,
                1.0419060462994596,
                0.9917393978520056,
                0.9091521032773944,
                1.0605463095070453,
                0.9685381713826684,
                0.9866493058823766,
                1.00305061187164,
                1.0051273128414386,
                1.037964258398104,
                1.0106663924241408,
                1.0274351814133602,
            ],
        },
    }

    list_of_all = []
    for key, value in all_arg_sets.items():
        for key2, value2 in value.items():
            list_of_all.append(
                (value2, (key, key2), np.asarray(all_expected_outputs[key][key2]))
            )

    return list_of_all


def dict_to_yaml_str(data, indent=0):
    yaml_str = ""
    for key, value in data.items():
        yaml_str += " " * indent + str(key) + ":"
        if isinstance(value, dict):
            yaml_str += "\n" + dict_to_yaml_str(value, indent + 2)
        else:
            yaml_str += " " + str(value) + "\n"
    return yaml_str


_trial_yamls_and_and_expected = trial_yamls_and_and_expected()


@pytest.mark.parametrize(
    "yaml_contents, name, expected_value", _trial_yamls_and_and_expected
)
def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value):
    fitting_configs = configs_numbered_keys()

    ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs)
    ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs)
    ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs)

    mace_params = _mace_params.copy()
    mace_params["valid_fraction"] = 0.1
    mace_params["checkpoints_dir"] = str(tmp_path)
    mace_params["model_dir"] = str(tmp_path)
    mace_params["train_file"] = "fit_multihead_dft.xyz"
    mace_params["E0s"] = "{1:0.0,8:1.0}"
    mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz"
    del mace_params["valid_fraction"]
    mace_params["max_num_epochs"] = 1  # many tests to do
    del mace_params["energy_key"]
    del mace_params["forces_key"]
    del mace_params["stress_key"]

    mace_params["name"] = "MACE_"

    filename = tmp_path / "config.yaml"
    with open(filename, "w", encoding="utf-8") as file:
        file.write(dict_to_yaml_str(yaml_contents))
    if len(yaml_contents) > 0:
        mace_params["config"] = str(tmp_path / "config.yaml")

    run_env = os.environ.copy()
    sys.path.insert(0, str(Path(__file__).parent.parent))
    run_env["PYTHONPATH"] = ":".join(sys.path)
    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])

    cmd = (
        sys.executable
        + " "
        + str(run_train)
        + " "
        + " ".join(
            [
                (f"--{k}={v}" if v is not None else f"--{k}")
                for k, v in mace_params.items()
            ]
        )
    )

    p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True)
    assert p.returncode == 0

    if "heads" in yaml_contents:
        headname = list(yaml_contents["heads"].keys())[0]
    else:
        headname = "Default"

    calc = MACECalculator(
        tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname
    )

    Es = []
    for at in fitting_configs:
        at.calc = calc
        Es.append(at.get_potential_energy())

    print(name)
    print("Es", Es)

    assert np.allclose(
        np.asarray(Es), expected_value, rtol=1e-8, atol=1e-8
    ), f"Expected {expected_value} but got {Es} with error {np.max(np.abs(Es - expected_value))}"


def test_multihead_finetuning_does_not_modify_default_keyspec(tmp_path):
    fitting_configs = configs_numbered_keys()
    ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs)

    args = build_default_arg_parser().parse_args(
        [
            "--name",
            "_MACE_",
            "--train_file",
            str(tmp_path / "fit_multihead_dft.xyz"),
            "--foundation_model",
            "small",
            "--device",
            "cpu",
            "--E0s",
            "{1:0.0,8:1.0}",
            "--energy_key",
            "2_energy",
            "--dry_run",
        ]
    )
    default_key_spec = KeySpecification.from_defaults()
    default_key_spec.info_keys["energy"] = "2_energy"
    run_mace_train(args)
    assert args.key_specification == default_key_spec


# for creating values
def make_output():
    outputs = {}
    for yaml_contents, name, expected_value in _trial_yamls_and_and_expected:
        if name[0] not in outputs:
            outputs[name[0]] = {}
        expected = test_key_specification_methods(
            Path("."), yaml_contents, name, expected_value, debug_test=False
        )
        outputs[name[0]][name[1]] = expected
    print(outputs)
