import sys
import time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch

from plotting_utils import apply_plot_style

# Put the directory one level above this script's folder at the front of
# sys.path (absolute path, so it does not depend on the current working
# directory). This lets the project-local `methods` package imported below
# be resolved.
# NOTE(review): `plotting_utils` is imported *before* this insertion, so it
# must already be importable without it (e.g. from this script's own
# directory) — confirm.
COMPARE_DIR = Path(__file__).resolve().parents[1]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTClassifier  # noqa: E402

# =================== Configuration =================== #
TRAIN_LIMIT = 3000   # leading MNIST training images loaded (10% later held out for validation)
TEST_LIMIT = 500     # leading MNIST test images loaded
BATCH_SIZE = 64      # training mini-batch size
EPOCHS = 30
HIDDEN_SIZE = 128    # recurrent state size H (hidden state is (H, B))
LEARNING_RATE = 1e-3 # passed to the FPTT classifier as `eta`
SEED = 42            # seeds numpy, torch and every default_rng used below

INPUT_SIZE = 28      # feature dimension per time step (28 pixels; MNIST images are 28x28)
TIME_STEPS = 28      # sequence length (one step per 28-pixel line of the image)
OUTPUT_SIZE = 10     # number of classes (width of the one-hot targets)

# Apply the shared matplotlib style from the project-local plotting_utils (runs at import time).
apply_plot_style()


# =================== MNIST data loading =================== #
def load_mnist_data(train_limit, test_limit, root="./data"):
    """
    Load a leading subset of MNIST as sequences of image rows.

    Args:
        train_limit: keep the first ``train_limit`` training images.
        test_limit: keep the first ``test_limit`` test images.
        root: dataset directory passed to ``torchvision.datasets.MNIST``
            (``download=True`` fetches it there if missing). Defaults to
            ``"./data"``, which is resolved relative to the current working
            directory.

    Returns:
        Tuple ``(x_train, t_train, y_train, x_test, t_test, y_test)``:
          * ``x_*``: float32 pixels scaled to [0, 1], laid out as
            ``(N, Features, Time) = (N, 28, 28)``, with ``x[n, :, t]`` equal
            to image row ``t``.
          * ``t_*``: float32 one-hot targets ``(N, OUTPUT_SIZE, TIME_STEPS)``;
            the same label is repeated at every time step.
          * ``y_*``: integer class labels ``(N,)``.
    """
    train_set = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=None
    )
    test_set = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=None
    )

    x_train = train_set.data[:train_limit].numpy().astype(np.float32) / 255.0
    y_train = train_set.targets[:train_limit].numpy()
    x_test = test_set.data[:test_limit].numpy().astype(np.float32) / 255.0
    y_test = test_set.targets[:test_limit].numpy()

    # Raw images are (N, rows, cols). The sequence layout is
    # (N, Features, Time), with time last like the targets below, and each
    # image *row* is one time step. So the row axis must move to the end.
    # An identity transpose has the same shape because the image is square,
    # but it would silently feed columns as time steps.
    x_train = x_train.transpose(0, 2, 1)  # (N, cols=features, rows=time)
    x_test = x_test.transpose(0, 2, 1)

    def make_targets(y, steps):
        """One-hot encode ``y`` and repeat it along a trailing time axis."""
        n = y.shape[0]
        onehot = np.zeros((n, OUTPUT_SIZE), dtype=np.float32)
        onehot[np.arange(n), y] = 1.0
        # Every time step carries the same label.
        return np.repeat(onehot[:, :, None], steps, axis=2)  # (N, 10, steps)

    t_train = make_targets(y_train, TIME_STEPS)
    t_test = make_targets(y_test, TIME_STEPS)
    return x_train, t_train, y_train, x_test, t_test, y_test


# =================== Small helpers =================== #
def set_weights(model, params):
    """
    Install a fixed set of parameters on an FPTT model so runs are reproducible.

    The five entries ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y`` are read
    from ``params`` in that order. Each ``.copy()`` is assigned to the model
    attribute of the same name; strict_fptt.StrictFPTTClassifier exposes these
    attributes. Because copies are stored, the model never aliases the
    caller's arrays. A missing key raises ``KeyError``.
    """
    for attr_name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        setattr(model, attr_name, params[attr_name].copy())


def get_weights(model):
    """
    Snapshot the model's five parameter arrays into a fresh dict.

    Keys are ``"W_hh"``, ``"W_xh"``, ``"b_h"``, ``"W_hy"``, ``"b_y"`` (in that
    order). Each value is a ``.copy()`` of the matching model attribute, so
    the snapshot does not share memory with the model's arrays.
    """
    snapshot = {}
    for attr_name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        snapshot[attr_name] = getattr(model, attr_name).copy()
    return snapshot


def evaluate(model, x, labels, batch_size=100):
    """
    Return the classification accuracy of ``model`` on ``(x, labels)``.

    Only the output of the last time step is used for classification.

    Args:
        model: StrictFPTTClassifier (implements ``forward_cycle(x, h0)``,
            returning ``(outputs, h_last)`` where ``outputs`` is a per-step
            list of ``(C, B)`` arrays and ``h_last`` is ``(H, B)``).
        x: inputs with the sample axis first, e.g. ``(N, 28, 28)``.
        labels: integer class labels, ``(N,)``.
        batch_size: evaluation mini-batch size (default 100).

    Returns:
        Fraction of correctly classified samples, as a float.

    Raises:
        ValueError: if ``x`` is empty or ``batch_size`` is not positive.
        RuntimeError: if the last-step output has no axis of size
            ``OUTPUT_SIZE``.
    """
    n_samples = len(x)
    if n_samples == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate() needs at least one sample")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    correct = 0
    for i in range(0, n_samples, batch_size):
        bx = x[i : i + batch_size]
        bl = labels[i : i + batch_size]
        current_bs = bx.shape[0]

        h0 = np.zeros((HIDDEN_SIZE, current_bs), dtype=np.float32)
        outs, _ = model.forward_cycle(bx, h0)

        # Accept both (C, B) and (B, C); the current FPTT returns (C, B).
        # When B == OUTPUT_SIZE the shape is ambiguous; the (C, B) reading
        # wins because it is checked first, which matches the current model.
        last = outs[-1]
        if last.shape[0] == OUTPUT_SIZE:
            logits_last = last.T  # (C, B) -> (B, C)
        elif last.shape[1] == OUTPUT_SIZE:
            logits_last = last  # already (B, C)
        else:
            raise RuntimeError(
                f"Unexpected output shape from forward_cycle: {last.shape}"
            )

        preds = np.argmax(logits_last, axis=1)
        correct += int(np.sum(preds == bl))

    return correct / n_samples


# =================== Main training loop (FPTT only) =================== #
if __name__ == "__main__":
    # Fix random seeds for reproducibility.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    x_train, t_train, y_train, x_test, t_test, y_test = load_mnist_data(
        TRAIN_LIMIT, TEST_LIMIT
    )

    # Standard protocol: keep a validation split for epoch selection; report test only once at the end.
    rng_split = np.random.default_rng(SEED)
    all_idx = rng_split.permutation(len(x_train))
    val_size = max(1, int(len(all_idx) * 0.1))
    val_idx = all_idx[:val_size]
    train_idx = all_idx[val_size:]
    x_train_fit, t_train_fit = x_train[train_idx], t_train[train_idx]
    x_val, y_val = x_train[val_idx], y_train[val_idx]

    # Deterministic initial weights (shapes match the (H, B) hidden-state convention).
    rng = np.random.default_rng(SEED)
    init_params = {
        "W_xh": rng.standard_normal((HIDDEN_SIZE, INPUT_SIZE)).astype(np.float32) * 0.1,
        "W_hh": rng.standard_normal((HIDDEN_SIZE, HIDDEN_SIZE)).astype(np.float32) * 0.05,
        "b_h": np.zeros((HIDDEN_SIZE, 1), dtype=np.float32),
        "W_hy": rng.standard_normal((OUTPUT_SIZE, HIDDEN_SIZE)).astype(np.float32) * 0.1,
        "b_y": np.zeros((OUTPUT_SIZE, 1), dtype=np.float32),
    }

    # A single FPTT model.
    model = StrictFPTTClassifier(
        INPUT_SIZE,
        HIDDEN_SIZE,
        OUTPUT_SIZE,
        eta=LEARNING_RATE,
        parts=8,
        clip=5.0,
        alpha=0.5,          # tune following the original paper if needed
        beta=0.5,
        rho=0.0,
        warmup_epochs=1,
        label_mode="all",
        use_oracle=False,
        oracle_id="mnist",
        # device="cuda"  # enable for GPU
    )

    # Use the fixed initialisation.
    set_weights(model, init_params)
    # Crucial: re-align the FPTT regulariser's shadow / dual state with the weights just set.
    if hasattr(model, "reset_state_buffers"):
        model.reset_state_buffers()

    model_rng = np.random.default_rng(SEED + 1)

    history_acc = []
    best_val_acc = -float("inf")
    best_epoch = 0
    best_params = None

    print(f"\n[FPTT] Starting {EPOCHS} epochs. Batch: {BATCH_SIZE}, LR: {LEARNING_RATE}\n")

    for epoch in range(EPOCHS):
        st = time.time()
        model.set_epoch(epoch)

        indices = np.arange(len(x_train_fit))
        model_rng.shuffle(indices)

        total_loss = 0.0
        num_batches = 0   # batches with a finite loss (denominator of the average)
        num_skipped = 0   # batches whose loss was NaN/inf

        for i in range(0, len(x_train_fit), BATCH_SIZE):
            idx = indices[i : i + BATCH_SIZE]
            bx = x_train_fit[idx]   # (B, 28, 28)
            bt = t_train_fit[idx]   # (B, 10, 28)
            bs = bx.shape[0]

            h0 = np.zeros((HIDDEN_SIZE, bs), dtype=np.float32)

            loss, _ = model.train_batch(bx, bt, h0)
            # Only finite losses enter the average. Counting skipped batches
            # in the denominator (as before) understated the mean loss.
            if np.isfinite(loss):
                total_loss += loss
                num_batches += 1
            else:
                num_skipped += 1

        # NaN (rather than a misleading 0.0) when every batch was non-finite.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        val_acc = evaluate(model, x_val, y_val)
        history_acc.append(val_acc)
        if val_acc > best_val_acc:
            best_val_acc = float(val_acc)
            best_epoch = int(epoch + 1)
            best_params = get_weights(model)

        msg = (
            f"[FPTT] Epoch {epoch+1:02d} | "
            f"Val Acc: {val_acc:.4f} | Loss: {avg_loss:.3f} | {time.time()-st:.2f}s"
        )
        if num_skipped:
            msg += f" | skipped {num_skipped} non-finite batch(es)"
        print(msg)

    # Restore the best-validation weights before the single test evaluation.
    if best_params is not None:
        set_weights(model, best_params)
        if hasattr(model, "reset_state_buffers"):
            model.reset_state_buffers()

    test_acc = evaluate(model, x_test, y_test)
    print(f"\n[FPTT] Final test_acc={test_acc:.4f} | best_val_acc={best_val_acc:.4f} (epoch={best_epoch:02d})")

    # Plot validation accuracy per epoch; the x-axis is 1-based to match the printed epoch numbers.
    epoch_axis = range(1, len(history_acc) + 1)
    plt.figure(figsize=(8, 5))
    plt.plot(epoch_axis, history_acc, marker="o")
    plt.grid(True)
    plt.title("FPTT on Sequential MNIST")
    plt.xlabel("Epoch")
    plt.ylabel("Val Accuracy")
    plt.tight_layout()
    plt.show()
