import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from dataset_forecasting import GasLoadForecastDataset
from main_model import CSDI_Forecasting
import yaml
import argparse

class CovariateConsistencyScorer(nn.Module):
    """MLP that scores how consistent a target forecast is with its covariates.

    The forecast (``pred_len`` values per sample) is concatenated with the
    flattened covariates (``pred_len * cov_dim`` values per sample) and fed
    through three ``Linear`` layers (ReLU between them) followed by a sigmoid,
    giving one score per sample.

    Args:
        pred_len: number of forecast steps.
        cov_dim: number of covariate channels per forecast step.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, pred_len, cov_dim, hidden_dim=64):
        super().__init__()
        # Forecast values plus every covariate value of the forecast window.
        in_features = pred_len * (1 + cov_dim)
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        # A single Sequential named ``net`` keeps state_dict keys (net.0/2/4.*)
        # compatible with checkpoints saved by earlier versions.
        self.net = nn.Sequential(*layers)

    def forward(self, future_pred, future_cov):
        """Return one sigmoid score per batch element, shape ``(B,)``.

        ``future_pred`` is expected as ``(B, pred_len)``; ``future_cov`` is
        reshaped to ``(B, -1)`` before concatenation.
        """
        batch = future_pred.size(0)
        cov_flat = future_cov.reshape(batch, -1)
        features = torch.cat((future_pred, cov_flat), dim=1)
        probs = self.net(features)  # (B, 1)
        return probs.squeeze(1)

_NEGATIVE_MODES = ("shuffle", "trend_flip", "cov_mismatch")


def _split_future(sample):
    """Return ``(future_target, future_cov)`` for one dataset sample.

    The history length is computed as ``len(gt_mask)`` minus the number of rows
    whose target-column mask entry is 0; rows from that index onward form the
    forecast window. Column 0 is the target, the remaining columns are
    covariates.
    """
    data = sample["observed_data"]  # [seq_len, feat_dim]
    mask = sample["gt_mask"]
    # NOTE(review): assumes the masked (zero) target rows are exactly the
    # trailing forecast rows — confirm against the dataset's mask layout.
    hist_len = mask.shape[0] - np.sum(mask[:, 0] == 0)
    return data[hist_len:, 0], data[hist_len:, 1:]  # [pred_len], [pred_len, cov_dim]


def construct_scorer_sample(dataset, idx, alpha_schedule, num_steps, is_positive=True, mode="shuffle", alt_idx=None):
    """Build one ``(future_target, future_cov, label)`` example for the scorer.

    Positive (label 1): the true future target diffused to a random step
    ``t`` drawn uniformly from ``[0, num_steps)``:
    ``sqrt(alpha_t) * x + sqrt(1 - alpha_t) * eps`` with standard-normal ``eps``.

    Negative (label 0): the true forecast window corrupted according to ``mode``:
      * ``"shuffle"``      - target values randomly permuted in time;
      * ``"trend_flip"``   - target reversed in time;
      * ``"cov_mismatch"`` - covariates replaced by those of ``dataset[alt_idx]``.

    Args:
        dataset: indexable; each item is a dict with ``"observed_data"`` and
            ``"gt_mask"`` arrays (rows are time steps, column 0 is the target).
        idx: index of the sample to use.
        alpha_schedule: indexable by step ``t``; each entry is used as a scalar
            ``alpha_t`` in ``[0, 1]``.
        num_steps: number of diffusion steps (upper bound for ``t``).
        is_positive: build a positive (True) or negative (False) example.
        mode: corruption for negatives; ignored for positives.
        alt_idx: donor index for ``"cov_mismatch"``.

    Returns:
        ``(future_target, future_cov, label)``: float32 arrays and an int label.

    Raises:
        ValueError: for a negative example with an unknown ``mode``, or with
            ``mode="cov_mismatch"`` and no ``alt_idx``. Such calls previously
            returned the *uncorrupted* sample labelled 0, poisoning training.
    """
    if not is_positive:
        if mode not in _NEGATIVE_MODES:
            raise ValueError(f"unknown negative mode {mode!r}; expected one of {_NEGATIVE_MODES}")
        if mode == "cov_mismatch" and alt_idx is None:
            raise ValueError("mode='cov_mismatch' requires alt_idx")

    future_target, future_cov = _split_future(dataset[idx])

    if is_positive:
        t = np.random.randint(0, num_steps)  # t in [0, num_steps)
        alpha_t = alpha_schedule[t]
        noise = np.random.randn(*future_target.shape).astype(np.float32)
        future_target = np.sqrt(alpha_t) * future_target + np.sqrt(1.0 - alpha_t) * noise
        label = 1
    else:
        # NOTE(review): negatives are not diffused while positives are, so the
        # scorer may learn to separate by noise level rather than covariate
        # consistency — consider noising negatives at a sampled t as well.
        if mode == "shuffle":
            future_target = np.random.permutation(future_target)
        elif mode == "trend_flip":
            future_target = future_target[::-1].copy()
        else:  # "cov_mismatch": swap in another sample's covariates
            _, future_cov = _split_future(dataset[alt_idx])
        label = 0

    return future_target.astype(np.float32), future_cov.astype(np.float32), label

def build_scorer_batch(dataset, batch_size=32, alpha_schedule=None, num_steps=None):
    """Sample a label-balanced batch of scorer training examples.

    Produces ``batch_size // 2`` (positive, negative) pairs, interleaved as
    pos, neg, pos, neg, ...; an odd ``batch_size`` therefore yields
    ``batch_size - 1`` examples. Sample indices are drawn uniformly with
    replacement. Each negative uses a uniformly chosen corruption mode, and its
    ``alt_idx`` is always a different index from its own (offset in
    ``[1, len(dataset) - 1]``), so ``"cov_mismatch"`` borrows another sample's
    covariates.

    Args:
        dataset: dataset passed through to ``construct_scorer_sample``.
        batch_size: requested batch size (must be >= 2).
        alpha_schedule: diffusion schedule used to noise positives (required).
        num_steps: number of diffusion steps (required).

    Returns:
        ``(pred, cov, labels)`` float32 tensors; ``pred`` stacks the per-sample
        future targets, ``cov`` the per-sample covariate windows, ``labels`` the
        0/1 labels.

    Raises:
        ValueError: if ``batch_size < 2``, the dataset has fewer than 2 items,
            or ``alpha_schedule`` / ``num_steps`` is missing. These cases
            previously failed later with opaque errors.
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be >= 2 to hold a positive/negative pair, got {batch_size}")
    total_len = len(dataset)
    if total_len < 2:
        raise ValueError(f"dataset needs at least 2 samples for cov_mismatch negatives, got {total_len}")
    if alpha_schedule is None or num_steps is None:
        raise ValueError("alpha_schedule and num_steps are required to build positive samples")

    pred_seqs = []
    cov_seqs = []
    labels = []

    for _ in range(batch_size // 2):
        # Positive sample
        idx = np.random.randint(0, total_len)
        pos_pred, pos_cov, pos_label = construct_scorer_sample(dataset, idx, alpha_schedule, num_steps, is_positive=True)
        pred_seqs.append(pos_pred)
        cov_seqs.append(pos_cov)
        labels.append(pos_label)

        # Negative sample; alt_idx is guaranteed != neg_idx
        neg_idx = np.random.randint(0, total_len)
        alt_idx = (neg_idx + np.random.randint(1, total_len)) % total_len
        mode = np.random.choice(["shuffle", "trend_flip", "cov_mismatch"])
        neg_pred, neg_cov, neg_label = construct_scorer_sample(dataset, neg_idx, alpha_schedule, num_steps, is_positive=False, mode=mode, alt_idx=alt_idx)
        pred_seqs.append(neg_pred)
        cov_seqs.append(neg_cov)
        labels.append(neg_label)

    return (
        torch.tensor(np.stack(pred_seqs), dtype=torch.float32),
        torch.tensor(np.stack(cov_seqs), dtype=torch.float32),
        torch.tensor(labels, dtype=torch.float32),
    )

def train_scorer_model(
    dataset,
    pred_len,
    cov_dim,
    num_epochs=10,
    batch_size=64,
    lr=1e-3,
    device="cuda" if torch.cuda.is_available() else "cpu",
    print_every=1,
    alpha_schedule=None,
    num_steps=None,
    dataset_name=None,
    num_batches=200,
):
    """Train a ``CovariateConsistencyScorer`` on freshly sampled batches.

    Each epoch draws ``num_batches`` new batches from ``build_scorer_batch``
    (there is no fixed pass over the dataset) and minimises ``nn.BCELoss``
    between the scorer's outputs and the 0/1 labels with Adam.

    Every ``print_every`` epochs, and always after the final epoch, the average
    epoch loss is printed and the weights are written to
    ``scorer_lastest_{dataset_name}.pt`` in the working directory.

    Args:
        dataset: dataset forwarded to ``build_scorer_batch``.
        pred_len: forecast length (scorer input size).
        cov_dim: covariate channels per forecast step.
        num_epochs: number of epochs.
        batch_size: examples per batch.
        lr: Adam learning rate.
        device: torch device string for the model and batches.
        print_every: logging / checkpoint interval in epochs (>= 1).
        alpha_schedule: diffusion schedule used to noise positive samples.
        num_steps: number of diffusion steps.
        dataset_name: tag embedded in the checkpoint filename.
        num_batches: batches per epoch (previously hard-coded to 200).

    Returns:
        The trained scorer (left in train mode, on ``device``).

    Raises:
        ValueError: if ``print_every`` or ``num_batches`` is < 1.
    """
    if print_every < 1:
        raise ValueError(f"print_every must be >= 1, got {print_every}")
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    # Model, loss and optimizer
    scorer = CovariateConsistencyScorer(pred_len, cov_dim).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(scorer.parameters(), lr=lr)
    # Filename spelling ("lastest") is kept so existing loaders still find the file.
    save_path = f"scorer_lastest_{dataset_name}.pt"

    for epoch in range(1, num_epochs + 1):
        scorer.train()
        total_loss = 0.0

        for _ in tqdm(range(num_batches), desc=f"Epoch {epoch}"):
            X_pred, X_cov, Y = build_scorer_batch(dataset, batch_size=batch_size, alpha_schedule=alpha_schedule, num_steps=num_steps)
            X_pred, X_cov, Y = X_pred.to(device), X_cov.to(device), Y.to(device)

            scores = scorer(X_pred, X_cov)  # shape: (B,)
            loss = criterion(scores, Y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        # Also checkpoint after the last epoch so the saved file always matches
        # the returned model, even when num_epochs is not a multiple of print_every.
        if epoch % print_every == 0 or epoch == num_epochs:
            print(f"[Epoch {epoch}] Avg Loss: {avg_loss:.4f}")
            torch.save(scorer.state_dict(), save_path)

    return scorer

def main():
    """Command-line entry point: train the covariate-consistency scorer.

    Builds the ``'train'`` subset of ``GasLoadForecastDataset`` from
    ``--csvpath`` / ``--statspath``, instantiates ``CSDI_Forecasting`` from
    ``config/base_forecasting.yaml`` only to read its ``alpha`` schedule and
    ``num_steps``, then trains the scorer for 400 epochs. The checkpoint
    filename embeds ``--data`` (the literal ``None`` if it is omitted).
    """
    parser = argparse.ArgumentParser(description="Particle Filter Scorer Training")
    parser.add_argument("--data", type=str, help="dataset name; embedded in the checkpoint filename")
    parser.add_argument("--csvpath", type=str, required=True, help="passed to GasLoadForecastDataset as csv_path")
    parser.add_argument("--statspath", type=str, required=True, help="passed to GasLoadForecastDataset as stats_path")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    history_len = 168
    pred_len = 24

    config_path = "config/base_forecasting.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    dataset = GasLoadForecastDataset(
        csv_path=args.csvpath,
        stats_path=args.statspath,
        subset='train',
        history_len=history_len,
        pred_len=pred_len,
    )

    # NOTE(review): the third positional argument (3) is presumably the
    # target/feature dimension expected by CSDI_Forecasting — confirm it
    # matches dataset.feature_dim.
    diffusion_model = CSDI_Forecasting(config, device, 3).to(device)
    alpha_schedule = diffusion_model.alpha
    num_steps = diffusion_model.num_steps

    # Use the dataset's own sizes rather than the constants above.
    pred_len = dataset.pred_len
    cov_dim = dataset.feature_dim - 1  # all features except the target column

    train_scorer_model(
        dataset,
        pred_len=pred_len,
        cov_dim=cov_dim,
        num_epochs=400,
        batch_size=64,
        lr=1e-3,
        device=device,
        alpha_schedule=alpha_schedule,
        num_steps=num_steps,
        dataset_name=args.data,
    )


if __name__ == "__main__":
    main()
