import re
import time
from numba import njit, prange

import pandas as pd
import numpy as np
from typing import List 
import torch
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    log_loss,
    accuracy_score,
    r2_score,
    precision_recall_curve,
)
from sklearn.metrics import auc as area_under_curve

from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset 

EPS = 1e-5


def printt(s=None):
    """Print one tab-terminated cell, or finish the current row.

    With an argument, ``str(s)`` is written followed by a tab instead of a
    newline, so consecutive calls build up one tab-separated row. Called with
    no argument (or ``None``), only a newline is written.
    """
    if s is not None:
        print(str(s), end="\t")
        return
    print()


def format_time(t):
    """Render a POSIX timestamp as a ``MMDDHHMMSS`` string in local time.

    Parameters
    ----------
    t
        Seconds since the epoch, as accepted by ``time.localtime``.
    """
    local_parts = time.localtime(t)
    return time.strftime("%m%d%H%M%S", local_parts)


def nan_weighted_avg(vals, weights, axis=None):
    """Weighted average that skips positions where either input is NaN.

    Parameters
    ----------
    vals, weights
        Arrays of identical shape (checked with an ``assert``).
    axis
        Axis passed to both ``sum`` reductions; ``None`` reduces everything.

    Returns
    -------
    ``np.nan`` when no position is valid in the whole input; otherwise
    ``sum(vals * weights) / sum(weights)`` computed over valid positions.
    The inputs are not modified.
    """
    assert vals.shape == weights.shape
    # A position is unusable as soon as one of the two operands is NaN.
    invalid = np.isnan(vals) | np.isnan(weights)
    if invalid.all():
        return np.nan
    clean_vals = vals.copy()
    clean_weights = weights.copy()
    # Zeroing both operands removes the position from numerator and denominator.
    clean_vals[invalid] = 0
    clean_weights[invalid] = 0
    numerator = (clean_vals * clean_weights).sum(axis=axis)
    denominator = clean_weights.sum(axis=axis)
    return numerator / denominator


def z_score_mask(ser, mask):
    """Standardise the masked part of ``ser`` using its own mean and std.

    Parameters
    ----------
    ser
        Series/array supporting boolean-mask indexing plus ``mean``/``std``.
    mask
        Boolean selector applied to ``ser``.

    Returns
    -------
    Only the selected entries, shifted by their mean and divided by their
    standard deviation (whatever ``std`` the container type provides).
    ``ser`` itself is never modified.
    """
    # Select once; the previous full copy of ``ser`` was never needed because
    # nothing here writes back into it.
    selected = ser[mask]
    return (selected - selected.mean()) / selected.std()


# ──────────────────────────────────────────────────────────────
# Backend‑agnostic helpers (K namespace)
# ──────────────────────────────────────────────────────────────
class K:
    """Backend kernel: one call surface for NumPy arrays and PyTorch tensors.

    Every helper dispatches on the runtime type of its input and raises
    ``NotImplementedError`` for anything that is neither ``np.ndarray`` nor
    ``torch.Tensor`` (``maximum`` additionally accepts one non-array operand).
    """

    @staticmethod
    def sum(x, axis=0, keepdims=True):
        """Sum ``x`` along ``axis``; the reduced axis is kept by default."""
        if isinstance(x, np.ndarray):
            return x.sum(axis=axis, keepdims=keepdims)
        if isinstance(x, torch.Tensor):
            return x.sum(dim=axis, keepdim=keepdims)
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def clip(x, min_val, max_val):
        """Limit the values of ``x`` to the range ``[min_val, max_val]``."""
        if isinstance(x, np.ndarray):
            return np.clip(x, min_val, max_val)
        if isinstance(x, torch.Tensor):
            return torch.clamp(x, min_val, max_val)
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def mean(x, axis=0, keepdims=True):
        """Average ``x`` along ``axis``; the reduced axis is kept by default."""
        if isinstance(x, np.ndarray):
            return x.mean(axis=axis, keepdims=keepdims)
        if isinstance(x, torch.Tensor):
            return x.mean(dim=axis, keepdim=keepdims)
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def seq_mean(x, keepdims=True):
        """Average over every element of ``x``.

        ``keepdims`` is accepted for signature parity with the other
        reductions but is not used: the result is always the global mean.
        """
        if isinstance(x, torch.Tensor):
            return x.mean()
        if isinstance(x, np.ndarray):
            return x.mean()
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def std(x, axis=0, keepdims=True):
        """Standard deviation along ``axis``.

        The tensor branch passes ``unbiased=False`` so that it matches the
        NumPy default used by the array branch.
        """
        if isinstance(x, np.ndarray):
            return x.std(axis=axis, keepdims=keepdims)
        if isinstance(x, torch.Tensor):
            return x.std(dim=axis, unbiased=False, keepdim=keepdims)
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def median(x, axis=0, keepdims=True):
        """Median along ``axis``.

        The two backends disagree on even-sized inputs: NumPy averages the
        middle pair, PyTorch returns the lower of the two.
        """
        if isinstance(x, np.ndarray):
            return np.median(x, axis=axis, keepdims=keepdims)
        if isinstance(x, torch.Tensor):
            # torch.median with ``dim`` returns (values, indices); keep values.
            return torch.median(x, dim=axis, keepdim=keepdims)[0]
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def shape(x):
        """Return the shape: a tuple for arrays, a list for tensors."""
        if isinstance(x, np.ndarray):
            return x.shape
        if isinstance(x, torch.Tensor):
            return list(x.shape)
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def cast(x, dtype="float"):
        """Convert ``x`` to ``dtype``.

        ``dtype`` is handed to ``ndarray.astype`` for arrays and looked up as
        an attribute of the ``torch`` module (e.g. ``torch.float``) for
        tensors.
        """
        if isinstance(x, np.ndarray):
            return x.astype(dtype)
        if isinstance(x, torch.Tensor):
            return x.type(getattr(torch, dtype))
        raise NotImplementedError("unsupported data type %s" % type(x))

    @staticmethod
    def maximum(x, y):
        """Element-wise maximum of ``x`` and ``y``.

        At least one operand must be an array or tensor; the other one may be
        anything the backend can combine with it (typically a scalar).
        """
        if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
            # Fixed: this branch used np.minimum and returned the smaller value.
            return np.maximum(x, y)
        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            return torch.max(x, y)
        # Tensor vs. non-tensor: clamping from below by the other operand is
        # the element-wise maximum. (The previous ``max=`` clamp computed the
        # minimum instead.)
        if isinstance(x, torch.Tensor):
            return torch.clamp(x, min=y)
        if isinstance(y, torch.Tensor):
            return torch.clamp(y, min=x)
        raise NotImplementedError("unsupported data type %s" % type(x))

    # --- sklearn‑style metrics (AUC, AP, etc.) elided for brevity ---
    # full implementation identical to original code supplied by the user


# Dynamically add simple NumPy/PyTorch mirrors

def generic_ops(method):
    """Create a dispatcher for a function that exists in both NumPy and PyTorch.

    The returned callable looks up ``method`` on ``np`` when its first
    argument is an ``np.ndarray`` and on ``torch`` when it is a
    ``torch.Tensor``, forwarding any extra positional arguments. Other input
    types raise ``NotImplementedError``.
    """

    def wrapper(x, *args):
        for array_type, backend in ((np.ndarray, np), (torch.Tensor, torch)):
            if isinstance(x, array_type):
                return getattr(backend, method)(x, *args)
        raise NotImplementedError("unsupported data type %s" % type(x))

    return wrapper


# Attach one static dispatcher per listed name to K at import time, so that
# e.g. K.abs(x) resolves to np.abs or torch.abs depending on the type of x.
# Every name listed here is looked up on both the ``np`` and ``torch`` modules
# by the wrapper that generic_ops builds.
# Note: the loop variable ``method`` stays bound at module level after the loop.
for method in [
    "abs",
    "log",
    "sqrt",
    "exp",
    "log1p",
    "tanh",
    "cosh",
    "squeeze",
    "reshape",
    "zeros_like",
]:
    setattr(K, method, staticmethod(generic_ops(method)))

# ──────────────────────────────────────────────────────────────
# Statistical helpers
# ──────────────────────────────────────────────────────────────

def zscore(x, axis=0):
    """Standardise ``x`` along ``axis`` using the K backend helpers.

    The mean and standard deviation are computed with kept dimensions so they
    broadcast against ``x``; ``EPS`` is added to the standard deviation to
    avoid dividing by zero on constant slices.
    """
    centered = x - K.mean(x, axis=axis)
    scale = K.std(x, axis=axis) + EPS
    return centered / scale


def robust_zscore(x, axis=0):
    """Median/MAD based z-score along ``axis``, clipped to ``[-3, 3]``.

    ``x`` is centred on its median and divided by the median absolute
    deviation scaled by 1.4826 (plus ``EPS`` to avoid a zero denominator).
    """
    center = K.median(x, axis=axis)
    mad = K.median(K.abs(x - center), axis=axis)
    scaled = (x - center) / (mad * 1.4826 + EPS)
    return K.clip(scaled, -3, 3)


def batch_corr(x, y, axis=0, keepdims=True):
    """Correlation-style score: mean product of the z-scored inputs.

    Both inputs are standardised along ``axis`` and the element-wise product
    is then averaged over *all* elements, so the result is a single value.
    ``keepdims`` is accepted but not used.
    """
    x_standardised = zscore(x, axis=axis)
    y_standardised = zscore(y, axis=axis)
    product = x_standardised * y_standardised
    return product.mean()


def robust_batch_corr(x, y, axis=0, keepdims=True):
    """Correlation of the robustly z-scored (median/MAD, clipped) inputs.

    Parameters
    ----------
    x, y
        Arrays or tensors of matching shape.
    axis
        Axis along which both the robust and the subsequent plain
        standardisation are performed.
    keepdims
        Forwarded to ``batch_corr`` for signature parity.

    Returns
    -------
    Whatever ``batch_corr`` returns for the robustly standardised inputs.
    """
    x = robust_zscore(x, axis=axis)
    y = robust_zscore(y, axis=axis)
    # Forward ``axis``: previously batch_corr always re-standardised along
    # axis 0, which disagreed with the robust step whenever axis != 0.
    return batch_corr(x, y, axis=axis, keepdims=keepdims)


# --- Fast AUC / AUPRC JIT kernels (unchanged) ---


@njit
def fast_auc(y_true, y_prob):
    """Numba-compiled ROC-AUC over the entries whose label is not NaN.

    Returns a pair ``(auc * ratio, ratio)`` where ``ratio`` is the fraction
    of entries in ``y_true`` that are not NaN.

    Assumes ``y_true`` holds 0/1 labels (the counting below relies on
    ``1 - y_i`` being 1 for negatives) -- TODO confirm with callers.
    NOTE(review): the normaliser ``nfalse * (n - nfalse)`` is zero when the
    kept labels are all one class, and ``mask.size`` is zero for empty input;
    neither case is guarded here.
    """
    # Drop entries with a NaN label from both arrays.
    mask = np.logical_not(np.isnan(y_true))
    ratio = np.sum(mask) / mask.size
    y_true = np.extract(mask, y_true)
    y_prob = np.extract(mask, y_prob)
    # Order the labels by ascending predicted score.
    y_true = y_true[np.argsort(y_prob)]
    nfalse = 0
    auc = 0
    n = len(y_true)
    for i in range(n):
        y_i = y_true[i]
        # Running count of negatives seen so far (lower-scored entries).
        nfalse += 1 - y_i
        # Each positive contributes the number of negatives ranked below it.
        auc += y_i * nfalse
    # Normalise by (#negatives * #positives).
    auc /= nfalse * (n - nfalse)
    return auc * ratio, ratio


def fast_auprc(y, p):
    """Average precision over the entries whose label is not NaN.

    Parameters
    ----------
    y, p
        Labels and scores. Either both tensors/arrays (flattened before
        scoring) or both lists.

    Returns
    -------
    tuple
        ``(average_precision * ratio, ratio)`` where ``ratio`` is the
        fraction of labels that are not NaN.

    Raises
    ------
    NotImplementedError
        If the inputs are not both arrays/tensors or both lists.
    """
    # Convert tensors first: np.isnan cannot be applied reliably to tensors
    # that live on another device or are attached to the autograd graph.
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()
    if isinstance(p, torch.Tensor):
        p = p.detach().cpu().numpy()

    both_arrays = isinstance(y, np.ndarray) and isinstance(p, np.ndarray)
    both_lists = isinstance(y, list) and isinstance(p, list)
    if not (both_arrays or both_lists):
        raise NotImplementedError("unsupported data type %s or %s" % (type(y), type(p)))

    y, p = np.asarray(y), np.asarray(p)
    if both_arrays:
        y, p = y.reshape(-1), p.reshape(-1)
    assert len(y) == len(p)

    mask = np.logical_not(np.isnan(y))
    ratio = np.sum(mask) / mask.size
    if not mask.all():
        # Actually drop the NaN-labelled entries. Previously the mask was only
        # used for ``ratio`` and the NaNs were passed straight to sklearn.
        y, p = y[mask], p[mask]
    return average_precision_score(y, p) * ratio, ratio


# ──────────────────────────────────────────────────────────────
# Loss & metric implementations (selected)
# ──────────────────────────────────────────────────────────────

def sequence_mse(y_true, y_pred):
    """Mean squared error averaged over every element (via ``K.seq_mean``)."""
    residual = y_true - y_pred
    return K.seq_mean(residual ** 2, keepdims=False)


def sequence_mae(y_true, y_pred):
    """Mean absolute error averaged over every element (via ``K.seq_mean``).

    Works for both NumPy arrays and PyTorch tensors, like the sibling
    ``sequence_mse`` / ``sequence_mase`` losses.
    """
    residual = y_true - y_pred
    # torch.abs rejects ndarrays, so pick the backend from the residual type.
    if isinstance(residual, np.ndarray):
        loss = np.abs(residual)
    else:
        loss = torch.abs(residual)
    return K.seq_mean(loss, keepdims=False)


def sequence_mase(y_true, y_pred):
    """Sum of squared and absolute error, averaged over every element.

    The absolute value is taken with ``np.abs`` when ``y_true`` is an
    ndarray and with the ``.abs()`` method otherwise.
    """
    residual = y_true - y_pred
    abs_residual = np.abs(residual) if isinstance(y_true, np.ndarray) else residual.abs()
    return K.seq_mean(residual ** 2 + abs_residual, keepdims=False)


def single_mase(y_true, y_pred):
    """Sum of squared and absolute error, averaged with ``K.mean``.

    ``K.mean`` is called with its default axis (0) and ``keepdims=False``,
    so only the leading axis is reduced.
    """
    residual = y_true - y_pred
    abs_residual = np.abs(residual) if isinstance(y_true, np.ndarray) else residual.abs()
    return K.mean(residual ** 2 + abs_residual, keepdims=False)


def single_mae(y_true, y_pred):
    """Mean absolute error that ignores missing (NaN) targets.

    Parameters
    ----------
    y_true, y_pred
        NumPy arrays or PyTorch tensors.

    Returns
    -------
    NumPy input: ``np.nanmean`` of the absolute error.
    Tensor input: mean absolute error over positions where ``y_true`` is not
    NaN, mirroring the masking done by ``single_mse``.
    """
    if isinstance(y_true, np.ndarray):
        loss = np.abs(y_true - y_pred)
        return np.nanmean(loss)
    # Mask before subtracting so NaN targets never enter the computation
    # (previously a single NaN target made the whole tensor loss NaN).
    mask = ~torch.isnan(y_true)
    y_pred = torch.masked_select(y_pred, mask)
    y_true = torch.masked_select(y_true, mask)
    return (y_true - y_pred).abs().mean()

def single_mse(y_true, y_pred):
    """Mean squared error that ignores missing values.

    NumPy input: ``np.nanmean`` of the squared error.
    Tensor input: positions where ``y_true`` is NaN are removed from both
    tensors before the squared error is averaged.
    """
    if not isinstance(y_true, np.ndarray):
        valid = ~torch.isnan(y_true)
        kept_pred = torch.masked_select(y_pred, valid)
        kept_true = torch.masked_select(y_true, valid)
        return ((kept_true - kept_pred) ** 2).mean()
    return np.nanmean((y_true - y_pred) ** 2)



# NEW METRIC --------------------------------------------------------------------

def cvrmse(y_true, y_pred):
    """Coefficient of Variation of RMSE (percentage).

    Steps
    -----
    1. Compute RMSE per variate.
    2. Divide by the mean absolute true value per variate (plus ``EPS``).
    3. Average across variates and multiply by 100.

    Parameters
    ----------
    y_true, y_pred
        Tensors, or array-likes that ``torch.as_tensor`` can convert. The
        first two dimensions are reduced, so inputs need at least two
        dimensions -- presumably (batch, horizon, variates); confirm with
        callers.

    Returns
    -------
    torch.Tensor
        Scalar CVRMSE in percentage (higher = worse).
    """
    # Convert array-likes to tensors when necessary (ndarrays share memory).
    if not torch.is_tensor(y_true):
        y_true = torch.as_tensor(y_true)
    if not torch.is_tensor(y_pred):
        y_pred = torch.as_tensor(y_pred)

    # 1) RMSE per variate: reduce the first two dimensions.
    rmse = torch.sqrt(((y_true - y_pred) ** 2).mean(dim=(0, 1)))

    # 2) Normaliser: mean absolute true value per variate.
    mean_true = y_true.abs().mean(dim=(0, 1))

    # 3) CVRMSE per variate; EPS guards against an all-zero variate.
    cvr = rmse / (mean_true + EPS)

    # 4) Average over all variates -> scalar percentage.
    return cvr.mean() * 100

# ------------------------------------------------------------------------------

# Existing metrics (rrse, mape, etc.) would follow here … (omitted for brevity)

# ──────────────────────────────────────────────────────────────
# Metric dispatchers
# ──────────────────────────────────────────────────────────────

def get_loss_fn(loss_fn):
    """Map a loss name to the callable that implements it.

    Known aliases are handled explicitly. Any other name is evaluated as a
    Python expression in this module's scope; if that fails, a leading
    ``neg_`` is stripped, the remainder is evaluated and the result is
    wrapped with ``neg_wrapper``.

    Raises
    ------
    NotImplementedError
        If the name cannot be resolved by any of the routes above.
    """
    if loss_fn == "outside_bce":
        return outside_cross_entropy
    if loss_fn == "cross_entropy":
        return cross_entropy
    if loss_fn == "mase":
        return sequence_mase
    if loss_fn == "mae":
        return single_mae
    # "mse", "single_mse" and every "label*" name share the NaN-aware MSE.
    if loss_fn in {"mse", "single_mse"} or loss_fn.startswith("label"):
        return single_mse

    # SECURITY NOTE: eval() executes arbitrary expressions; ``loss_fn`` must
    # only ever come from trusted configuration. No locals are introduced
    # here on purpose, because eval() resolves names against this scope.
    try:
        return eval(loss_fn)
    except Exception:
        # Not resolvable as written; try the "neg_<name>" convention next.
        pass
    try:
        return neg_wrapper(eval(re.sub("^neg_", "", loss_fn)))
    except Exception:
        raise NotImplementedError("loss function %s is not implemented" % loss_fn)


def get_metric_fn(eval_metric):
    """Map a metric name to the callable that implements it.

    Known aliases are handled explicitly (``"corr"`` is returned wrapped in
    ``neg_wrapper``). Any other name is evaluated as a Python expression in
    this module's scope; if that fails, a leading ``neg_`` is stripped, the
    remainder is evaluated and the result is wrapped with ``neg_wrapper``.

    Raises
    ------
    NotImplementedError
        If the name cannot be resolved by any of the routes above.
    """
    if eval_metric == "cvrmse":
        return cvrmse
    if eval_metric in ("rse", "rrse"):
        return rrse
    if eval_metric == "mae":
        return single_mae
    if eval_metric == "mse":
        return single_mse
    if eval_metric == "corr":
        return neg_wrapper(robust_batch_corr)

    # SECURITY NOTE: eval() executes arbitrary expressions; ``eval_metric``
    # must only ever come from trusted configuration. No locals are
    # introduced here on purpose, because eval() resolves names against
    # this scope.
    try:
        return eval(eval_metric)
    except Exception:
        # Not resolvable as written; try the "neg_<name>" convention next.
        pass
    try:
        return neg_wrapper(eval(re.sub("^neg_", "", eval_metric)))
    except Exception:
        raise NotImplementedError("metric function %s is not implemented" % eval_metric)

        
class TimeFeature:
    """Base class for callables that turn a ``DatetimeIndex`` into a feature.

    Subclasses override ``__call__``; the base implementation returns
    ``None``. ``repr`` shows the concrete class name followed by ``()``.
    """

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return None

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        seconds = index.second
        return seconds / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour (0-59) scaled linearly onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        minutes = index.minute
        return minutes / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day (0-23) scaled linearly onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        hours = index.hour
        return hours / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        weekdays = index.dayofweek
        return weekdays / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month (1-31) shifted to start at zero and scaled onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_day = index.day - 1
        return zero_based_day / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year shifted to start at zero and scaled onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_day = index.dayofyear - 1
        return zero_based_day / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year (1-12) shifted to start at zero and scaled onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_month = index.month - 1
        return zero_based_month / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """ISO week of year shifted to start at zero and scaled onto [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # NOTE(review): ``isocalendar().week`` is a pandas object rather than a
        # plain ndarray, so the return annotation may be looser than stated.
        iso_week = index.isocalendar().week
        return (iso_week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Pick the time features that suit a sampling frequency.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    Freshly instantiated feature objects for the first offset type in the
    table below that the parsed frequency is an instance of (possibly an
    empty list).

    Raises
    ------
    RuntimeError
        If the parsed frequency matches none of the supported offset types.
    """

    minute_features = [MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]
    daily_features = [DayOfWeek, DayOfMonth, DayOfYear]

    # Insertion order matters: the first matching offset type wins.
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: list(daily_features),
        offsets.BusinessDay: list(daily_features),
        offsets.Hour: [HourOfDay] + daily_features,
        offsets.Minute: list(minute_features),
        offsets.Second: [SecondOfMinute] + minute_features,
    }

    offset = to_offset(freq_str)

    matched = next(
        (
            feature_classes
            for offset_type, feature_classes in features_by_offsets.items()
            if isinstance(offset, offset_type)
        ),
        None,
    )
    # Compare against None explicitly: an empty feature list is a valid match.
    if matched is not None:
        return [feature_cls() for feature_cls in matched]

    raise RuntimeError(f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """)


def time_features(dates, timeenc=1, freq="h") -> np.ndarray:
    """
    Build a time-feature matrix from the ``date`` column of ``dates``.

    Parameters
    ----------
    dates
        DataFrame with a ``date`` column of timestamps. It is not modified.
    timeenc
        ``0`` selects raw integer calendar fields; any other value selects the
        features from ``time_features_from_frequency_str``, each encoded
        between [-0.5, 0.5].
    freq
        Sampling frequency.

        With ``timeenc == 0`` (case-insensitive single letter):

        * y - []
        * m - [month]
        * w - [month]
        * d - [month, day, weekday]
        * b - [month, day, weekday]
        * h - [month, day, weekday, hour]
        * t - [month, day, weekday, hour, minute*]

        Otherwise, a pandas frequency string:

        * Q - [month]
        * M - [month]
        * W - [day of month, week of year]
        * D - [day of week, day of month, day of year]
        * B - [day of week, day of month, day of year]
        * H - [hour of day, day of week, day of month, day of year]
        * T - [minute of hour, hour of day, day of week, day of month, day of year]
        * S - [second of minute, minute of hour, hour of day, day of week, day of month, day of year]

        *With ``timeenc == 0`` the minute is a number from 0-3 naming the
        15 minute period it falls into.

    Returns
    -------
    np.ndarray
        One row per timestamp, one column per feature.

    Raises
    ------
    KeyError
        With ``timeenc == 0`` when ``freq`` is not one of the letters above.
    """
    if timeenc == 0:
        # Work on a copy so the helper columns never leak into the caller's
        # frame (and never trigger chained-assignment warnings on slices).
        frame = dates.copy()
        stamps = frame.date
        # Series.apply is called with the function only: the former stray
        # positional ``1`` was bound to the deprecated ``convert_dtype``.
        frame["month"] = stamps.apply(lambda ts: ts.month)
        frame["day"] = stamps.apply(lambda ts: ts.day)
        frame["weekday"] = stamps.apply(lambda ts: ts.weekday())
        frame["hour"] = stamps.apply(lambda ts: ts.hour)
        # Quarter-hour bucket (0-3) rather than the raw minute.
        frame["minute"] = stamps.apply(lambda ts: ts.minute // 15)
        freq_map = {
            "y": [],
            "m": ["month"],
            "w": ["month"],
            "d": ["month", "day", "weekday"],
            "b": ["month", "day", "weekday"],
            "h": ["month", "day", "weekday", "hour"],
            "t": ["month", "day", "weekday", "hour", "minute"],
        }
        return frame[freq_map[freq.lower()]].values

    index = pd.to_datetime(dates.date.values)
    encoded = [feat(index) for feat in time_features_from_frequency_str(freq)]
    # Stack as (n_features, n_timestamps), then flip to one row per timestamp.
    return np.vstack(encoded).transpose(1, 0)
    

# Public names exported by ``from <module> import *`` (optional, but tidy).
# Only the time-feature API is listed here.
__all__ = [
    "time_features",
    "time_features_from_frequency_str",
    "TimeFeature",
    "SecondOfMinute",
    "MinuteOfHour",
    "HourOfDay",
    "DayOfWeek",
    "DayOfMonth",
    "DayOfYear",
    "MonthOfYear",
    "WeekOfYear",
]