import torch
import torch.nn.functional as F
import math
import pandas as pd
from datetime import datetime
from typing import Dict, Any, List, Tuple

def load_tsf_as_timeseries(
    path: str,
    encodings: Tuple[str, ...] = ("utf-8", "windows-1252", "latin-1"),
    return_meta: bool = False,
) -> pd.DataFrame | Tuple[pd.DataFrame, Dict[str, Any], str]:
    """
    Parse a .tsf file into a wide pandas DataFrame:
      - index = timestamp (DatetimeIndex)
      - columns = series names (e.g., T1, T2, ...)
      - values = numeric readings
    Robust to non-UTF8 encodings (tries utf-8, then windows-1252, then latin-1).

    Args
    ----
    path: str
        Path to the .tsf file.
    encodings: tuple[str, ...]
        Ordered encodings to try.
    return_meta: bool
        If True, also returns (meta_dict, used_encoding).

    Returns
    -------
    df or (df, meta, used_encoding)
    """
    def _read_lines(p: str, encs: Tuple[str, ...]) -> Tuple[List[str], str]:
        last_err = None
        for enc in encs:
            try:
                with open(p, "r", encoding=enc, errors="strict") as f:
                    return f.readlines(), enc
            except UnicodeDecodeError as e:
                last_err = e
                continue
        # Final fallback: replace invalid bytes to avoid hard failure
        with open(p, "r", encoding=encs[-1], errors="replace") as f:
            return f.readlines(), f"{encs[-1]} (errors=replace)"

    lines, used_encoding = _read_lines(path, encodings)

    meta: Dict[str, Any] = {}
    data_section = False
    series_list: List[pd.Series] = []

    # parse @frequency to build timestamps
    freq_map = {
        "daily": "D", "weekly": "W", "monthly": "MS", "yearly": "YS",
        "hourly": "H", "minutely": "T", "quarterly": "QS"
    }

    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#"):
            continue

        if line.startswith("@") and not data_section:
            if line.lower().startswith("@data"):
                data_section = True
            else:
                parts = line[1:].split(None, 1)  # remove '@', split into key and value
                key = parts[0].lower()
                val = parts[1].strip() if len(parts) > 1 else True
                meta[key] = val
            continue

        if data_section:
            # Expected format: series_name:start_timestamp:val1,val2,...
            parts = line.split(":", 2)
            if len(parts) != 3:
                # Skip malformed lines
                continue

            series_name, start_ts_str, values_str = parts

            # Parse numeric values, mapping "" or "?" to NaN
            vals: List[float] = []
            for v in values_str.split(","):
                v = v.strip()
                if v in ("", "?"):
                    vals.append(float("nan"))
                else:
                    try:
                        vals.append(float(v))
                    except ValueError:
                        # Strip potential thousands separators or weird whitespace
                        vv = v.replace("\u00A0", "").replace(",", "")
                        vals.append(float(vv) if vv else float("nan"))

            # Parse timestamp like "1996-03-18 00-00-00"
            try:
                start_dt = datetime.strptime(start_ts_str, "%Y-%m-%d %H-%M-%S")
            except ValueError:
                start_dt = pd.to_datetime(start_ts_str, errors="coerce")

            # If unparsable, create a simple integer index as a fallback
            if pd.isna(start_dt):
                idx = pd.RangeIndex(start=0, stop=len(vals), step=1)
            else:
                freq = freq_map.get(str(meta.get("frequency", "")).lower(), "D")
                idx = pd.date_range(start=start_dt, periods=len(vals), freq=freq)

            series = pd.Series(vals, index=idx, name=series_name)
            series_list.append(series)

    if not series_list:
        raise ValueError("No series parsed from the .tsf file. Check file format or encoding.")

    # Outer-join on time to keep full union of timestamps across series; sort the index.
    df = pd.concat(series_list, axis=1).sort_index()
    df.index.name = "date"
    # check no Nan is present in the dataframe, otherwise raise error
    if df.isna().all().all():
        raise ValueError("All values in the DataFrame are NaN. Check file content.")

    return (df, meta, used_encoding) if return_meta else df

def MinMaxScaler(data, return_scalers=False):
    """Min-max normalise a tensor with statistics taken along dim 0.

    Args:
      - data: original data (torch.Tensor)
      - return_scalers: when True, also return the min and max used

    Returns:
      - norm_data: normalized data, or the tuple
        (norm_data, min_val, max_val) when return_scalers is True
    """
    lowest = data.min(dim=0).values
    highest = data.max(dim=0).values
    # The 1e-7 term keeps the division finite when a slice is constant.
    scaled = (data - lowest) / ((highest - lowest) + 1e-7)
    return (scaled, lowest, highest) if return_scalers else scaled


def MinMaxArgs(data, min, max):
    """
    Min-max scale data with externally supplied bounds.

    Args:
        data: given data (torch.Tensor)
        min: given min value (torch.Tensor)
        max: given max value (torch.Tensor)

    Returns:
        (data - min) / (max - min + 1e-7)
    """
    span = max - min
    # The 1e-7 term keeps the division finite when max equals min.
    return (data - min) / (span + 1e-7)



def compute_pad_left_right(pad_total: int):
    """
    Splits a total padding length into a left part and a right part.
    Parameters:
        pad_total (int): Total padding length.
    Returns:
        tuple: (pad_left, pad_right), where pad_left is pad_total // 2 and
               pad_right is the remainder, so an odd total puts the extra
               element on the right.
    """
    half = pad_total // 2
    return half, pad_total - half


def get_pad_length_to_power_of_two(n: int):
    """
    Computes the padding length required to make the length of a signal a power of two.
    Parameters:
        n (int): Length of the signal; must be at least 1.
    Returns:
        tuple: A tuple containing the left and right padding lengths.
    Raises:
        ValueError: If n is smaller than 1.
    """
    n = int(n)
    if n < 1:
        raise ValueError(f"Signal length must be a positive integer, got {n}.")
    # Smallest power of two >= n, computed with integer arithmetic only so the
    # result does not depend on floating-point rounding of a logarithm.
    next_power = 1 << (n - 1).bit_length()
    # Compute total number of padding elements required
    pad_total = next_power - n
    return compute_pad_left_right(pad_total)


def get_pad_length_to_match_seq_len(n: int, seq_len: int, stride: int = 1):
    """
    Computes the padding required so that a signal of length n is covered
    exactly by a whole number of sliding windows of length seq_len taken
    every `stride` steps. At least one window is always fitted, so a signal
    shorter than seq_len is padded up to seq_len.
    Parameters:
        n (int): Length of the signal.
        seq_len (int): Desired sequence (window) length.
        stride (int): Stride length for the sliding window; must be positive.
    Returns:
        tuple: A tuple containing the left and right padding lengths.
    """
    # Integer ceil of (n - seq_len) / stride; avoids float division.
    extra_steps = -((seq_len - n) // stride)
    # Number of windows needed to cover the signal, never fewer than one
    # (without the clamp, n < seq_len yields a length shorter than a window).
    n_win = max(1, extra_steps + 1)
    # total length required to fit exactly n_win windows
    L_req = (n_win - 1) * stride + seq_len
    # padding needed
    pad_total = max(0, L_req - n)
    return compute_pad_left_right(pad_total)


def pad_signal(signal, pad_left, pad_right, mode="reflect"):
    """
    Pads a signal tensor along its last dimension using F.pad.
    Parameters:
        signal (torch.Tensor): Input signal tensor.
        pad_left (int): Number of elements to pad at the start of the last dimension.
        pad_right (int): Number of elements to pad at the end of the last dimension.
        mode (str): Padding mode forwarded to F.pad. Default is 'reflect'.
    Returns:
        torch.Tensor: Padded signal tensor; the input itself is returned
                      unchanged when both pad lengths are zero.
    """
    nothing_to_pad = pad_left == 0 and pad_right == 0
    if nothing_to_pad:
        return signal
    # A two-element padding spec applies to the last dimension only.
    padding = (pad_left, pad_right)
    return F.pad(signal, padding, mode=mode)


def unpad_signal(padded_signal, pad_left, pad_right):
    """
    Removes padding from a signal tensor by trimming its first dimension.
    Parameters:
        padded_signal (torch.Tensor): Padded signal tensor (indexed as [rows, :]).
        pad_left (int): Number of elements to remove from the start of dim 0.
        pad_right (int): Number of elements to remove from the end of dim 0.
    Returns:
        torch.Tensor: Signal tensor without padding; the input itself is
                      returned when neither pad length is positive.
    """
    # NOTE(review): this trims dim 0 while pad_signal pads the last
    # dimension — presumably callers transpose in between; confirm.
    if pad_left <= 0 and pad_right <= 0:
        return padded_signal
    first = pad_left if pad_left > 0 else None
    last = -pad_right if pad_right > 0 else None
    return padded_signal[first:last, :]


def pad_timestamps(index, pad_left, pad_right):
    """
    Extends a DateTimeIndex on both sides with evenly spaced timestamps.

    The spacing is the difference between the first two entries of the
    index, so the index needs at least two entries whenever padding is
    requested.

    Parameters:
        index (pd.DatetimeIndex): Original DateTimeIndex.
        pad_left (int): Number of periods to pad on the left.
        pad_right (int): Number of periods to pad on the right.

    Returns:
        pd.DatetimeIndex: Padded DateTimeIndex; the input itself is returned
                          when both pad lengths are zero.
    """
    if (pad_left, pad_right) == (0, 0):
        return index

    first, last = index[0], index[-1]
    # Step size taken from the first two timestamps only.
    step = index[1] - first

    # Extrapolate linearly before the first and after the last timestamp.
    before = pd.date_range(start=first - pad_left * step, periods=pad_left, freq=step)
    after = pd.date_range(start=last + step, periods=pad_right, freq=step)

    # Order: left padding, original index, right padding.
    tail = index.append(after)
    return before.append(tail)


def ema_trend(x, alpha=0.1):
    """
    Applies exponential moving average smoothing along the time dimension.

    The first time step is copied as-is; every later step t is
    alpha * x[t] + (1 - alpha) * ema[t - 1].

    Args:
        x (torch.Tensor): Input tensor of shape (..., L, K), where L is the time dimension (second last),
                          and K is the feature dimension (last).
        alpha (float): Smoothing factor.

    Returns:
        torch.Tensor: Smoothed tensor of the same shape as input.

    Raises:
        ValueError: If x has fewer than 2 dimensions.
    """
    input_shape = x.shape
    if x.ndim < 2:
        raise ValueError("Input tensor must have at least 2 dimensions (time, features).")

    n_steps, n_feat = input_shape[-2], input_shape[-1]
    # Collapse all leading dimensions into one batch axis: (B, L, K).
    series = x.reshape(-1, n_steps, n_feat)

    smoothed = torch.zeros_like(series)
    smoothed[:, 0, :] = series[:, 0, :]
    carry = 1 - alpha
    # The recurrence is inherently sequential in time, hence the explicit loop.
    for step in range(1, n_steps):
        smoothed[:, step, :] = alpha * series[:, step, :] + carry * smoothed[:, step - 1, :]

    return smoothed.reshape(input_shape)


def flatten_windows(
    x, overlapping_stride: int = 1, avg_window: bool = True, eps: float = 0.5
):
    """
    Merge overlapping windows back into a single sequence, using either a
    weighted average or a centered-selection strategy.

    Parameters:
        x (torch.Tensor): Tensor of shape (n_windows, window_length, *features).
                          Any number of trailing feature dimensions (including none) is supported.
        overlapping_stride (int): Stride used when windows were created.
        avg_window (bool): Whether to average using distance-based weights (True)
                           or pick, for every position, the value from the window whose
                           center is closest (False). Ties keep the earlier window.
        eps (float): Added to the distance before inversion, weight = 1 / (distance + eps).
                     It avoids division by zero and also controls how peaked the weights are:
                     eps = 1e-8 is extremely peaky and preserves the most high-frequency detail
                     (i.e., it is the least smooth).
                     eps = 0.1 to eps = 0.5 starts to "blunt" the peak; still sharp, but the
                     center point's dominance is reduced.
                     eps = 1.0: the function 1 / (distance + 1) has a natural hyperbolic decay;
                     smooth but still giving significant priority to the center.
                     eps > 1.0 (e.g., eps = 5.0): the larger eps gets, the broader and flatter
                     the weights, leading to more aggressive smoothing.

    Returns:
        torch.Tensor: Tensor of shape (total_length, *features) without overlaps, where
                      total_length = (n_windows - 1) * overlapping_stride + window_length.
    """
    n_win, L, *feat = x.shape
    total_len = (n_win - 1) * overlapping_stride + L

    device, dtype = x.device, x.dtype
    # Distances and weights need a real floating dtype even if x is not one.
    weight_dtype = dtype if dtype.is_floating_point else torch.get_default_dtype()
    # Singleton axes that let a per-position vector broadcast over every
    # feature dimension (the original (L, 1) shape only worked for exactly one).
    feat_axes = (1,) * len(feat)

    # Distance of each in-window offset to the window center. It is the same
    # for every window, so it is computed once outside the loop.
    offsets = torch.arange(L, device=device, dtype=weight_dtype)
    distances = torch.abs(offsets - (L - 1) / 2)

    acc = torch.zeros((total_len, *feat), device=device, dtype=dtype)

    if avg_window:
        # Weight inversely proportional to distance (closer points weigh more).
        weights = 1 / (distances + eps)
        weights_b = weights.reshape(L, *feat_axes)
        weight_sum = torch.zeros((total_len,), device=device, dtype=weight_dtype)

        for i in range(n_win):
            start = i * overlapping_stride
            end = start + L
            acc[start:end] += x[i] * weights_b
            weight_sum[start:end] += weights

        # Normalize accumulated values by the total weight at each position.
        return acc / weight_sum.reshape(total_len, *feat_axes)

    # Centered selection: remember the best (smallest) distance seen so far
    # per position. A separate "assigned" mask is used instead of treating 0
    # as "unset", because 0 is a valid distance (the exact window center).
    best_dist = torch.zeros((total_len,), device=device, dtype=weight_dtype)
    assigned = torch.zeros((total_len,), device=device, dtype=torch.bool)

    for i in range(n_win):
        start = i * overlapping_stride
        end = start + L
        take = ~assigned[start:end] | (distances < best_dist[start:end])

        # Slices are views, so the masked assignments write through.
        acc[start:end][take] = x[i][take]
        best_dist[start:end][take] = distances[take]
        assigned[start:end] = True

    return acc


def normalize_minmax_neg_one_to_one(
    x, min_val=None, max_val=None, eps=1e-8, dims_to_reduce=(0, 1)
):
    """
    Normalize a tensor to [-1, 1] using min-max scaling across specified dimensions.

    Each of min_val / max_val is used as given when supplied; a missing bound
    is computed from x over dims_to_reduce (with keepdim=True so it
    broadcasts against x). Supplying only one bound therefore keeps that
    bound and derives the other from the data.

    Returns the normalized tensor and the min/max used, so the result can be
    rescaled later.
    """
    if min_val is None:
        min_val = x.amin(dim=dims_to_reduce, keepdim=True)
    if max_val is None:
        max_val = x.amax(dim=dims_to_reduce, keepdim=True)

    # Min-max normalize to [0, 1]; eps keeps the division finite when max == min.
    x_norm = (x - min_val) / (max_val - min_val + eps)

    # Scale to [-1, 1]
    x_scaled = x_norm * 2 - 1

    return x_scaled, min_val, max_val


def unnormalize_minmax_neg_one_to_one(x_scaled, min_val, max_val, eps=1e-8):
    """
    Map values from [-1, 1] back to the original range described by the
    saved min and max (eps is added to the range width before rescaling).
    """
    # Width of the original range, including the eps offset.
    span = max_val - min_val + eps
    # [-1, 1] -> [0, 1] -> original range.
    unit = (x_scaled + 1) / 2
    return unit * span + min_val


def pad_hw(
    x: torch.Tensor,
    target_h: int,
    target_w: int,
    mode: str = "constant",
    value: float = 0.0,
) -> torch.Tensor:
    """
    Pads the H and W dims of a tensor (B, T, C, H, W) up to target_h / target_w.
    The padding is split evenly between both sides; when the amount is odd
    the extra row/column goes to the bottom/right.
    • mode  : 'constant', 'replicate', 'reflect', or 'circular'  (see F.pad docs)
    • value : fill value when mode == 'constant'

    Raises ValueError when x is not 5-D or is already larger than the target.
    """
    if x.ndim != 5:
        raise ValueError("Expected shape (B, T, C, H, W)")

    h, w = x.shape[-2:]
    if h > target_h or w > target_w:
        raise ValueError(
            f"Current size (H={h}, W={w}) exceeds targets "
            f"(target_h={target_h}, target_w={target_w}); crop first."
        )

    # Total rows / columns still missing.
    missing_rows = target_h - h
    missing_cols = target_w - w

    # Floor half goes on top/left, the remainder on bottom/right.
    top = missing_rows // 2
    left = missing_cols // 2

    # F.pad reads the spec from the last dimension backwards:
    # (w_left, w_right, h_top, h_bottom, <third-from-last dim pair>).
    # The trailing (0, 0) leaves that third-from-last dim (C here) untouched.
    padding = (left, missing_cols - left, top, missing_rows - top, 0, 0)

    return F.pad(x, padding, mode=mode, value=value)


def unpad_hw(
    x: torch.Tensor,
    orig_h: int | None = None,
    orig_w: int | None = None,
    pads: tuple[int, int, int, int] | None = None,
) -> torch.Tensor:
    """
    Remove padding from tensor (B, T, C, H, W) that was added with `pad_hw`.

    Args
    ----
    x      : padded tensor
    orig_h : original height **before** padding         (mutually exclusive with `pads`)
    orig_w : original width  **before** padding
    pads   : 4-tuple (pad_left, pad_right, pad_top, pad_bottom)
             returned by `pad_hw`                       (mutually exclusive with `orig_*`)

    Returns
    -------
    x_cropped : tensor with the padding stripped off
    """
    if x.ndim != 5:
        raise ValueError("Expected shape (B, T, C, H, W)")

    _, _, _, h, w = x.shape

    if pads is not None and (orig_h is not None or orig_w is not None):
        raise ValueError("Specify either `pads` *or* (`orig_h`, `orig_w`), not both.")

    # ── Case 1: you know the pads exactly ────────────────────────────────
    if pads is not None:
        pad_left, pad_right, pad_top, pad_bottom = pads

    # ── Case 2: you only know the target size ───────────────────────────
    else:
        if orig_h is None or orig_w is None:
            raise ValueError("Need both `orig_h` and `orig_w` when `pads` is None.")

        if orig_h > h or orig_w > w:
            raise ValueError("`orig_h/orig_w` larger than padded tensor.")

        # Assume symmetric padding like `pad_hw`
        diff_h = h - orig_h
        diff_w = w - orig_w

        pad_top = diff_h // 2
        pad_bottom = diff_h - pad_top
        pad_left = diff_w // 2
        pad_right = diff_w - pad_left

    # ── Crop away the padding (keep ":" for B, T, C) ────────────────────
    h_end = -pad_bottom if pad_bottom > 0 else None
    w_end = -pad_right if pad_right > 0 else None

    return x[..., pad_top:h_end, pad_left:w_end]
