import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import json
import numpy as np
from sklearn.model_selection import train_test_split
import argparse
import os

def augment_sequence(score_seq, label, min_ratio=0.3, stride=1):
    """Expand one labelled sequence into all of its contiguous sub-windows.

    Window lengths start at ``max(1, int(len(score_seq) * min_ratio))`` and
    grow by ``stride`` up to (at most) the full sequence length. For every
    window length, every possible start offset is emitted.

    Args:
        score_seq: Sliceable sequence of numeric scores.
        label: Label attached unchanged to every generated window.
        min_ratio: Fraction of the full length used as the shortest window.
        stride: Step between consecutive window lengths (not start offsets).

    Returns:
        A list of ``(np.ndarray[float32], label)`` tuples ordered by window
        length, then by start offset. Empty when ``score_seq`` is empty.
    """
    total = len(score_seq)
    shortest = max(1, int(total * min_ratio))
    return [
        (np.array(score_seq[offset:offset + window], dtype=np.float32), label)
        for window in range(shortest, total + 1, stride)
        for offset in range(total - window + 1)
    ]

def load_json_score_data_with_augmentation(path, augment=True):
    """Load labelled score sequences from a JSON file.

    The file must contain a list of objects, each with a ``"label"`` entry
    and a ``"per_block_scores"`` list whose elements carry
    ``"model_a_score"`` and ``"model_b_score"``. The two scores of each block
    are interleaved into one flat sequence: ``a0, b0, a1, b1, ...``.

    Args:
        path: Path of the JSON file to read.
        augment: When true, every sequence is expanded with
            ``augment_sequence`` (default arguments); otherwise each item
            yields exactly one sample.

    Returns:
        A list of ``(np.ndarray[float32], label)`` tuples.

    Raises:
        KeyError: If an item or block lacks one of the expected keys.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    data = []
    for item in raw_data:
        label = item["label"]
        score_sequence = [
            block[key]
            for block in item["per_block_scores"]
            for key in ("model_a_score", "model_b_score")
        ]

        if augment:
            data.extend(augment_sequence(score_sequence, label))
        else:
            data.append((np.array(score_sequence, dtype=np.float32), label))
    return data

class ScoreSetDataset(Dataset):
    """Map-style dataset over pre-built ``(scores, label)`` samples.

    ``data`` is any indexable, sized collection of ``(scores, label)`` pairs.
    Conversion to tensors happens lazily, one sample per ``__getitem__`` call.
    """

    def __init__(self, data):
        # Stored by reference; nothing is copied or converted up front.
        self.data = data

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` tensors for sample ``idx``.

        ``features`` gains a trailing singleton axis so each score becomes a
        one-feature set element; ``target`` is the label cast to float.
        """
        sample_scores, sample_label = self.data[idx]
        features = torch.tensor(sample_scores).unsqueeze(-1)
        target = torch.tensor(sample_label).float()
        return features, target

def collate_fn(batch):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        batch: Iterable of ``(scores, label)`` tensor pairs, as produced by
            ``ScoreSetDataset.__getitem__``.

    Returns:
        A ``(padded_scores, lengths, labels)`` tuple: the sequences padded
        with ``0.0`` to the longest one (batch dimension first), the original
        length of each sequence, and the labels stacked into one tensor.
    """
    seqs, targets = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
    # True lengths let the model mask out the padded positions later.
    seq_lengths = torch.tensor(list(map(len, seqs)))
    return padded, seq_lengths, torch.stack(targets)

class DeepSetClassifier(nn.Module):
    """Deep Sets style binary classifier over variable-length score sets.

    Every scalar score is embedded independently by ``phi``; the embeddings
    of the real (non-padded) positions are mean-pooled, and ``rho`` maps the
    pooled vector to one sigmoid probability per sample.

    Args:
        hidden_dim: Width of the hidden layers in both ``phi`` and ``rho``.
    """

    def __init__(self, hidden_dim=512):
        super().__init__()
        # Layer layout determines the state_dict keys used by saved
        # checkpoints (phi.0, phi.2, rho.0, rho.2) -- keep it stable.
        self.phi = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, lengths):
        """Return one probability per sample.

        Args:
            x: Padded scores of shape ``(batch, max_len, 1)``.
            lengths: 1-D tensor with the true length of each sequence, on the
                same device as ``x``.

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``.
        """
        phi_x = self.phi(x)

        # Build the position index directly on x's device instead of
        # creating it on the CPU and copying it over on every call.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        mask = (positions < lengths.unsqueeze(1)).unsqueeze(-1)

        # phi(0) is generally non-zero (bias terms), so padded positions
        # must be zeroed out before pooling.
        phi_x = phi_x * mask

        # Mean over real positions. Clamp the divisor so an empty set pools
        # to a zero vector instead of 0/0 = NaN.
        safe_lengths = lengths.clamp(min=1).unsqueeze(-1)
        agg = phi_x.sum(dim=1) / safe_lengths

        out = self.rho(agg)
        return torch.sigmoid(out).squeeze(-1)


def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3, save_path="best_model.pt"):
    """Train ``model`` with BCE loss and checkpoint the best validation epoch.

    Runs on CUDA when available, otherwise on CPU. After every epoch the
    model is evaluated on ``val_loader`` with a 0.5 decision threshold, and
    its ``state_dict`` is written to ``save_path`` whenever validation
    accuracy strictly exceeds the best seen so far (starting from 0.0). A run
    whose accuracy never rises above 0 therefore writes no checkpoint.

    Args:
        model: Module called as ``model(x, lengths)``; it must return
            probabilities, because ``nn.BCELoss`` is applied to its output.
        train_loader: Iterable of ``(x, lengths, labels)`` tensor batches.
        val_loader: Iterable of ``(x, lengths, labels)`` tensor batches. If
            it yields no samples, accuracy is reported as 0.0.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        save_path: Destination file for the best ``state_dict``.

    Returns:
        The best validation accuracy observed (0.0 if it never improved).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    loss_fn = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for step, (x, lengths, labels) in enumerate(train_loader, start=1):
            x, lengths, labels = x.to(device), lengths.to(device), labels.to(device)
            preds = model(x, lengths)
            loss = loss_fn(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once per step; every .item() call copies the
            # value back to the host.
            loss_value = loss.item()
            total_loss += loss_value
            if step % 100 == 0:
                print(f"  Step {step}, Current Loss: {loss_value:.4f}")

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, lengths, labels in val_loader:
                x, lengths, labels = x.to(device), lengths.to(device), labels.to(device)
                preds = model(x, lengths)
                predicted = (preds > 0.5).float()
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        # An empty validation set must not crash with ZeroDivisionError
        # after a whole epoch of training has already been spent.
        acc = correct / total if total else 0.0
        # NOTE: the reported loss is the sum over batches, not the mean.
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Val Acc: {acc:.4f}")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model to {save_path}")

    return best_acc


if __name__ == "__main__":
    # ---- Command-line options ----
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_path", type=str, default="./merged_train_domain_length.json", help="Path to your JSON file")
    cli.add_argument("--epochs", type=int, default=5)
    cli.add_argument("--batch_size", type=int, default=256)
    cli.add_argument("--hidden_dim", type=int, default=256)
    cli.add_argument("--save_path", type=str, default="DeepSet_auged.pt")
    args = cli.parse_args()

    # Load the records WITHOUT augmentation so the train/validation split is
    # made per original record: windows cut from one record can then never
    # end up on both sides of the split.
    records = load_json_score_data_with_augmentation(args.data_path, augment=False)

    train_idx, val_idx = train_test_split(
        list(range(len(records))), test_size=0.2, random_state=42
    )

    # Augment each side independently, after the split.
    train_data = [
        window
        for i in train_idx
        for window in augment_sequence(*records[i])
    ]
    val_data = [
        window
        for i in val_idx
        for window in augment_sequence(*records[i])
    ]

    train_loader = DataLoader(
        ScoreSetDataset(train_data),
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        ScoreSetDataset(val_data),
        batch_size=args.batch_size,
        collate_fn=collate_fn,
    )

    model = DeepSetClassifier(hidden_dim=args.hidden_dim)
    train_model(model, train_loader, val_loader, epochs=args.epochs, save_path=args.save_path)

    # Example invocation:
    #   python train_deepsets_with_aug.py --data_path your_data.json --epochs 20
