import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer
import argparse

from trigger_datasets import SteeringDataset
from modeling_trigger import BertRegressor, run_epoch

# Default to exposing only GPU index 1 to CUDA, but keep any value the caller
# already exported so the device can be chosen from the shell without editing
# this file (e.g. `CUDA_VISIBLE_DEVICES=0 python <script> ...`).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to each generator's seeding function
            (defaults to 42).
    """
    # Seeding order: stdlib ``random``, then NumPy, then PyTorch.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

def train(args):
    """Train a ``BertRegressor`` and report metrics on a held-out test split.

    The dataset stored at ``args.data`` is split 80/10/10 into train,
    validation and test subsets. After every epoch the model is evaluated on
    the validation subset and its ``state_dict`` is written to
    ``args.save_path`` whenever the validation loss improves. Finally the
    best checkpoint written during this run is reloaded and evaluated on the
    test subset. If no checkpoint was written in this run (for example
    ``args.epochs`` is 0 or the validation loss was never below infinity,
    as happens with NaN), the current in-memory weights are evaluated
    instead of loading whatever file may already exist at ``args.save_path``.

    Args:
        args: Namespace with the attributes ``data`` (path of the dataset
            file read with ``torch.load``), ``save_path`` (checkpoint
            destination), ``epochs`` (number of training epochs) and
            ``output_dim`` (forwarded to ``BertRegressor``).
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset
    # NOTE(review): torch.load relies on pickle; only point --data at files
    # from a trusted source.
    dataset = torch.load(args.data)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Split 80/10/10; the test subset absorbs the rounding remainder.
    total = len(dataset)
    train_len = int(0.8 * total)
    val_len = int(0.1 * total)
    test_len = total - train_len - val_len
    train_set, val_set, test_set = random_split(dataset, [train_len, val_len, test_len])

    train_loader = DataLoader(SteeringDataset(train_set, tokenizer), batch_size=16, shuffle=True)
    val_loader = DataLoader(SteeringDataset(val_set, tokenizer), batch_size=16)
    test_loader = DataLoader(SteeringDataset(test_set, tokenizer), batch_size=16)

    model = BertRegressor(output_dim=args.output_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    loss_fn = nn.MSELoss()

    # Create the checkpoint directory up front so a missing folder does not
    # surface only after a full epoch of training.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float("inf")
    # Tracks whether THIS run wrote a checkpoint, so the final evaluation
    # never picks up a stale file left behind by an earlier run.
    checkpoint_saved = False

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch {epoch}")

        # run_epoch is called with the optimizer for the training pass and
        # without it for evaluation; its second return value is reported as
        # cosine similarity below.
        train_loss, _ = run_epoch(model, train_loader, loss_fn, device, optimizer)
        val_loss, val_cos = run_epoch(model, val_loader, loss_fn, device)

        print(f"Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val CosSim={val_cos:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), args.save_path)
            checkpoint_saved = True
            print(f"Saved best model to {args.save_path}")

    # Final test
    print("\nFinal Testing...")
    if checkpoint_saved:
        # map_location places the restored tensors on the device in use.
        model.load_state_dict(torch.load(args.save_path, map_location=device))
    else:
        print("Warning: no checkpoint was saved during this run; evaluating current model weights.")
    test_loss, test_cos = run_epoch(model, test_loader, loss_fn, device)
    print(f"Final Test Loss={test_loss:.4f} | Test CosSim={test_cos:.4f}")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to train().
    cli = argparse.ArgumentParser(
        description="Train BERT regressor for reflection vectors"
    )

    # Required string-valued paths.
    for flag, help_text in (
        ("--data", "Path to .pt dataset file"),
        ("--save_path", "Path to save best model checkpoint"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional integer settings with their defaults.
    for flag, default in (("--epochs", 20), ("--output_dim", 4096)):
        cli.add_argument(flag, type=int, default=default)

    train(cli.parse_args())
