'''
TRANSFORMS
==========
 + Composer
 + Copy
   - One-Hot
 + Dynamical Systems
   - Project X
   - Shadow Manifold
   - Single Delay Manifold

'''

import torch
from neurowave.registry import TRANSFORM_REGISTRY, register_transform


# ========
# COMPOSER
# ========

class ComposeTransform:
    """Chain several transforms into a single callable.

    Each transform is called with one positional argument (the current
    ``data`` object) and its return value is handed to the next transform.
    """

    def __init__(self, transforms):
        """
        transforms: iterable of single-argument callables, applied in order.
                    The object is stored as-is (no copy is made).
        """
        self.transforms = transforms

    def __call__(self, data):
        """Run ``data`` through every transform in order and return the result."""
        result = data
        for step in self.transforms:
            # Each step's output becomes the next step's input.
            result = step(result)
        return result

@register_transform("compose")
def tf_compose(cfg):
    """
    Build a ComposeTransform from the transform names in cfg['transforms'].

    Expects cfg['transforms'] to be an iterable of registered transform
    names, e.g.
      cfg['transforms'] = ["project_x", "shadow_manifold"]

    Every name is looked up in TRANSFORM_REGISTRY and the resulting factory
    is called with this same cfg. A name that is missing from the registry
    surfaces as whatever error the registry lookup raises.
    """
    pipeline = []
    for name in cfg['transforms']:
        factory = TRANSFORM_REGISTRY[name]
        # All sub-transforms share the one config object.
        pipeline.append(factory(cfg))
    return ComposeTransform(pipeline)



# COPY TASKS
# ==========

@register_transform('one_hot')
def tf_one_hot(cfg):
    """
    Build a transform that one-hot encodes the first element of a
    ``(x, label)`` pair and passes ``label`` through unchanged.

    cfg['num_classes'] gives the width of the encoding; it is read each
    time the transform is applied.

    NOTE(review): the previous documentation described this as encoding the
    labels, but the code encodes ``x`` -- confirm which one is intended.
    """
    def transform_fn(data):
        x, label = data
        # squeeze() with no argument drops every size-1 dimension of x.
        indices = x.long().squeeze()
        encoded = torch.nn.functional.one_hot(
            indices, num_classes=cfg['num_classes']
        )
        return encoded.float(), label
    return transform_fn

# DYNAMICAL SYSTEMS
# =================

# Data Transformation
# -------------------

@register_transform('project_x')
def tf_project_x(cfg):
    """
    Build a transform that keeps only the first component along the last
    axis of ``x`` (the x-coordinate, per the original description for the
    Lorenz system). The last axis is retained with size one.

    ``cfg`` is accepted for registry compatibility and is not used.
    """
    def transform_fn(data):
        x, label = data
        # Slice (rather than index) so the trailing axis is not dropped.
        first_component = x[..., 0:1]
        return first_component, label
    return transform_fn


# Label Transformation
# --------------------

@register_transform('shadow_manifold')
def tf_shadow_manifold(cfg):
    """
    Build a transform that replaces the label with a three-column delay
    embedding of its first column, using lags of 0, 10 and 20 steps along
    the leading axis (the shadow manifold representation, per the original
    description for the Lorenz system).

    The first 20 entries of ``x`` are dropped so that ``x`` and the new
    label have the same length along the leading axis.

    ``cfg`` is accepted for registry compatibility and is not used.
    """
    def transform_fn(data):
        x, label = data
        delay, span = 10, 20
        first_col = label[:, :1]
        # Row i of the embedding is (l[i], l[i + delay], l[i + span]).
        embedded = torch.cat(
            (first_col[:-span], first_col[delay:-delay], first_col[span:]),
            dim=-1,
        )
        return x[span:], embedded
    return transform_fn


@register_transform('single_delay_manifold')
def tf_single_delay_manifold(cfg):
    """
    Build a transform that replaces the label with a three-column delay
    embedding of its first column, using lags of 0, 1 and 2 steps along
    the leading axis.

    The first 2 entries of ``x`` are dropped so that ``x`` and the new
    label have the same length along the leading axis.

    ``cfg`` is accepted for registry compatibility and is not used.

    This factory used to be defined under the name ``tf_shadow_manifold``,
    which rebound the module-level name of the 10/20-lag variant. The
    registry key 'single_delay_manifold' is unchanged.
    """
    def transform_fn(data):
        x, label = data
        # Row i of the embedding is (l[i], l[i + 1], l[i + 2]).
        label = torch.cat(
            [label[:-2, :1], label[1:-1, :1], label[2:, :1]], dim=-1
        )
        return x[2:], label
    return transform_fn

