from typing import List, Tuple

import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import System
from omegaconf import DictConfig

from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info


def read_systems(filename: str) -> List[System]:
    """Placeholder reader for systems stored with metatensor.

    :param filename: name of the file to read

    :raises NotImplementedError: always, because serialization of systems is
        not yet available in metatensor.
    """
    message = "Reading metatensor systems is not yet implemented."
    raise NotImplementedError(message)


def _wrapped_metatensor_read(filename) -> TensorMap:
    """Load a ``TensorMap`` from ``filename`` with ``mts.load``.

    Any exception raised while loading is converted into a ``ValueError``
    that names the file and is chained to the original exception.

    :param filename: file to load
    :return: the object returned by ``mts.load``
    :raises ValueError: if ``mts.load`` raises any ``Exception``
    """
    try:
        loaded = mts.load(filename)
    except Exception as err:
        message = f"Failed to read '{filename}' with torch: {err}"
        raise ValueError(message) from err
    return loaded


def read_energy(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read an energy target and split it into one ``TensorMap`` per system.

    The file named by ``target["read_from"]`` is loaded, checked against the
    layout produced by ``get_energy_target_info`` and then split along the
    samples, using the values found in the ``"system"`` sample column.

    :param target: configuration of the target; the keys ``"read_from"``,
        ``"forces"``, ``"stress"`` and ``"virial"`` are read here, and the
        whole configuration is forwarded to ``get_energy_target_info``
    :return: the per-system pieces returned by ``mts.split`` (one per unique
        system index) together with the target information
    :raises ValueError: if the file cannot be read, if the loaded
        ``TensorMap`` does not contain exactly one block, or if its metadata
        does not match the expected layout
    """
    tensor_map = _wrapped_metatensor_read(target["read_from"])

    if len(tensor_map) != 1:
        raise ValueError("Energy TensorMaps should have exactly one block.")

    # forces request position gradients; either stress or virial requests
    # strain gradients
    add_position_gradients = target["forces"]
    add_strain_gradients = target["stress"] or target["virial"]
    target_info = get_energy_target_info(
        target, add_position_gradients, add_strain_gradients
    )

    # now check all the expected metadata (from target_info.layout) matches
    # the actual metadata in the tensor maps
    _check_tensor_map_metadata(tensor_map, target_info.layout)

    # one single-entry selection per distinct system index in the samples
    selections = [
        Labels(
            names=["system"],
            values=torch.tensor([[int(i)]]),
        )
        for i in torch.unique(
            torch.concatenate(
                [block.samples.column("system") for block in tensor_map.blocks()]
            )
        )
    ]
    tensor_maps = mts.split(tensor_map, "samples", selections)
    return tensor_maps, target_info


def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
    """Read a generic target and split it into one ``TensorMap`` per system.

    The file named by ``target["read_from"]`` is loaded, rejected if any block
    carries gradients, and checked against the layout produced by
    ``get_generic_target_info``. The layout stored in the returned target
    information is then replaced by an empty ``TensorMap`` mirroring the
    loaded data, so that its properties match the data that was actually read.

    :param target: configuration of the target; the key ``"read_from"`` is
        read here, and the whole configuration is forwarded to
        ``get_generic_target_info``
    :return: the per-system pieces returned by ``mts.split`` (one per unique
        system index found in any block) together with the target information
    :raises ValueError: if the file cannot be read, if any block has gradient
        blocks, or if the metadata does not match the expected layout
    """
    tensor_map = _wrapped_metatensor_read(target["read_from"])

    for block in tensor_map.blocks():
        if len(block.gradients_list()) > 0:
            raise ValueError("Only energy targets can have gradient blocks.")

    target_info = get_generic_target_info(target)
    _check_tensor_map_metadata(tensor_map, target_info.layout)

    # make sure that the properties of the target_info.layout also match the
    # actual properties of the tensor maps
    target_info.layout = _empty_tensor_map_like(tensor_map)

    # Gather the system indices from every block, not only the first one:
    # a system whose samples only show up in a later block would otherwise
    # never get its own TensorMap. This mirrors what `read_energy` does.
    system_indices = torch.unique(
        torch.concatenate(
            [block.samples.column("system") for block in tensor_map.blocks()]
        )
    )
    selections = [
        Labels(
            names=["system"],
            values=torch.tensor([[int(i)]]),
        )
        for i in system_indices
    ]
    tensor_maps = mts.split(tensor_map, "samples", selections)
    return tensor_maps, target_info


def _check_tensor_map_metadata(tensor_map: TensorMap, layout: TensorMap):
    """Verify that the metadata of ``tensor_map`` agrees with ``layout``.

    The keys must be equal. For every key, the sample names, the components
    and the set of gradient names of the two blocks must agree, and each
    gradient block must in turn agree on its sample names and components.
    Properties are not compared.

    :param tensor_map: the ``TensorMap`` whose metadata is checked
    :param layout: the ``TensorMap`` providing the expected metadata
    :raises ValueError: on the first mismatch that is found
    """
    actual_keys = tensor_map.keys
    expected_keys = layout.keys
    if actual_keys != expected_keys:
        raise ValueError(
            "Unexpected keys in metatensor targets: "
            f"expected: {expected_keys} "
            f"actual: {actual_keys}"
        )

    for key in expected_keys:
        actual = tensor_map.block(key)
        expected = layout.block(key)

        actual_names = actual.samples.names
        expected_names = expected.samples.names
        if actual_names != expected_names:
            raise ValueError(
                "Unexpected samples in metatensor targets: "
                f"expected: {expected_names} "
                f"actual: {actual_names}"
            )

        actual_components = actual.components
        expected_components = expected.components
        if actual_components != expected_components:
            raise ValueError(
                "Unexpected components in metatensor targets: "
                f"expected: {expected_components} "
                f"actual: {actual_components}"
            )

        # the properties can be different from those of the default `TensorMap`
        # given by `get_generic_target_info`, so we don't check them

        # gradient names are compared as sets, so their order is irrelevant
        actual_gradients = actual.gradients_list()
        expected_gradients = expected.gradients_list()
        if set(actual_gradients) != set(expected_gradients):
            raise ValueError(
                "Unexpected gradients in metatensor targets: "
                f"expected: {expected_gradients} "
                f"actual: {actual_gradients}"
            )

        for name in expected_gradients:
            actual_gradient = actual.gradient(name)
            expected_gradient = expected.gradient(name)

            actual_names = actual_gradient.samples.names
            expected_names = expected_gradient.samples.names
            if actual_names != expected_names:
                raise ValueError(
                    "Unexpected samples in metatensor targets "
                    f"for `{name}` gradient block: "
                    f"expected: {expected_names} "
                    f"actual: {actual_names}"
                )

            actual_components = actual_gradient.components
            expected_components = expected_gradient.components
            if actual_components != expected_components:
                raise ValueError(
                    "Unexpected components in metatensor targets "
                    f"for `{name}` gradient block: "
                    f"expected: {expected_components} "
                    f"actual: {actual_components}"
                )


def _empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap:
    """Create a ``TensorMap`` with the keys of ``tensor_map`` and empty blocks.

    Each block of ``tensor_map`` is replaced by the result of
    ``_empty_tensor_block_like``.

    :param tensor_map: the ``TensorMap`` to mirror
    :return: a new ``TensorMap`` with the same keys
    """
    empty_blocks: List[TensorBlock] = [
        _empty_tensor_block_like(block) for block in tensor_map.blocks()
    ]
    return TensorMap(keys=tensor_map.keys, blocks=empty_blocks)


def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock:
    """Create a ``TensorBlock`` like ``tensor_block`` but without any sample.

    The components and properties are reused as they are. The values and the
    sample labels keep their trailing dimensions, device and sample names, but
    have a first dimension of length zero. Gradients are processed
    recursively in the same way.

    :param tensor_block: the block to mirror
    :return: a new block with zero samples
    """
    source_values = tensor_block.values
    source_samples = tensor_block.samples

    # zero entries along the first axis, same trailing dimensions as the source
    empty_values = torch.empty(
        (0,) + source_values.shape[1:],
        dtype=torch.float64,  # metatensor can't serialize otherwise
        device=source_values.device,
    )
    empty_sample_values = torch.empty(
        (0, source_samples.values.shape[1]),
        dtype=source_samples.values.dtype,
        device=source_samples.values.device,
    )
    empty_samples = Labels(names=source_samples.names, values=empty_sample_values)

    result = TensorBlock(
        values=empty_values,
        samples=empty_samples,
        components=tensor_block.components,
        properties=tensor_block.properties,
    )
    for name, gradient in tensor_block.gradients():
        result.add_gradient(name, _empty_tensor_block_like(gradient))
    return result
