import os
import json
import torch
import random
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight
import pandas as pd
from tqdm import tqdm
import argparse
from utils import TextClassificationDataset, BERTClassifier
from utils import set_seed, load_reasoning_step_training_data, load_reasoning_step_testing_data, evaluate, train

# -----------------------------
# Main
# -----------------------------
def main(args):
    """Fine-tune a BERT classifier, evaluate it, and persist the artefacts.

    Steps, in order: call ``set_seed``, create ``args.output_dir``, load the
    training and testing data through the ``load_reasoning_step_*`` helpers,
    hold out 20% of the training data for validation, train for
    ``args.epochs`` epochs with a class-weighted cross-entropy loss
    (evaluating on the validation split after every epoch), evaluate on the
    test set, and write ``bert_model.pth`` and ``results.json`` into
    ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``seed``, ``output_dir``, ``train_data_file``,
            ``train_labels_file``, ``test_data_file``, ``test_labels_file``,
            ``target_label``, ``model_name``, ``max_length``, ``batch_size``,
            ``num_classes``, ``learning_rate`` and ``epochs``.
    """
    set_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Load dataset
    train_texts, train_labels, unique_labels = load_reasoning_step_training_data(
        args.train_data_file, args.train_labels_file, args.target_label
    )
    test_texts, test_labels = load_reasoning_step_testing_data(
        args.test_data_file, args.test_labels_file, args.target_label
    )

    tokenizer = BertTokenizer.from_pretrained(args.model_name)

    # Stratify the hold-out split so that the training and validation subsets
    # keep the class proportions of the full training data. Without it, a rare
    # class can end up entirely in the validation subset, in which case the
    # balanced class-weight computation below fails because a label listed in
    # `classes` is missing from `y`.
    try:
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            train_texts,
            train_labels,
            test_size=0.2,
            random_state=args.seed,
            stratify=train_labels,
        )
    except ValueError:
        # Stratification is impossible for this data (e.g. a class with a
        # single sample). Fall back to a plain random split; if the data is
        # unusable for another reason, this call raises the same error again.
        print("Warning: stratified split failed; falling back to a random split.")
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            train_texts, train_labels, test_size=0.2, random_state=args.seed
        )

    train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, args.max_length)
    val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, args.max_length)
    test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, args.max_length)

    # Only the training loader is shuffled; evaluation order is kept fixed.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BERTClassifier(args.model_name, args.num_classes).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # Linear decay over every optimisation step of the run, with no warm-up.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * args.epochs
    )

    # Weight the loss inversely to class frequency in the training subset.
    class_weights = compute_class_weight(
        class_weight='balanced', classes=np.array(unique_labels), y=train_labels
    )
    loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float).to(device))

    results = []

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        train_loss = train(model, train_loader, optimizer, scheduler, device, loss_fn)
        val_acc, val_f1, val_report = evaluate(model, val_loader, device)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Accuracy: {val_acc:.4f} | F1: {val_f1:.4f}")
        print(f"Validation Classification Report:\n{val_report}")

        # NOTE: the stored epoch index is zero-based, unlike the printed one.
        results.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_acc": val_acc,
            "val_f1": val_f1,
            "val_report": val_report
        })

    # Persist the trained weights before anything else can fail, so that an
    # error during test evaluation or result serialisation does not throw
    # away the whole training run.
    torch.save(model.state_dict(), os.path.join(args.output_dir, "bert_model.pth"))

    # Evaluate on the held-out test set
    test_acc, test_f1, test_report = evaluate(model, test_loader, device)
    print(f"\nTest Accuracy: {test_acc:.4f} | F1: {test_f1:.4f}")
    print(f"Test Classification Report:\n{test_report}")

    results.append({
        "dataset": "test",
        "test_acc": test_acc,
        "test_f1": test_f1,
        "test_report": test_report
    })

    # Save results
    with open(os.path.join(args.output_dir, "results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

# -----------------------------
# Parameters
# -----------------------------
if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Optional settings, registered as (flag, value type, default).
    optional_flags = (
        ("--model_name", str, "bert-base-uncased"),
        ("--num_classes", int, 2),
        ("--max_length", int, 256),
        ("--batch_size", int, 16),
        ("--epochs", int, 5),
        ("--learning_rate", float, 2e-5),
        ("--seed", int, 42),
        ("--output_dir", str, "runs"),
    )
    for flag, value_type, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, default=fallback)

    # Mandatory input files; all of them are plain string paths.
    required_paths = (
        "--train_data_file",
        "--train_labels_file",
        "--test_data_file",
        "--test_labels_file",
    )
    for flag in required_paths:
        cli.add_argument(flag, type=str, required=True)

    # Mandatory label selector, the only option that carries help text.
    cli.add_argument(
        "--target_label",
        type=str,
        required=True,
        help="Target step-type category for binary classification",
    )

    args = cli.parse_args()
    main(args)