"""FedSSM Training Script for VQA."""

import os
import sys
import yaml
import argparse
import numpy as np
import time
from pathlib import Path
import torch

sys.path.append('/root/autodl-tmp/vqa_baseline')
sys.path.append('/root/autodl-tmp')

from transformers import LxmertTokenizer
from dataset_lxmert import VQADataset
from model_lora import LXMERTForVQA
from FL_withVQAdata.client import FederatedClient
from FL_withVQAdata.data_partition import create_client_datasets
from FL_withVQAdata.fl_logger import FederatedLogger
from FL_withVQAdata.utils import plot_client_distribution

from fedssm import FedSSMServer, set_seed


def setup_dirs(config):
    """Ensure the run's output, checkpoint and log directories exist.

    Reads ``output_dir``, ``checkpoint_dir`` and ``log_dir`` from
    ``config['output']``, creates each one (including missing parents, and
    without failing when it already exists), and returns them as a list of
    ``Path`` objects in that same order.
    """
    section = config['output']
    requested = (section['output_dir'], section['checkpoint_dir'], section['log_dir'])

    created = []
    for entry in requested:
        target = Path(entry)
        target.mkdir(parents=True, exist_ok=True)
        created.append(target)
    return created


def _apply_cli_overrides(config, args):
    """Overwrite config entries with values given explicitly on the command line.

    Only options that were actually supplied (i.e. are not ``None``) replace
    the value loaded from the YAML file. ``config`` is modified in place.
    """
    # Compare against None instead of relying on truthiness so that an
    # explicit 0 is honoured rather than silently falling back to the config.
    if args.num_clients is not None:
        config['federated']['num_clients'] = args.num_clients
    if args.num_rounds is not None:
        config['federated']['num_rounds'] = args.num_rounds
    if args.local_epochs is not None:
        config['training']['local_epochs'] = args.local_epochs


def _run_training(config, logger, ckpt_dir):
    """Build data, model, server and clients, then run every federated round.

    Args:
        config: Parsed configuration dictionary (CLI overrides already applied).
        logger: Logger object used for all progress output.
        ckpt_dir: ``Path`` of the directory that receives ``best.pt`` and
            ``final.pt``.
    """
    logger.log_config(config)

    # NOTE(review): huggingface_hub appears to read HF_ENDPOINT when it is
    # first imported, and `transformers` is already imported at module top, so
    # this assignment may come too late to redirect downloads -- confirm, and
    # consider exporting the variable before launching the script instead.
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

    logger.info("Loading tokenizer and dataset...")
    tokenizer = LxmertTokenizer.from_pretrained(config['model']['pretrained_model'])

    dataset = VQADataset(
        data_dir=config['data']['data_dir'],
        annotations_dir=config['data']['annotations_dir'],
        images_dir=config['data']['images_dir'],
        split='train', tokenizer=tokenizer,
        max_seq_length=config['model']['max_seq_length'],
        num_visual_features=config['model']['num_visual_features'],
        train_ratio=1.0
    )
    logger.info(f"Dataset: {len(dataset)} samples, {dataset.num_answers} answers")

    logger.info("Partitioning data...")
    train_sets, test_sets = create_client_datasets(
        dataset,
        num_clients=config['federated']['num_clients'],
        partition_strategy=config['federated']['data_partition_strategy'],
        alpha=config['federated'].get('non_iid_alpha', 0.5),
        test_ratio=config['federated'].get('client_test_ratio', 0.2),
        seed=config['env']['seed'],
        min_samples=config['federated'].get('min_samples_per_client', 1000)
    )

    # Per-client total sample count (train + test), plotted for inspection.
    sizes = {i: len(train_sets[i]) + len(test_sets[i]) for i in range(len(train_sets))}
    plot_client_distribution(sizes, logger.exp_dir / "data_distribution.png")

    logger.info("Creating model and server...")
    model = LXMERTForVQA(config=config, num_answers=dataset.num_answers, use_lora=config['lora']['enabled'])
    server = FedSSMServer(model, config)

    # NOTE(review): the server and every client are handed the same `model`
    # object; whether FedSSMServer / FederatedClient copy it is not visible
    # from this file -- confirm that clients cannot mutate the global weights.
    clients = {}
    for cid in range(config['federated']['num_clients']):
        clients[cid] = FederatedClient(client_id=cid, model=model, dataset=train_sets[cid], config=config)

    logger.info("Starting training...")
    num_rounds = config['federated']['num_rounds']
    for rnd in range(num_rounds):
        t0 = time.time()
        logger.info(f"\n{'='*60}\nRound {rnd+1}/{num_rounds}\n{'='*60}")

        global_sd = server.get_state_dict()

        # Evaluate all clients on their own test split with the global weights;
        # the resulting losses and gradient signals feed client selection.
        losses, grads = {}, {}
        for cid, client in clients.items():
            client.set_state_dict(global_sd)
            metrics = client.evaluate(test_sets[cid])
            losses[cid] = metrics['loss']
            profile = server.selector.profiles.get(cid)
            # Prefer the recorded mean gradient norm when the client has a
            # history; otherwise fall back to the evaluation loss.
            grads[cid] = profile.mean_grad_norm if profile and profile.grad_norm_history else metrics['loss']

        # Select clients
        selected, info = server.select_clients(clients, rnd, losses, grads)
        logger.info(f"Surprise: {info['surprise']:.4f}, Selected: {len(selected)} clients")

        # Local training: each selected client starts from the global weights.
        sds, weights, round_losses = [], [], {}
        for cid in selected:
            clients[cid].set_state_dict(global_sd)
            stats = clients[cid].train()
            sds.append(clients[cid].get_state_dict())
            weights.append(stats['samples'])
            round_losses[cid] = stats['loss']
            # Gradient estimate: absolute gap between the training loss and the
            # pre-training evaluation loss, floored at 0.01.
            grad_est = max(0.01, abs(stats['loss'] - losses.get(cid, stats['loss'])))
            server.update_client_stats(cid, stats['loss'], grad_est, rnd)
            logger.info(f"  Client {cid}: loss={stats['loss']:.4f}, acc={stats['accuracy']:.4f}")

        # Aggregate the client state dicts into the new global state.
        agg_sd = server.aggregate(sds, weights, selected, round_losses)
        server.set_state_dict(agg_sd)

        # Periodic evaluation: unweighted mean of per-client test accuracy.
        if (rnd + 1) % config['evaluation']['eval_every_n_rounds'] == 0:
            test_accs = []
            eval_sd = server.get_state_dict()
            for cid, client in clients.items():
                client.set_state_dict(eval_sd)
                m = client.evaluate(test_sets[cid])
                test_accs.append(m['accuracy'])
            acc = np.mean(test_accs)
            logger.info(f"Test Accuracy: {acc:.4f}")
            if acc > server.best_acc:
                server.best_acc = acc
                server.save(str(ckpt_dir / "best.pt"))

        logger.info(f"Round time: {time.time()-t0:.1f}s")

    server.save(str(ckpt_dir / "final.pt"))
    logger.info(f"\nTraining complete. Best acc: {server.best_acc:.4f}")


def main(args):
    """Run a complete FedSSM federated training session.

    Loads the YAML configuration, applies command-line overrides, seeds the
    run, creates the output directories and then trains for the configured
    number of rounds, saving ``best.pt`` / ``final.pt`` into the checkpoint
    directory.

    Args:
        args: Parsed command-line namespace providing ``config`` (path to the
            YAML file) and the optional overrides ``num_clients``,
            ``num_rounds`` and ``local_epochs`` (``None`` when not given).
    """
    print("=" * 60)
    print("FedSSM: State Space Model-based Federated Learning")
    print("=" * 60)

    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    _apply_cli_overrides(config, args)

    set_seed(config['env']['seed'])
    _, ckpt_dir, log_dir = setup_dirs(config)

    logger = FederatedLogger(log_dir=str(log_dir), experiment_name=config['output']['run_name'])
    try:
        _run_training(config, logger, ckpt_dir)
    finally:
        # Close the logger even when training aborts with an exception.
        logger.close()


if __name__ == "__main__":
    # Script entry point: parse the config path plus the optional integer
    # overrides, then hand the resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, default='configs/default.yaml')
    for override_flag in ('--num_clients', '--num_rounds', '--local_epochs'):
        cli.add_argument(override_flag, type=int, default=None)
    parsed_args = cli.parse_args()
    main(parsed_args)
