import argparse
import torch
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, AdamW

from data_loader import load_federated_datasets
from lora_model import RobertaWithLoRA
from distill import distillation_loss
from utils import set_seed, Logger, evaluate
from client import train_client
from server import aggregate_logits
from config import get_config

def main():
    """Run the FD-LoRA federated training script end to end.

    Reads the ``--config`` command-line argument, loads the configuration,
    seeds the run, builds one model/optimizer pair per client, and then
    executes ``cfg.num_rounds`` rounds. In each round every client is
    trained via ``train_client``, the values it returns are combined with
    ``aggregate_logits``, and client 0 is evaluated and its accuracy logged.

    Raises:
        ValueError: If ``cfg.num_clients`` is smaller than 1. Client 0 is
            evaluated after every round, so at least one client is needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    args = parser.parse_args()

    # Load config
    cfg = get_config(args.config)

    # Fail fast with a clear message: without this check an empty client list
    # only surfaces later as an IndexError on clients[0] (or inside the
    # aggregation helper), after the tokenizer and datasets were loaded.
    if cfg.num_clients < 1:
        raise ValueError(
            f"cfg.num_clients must be at least 1, got {cfg.num_clients}"
        )

    # Set seed
    set_seed(cfg.seed)

    # Logger
    logger = Logger(cfg.log_path)
    logger.log("Starting FD-LoRA training...")

    # Tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(cfg.model_name)

    # Load federated datasets: per-client datasets (indexed by client number
    # below) plus one dataset that is handed to every client.
    client_datasets, public_dataset = load_federated_datasets(cfg, tokenizer)

    # Initialize clients: each one owns an independent model and optimizer.
    # NOTE(review): AdamW here is the one imported from `transformers`, which
    # upstream has deprecated in favour of torch.optim.AdamW -- consider
    # migrating, but confirm the defaults (e.g. weight decay) first.
    clients = []
    for i in range(cfg.num_clients):
        model = RobertaWithLoRA(cfg.model_name, lora_rank=cfg.lora_rank, lora_alpha=cfg.lora_alpha)
        model.to(cfg.device)
        optimizer = AdamW(model.parameters(), lr=cfg.lr)
        clients.append({"model": model, "optimizer": optimizer, "dataset": client_datasets[i]})

    # No aggregated logits exist before the first round, so clients receive
    # None in round 0.
    global_logits = None

    # Start federated training. The loop variable is deliberately not called
    # `round`, which would shadow the builtin of the same name.
    for round_idx in range(cfg.num_rounds):
        logger.log(f"Round {round_idx + 1}/{cfg.num_rounds}")
        local_logits = []

        for idx, client in enumerate(clients):
            logger.log(f"Training client {idx}")
            # Clients see the aggregate produced by the previous round;
            # presumably train_client returns this client's logits on the
            # public dataset -- verify against client.py.
            logit = train_client(
                client["model"],
                client["optimizer"],
                client["dataset"],
                public_dataset,
                global_logits,
                cfg,
            )
            local_logits.append(logit)

        # Aggregate logits on the server; the result feeds the next round.
        global_logits = aggregate_logits(local_logits)

        # Evaluate a reference client (always client 0).
        acc = evaluate(clients[0]["model"], cfg.eval_dataset, tokenizer, cfg.device)
        logger.log(f"Evaluation Accuracy: {acc:.2f}%")

    logger.log("Training complete.")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
