# pylint: disable=missing-function-docstring
import hashlib
import io
import json
import math
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import Any
import click
import numpy as np
import zstandard as zstd
from tqdm.auto import tqdm
from tsbench.experiments import aws


def tar_directory(path: Path) -> bytes:
    """Pack all files below ``path`` into an uncompressed, in-memory tar archive.

    Every file found while walking ``path`` is stored under its path relative
    to ``path``. Only files are added explicitly, so directories that contain
    no files do not show up in the archive.

    Directories and files are visited in sorted order, which makes the order
    of the archive members independent of the order in which the file system
    happens to enumerate directory entries.

    Args:
        path: Root directory to archive. Plain string paths work as well
            since the value is only handed to ``os.walk`` and ``os.path``.

    Returns:
        The bytes of the tar archive.
    """
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for root, dirs, files in os.walk(path):
            # Sorting in place steers the (top-down) traversal of os.walk.
            dirs.sort()
            for file in sorted(files):
                target = os.path.join(root, file)
                tar.add(target, arcname=os.path.relpath(target, path))
    return buf.getvalue()


def compress_directory(source: Path, target: Path) -> None:
    """Store the contents of ``source`` as a zstd-compressed tar archive.

    The directory is first archived in memory via ``tar_directory``; the
    resulting bytes are compressed in one shot and written to ``target``,
    replacing any previous content of that file.

    Args:
        source: Directory whose files should be archived.
        target: File that receives the compressed archive.
    """
    codec = zstd.ZstdCompressor()
    archive_bytes = tar_directory(source)
    with target.open("wb+") as sink:
        payload = codec.compress(archive_bytes)
        sink.write(payload)


def build_configuration(job: aws.TrainingJob) -> Any:
    """Assemble the configuration dictionary describing a training job.

    The result always contains the job's ``seed``. If the job defines a
    ``training_time`` hyperparameter, the result additionally contains that
    value and a ``hyperparameters`` mapping consisting of
    ``context_length_multiple`` followed by every hyperparameter whose name
    starts with ``<model>_``, stored with that prefix removed.

    The insertion order of the keys matters: ``move_job`` hashes the JSON dump
    of this dictionary to derive the output directory.

    Args:
        job: Job whose ``hyperparameters`` mapping is read.

    Returns:
        The configuration as a plain dictionary.
    """
    hyperparameters = job.hyperparameters
    model = hyperparameters["model"]
    config = {"seed": hyperparameters["seed"]}

    # Jobs without a training time contribute nothing beyond their seed.
    if "training_time" not in hyperparameters:
        return config

    config["training_time"] = hyperparameters["training_time"]

    model_hyperparameters = {
        "context_length_multiple": hyperparameters["context_length_multiple"],
    }
    prefix = f"{model}_"
    for name, value in hyperparameters.items():
        if name.startswith(prefix):
            # Drop the "<model>_" prefix from the stored key.
            model_hyperparameters[name[len(model) + 1 :]] = value
    config["hyperparameters"] = model_hyperparameters

    return config


def build_performance(job: aws.TrainingJob) -> Any:
    """Collect the metrics of a training job into a nested dictionary.

    The number of models is taken from the length of the ``training_time``
    metric. For every model index, the metrics are grouped into ``training``,
    ``evaluation`` and ``testing``; a group is left out when all of its
    metrics are empty. Some metrics are stored under a shorter name
    (``training_time`` -> ``duration``, ``mean_weighted_quantile_loss`` ->
    ``ncrps``, ``val_mean_weighted_quantile_loss`` -> ``val_ncrps``) and
    ``num_gradient_updates`` is converted to an integer.

    Args:
        job: Job whose ``metrics`` mapping is read. Each metric must support
            ``len``, integer indexing and ``.item()`` on its elements
            (presumably NumPy arrays).

    Returns:
        A dictionary with a ``meta`` entry (number of model parameters and
        mean latency) and a ``performances`` list with one entry per model.

    Raises:
        AssertionError: If the number of models is neither 1 nor 11, or if any
            per-model metric value is NaN.
    """
    metrics = job.metrics
    num_models = len(metrics["training_time"])

    groups = {
        "training": ["training_time", "num_gradient_updates"],
        "evaluation": ["train_loss", "val_loss", "val_mean_weighted_quantile_loss"],
        "testing": ["mase", "smape", "nrmse", "nd", "mean_weighted_quantile_loss"],
    }
    aliases = {
        "training_time": "duration",
        "mean_weighted_quantile_loss": "ncrps",
        "val_mean_weighted_quantile_loss": "val_ncrps",
    }
    integral = {"num_gradient_updates"}

    performances = []
    for index in range(num_models):
        entry = {}
        for group, names in groups.items():
            # Skip groups for which no metric was recorded at all.
            if all(len(metrics[name]) == 0 for name in names):
                continue
            values = {}
            for name in names:
                value = metrics[name][index].item()
                values[aliases.get(name, name)] = int(value) if name in integral else value
            entry[group] = values
        performances.append(entry)

    meta = {
        "num_model_parameters": int(metrics["num_model_parameters"][0].item()),
        "latency": np.mean(metrics["latency"]).item(),
    }

    # Sanity checks: expected number of models and no NaN among the metrics.
    assert len(performances) in (1, 11)
    assert not any(
        math.isnan(value)
        for entry in performances
        for values in entry.values()
        for value in values.values()
    )

    return {"meta": meta, "performances": performances}


def check_target(target: Path) -> bool:
    """Tell whether ``target`` holds a complete export of a job.

    A directory is complete if it contains ``config.json``,
    ``performance.json`` and ``test_forecasts.tar`` and either nothing else,
    or exactly one more entry named ``val_forecasts.tar``.

    Args:
        target: Directory to inspect.

    Returns:
        ``True`` if the directory is complete, ``False`` otherwise.
    """
    required = ("config.json", "performance.json", "test_forecasts.tar")
    # The directory is only listed once the required files are known to exist.
    if not all((target / name).exists() for name in required):
        return False

    num_entries = len(os.listdir(target))
    if num_entries == 3:
        return True
    return num_entries == 4 and (target / "val_forecasts.tar").exists()


def move_job(job: aws.TrainingJob, to: Path) -> None:
    """Export configuration, performance and forecasts of a job below ``to``.

    The data ends up in ``to/<model>/<dataset>/<hash>``, where ``<hash>`` is
    the MD5 hex digest of the JSON-serialized configuration. If that directory
    already exists and passes ``check_target``, nothing is done; if it exists
    but fails the check, it is removed and written again.

    Test forecasts of all models are always stored in ``test_forecasts.tar``.
    Validation forecasts are stored in ``val_forecasts.tar`` only when the job
    has more than one model.

    Args:
        job: The job to export.
        to: Base directory of the export.

    Raises:
        AssertionError: If the written directory does not pass
            ``check_target``.
    """
    # Gather the job's description
    config = build_configuration(job)
    performance = build_performance(job)

    # The output directory is keyed by the hash of the config
    digest = hashlib.md5(json.dumps(config).encode("utf-8")).hexdigest()
    target = to / job.hyperparameters["model"] / job.hyperparameters["dataset"] / digest

    # A complete directory is kept as is, an incomplete one is rebuilt
    if target.exists():
        if check_target(target):
            return
        shutil.rmtree(target)

    target.mkdir(parents=True, exist_ok=True)

    # Store the job's description
    for filename, content in (("config.json", config), ("performance.json", performance)):
        with (target / filename).open("w+") as f:
            json.dump(content, f, indent=4)

    num_models = len(performance["performances"])

    # Test predictions are always exported, val predictions only for multiple models
    exports = [("predictions", "test_forecasts.tar")]
    if num_models > 1:
        exports.append(("val_predictions", "val_forecasts.tar"))

    for source_name, archive_name in exports:
        # Collect the forecasts of all models in a scratch directory and tar it
        with tempfile.TemporaryDirectory() as staging:
            for index in range(num_models):
                shutil.copyfile(
                    job._cache_dir()  # pylint: disable=protected-access
                    / "artifacts"
                    / source_name
                    / f"model_{index}"
                    / "forecasts.npz",
                    os.path.join(staging, f"forecasts_{index:02}.npz"),
                )
            with (target / archive_name).open("wb+") as sink:
                sink.write(tar_directory(staging))

    # Make sure the directory is complete
    assert check_target(target)


@click.command()
@click.argument("experiment", type=str, nargs=1)
def main(experiment: str):
    # Exports all completed jobs of `experiment` into `~/ts-bench` and bundles
    # the resulting configs and performances into `~/ts-bench/metadata.tar.gz`.

    # Look up the completed jobs of the experiment
    completed_jobs = aws.Analysis(experiment, only_completed=True)

    # Export each of them into the output directory
    output_root = Path.home() / "ts-bench"
    for job in tqdm(completed_jobs):
        move_job(job, output_root)

    # Bundle every config and performance file into one gzipped tar archive
    metadata_names = ("config.json", "performance.json")
    with (output_root / "metadata.tar.gz").open("wb+") as sink:
        with tarfile.open(fileobj=sink, mode="w:gz") as archive:
            for directory, _, filenames in os.walk(output_root):
                members = (
                    os.path.join(directory, filename)
                    for filename in filenames
                    if filename in metadata_names
                )
                for member in members:
                    archive.add(member, arcname=os.path.relpath(member, output_root))


if __name__ == "__main__":
    # Only run the command when this file is executed as a script.
    # `main` is wrapped by `click.command`, which is expected to supply
    # `experiment` from the command line, hence the call without arguments.
    # pylint: disable=no-value-for-parameter
    main()
