'''
DATA
====

+ Data Class
  - Copy Task(s)
  - Dynamical System(s)
'''

# Additional imports for dynamical systems
from functools import partial
import numpy as np
from scipy.integrate import solve_ivp
from torch.utils.data import Dataset
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from omegaconf.dictconfig import DictConfig

from neurowave.transforms import *
from neurowave.registry import register_dataset, TRANSFORM_REGISTRY


# ==========
# COPY TASKS
# ==========

# Generators
# ==========

def copy_simple(seq_len=20, memory_len=10):
    """
    Generate one sample of the simple copy task.

    The input starts with `memory_len` random tokens followed by zeros; the
    target is zero everywhere except its last `memory_len` positions, which
    repeat those tokens in the same order.

    Args:
        seq_len: Number of zero steps appended after the tokens.
        memory_len: Number of tokens to memorize (0 yields all-zero tensors).

    Returns:
        Tuple `(X, Y)`: `X` is a float tensor and `Y` a long tensor, both of
        shape `(seq_len + memory_len, 1)`.
    """
    total_len = seq_len + memory_len
    # `high` is exclusive, so tokens lie in 1..8 and 0 stays the blank symbol.
    data = np.random.randint(low=1, high=9, size=memory_len)

    X = np.zeros(total_len)
    X[:memory_len] = data

    Y = np.zeros(total_len)
    # Index from the front: `Y[-memory_len:]` would select the WHOLE array
    # when memory_len == 0 and fail to broadcast the empty token array.
    Y[total_len - memory_len:] = data

    return torch.tensor(X).float().unsqueeze(-1), torch.tensor(Y).long().unsqueeze(-1)

def copy_selective(seq_len=20, memory_len=10, garbage_len=5):
    """
    Generate one sample of the selective copy task.

    `memory_len` random tokens are scattered (in increasing position order)
    over the first `seq_len - memory_len - garbage_len` input positions. A
    marker token `9` is written `memory_len + 1` positions before the end of
    the input, and the target repeats the tokens, in the order they appear in
    the input, on its last `memory_len` positions.

    Args:
        seq_len: Base sequence length.
        memory_len: Number of tokens to memorize.
        garbage_len: Extra length added to the sequence and removed from the
            region in which tokens may be placed.

    Returns:
        Tuple `(X, Y)`: `X` is a float tensor and `Y` a long tensor, both of
        shape `(seq_len + memory_len + garbage_len, 1)`.

    Raises:
        ValueError: If fewer than `memory_len` positions are available, i.e.
            `seq_len - memory_len - garbage_len < memory_len`. Note that the
            default arguments (20, 10, 5) do not satisfy this constraint.
    """
    slots = seq_len - memory_len - garbage_len
    if memory_len > slots:
        # Same exception type NumPy would raise, but with an actionable message.
        raise ValueError(
            "copy_selective requires seq_len - memory_len - garbage_len >= memory_len "
            f"to place tokens without replacement; got seq_len={seq_len}, "
            f"memory_len={memory_len}, garbage_len={garbage_len}"
        )

    total_len = seq_len + memory_len + garbage_len
    # `high` is exclusive, so tokens lie in 1..8; 0 is blank and 9 is the marker.
    data = np.random.randint(low=1, high=9, size=memory_len)
    positions = np.sort(np.random.choice(slots, memory_len, replace=False))

    X = np.zeros(total_len)
    X[positions] = data
    # indicator for model to begin outputting memorized tokens
    X[-(memory_len + 1)] = 9

    Y = np.zeros(total_len)
    # Index from the front: `Y[-memory_len:]` would select the WHOLE array
    # when memory_len == 0 and fail to broadcast the empty token array.
    Y[total_len - memory_len:] = data

    return torch.tensor(X).float().unsqueeze(-1), torch.tensor(Y).long().unsqueeze(-1)

def copy_ordered(seq_len=20, memory_len=10, garbage_len=5):
    """
    Generate one sample of the ordered (sorted) copy task.

    `memory_len` random tokens are scattered over the first
    `seq_len - memory_len - garbage_len` input positions and a marker token
    `9` is written `memory_len + 1` positions before the end of the input.
    The target holds the tokens sorted in ascending order (not in the order
    they were presented) on its last `memory_len` positions.

    Args:
        seq_len: Base sequence length.
        memory_len: Number of tokens to memorize.
        garbage_len: Extra length added to the sequence and removed from the
            region in which tokens may be placed.

    Returns:
        Tuple `(X, Y)`: `X` is a float tensor and `Y` a long tensor, both of
        shape `(seq_len + memory_len + garbage_len, 1)`.

    Raises:
        ValueError: If fewer than `memory_len` positions are available, i.e.
            `seq_len - memory_len - garbage_len < memory_len`. Note that the
            default arguments (20, 10, 5) do not satisfy this constraint.
    """
    slots = seq_len - memory_len - garbage_len
    if memory_len > slots:
        # Same exception type NumPy would raise, but with an actionable message.
        raise ValueError(
            "copy_ordered requires seq_len - memory_len - garbage_len >= memory_len "
            f"to place tokens without replacement; got seq_len={seq_len}, "
            f"memory_len={memory_len}, garbage_len={garbage_len}"
        )

    total_len = seq_len + memory_len + garbage_len
    # `high` is exclusive, so tokens lie in 1..8; 0 is blank and 9 is the marker.
    data = np.random.randint(low=1, high=9, size=memory_len)
    positions = np.sort(np.random.choice(slots, memory_len, replace=False))
    ordered_data = np.sort(data)  # the target is the ascending-sorted tokens

    X = np.zeros(total_len)
    X[positions] = data
    # indicator for model to begin outputting memorized tokens
    X[-(memory_len + 1)] = 9

    Y = np.zeros(total_len)
    # Index from the front: `Y[-memory_len:]` would select the WHOLE array
    # when memory_len == 0 and fail to broadcast the empty token array.
    Y[total_len - memory_len:] = ordered_data

    return torch.tensor(X).float().unsqueeze(-1), torch.tensor(Y).long().unsqueeze(-1)



# Dataset Class
# =============

@register_dataset('copy')
class CopyTask(Dataset):
    """
    Dataset that generates copy-task samples on the fly.

    Configuration keys read from `cfg` (with defaults):
        task ('simple'): key into `generators`.
        seq_len (20), memory_len (10): forwarded positionally to the generator.
        num_samples (1000): value reported by `len()`.
        transform (None): either a single registry key (instantiated with no
            arguments) or a sequence of keys wrapped in the registry's
            'compose' transform together with `num_classes=10`.

    Every `__getitem__` call draws a new random sample; the index is ignored.
    Generators are called with `(seq_len, memory_len)` only, so any further
    generator argument keeps its own default.
    """

    generators = {
        'simple': copy_simple,
        'selective': copy_selective,
        'ordered': copy_ordered,
    }

    def __init__(self, cfg: DictConfig):
        # Task and size parameters
        self.task = cfg.get('task', 'simple')
        self.seq_len = cfg.get('seq_len', 20)
        self.memory_len = cfg.get('memory_len', 10)
        self.num_samples = cfg.get('num_samples', 1000)

        # Resolve the transform specification into a callable (or None)
        spec = cfg.get('transform', None)
        if spec is None:
            self.transform = None
        elif isinstance(spec, str):
            self.transform = TRANSFORM_REGISTRY[spec]()
        else:
            compose_cfg = {
                'transforms': list(spec),
                'num_classes': 10  # Assuming 10 classes for one-hot encoding
            }
            self.transform = TRANSFORM_REGISTRY['compose'](compose_cfg)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        generate = self.generators[self.task]
        x, label = generate(self.seq_len, self.memory_len)
        if not self.transform:
            return x, label
        # Transforms receive and return the (input, label) pair
        x, label = self.transform((x, label))
        return x, label


# =================
# DYNAMICAL SYSTEMS
# =================

# Systems
# =======

def lorenz(t, x, sigma=10., rho=28., beta=8./3.):
    """
    Right-hand side of the Lorenz system.

    Args:
        t: Time; unused, present so the signature is `f(t, x)`.
        x: Array-like state whose components 0, 1 and 2 are read.
        sigma, rho, beta: System parameters.

    Returns:
        Array shaped like `x` with the derivative written to components
        0, 1 and 2. Floating/complex input dtypes are preserved; any other
        dtype is promoted to float64.
    """
    x = np.asanyarray(x)
    # An integer state would otherwise yield an integer `dx` and silently
    # truncate the derivative on assignment.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dx = np.empty_like(x, dtype=dtype)
    dx[0] = sigma * (x[1] - x[0])
    dx[1] = x[0] * (rho - x[2]) - x[1]
    dx[2] = x[0] * x[1] - beta * x[2]
    return dx

def rossler(t, x, a=0.2, b=0.2, c=5.7):
    """
    Right-hand side of the Rossler system.

    Args:
        t: Time; unused, present so the signature is `f(t, x)`.
        x: Array-like state whose components 0, 1 and 2 are read.
        a, b, c: System parameters.

    Returns:
        Array shaped like `x` with the derivative written to components
        0, 1 and 2. Floating/complex input dtypes are preserved; any other
        dtype is promoted to float64.
    """
    x = np.asanyarray(x)
    # An integer state would otherwise yield an integer `dx` and silently
    # truncate the derivative on assignment.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dx = np.empty_like(x, dtype=dtype)
    dx[0] = -x[1] - x[2]
    dx[1] = x[0] + a * x[1]
    dx[2] = b + x[2] * (x[0] - c)
    return dx

def chen(t, x, a=35.0, b=3.0, c=28.0):
    """
    Right-hand side of the Chen system.

    Args:
        t: Time; unused, present so the signature is `f(t, x)`.
        x: Array-like state whose components 0, 1 and 2 are read.
        a, b, c: System parameters.

    Returns:
        Array shaped like `x` with the derivative written to components
        0, 1 and 2. Floating/complex input dtypes are preserved; any other
        dtype is promoted to float64.
    """
    x = np.asanyarray(x)
    # An integer state would otherwise yield an integer `dx` and silently
    # truncate the derivative on assignment.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dx = np.empty_like(x, dtype=dtype)
    dx[0] = a * (x[1] - x[0])
    dx[1] = (c - a) * x[0] - x[0] * x[2] + c * x[1]
    dx[2] = x[0] * x[1] - b * x[2]
    return dx

def thomas(t, x, b=0.19):
    """
    Right-hand side of Thomas' cyclically symmetric system.

    Args:
        t: Time; unused, present so the signature is `f(t, x)`.
        x: Array-like state whose components 0, 1 and 2 are read.
        b: Linear damping coefficient applied to every component.

    Returns:
        Array shaped like `x` with the derivative written to components
        0, 1 and 2. Floating/complex input dtypes are preserved; any other
        dtype is promoted to float64.
    """
    x = np.asanyarray(x)
    # An integer state would otherwise yield an integer `dx` and silently
    # truncate the derivative on assignment.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dx = np.empty_like(x, dtype=dtype)
    dx[0] = np.sin(x[1]) - b * x[0]
    dx[1] = np.sin(x[2]) - b * x[1]
    dx[2] = np.sin(x[0]) - b * x[2]
    return dx

def sprott_b(t, x):
    """
    Right-hand side of the Sprott B system.

    `t` is unused; components 0, 1 and 2 of `x` are read and the derivative
    is returned in an array created with `np.empty_like(x)`.
    """
    u, v, w = x[0], x[1], x[2]

    dx = np.empty_like(x)
    dx[0] = v * w
    dx[1] = u - v
    dx[2] = 1 - u * v
    return dx

def sink(t, x, omega=1.2, alpha=0.8, lmbda=0.8):
    """
    Right-hand side of a linear spiral-sink system.

    Components 0 and 1 form a damped rotation (rate `omega`, damping
    `alpha`); component 2 is driven by component 0 scaled by `-lmbda`.

    Args:
        t: Time; unused, present so the signature is `f(t, x)`.
        x: Array-like state whose components 0 and 1 are read.
        omega, alpha, lmbda: System parameters.

    Returns:
        Array shaped like `x` with the derivative written to components
        0, 1 and 2. Floating/complex input dtypes are preserved; any other
        dtype is promoted to float64.
    """
    x = np.asanyarray(x)
    # An integer state would otherwise yield an integer `dx` and silently
    # truncate the derivative on assignment.
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    dx = np.empty_like(x, dtype=dtype)
    dx[0] = -alpha * x[0] - omega * x[1]
    dx[1] = omega * x[0] - alpha * x[1]
    # NOTE(review): this uses x[0]; a decoupled decay along the third axis
    # would use x[2] -- confirm which is intended before changing.
    dx[2] = -lmbda * x[0]
    return dx


# Data Class
# ===========

@register_dataset('dynamical_systems')
class DynamicalSystem(Dataset):
    """
    Dataset that integrates a 3-component ODE from a random initial state.

    Configuration keys read from `cfg` (with defaults):
        task ('lorenz'): key into `systems` (and `scale`).
        trajectory_length (300): number of sampled time points.
        num_trajectories (1000): value reported by `len()`.
        dt (0.01): spacing between sampled time points.
        system_kwargs ({}): keyword arguments bound to the system function.
        transform (None): either a single registry key (instantiated with no
            arguments) or a sequence of keys wrapped in the registry's
            'compose' transform.

    Every `__getitem__` call integrates a new trajectory; the index is ignored.
    """

    # Right-hand-side function for each supported task
    systems = {
        'lorenz': lorenz,
        'rossler': rossler,
        'chen': chen,
        'thomas': thomas,
        'sprott_b': sprott_b,
        'sink': sink,
    }

    # Factor applied to the uniform(-1, 1) initial state of each task
    scale = {
        'lorenz': 15.0,
        'rossler': 10.0,
        'chen': 35.0,
        'thomas': 1.0,
        'sprott_b': 1.0,
        'sink': 1.0,
    }

    def __init__(self, cfg: DictConfig):
        # Task and sampling parameters
        self.task = cfg.get('task', 'lorenz')
        self.trajectory_length = cfg.get('trajectory_length', 300)
        self.num_trajectories = cfg.get('num_trajectories', 1000)
        self.dt = cfg.get('dt', 0.01)

        # Bind any user-supplied system parameters to the selected system
        self.system_kwargs = cfg.get('system_kwargs', {})
        rhs = self.systems.get(self.task)
        self.derivatives = partial(rhs, **self.system_kwargs)

        # Resolve the transform specification into a callable (or None)
        spec = cfg.get('transform', None)
        if spec is None:
            self.transform = None
        elif isinstance(spec, str):
            self.transform = TRANSFORM_REGISTRY[spec]()
        else:
            compose_cfg = {
                'transforms': list(spec),
            }
            self.transform = TRANSFORM_REGISTRY['compose'](compose_cfg)

    def __len__(self):
        return self.num_trajectories

    def __getitem__(self, idx):
        n_steps = self.trajectory_length
        t_start = 0.0
        t_end = (n_steps - 1) * self.dt
        sample_times = np.linspace(t_start, t_end, n_steps)

        # Random initial state in a task-dependent cube around the origin
        amplitude = self.scale.get(self.task, 1.0)
        state0 = np.random.uniform(-1, 1, size=3) * amplitude

        sol = solve_ivp(
            self.derivatives,
            (t_start, t_end),
            state0,
            method='DOP853',
            t_eval=sample_times,
            rtol=1e-3,
            atol=1e-3,
        )
        if not sol.success:
            raise RuntimeError(f"Integration failed for {self.task}: {sol.message}")

        # Input and label start out as two independent copies of the
        # trajectory, transposed so that time is the first axis
        trajectory = sol.y.T
        x = torch.tensor(trajectory, dtype=torch.float32)
        label = torch.tensor(trajectory, dtype=torch.float32)
        if not self.transform:
            return x, label
        x, label = self.transform((x, label))
        return x, label


'''
TESTING
=======

  + Copy Task Dataset

  + Dynamical System Dataset
     - Raw Dataset
     - Transformed Dataset
     - DataLoader

'''

if __name__ == '__main__':
    # Smoke tests: build each dataset, print sample shapes and save sample
    # plots under ./out/test/datasets/.
    import os
    from omegaconf import DictConfig
    from torch.utils.data import DataLoader
    from utils import plot_one_hot, plot_3d_trajectory

    # Ensure output directory exists
    out_dir = './out/test/datasets'
    os.makedirs('./out/test/datasets/', exist_ok=True)

    # COPY TASKS TEST
    # ===============
    print('Testing Copy Tasks')

    # (task name, memory_len) pairs exercised below. The config key is
    # `seq_len`; CopyTask does not read a `sample_len` key.
    copy_tasks = (('simple', 10), ('selective', 5), ('ordered', 5))

    # Raw Copy Task Testing
    # ---------------------
    for task, memory_len in copy_tasks:
        print(f'\tTesting {task.capitalize()} Copy Task')
        copy_task = CopyTask(DictConfig(dict(task=task, seq_len=20, memory_len=memory_len, num_samples=1000)))
        x, label = copy_task[0]
        print(f'\t\tData Shape: {x.shape}, Label Shape: {label.shape}')
        print(f'\t\tData[:10] = {x[:10].flatten()}')
        print(f'\t\tLabel[-10:] = {label[-10:].flatten()}')

    # Transform Testing
    # -----------------
    print('\tTesting One Hot Transform')
    transform = ['one_hot']
    for task, memory_len in copy_tasks:
        copy_task = CopyTask(DictConfig(dict(task=task, seq_len=20, memory_len=memory_len,
                                             num_samples=1000, transform=transform)))
        x, label = copy_task[0]
        # Save a sample plot
        print(f'\t\tTransformed Data Shape: {x.shape}, Transformed Label Shape: {label.shape}')
        plot_one_hot(x, title='One Hot Encoded Data Sample')
        plt.tight_layout()
        plt.savefig(f'{out_dir}/{task}_copy_task_one_hot.png')
        plt.close('all')

        one_hot_label = torch.nn.functional.one_hot(label.squeeze().long(), num_classes=10).float()
        plot_one_hot(one_hot_label, title='One Hot Encoded Label Sample')
        plt.tight_layout()
        plt.savefig(f'{out_dir}/{task}_copy_task_label.png')
        # Close the label figure too so figures do not pile up between tests
        plt.close('all')

    # Data Loading Test
    # -----------------
    print('Testing DataLoader for Copy Tasks')
    copy_task = CopyTask(DictConfig(dict(task='simple', seq_len=20, memory_len=10,
                                         num_samples=1000, transform=transform)))
    dataloader = DataLoader(copy_task, batch_size=32, shuffle=True)
    print(f'\tDataLoader Length: {len(dataloader)}')
    print(f'\tDataLoader Batch Size: {dataloader.batch_size}')
    for batch_idx, (x, label) in enumerate(dataloader):
        print(f'\tBatch {batch_idx}: Data Shape: {x.shape}, Label Shape: {label.shape}')
        for j, (x_, label_) in enumerate(zip(x, label)):
            print(f'\t\tElement {j} Data Shape: {x_.shape}, Label Shape: {label_.shape}')
            plot_one_hot(x_, title='DataLoader Batch One Hot Encoded Data')
            plt.tight_layout()
            plt.savefig(f'{out_dir}/copy_task_loader_batch_{batch_idx}_{j}.png')
            plt.close('all')
            if j >= 2:  # Limit to the first 3 elements of the batch for brevity
                break
        break  # Only the first batch is inspected

    # DYNAMICAL SYSTEMS TEST
    # ======================

    # Raw Data Test
    # -------------
    # NOTE(review): 'rossler' is registered in DynamicalSystem but is not
    # exercised here -- confirm whether that omission is intentional.
    systems = ['lorenz', 'chen', 'thomas', 'sprott_b', 'sink']
    print('Testing Dynamical Systems')
    for system in systems:
        print(f"\tTesting system: {system}")
        dynamical_system = DynamicalSystem(DictConfig(dict(task=system, trajectory_length=1000, dt=0.01)))
        x, label = dynamical_system[0]
        print(f'\tData Shape: {x.shape}, Label Shape: {label.shape}')
        plot_3d_trajectory(label, title=f'Test {system.upper()} System')
        plt.tight_layout()
        plt.savefig(f'{out_dir}/{system.lower()}_attractor.png')
        plt.close('all')

    # Transformed Data Test
    # ----------------------
    print('Testing Transforms')

    # (section header, transform key, description, file suffix)
    manifold_tests = (
        ('Single Delay Transforms', 'single_delay_manifold', 'single delay', 'single_delay'),
        ('Shadow Manifold Transforms', 'shadow_manifold', 'shadow manifold', 'shadow_manifold'),
    )
    for header, transform_key, description, suffix in manifold_tests:
        print(f'\t{header}')
        transform = ['project_x', transform_key]
        for i, system in enumerate(systems):
            print(f"\t\tTesting {description} on system: {system}")
            dynamical_system = DynamicalSystem(DictConfig(dict(task=system, trajectory_length=1000,
                                                               dt=0.01, transform=transform)))
            x, label = dynamical_system[0]
            if i == 0:
                # Show a few raw values for the first system only
                print(f'\t\tData[:4] = {x[:4]}')
                print(f'\t\tLabel[:4] = {label[:4]}')

            print(f'\t\tTransformed Data Shape: {x.shape}, Transformed Label Shape: {label.shape}')
            plot_3d_trajectory(label, title=f'Transformed {system.upper()} System')
            plt.tight_layout()
            plt.savefig(f'{out_dir}/{system.lower()}_{suffix}_attractor.png')
            plt.close('all')

    # Data Loading Test
    # -----------------
    print('Testing DataLoader')

    transform = ['project_x', 'shadow_manifold']
    dynamical_system = DynamicalSystem(DictConfig(dict(task='lorenz', trajectory_length=1000,
                                                       dt=0.01, transform=transform)))
    dataloader = DataLoader(dynamical_system, batch_size=32, shuffle=True)
    print(f'\tDataLoader Length: {len(dataloader)}')
    print(f'\tDataLoader Batch Size: {dataloader.batch_size}')
    # Enumerate the batches so the printed batch index is this loop's own
    # counter rather than a leftover value from an earlier loop.
    for batch_idx, (x, label) in enumerate(dataloader):
        print(f'\tBatch {batch_idx}: Data Shape: {x.shape}, Label Shape: {label.shape}')
        for j, (x_, label_) in enumerate(zip(x, label)):
            print(f'\t\tElement {j} Data Shape: {x_.shape}, Label Shape: {label_.shape}')
            plot_3d_trajectory(label_, title='DataLoader Batch Trajectory')
            plt.tight_layout()
            plt.savefig(f'{out_dir}/lorenz_loader_batch_{j}.png')
            plt.close('all')
            if j >= 2:  # Limit to the first 3 elements of the batch for brevity
                break
        break  # Only the first batch is inspected
