#!/usr/bin/env python
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from dmd import DMD


def optimize_n_delay(data,
                     delay_range,
                     project_state=False,
                     variance=0.95,
                     embed_input=True):
    """
    Pick the DMD delay with the lowest held-out prediction error.

    The last 20% of the time axis is held out as test data; for every
    candidate delay a model is fitted on the first 80% and scored on the
    held-out part via `_evaluate_delay`.

    Parameters:
    - data: np.ndarray, system state data (T x N) or (K x T x N)
    - delay_range: iterable of int delays to evaluate (lists, ranges and
      generators are all accepted)
    - project_state: bool, whether to PCA-project data first
    - variance: float, passed to PCA_down_data as `var_explained` when
      project_state=True
    - embed_input: bool, not used by this function; kept for interface
      compatibility

    Returns:
    - best_delay: int, the delay with the smallest non-NaN error
    - errors: np.ndarray with one error per delay, in the order of
      delay_range (NaN entries are kept in place)

    Raises:
    - ValueError: if delay_range is empty, or if the error is NaN for
      every candidate delay
    """
    # Materialise first: np.array() on a generator yields a 0-d object
    # array that cannot be iterated.
    delays = np.array(list(delay_range))
    if delays.size == 0:
        raise ValueError("delay_range must contain at least one delay")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Optional PCA projection
    if project_state:
        data = PCA_down_data(data, var_explained=variance)

    # Train/test split (last 20% of the time axis as test)
    if data.ndim == 2:
        split = int(data.shape[0] * 0.8)
        train_data, test_data = data[:split], data[split:]
    else:  # shape K x T x N
        split = int(data.shape[1] * 0.8)
        n_features = data.shape[2]
        # NOTE(review): this stacks the K trials into one long 2-D
        # sequence, so the last sample of one trial is followed by the
        # first sample of the next -- confirm DMD is meant to see the
        # trials concatenated like this.
        train_data = data[:, :split, :].reshape(-1, n_features)
        test_data = data[:, split:, :].reshape(-1, n_features)

    errors = np.asarray(
        [_evaluate_delay(train_data, test_data, d, device) for d in delays],
        dtype=float,
    )

    # NaN scores mark delays that could not be evaluated; ignore them
    # when choosing, but fail clearly when nothing usable is left.
    valid = ~np.isnan(errors)
    if not valid.any():
        raise ValueError(
            "prediction error is NaN for every delay in delay_range")
    best_delay = delays[valid][np.argmin(errors[valid])]

    return int(best_delay), errors


def _evaluate_delay(train_data, test_data, delay, device):
    """
    Score a single candidate delay.

    Builds a DMD model on `train_data` with `n_delays=delay` on `device`,
    fits it, predicts on `test_data`, and returns the MASE between
    `test_data` and that prediction.
    """
    dmd = DMD(train_data, n_delays=delay, device=device)
    dmd.fit()
    # The prediction is moved to the CPU before it is compared with
    # test_data (presumably it is a torch tensor on `device`).
    forecast = dmd.predict(test_data=test_data).cpu()
    return mase(test_data, forecast)


def PCA_down_data(data,
                  var_explained=0.95,
                  min_dim=10,
                  return_n_components=False,
                  num_PCs='mean',
                  centering=True,
                  whiten=False):
    """
    PCA-reduce one dataset, or each dataset of a list, to a common
    number of principal components.

    For every dataset the number of PCs needed to explain
    `var_explained` of the variance is computed (at least `min_dim`, but
    never more than the dataset can provide). These per-dataset counts
    are combined according to `num_PCs`, and every dataset is then
    projected with its own PCA onto that many components.

    Parameters:
    - data: np.ndarray (T x N) or (K x T x N), or a list of such arrays;
      3-D arrays are flattened to (K*T x N) for the PCA and reshaped
      back to (K x T x n_components) afterwards
    - var_explained: float, cumulative explained-variance ratio to reach
    - min_dim: int, lower bound on the per-dataset component count
    - return_n_components: bool, also return the per-dataset counts
    - num_PCs: 'mean', 'max' or 'min', how the per-dataset counts are
      combined into the common count
    - centering: bool, subtract the per-feature mean before the final
      PCA. sklearn's PCA also centers its input internally, so
      centering=False does not give an uncentered decomposition.
    - whiten: bool, forwarded to sklearn's PCA

    Returns:
    - the projected array (or a list of projected arrays when `data` is
      a list); if return_n_components=True, a tuple of that result and
      the list of per-dataset component counts (before combining)

    Raises:
    - ValueError: if num_PCs is not 'mean', 'max' or 'min'
    """
    # Validate up front so a bad option fails before any PCA is fitted.
    if num_PCs not in ('mean', 'max', 'min'):
        raise ValueError("num_PCs must be 'mean', 'max', or 'min'")

    def _compute_n(d):
        p = PCA().fit(d)
        cumvar = np.cumsum(p.explained_variance_ratio_)
        needed = int(np.searchsorted(cumvar, var_explained)) + 1
        # Clamp to the number of components that actually exist:
        # min_dim (default 10) can exceed it for small or
        # low-dimensional data, and searchsorted can overshoot by one
        # when var_explained is not below the last cumulative ratio.
        return min(max(needed, min_dim), len(cumvar))

    single = not isinstance(data, list)
    arrs = [data] if single else data

    # Remember original 3-D shapes so the projections can be unflattened.
    shapes = [a.shape if a.ndim == 3 else None for a in arrs]
    flat = []
    for a in arrs:
        if a.ndim == 3:
            K, T, N = a.shape
            flat.append(a.reshape(K * T, N))
        else:
            flat.append(a)

    n_comps = [_compute_n(f) for f in flat]
    if num_PCs == 'mean':
        nc = int(round(np.mean(n_comps)))
    elif num_PCs == 'max':
        nc = int(max(n_comps))
    else:  # 'min'
        nc = int(min(n_comps))

    # The common count must be valid for every dataset, otherwise
    # PCA(n_components=nc) rejects the smaller ones.
    nc = min(nc, min(min(f.shape) for f in flat))

    proj = []
    for f in flat:
        if centering:
            f = f - f.mean(axis=0)
        pca = PCA(n_components=nc, whiten=whiten)
        proj.append(pca.fit_transform(f))

    out = []
    for p, shape in zip(proj, shapes):
        if shape is not None:
            K, T, _ = shape
            out.append(p.reshape(K, T, -1))
        else:
            out.append(p)

    result = out[0] if single else out
    if return_n_components:
        return result, n_comps
    return result


def torch_convert(x):
    """Return `x` as a torch tensor if it is a NumPy array; otherwise return `x` unchanged."""
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    return x


def mae(x, y):
    """
    Mean absolute error between `x` and `y`.

    NumPy arrays are converted to torch tensors first; the mean is taken
    over all elements and returned as a Python scalar.
    """
    lhs = torch_convert(x)
    rhs = torch_convert(y)
    abs_diff = torch.abs(lhs - rhs)
    return abs_diff.mean().item()


def mase(true_vals, pred_vals):
    """
    Mean absolute scaled error of `pred_vals` against `true_vals`.

    The MAE of the prediction is divided by the MAE of a persistence
    baseline, i.e. the mean absolute difference between consecutive time
    steps of `true_vals`. Time is taken to be axis 0 for inputs with up
    to two dimensions and axis 1 otherwise.

    Parameters:
    - true_vals: np.ndarray or torch tensor of ground-truth values
    - pred_vals: np.ndarray or torch tensor of predictions

    Returns:
    - float: MAE(true, pred) / MAE(persistence baseline). When the
      baseline is exactly zero (e.g. a constant true series) the result
      is inf for a non-zero prediction error and NaN otherwise, instead
      of raising ZeroDivisionError.
    """
    t = torch_convert(true_vals)
    p = torch_convert(pred_vals)

    # persistence baseline: shift by one in time
    if t.ndim <= 2:
        base = mae(t[:-1], t[1:])
    else:
        base = mae(t[:, :-1], t[:, 1:])

    err = mae(t, p)
    if base == 0:
        # Both values are plain Python floats, so dividing would raise.
        # `err > 0` is False for 0 and for NaN, which both map to NaN.
        return float('inf') if err > 0 else float('nan')
    return err / base
