import os
import json
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np
from collections import Counter


class PairDataset(Dataset):
    """Dataset over ``(ei, ej, label, weight)`` tuples of embedding pairs.

    Each sample is turned into a feature vector made of the two embeddings,
    their difference and the absolute difference, concatenated in that order.
    """

    def __init__(self, pairs):
        # Indexable collection of (ei, ej, label, weight) tuples.
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(features, label, weight)`` as float32 tensors for sample *idx*."""
        raw_first, raw_second, raw_label, raw_weight = self.pairs[idx]

        first = torch.tensor(raw_first, dtype=torch.float32)
        second = torch.tensor(raw_second, dtype=torch.float32)
        target = torch.tensor(raw_label, dtype=torch.float32)
        importance = torch.tensor(raw_weight, dtype=torch.float32)

        # Feature layout: [ei, ej, ei - ej, |ei - ej|].
        delta = first - second
        features = torch.cat((first, second, delta, delta.abs()), dim=0)
        return features, target, importance


class RewardModel(nn.Module):
    """MLP that maps a pair feature vector to a single value in (0, 1).

    Two hidden layers (``hidden_dim`` and ``hidden_dim // 2`` units), each
    followed by ReLU and dropout with p=0.3, then a one-unit output layer
    passed through a sigmoid.
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        narrow_dim = hidden_dim // 2
        # The layer order fixes the state-dict keys (mlp.0, mlp.3, mlp.6).
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, narrow_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(narrow_dim, 1),
            nn.Sigmoid(),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP and drop the trailing size-1 dimension of its output."""
        scores = self.mlp(x)
        return scores.squeeze(-1)


def build_pairs(data):
    """Build ordered preference pairs from every two items with different ``irt``.

    For each unordered pair of items whose ``irt`` values differ, two tuples
    are emitted: ``(higher_emb, lower_emb, 1, gap)`` followed by
    ``(lower_emb, higher_emb, 0, gap)``, where ``gap`` is the absolute ``irt``
    difference. Items with equal ``irt`` produce nothing.
    """
    pairs = []
    for position, first in enumerate(data):
        for second in data[position + 1:]:
            first_emb, second_emb = first["embedding"], second["embedding"]
            if first["irt"] > second["irt"]:
                upper_emb, lower_emb = first_emb, second_emb
            elif first["irt"] < second["irt"]:
                upper_emb, lower_emb = second_emb, first_emb
            else:
                # Ties carry no preference signal.
                continue
            gap = abs(first["irt"] - second["irt"])
            pairs.append((upper_emb, lower_emb, 1, gap))
            pairs.append((lower_emb, upper_emb, 0, gap))
    return pairs


def _pairs_within_topics(items):
    """Build preference pairs between items that share the same ``sub_topic``."""
    by_topic = {}
    for item in items:
        by_topic.setdefault(item["sub_topic"], []).append(item)
    pairs = []
    for group in by_topic.values():
        # A topic with a single item cannot form a pair.
        if len(group) > 1:
            pairs.extend(build_pairs(group))
    return pairs


def _mean_ignoring_nan(values):
    """Average the non-NaN entries of *values*; NaN when none are left."""
    kept = [v for v in values if not np.isnan(v)]
    return float(np.mean(kept)) if kept else float("nan")


def cross_validate(data, all_pairs, embedding_dim, folds=5, epochs=5, batch_size=64, lr=1e-3, device_str=None):
    """Run K-fold cross-validation of the pairwise reward model.

    Items (not pairs) are split into folds; preference pairs are then built
    separately inside the training and the validation partition, and only
    between items that share a ``sub_topic``. A fresh ``RewardModel`` is
    trained per fold and scored on the validation pairs.

    Args:
        data: Sequence of item dicts providing ``"sub_topic"``,
            ``"embedding"`` and ``"irt"``.
        all_pairs: Not used by this function; kept so existing call sites
            keep working.
        embedding_dim: Length of a single embedding. The model input size is
            four times this value.
        folds: Number of KFold splits.
        epochs: Training epochs per fold.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.
        device_str: Torch device string; when falsy, CUDA is used if
            available, otherwise CPU.

    Returns:
        A dict with the per-fold scores (``"fold_accuracy"``, ``"fold_auc"``)
        and their NaN-ignoring means (``"mean_accuracy"``, ``"mean_auc"``).
        A fold without training pairs or without validation pairs
        contributes NaN.
    """
    device = torch.device(device_str or ("cuda" if torch.cuda.is_available() else "cpu"))
    print("Using device:", device)
    kf = KFold(n_splits=folds, shuffle=True, random_state=42)
    acc_scores, auc_scores = [], []

    idx_list = list(range(len(data)))
    for fold, (train_idx, val_idx) in enumerate(kf.split(idx_list)):
        print(f"\n==== Fold {fold + 1} ====")
        train_pairs = _pairs_within_topics([data[i] for i in train_idx])
        val_pairs = _pairs_within_topics([data[i] for i in val_idx])
        print(f"Train pairs: {len(train_pairs)}, Val pairs: {len(val_pairs)}")

        if not train_pairs:
            # A shuffling DataLoader cannot be built over an empty dataset, and
            # an untrained model would only produce meaningless scores.
            print(f"Fold {fold + 1}: no training pairs, skipping fold")
            acc_scores.append(np.nan)
            auc_scores.append(np.nan)
            continue

        train_loader = DataLoader(PairDataset(train_pairs), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(PairDataset(val_pairs), batch_size=batch_size, shuffle=False)

        model = RewardModel(input_dim=embedding_dim * 4).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        # Per-sample losses are needed so each pair can be re-weighted below.
        criterion = nn.BCELoss(reduction="none")

        for epoch in range(epochs):
            model.train()
            total_loss = 0.0
            for x, label, weight in train_loader:
                x, label, weight = x.to(device), label.to(device), weight.to(device)
                optimizer.zero_grad()
                prob = model(x)
                loss = criterion(prob, label)
                # Scale each pair's loss by the square root of its weight.
                loss = (loss * torch.sqrt(weight)).mean()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch + 1}/{epochs}, Train Loss={total_loss / max(1, len(train_loader)):.4f}")

        model.eval()
        y_true, y_pred, y_prob = [], [], []
        with torch.no_grad():
            for x, label, _ in val_loader:
                x, label = x.to(device), label.to(device)
                prob = model(x)
                pred = (prob > 0.5).float()
                y_true.extend(label.cpu().tolist())
                y_pred.extend(pred.cpu().tolist())
                y_prob.extend(prob.cpu().tolist())

        if y_true:
            acc = accuracy_score(y_true, y_pred)
            try:
                auc = roc_auc_score(y_true, y_prob)
            except ValueError:
                # roc_auc_score rejected the inputs; record the AUC as missing.
                auc = np.nan
        else:
            acc, auc = np.nan, np.nan

        acc_scores.append(acc)
        auc_scores.append(auc)
        print(f"Fold {fold + 1}: Accuracy={acc:.4f}, AUC={auc:.4f}")

    mean_acc = _mean_ignoring_nan(acc_scores)
    mean_auc = _mean_ignoring_nan(auc_scores)
    print("\n==== Cross Validation Result ====")
    print(f"Average Accuracy: {mean_acc:.4f}")
    print(f"Average AUC: {mean_auc:.4f}")

    return {
        "fold_accuracy": acc_scores,
        "fold_auc": auc_scores,
        "mean_accuracy": mean_acc,
        "mean_auc": mean_auc,
    }


def train_and_save_final_model(all_pairs, embedding_dim, model_path="reward_model.pth", epochs=8, batch_size=64, lr=1e-3, device_str=None):
    """Train a ``RewardModel`` on every pair and save its ``state_dict``.

    Args:
        all_pairs: Non-empty collection of ``(ei, ej, label, weight)`` tuples.
        embedding_dim: Length of a single embedding. The model input size is
            four times this value.
        model_path: Destination passed to ``torch.save``. When it is a
            filesystem path, missing parent directories are created.
        epochs: Number of training epochs.
        batch_size: Batch size of the training loader.
        lr: Adam learning rate.
        device_str: Torch device string; when falsy, CUDA is used if
            available, otherwise CPU.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``all_pairs`` is empty.
    """
    print("\n==== Training Final Model on All Data ====")
    if len(all_pairs) == 0:
        # Fail here with a clear message instead of inside the DataLoader sampler.
        raise ValueError("Cannot train the final model: no training pairs were provided.")

    device = torch.device(device_str or ("cuda" if torch.cuda.is_available() else "cpu"))
    print("Using device:", device)

    # Create the output directory up front so a bad path fails before training,
    # not after all epochs have run.
    if isinstance(model_path, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(model_path))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    train_loader = DataLoader(PairDataset(all_pairs), batch_size=batch_size, shuffle=True)
    model = RewardModel(input_dim=embedding_dim * 4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Per-sample losses are needed so each pair can be re-weighted below.
    criterion = nn.BCELoss(reduction="none")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for x, label, weight in train_loader:
            x, label, weight = x.to(device), label.to(device), weight.to(device)
            optimizer.zero_grad()
            prob = model(x)
            loss = criterion(prob, label)
            # Scale each pair's loss by the square root of its weight.
            loss = (loss * torch.sqrt(weight)).mean()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Final Train Loss={total_loss / max(1, len(train_loader)):.4f}")

    torch.save(model.state_dict(), model_path)
    print(f"Saved final model to {model_path}")
    return model


def predict_preference(model, device, e1, e2):
    """Score the ordered pair ``(e1, e2)`` with *model* and return a Python float.

    The model is put into eval mode, the pair is encoded as
    ``[e1, e2, e1 - e2, |e1 - e2|]`` with a leading batch dimension of one,
    and the forward pass runs without gradient tracking on *device*.
    """
    model.eval()

    first = torch.tensor(e1, dtype=torch.float32)
    second = torch.tensor(e2, dtype=torch.float32)
    delta = first - second
    features = torch.cat((first, second, delta, delta.abs()), dim=0)
    batch = features.unsqueeze(0).to(device)

    with torch.no_grad():
        return model(batch).item()


def main(args):
    """Load the dataset, cross-validate the reward model, then train and save it.

    Every setting is taken from *args* when given, otherwise from an
    environment variable (``DATA_PATH``, ``MODEL_PATH``, ``CV_FOLDS``,
    ``CV_EPOCHS``, ``TRAIN_EPOCHS``, ``BATCH_SIZE``, ``LR``, ``DEVICE``),
    otherwise from a built-in default.

    Each loaded item gets a ``"sub_topic"`` key: the text after the last
    ``"->"`` of its ``"topic"`` string, or ``"unknown"`` when ``"topic"`` is
    not a string. Pairs are only built between items sharing a sub-topic.

    Raises:
        FileNotFoundError: If no data path is configured or it does not exist.
        ValueError: If the file does not hold a non-empty JSON list, or no
            preference pair can be built from it.
    """
    data_path = args.data_path or os.environ.get("DATA_PATH", "")
    model_out = args.model_path or os.environ.get("MODEL_PATH", "reward_model.pth")
    folds = args.folds if args.folds is not None else int(os.environ.get("CV_FOLDS", "5"))
    cv_epochs = args.cv_epochs if args.cv_epochs is not None else int(os.environ.get("CV_EPOCHS", "8"))
    train_epochs = args.train_epochs if args.train_epochs is not None else int(os.environ.get("TRAIN_EPOCHS", "8"))
    batch_size = args.batch_size if args.batch_size is not None else int(os.environ.get("BATCH_SIZE", "64"))
    lr = args.lr if args.lr is not None else float(os.environ.get("LR", "1e-3"))
    device_str = args.device or os.environ.get("DEVICE", None)

    if not data_path or not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Reject an empty or non-list payload here; otherwise the failure only
    # shows up later as a bare IndexError on data[0].
    if not isinstance(data, list) or not data:
        raise ValueError(f"Expected a non-empty JSON list of items in {data_path}")

    for item in data:
        topic = item.get("topic", "")
        item["sub_topic"] = topic.split("->")[-1].strip() if isinstance(topic, str) else "unknown"

    topic2data = {}
    for d in data:
        topic2data.setdefault(d["sub_topic"], []).append(d)

    all_pairs = []
    for group in topic2data.values():
        # A topic with a single item cannot form a pair.
        if len(group) > 1:
            all_pairs.extend(build_pairs(group))

    print(f"Total pairs: {len(all_pairs)}")
    if not all_pairs:
        # Nothing to train on: stop with an explicit message instead of
        # letting the training loaders fail on an empty dataset.
        raise ValueError(f"No preference pairs could be built from {data_path}")

    embedding_dim = len(data[0]["embedding"])

    labels = [p[2] for p in all_pairs]
    cnt = Counter(labels)
    print("Label distribution:")
    print(f"r=0: {cnt.get(0, 0)}")
    print(f"r=1: {cnt.get(1, 0)}")

    cross_validate(data, all_pairs, embedding_dim, folds=folds, epochs=cv_epochs, batch_size=batch_size, lr=lr, device_str=device_str)

    train_and_save_final_model(all_pairs, embedding_dim, model_path=model_out, epochs=train_epochs, batch_size=batch_size, lr=lr, device_str=device_str)


if __name__ == "__main__":
    # Every flag is optional and stays None when omitted, which lets main()
    # fall back to its environment variable or built-in default.
    parser = argparse.ArgumentParser()
    for option, value_type in (
        ("--data-path", None),
        ("--model-path", None),
        ("--folds", int),
        ("--cv-epochs", int),
        ("--train-epochs", int),
        ("--batch-size", int),
        ("--lr", float),
        ("--device", None),
    ):
        parser.add_argument(option, type=value_type)
    args = parser.parse_args()
    main(args)
