# Copyright (c) OpenMMLab. All rights reserved.
# Modified from original implementation to support flexible application
import numpy as np
from collections import abc
from copy import deepcopy
from itertools import chain
from numbers import Number
from typing import Any, Sequence, Union, List

import mmengine
import torch
from mmengine.structures import BaseDataElement


class DataSample(BaseDataElement):
    """Dict-like batch container built on ``mmengine.structures.BaseDataElement``.

    ``BaseDataElement`` differentiates data keys from metainfo keys; this
    class only uses data keys for simplicity.

    A ``DataSample`` can be viewed as a special dict that either holds a whole
    batch or a single sample, with helpers to convert between the two forms.

    Example:
        A batched object such as::

            batch = DataSample(
                images=torch.randn(B, H, W, C),
                meta=BaseDataElement(ids=[id_1, ..., id_B]),
            )

        can be split into ``B`` per-sample objects::

            samples = batch.split()
            # samples[i].images   -> tensor of shape (H, W, C)
            # samples[i].meta.ids -> id_i

        and per-sample objects can be merged back into one batch::

            batch = DataSample.stack(samples)
            # batch.images   -> tensor of shape (B, H, W, C)
            # batch.meta.ids -> [id_1, ..., id_B]

    Note:
        ``stack`` merges equal-length list/tuple fields position by position,
        so a per-sample ``shape=[h_i, w_i]`` becomes
        ``[[h_1, ..., h_B], [w_1, ..., w_B]]``. ``split`` treats every list as
        having one entry per sample and therefore does not undo that merge.
    """

    def __getitem__(self, key: str) -> Any:
        """Return the field ``key`` using dictionary-style access.

        Raises:
            KeyError: If the attribute lookup for ``key`` fails.
        """
        try:
            return getattr(self, key)
        except AttributeError as err:
            # Translate to the exception type dict users expect, keeping the
            # original failure attached as the cause.
            raise KeyError(f"Field '{key}' not found in DataSample.") from err

    def __setitem__(self, key: str, value: Any) -> None:
        """Set the field ``key`` to ``value`` using dictionary-style access."""
        setattr(self, key, value)

    @classmethod
    def stack(cls, data_samples: Sequence["DataSample"]) -> "DataSample":
        """Merge a sequence of ``DataSample`` instances into a single one.

        Per field, the values of all samples are merged as follows:

        * all tensors with identical shapes -> ``torch.stack`` along dim 0;
          tensors with differing shapes are kept as a list;
        * all ``dict`` / ``BaseDataElement`` -> a new ``BaseDataElement`` whose
          fields (taken from the first element's keys) are merged recursively;
        * all lists/tuples of equal length -> a list merged position by
          position;
        * anything else -> the plain list of per-sample values.

        Args:
            data_samples: Samples to merge; all must expose the same keys.

        Returns:
            A new instance of ``cls`` holding the merged fields.

        Raises:
            ValueError: If ``data_samples`` is empty or the samples do not
                share an identical key set.
        """
        if not data_samples:
            raise ValueError("No data samples to stack.")

        reference_keys = set(data_samples[0].keys())
        for ds in data_samples:
            if set(ds.keys()) != reference_keys:
                raise ValueError("All DataSample instances must have identical keys.")

        def recurse(vals: List[Any]) -> Any:
            """Merge the per-sample values of one field."""
            if all(isinstance(v, torch.Tensor) for v in vals):
                # Only same-shaped tensors can form a batch tensor.
                if len({v.shape for v in vals}) == 1:
                    return torch.stack(vals, dim=0)
                return vals
            if all(isinstance(v, (dict, BaseDataElement)) for v in vals):
                out = BaseDataElement()
                for k in vals[0].keys():
                    sub_vals = [
                        v.get(k) if isinstance(v, BaseDataElement) else v[k]
                        for v in vals
                    ]
                    setattr(out, k, recurse(sub_vals))
                return out
            if all(isinstance(v, (list, tuple)) for v in vals):
                length = len(vals[0])
                if all(len(v) == length for v in vals):
                    # Position-wise merge: entry i gathers element i of every
                    # sample's sequence.
                    return [recurse([v[i] for v in vals]) for i in range(length)]
            # Mixed types or ragged sequences: keep the raw per-sample list.
            return vals

        result = cls()
        for k in reference_keys:
            setattr(result, k, recurse([ds.get(k) for ds in data_samples]))
        return result

    def split(self, allow_nonseq_value: bool = True) -> List["DataSample"]:
        """Split batched fields into one ``DataSample`` per batch entry.

        The batch size is taken from the first top-level field that is a
        tensor with at least one dimension (its ``shape[0]``) or a list (its
        length). Each field is then split as follows:

        * tensor with at least one dimension -> indexed along dim 0;
        * list -> one element per sample;
        * ``dict`` / ``BaseDataElement`` -> split recursively into
          ``BaseDataElement`` parts;
        * anything else (including 0-dim tensors) -> the same object is
          assigned to every sample when ``allow_nonseq_value`` is true.

        Args:
            allow_nonseq_value: Whether values that carry no batch dimension
                may be shared across all samples instead of raising.

        Returns:
            A list of ``batch_size`` instances of this object's class.

        Raises:
            ValueError: If no batch size can be inferred, or if a field cannot
                be split (size mismatch, or a non-splittable value while
                ``allow_nonseq_value`` is false). The underlying error is
                attached as the cause.
        """
        batch_size = None
        for v in self.values():
            # 0-dim tensors have no batch axis; reading shape[0] would raise.
            if isinstance(v, torch.Tensor) and v.dim() > 0:
                batch_size = v.shape[0]
                break
            if isinstance(v, list):
                batch_size = len(v)
                break
        if batch_size is None:
            raise ValueError("Cannot infer batch size; no tensor or list fields.")

        samples = [self.__class__() for _ in range(batch_size)]

        def recurse_split(val: Any) -> List[Any]:
            """Return ``batch_size`` per-sample pieces of ``val``."""
            if isinstance(val, torch.Tensor) and val.dim() > 0:
                if val.shape[0] != batch_size:
                    raise ValueError("Tensor batch size mismatch.")
                return [val[i] for i in range(batch_size)]
            if isinstance(val, list):
                if len(val) != batch_size:
                    raise ValueError("List batch size mismatch.")
                return val
            if isinstance(val, (dict, BaseDataElement)):
                parts = [BaseDataElement() for _ in range(batch_size)]
                items_iter = (
                    val.items()
                    if isinstance(val, dict)
                    else ((k, val.get(k)) for k in val.keys())
                )
                for k, sub in items_iter:
                    for part, piece in zip(parts, recurse_split(sub)):
                        setattr(part, k, piece)
                return parts
            if allow_nonseq_value:
                # No batch axis: every sample shares the same object.
                return [val for _ in range(batch_size)]
            raise ValueError(f"Field type {type(val)} not splittable.")

        for k, v in self.items():
            try:
                split_vals = recurse_split(v)
            except Exception as err:
                # Narrow catch so KeyboardInterrupt/SystemExit propagate, and
                # chain the cause so the concrete reason is not lost.
                raise ValueError(f"Field '{k}: {v}' cannot be split.") from err

            for sample, piece in zip(samples, split_vals):
                setattr(sample, k, piece)
        return samples


if __name__ == "__main__":
    # Minimal self-test executed only when this file is run as a script.

    print("Testing DataSample.stack...")
    # Stack test: two samples with a flat tensor field and a nested element.
    ds1 = DataSample(a=torch.tensor([1, 2]), b=BaseDataElement(x=torch.tensor([3, 4])))
    ds2 = DataSample(a=torch.tensor([5, 6]), b=BaseDataElement(x=torch.tensor([7, 8])))
    stacked = DataSample.stack([ds1, ds2])
    assert isinstance(
        stacked.a, torch.Tensor
    ), "Field 'a' should be a Tensor after stacking"
    assert stacked.a.shape == (2, 2), f"Expected shape (2,2), got {stacked.a.shape}"
    assert torch.equal(
        stacked.a, torch.tensor([[1, 2], [5, 6]])
    ), "Stacked values for 'a' mismatch"
    # Nested elements are merged field by field.
    assert isinstance(
        stacked.b, BaseDataElement
    ), "Field 'b' should be BaseDataElement after stacking"
    assert torch.equal(
        stacked.b.x, torch.tensor([[3, 4], [7, 8]])
    ), "Stacked values for 'b.x' mismatch"
    print("DataSample.stack tests passed.")

    print("Testing DataSample.split...")
    # Split test: batch of two with a tensor field and two levels of nesting.
    sample = DataSample(
        img=torch.tensor([[10, 20], [30, 40]]),
        meta=BaseDataElement(vals=[100, 200], info=BaseDataElement(scalar=[0.1, 0.2])),
    )
    parts = sample.split()
    assert (
        isinstance(parts, list) and len(parts) == 2
    ), "split should return list of length 2"
    p0, p1 = parts
    # First part
    assert torch.equal(p0.img, torch.tensor([10, 20])), "First split img mismatch"
    assert isinstance(
        p0.meta, BaseDataElement
    ), "meta should be BaseDataElement in split"
    assert p0.meta.vals == 100, "First split meta.vals mismatch"
    assert isinstance(
        p0.meta.info, BaseDataElement
    ), "meta.info should be BaseDataElement"
    assert p0.meta.info.scalar == 0.1, "First split nested scalar mismatch"
    # Second part
    assert torch.equal(p1.img, torch.tensor([30, 40])), "Second split img mismatch"
    assert p1.meta.vals == 200, "Second split meta.vals mismatch"
    assert p1.meta.info.scalar == 0.2, "Second split nested scalar mismatch"
    print("DataSample.split tests passed.")
