# main.py
import ray
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from typing import Dict
import time
from worker import WorkerA,WorkerB
from server import ServerA,ServerB
from channel import Channel
import sys
from typing import Dict, List, Tuple
import asyncio
from collections import defaultdict
import os
# Restrict this process to a fixed subset of GPUs. CUDA reads this variable
# when it is first initialised in the process, so it must be set before any
# CUDA call. NOTE(review): `torch`, `worker` and `server` are imported above
# this line; confirm none of them initialise CUDA at import time, or the mask
# will be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,6,7"

# Seed every RNG the driver uses so runs are reproducible. NumPy's global RNG
# matters here because train_federated shuffles its batches with
# np.random.shuffle.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)

import ray
import asyncio
import numpy as np
from collections import defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from typing import Dict, List, Tuple


def _make_batch_index(num_samples: int, batch_size: int) -> List[Dict]:
    """Describe contiguous ``[start, end)`` row ranges covering ``num_samples`` rows.

    Each entry is ``{'batch_id': 'batch_<k>', 'start': ..., 'end': ...}``. The
    final batch is shorter when ``num_samples`` is not a multiple of
    ``batch_size``.
    """
    return [
        {
            'batch_id': f"batch_{start // batch_size}",
            'start': start,
            'end': min(start + batch_size, num_samples),
        }
        for start in range(0, num_samples, batch_size)
    ]


def _compute_metrics(labels: np.ndarray, predictions) -> Dict[str, float]:
    """Score ``predictions`` against ``labels`` with sklearn classification metrics.

    ``zero_division=1`` makes precision/recall/F1 return 1 (without a warning)
    when their denominator is zero. sklearn's default ``average='binary'`` is
    used, so NOTE(review): this assumes binary labels.
    """
    return {
        'accuracy': accuracy_score(labels, predictions),
        'precision': precision_score(labels, predictions, zero_division=1),
        'recall': recall_score(labels, predictions, zero_division=1),
        'f1': f1_score(labels, predictions, zero_division=1),
    }


async def _run_epoch(server_a, server_b, batches,
                     train_data_a, train_data_b, train_labels) -> None:
    """Submit one full pass over the training rows to both servers and wait for it.

    ``batches`` is shuffled in place with NumPy's global RNG, so every epoch
    reshuffles the previous epoch's order.
    """
    np.random.shuffle(batches)
    futures = []
    for batch in batches:
        batch_id = batch['batch_id']
        rows = slice(batch['start'], batch['end'])
        # Server B's call is submitted before server A's for every batch.
        # NOTE(review): A presumably consumes B's output via the shared
        # channel, so keep this submission order.
        futures.append(server_b.process_batch.remote(batch_id, train_data_b[rows]))
        futures.append(server_a.process_batch.remote(
            batch_id, train_data_a[rows], train_labels[rows]
        ))
    await asyncio.gather(*futures)


async def _evaluate(server_a, server_b, test_data_a, test_data_b, test_labels):
    """Push the whole test set through both servers and score server A's predictions.

    Returns the metrics dict, or ``None`` when server A's result is ``None``
    or has no ``'predictions'`` entry.
    """
    # NOTE(review): the fixed id "test_batch" is reused for every evaluation.
    # Confirm the servers treat it as inference-only, and that reusing the id
    # cannot pick up stale channel entries between clear_all() calls.
    test_batch_id = "test_batch"
    _, result = await asyncio.gather(
        server_b.process_batch.remote(test_batch_id, test_data_b),
        server_a.process_batch.remote(test_batch_id, test_data_a, test_labels),
    )
    if result is None or 'predictions' not in result:
        return None
    return _compute_metrics(test_labels, result['predictions'])


async def train_federated(
        num_epochs: int,
        train_data_a: np.ndarray,
        train_data_b: np.ndarray,
        train_labels: np.ndarray,
        test_data_a: np.ndarray,
        test_data_b: np.ndarray,
        test_labels: np.ndarray,
        num_workers: int = 4,
        batch_size: int = 32,
        global_update_frequency: int = 5,
        patience: int = 5,
        eval_frequency: int = 5,
        target_accuracy: float = 0.97,
        embedding_dim: int = 32,
        num_cpus: int = 103,
) -> Dict[str, float]:
    """Train the two-party split model on Ray actors and return the best test metrics.

    Party A holds ``*_data_a`` plus the labels; party B holds ``*_data_b``.
    Rows of the A, B and label arrays are aligned by index.

    Each epoch submits every mini-batch to both servers. Every
    ``global_update_frequency`` epochs (epoch 0 excluded), both servers sync
    parameters and the channel is cleared. Every ``eval_frequency`` epochs,
    the full test set is scored.

    Args:
        num_epochs: Maximum number of epochs to run.
        train_data_a, train_data_b: 2-D training features for party A / B.
        train_labels: Training labels, sent to server A only.
        test_data_a, test_data_b, test_labels: The test split, same layout.
        num_workers: Passed through to the ServerA/ServerB constructors.
        batch_size: Rows per training mini-batch.
        global_update_frequency: Epoch interval between parameter syncs.
        patience: Number of consecutive *evaluations* (not epochs) without an
            accuracy improvement before training stops early.
        eval_frequency: Epoch interval between test-set evaluations.
        target_accuracy: Stop as soon as test accuracy reaches this value.
        embedding_dim: Embedding size passed to both servers.
        num_cpus: CPU count handed to ``ray.init``.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
        All values stay 0 if no evaluation produced predictions.

    Raises:
        ValueError: If ``batch_size``, ``global_update_frequency`` or
            ``eval_frequency`` is not positive. This is checked before Ray
            is started.
    """
    for name, value in (('batch_size', batch_size),
                        ('global_update_frequency', global_update_frequency),
                        ('eval_frequency', eval_frequency)):
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")

    input_dim_a = train_data_a.shape[1]
    input_dim_b = train_data_b.shape[1]
    worker_batches = _make_batch_index(len(train_labels), batch_size)

    best_metrics = {
        'accuracy': 0,
        'precision': 0,
        'recall': 0,
        'f1': 0
    }
    no_improvement = 0

    ray.init(num_cpus=num_cpus)
    try:
        channel = Channel.remote()
        server_a = ServerA.remote(num_workers, input_dim_a, embedding_dim, channel)
        server_b = ServerB.remote(num_workers, input_dim_b, embedding_dim, channel)

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch}")
            await _run_epoch(server_a, server_b, worker_batches,
                             train_data_a, train_data_b, train_labels)

            if epoch > 0 and epoch % global_update_frequency == 0:
                print("Performing global model update...")
                await asyncio.gather(
                    server_a.sync_parameters.remote(),
                    server_b.sync_parameters.remote(),
                )
                await channel.clear_all.remote()

            if epoch % eval_frequency != 0:
                continue
            metrics = await _evaluate(server_a, server_b,
                                      test_data_a, test_data_b, test_labels)
            if metrics is None:
                continue

            print(f"Test Metrics: Accuracy: {metrics['accuracy']:.4f}, "
                  f"Precision: {metrics['precision']:.4f}, "
                  f"Recall: {metrics['recall']:.4f}, "
                  f"F1: {metrics['f1']:.4f}")

            # `epoch` is a 0-based index, so epoch + 1 epochs have completed.
            if metrics['accuracy'] >= target_accuracy:
                print(f"\nReached target accuracy of {target_accuracy:.0%} "
                      f"after {epoch + 1} epochs")
                best_metrics = metrics
                break

            if metrics['accuracy'] > best_metrics['accuracy']:
                best_metrics = metrics
                no_improvement = 0
            else:
                no_improvement += 1
                if no_improvement >= patience:
                    print(f"Early stopping triggered after {epoch + 1} epochs")
                    break

        await asyncio.gather(
            server_a.cleanup.remote(),
            server_b.cleanup.remote(),
        )
    finally:
        # Always tear down the Ray runtime, even if training raised, so a
        # later ray.init() in the same process does not fail.
        ray.shutdown()

    return best_metrics


if __name__ == "__main__":
    # Re-seed right before building the dataset, so data generation and the
    # training run start from a known RNG state regardless of import-time work.
    # NumPy is seeded as well because train_federated shuffles its batches
    # with np.random.shuffle.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)

    # Resolve the project root relative to this file rather than the current
    # working directory, so `data.genvfldataset` imports no matter where the
    # script is launched from.
    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    from data.genvfldataset import SyntheticVFLDataset

    dataset = SyntheticVFLDataset()
    train_x_a, train_y = dataset.get_train_data_for_a()
    train_x_b = dataset.get_train_data_for_b()
    test_x_a, test_y = dataset.get_test_data_for_a()
    test_x_b = dataset.get_test_data_for_b()

    # The getters return objects exposing `.numpy()` (presumably torch
    # tensors); train_federated works on NumPy arrays.
    train_x_a, train_x_b, train_y, test_x_a, test_x_b, test_y = (
        t.numpy() for t in (train_x_a, train_x_b, train_y, test_x_a, test_x_b, test_y)
    )

    # perf_counter is monotonic, so the elapsed time is unaffected by
    # wall-clock adjustments during a long run.
    start_time = time.perf_counter()
    best_metrics = asyncio.run(train_federated(
        num_epochs=500,
        train_data_a=train_x_a,
        train_data_b=train_x_b,
        train_labels=train_y,
        test_data_a=test_x_a,
        test_data_b=test_x_b,
        test_labels=test_y,
        num_workers=5,
        batch_size=32,
        global_update_frequency=5
    ))
    training_time = time.perf_counter() - start_time

    print(f"\nTraining time: {training_time:.2f} seconds")
    print("\nTraining Complete!")
    print("Best Test Metrics:")
    print(f"Accuracy: {best_metrics['accuracy']:.4f}")
    print(f"Precision: {best_metrics['precision']:.4f}")
    print(f"Recall: {best_metrics['recall']:.4f}")
    print(f"F1 Score: {best_metrics['f1']:.4f}")