import ray
import torch
import torch.nn as nn
import torch.optim as optim
import asyncio
from queue import Queue
from typing import Dict, Tuple, List, Optional
import numpy as np
from worker import WorkerA,WorkerB


@ray.remote
class ServerA:
    """Ray actor that dispatches batches (data + labels) to a pool of ``WorkerA`` actors.

    For each batch it:
    1. caches the data/labels on a worker;
    2. polls ``channel`` for the embedding of that batch;
    3. hands the embedding to the same worker; and
    4. sends the worker's ``'gradient_b'`` back through ``channel``.

    Args:
        num_workers: Number of ``WorkerA`` actors to create; must be >= 1.
        input_dim: Forwarded to every ``WorkerA``.
        embedding_dim: Forwarded to every ``WorkerA``.
        channel: Actor handle exposing the remote methods ``receive_embedding``,
            ``send_gradient``, ``get_pending_batches`` and ``clear_batch``.
        poll_attempts: Maximum number of times to poll ``channel`` for an
            embedding before timing out (default 100).
        poll_interval: Seconds to sleep between unsuccessful polls (default 0.1).

    Raises:
        ValueError: If ``num_workers`` or ``poll_attempts`` is smaller than 1.
    """

    def __init__(self, num_workers: int, input_dim: int, embedding_dim: int, channel,
                 poll_attempts: int = 100, poll_interval: float = 0.1):
        try:
            # Round-robin dispatch uses ``% num_workers`` and parameter averaging
            # needs at least one worker, so reject an empty pool up front instead
            # of failing later with ZeroDivisionError / IndexError.
            if num_workers < 1:
                raise ValueError(f"num_workers must be >= 1, got {num_workers}")
            if poll_attempts < 1:
                raise ValueError(f"poll_attempts must be >= 1, got {poll_attempts}")

            self.channel = channel
            self.num_workers = num_workers
            self.input_dim = input_dim
            self.embedding_dim = embedding_dim
            self.poll_attempts = poll_attempts
            self.poll_interval = poll_interval

            self.workers = [
                WorkerA.remote(worker_id=i, input_dim=input_dim, embedding_dim=embedding_dim)
                for i in range(num_workers)
            ]

            # batch_id -> index of the worker handling it; only populated while
            # the batch is in flight inside process_batch.
            self.batch_to_worker = {}
            # Index of the worker that receives the next batch (round-robin).
            self.current_worker = 0

            # Not read by any method of this class; presumably consulted by an
            # external training loop to decide when to call sync_parameters — verify.
            self.epoch_counter = 0
            self.sync_frequency = 10
        except Exception as e:
            print(f"Error initializing ServerA: {e}")
            raise

    async def process_batch(self, batch_id: str, data: np.ndarray, labels: np.ndarray):
        """Run one batch through a worker and forward its gradient over the channel.

        Returns:
            The worker's result (which contains ``'gradient_b'``) on success.
            Returns ``None`` in three cases: the embedding did not arrive in time,
            the worker returned an invalid result, or any other error occurred.
            Errors are printed, not raised.
        """
        try:
            worker_id = self.current_worker
            self.current_worker = (self.current_worker + 1) % self.num_workers
            self.batch_to_worker[batch_id] = worker_id
            worker = self.workers[worker_id]

            await worker.cache_data.remote(batch_id, data, labels)

            embedding = await self._wait_for_embedding(batch_id)
            if embedding is None:
                raise TimeoutError(f"Timeout waiting for embedding for batch {batch_id}")

            result = await worker.receive_embedding.remote(batch_id, embedding)

            if result is None or 'gradient_b' not in result:
                print(f"Warning: Invalid result from worker for batch {batch_id}")
                return None

            await self.channel.send_gradient.remote(batch_id, result['gradient_b'])
            return result

        except Exception as e:
            print(f"Error processing batch in ServerA: {e}")
            return None
        finally:
            # Always release the bookkeeping entry. Previously it was removed
            # only on success, so failed or timed-out batches leaked forever.
            self.batch_to_worker.pop(batch_id, None)

    async def _wait_for_embedding(self, batch_id: str):
        """Poll ``channel`` for ``batch_id``'s embedding; return ``None`` if it never arrives."""
        for _ in range(self.poll_attempts):
            embedding = await self.channel.receive_embedding.remote(batch_id)
            if embedding is not None:
                return embedding
            await asyncio.sleep(self.poll_interval)
        return None

    async def sync_parameters(self):
        """Fetch every worker's parameters, average them, and push the average to all workers.

        Raises:
            Exception: Any error from the remote calls is printed and re-raised.
        """
        try:
            param_futures = [worker.get_parameters.remote() for worker in self.workers]
            all_params = await asyncio.gather(*param_futures)

            avg_params = self._average_parameters(all_params)

            update_futures = [worker.set_parameters.remote(avg_params) for worker in self.workers]
            await asyncio.gather(*update_futures)

        except Exception as e:
            print(f"Error syncing parameters in ServerA: {e}")
            raise

    def _average_parameters(self, parameters_list: List[Dict]) -> Dict:
        """Average parameter dicts element-wise, key by key.

        The key set is taken from the first dict. Every other dict must contain
        those keys, otherwise a KeyError is raised. Values are combined with
        ``np.mean(..., axis=0)``, so same-named values should have matching
        shapes. An empty list yields an empty dict.
        """
        if not parameters_list:
            return {}
        return {
            name: np.mean([params[name] for params in parameters_list], axis=0)
            for name in parameters_list[0]
        }

    async def cleanup(self):
        """Clear every batch the channel reports as pending.

        Raises:
            Exception: Any error from the remote calls is printed and re-raised.
        """
        try:
            pending_batches = await self.channel.get_pending_batches.remote()

            clear_futures = [self.channel.clear_batch.remote(batch_id) for batch_id in pending_batches]
            await asyncio.gather(*clear_futures)

        except Exception as e:
            print(f"Error cleaning up ServerA: {e}")
            raise


@ray.remote
class ServerB:
    """Ray actor that dispatches data-only batches to a pool of ``WorkerB`` actors.

    For each batch it:
    1. has a worker produce an embedding;
    2. sends that embedding through ``channel``;
    3. polls ``channel`` for the matching gradient; and
    4. hands the gradient back to the same worker.

    Args:
        num_workers: Number of ``WorkerB`` actors to create; must be >= 1.
        input_dim: Forwarded to every ``WorkerB``.
        embedding_dim: Forwarded to every ``WorkerB``.
        channel: Actor handle exposing the remote methods ``send_embedding``,
            ``receive_gradient``, ``get_pending_batches`` and ``clear_batch``.
        poll_attempts: Maximum number of times to poll ``channel`` for a
            gradient before timing out (default 100).
        poll_interval: Seconds to sleep between unsuccessful polls (default 0.1).

    Raises:
        ValueError: If ``num_workers`` or ``poll_attempts`` is smaller than 1.
    """

    def __init__(self, num_workers: int, input_dim: int, embedding_dim: int, channel,
                 poll_attempts: int = 100, poll_interval: float = 0.1):
        try:
            # Round-robin dispatch uses ``% num_workers`` and parameter averaging
            # needs at least one worker, so reject an empty pool up front instead
            # of failing later with ZeroDivisionError / IndexError.
            if num_workers < 1:
                raise ValueError(f"num_workers must be >= 1, got {num_workers}")
            if poll_attempts < 1:
                raise ValueError(f"poll_attempts must be >= 1, got {poll_attempts}")

            self.channel = channel
            self.num_workers = num_workers
            self.input_dim = input_dim
            self.embedding_dim = embedding_dim
            self.poll_attempts = poll_attempts
            self.poll_interval = poll_interval

            self.workers = [
                WorkerB.remote(worker_id=i, input_dim=input_dim, embedding_dim=embedding_dim)
                for i in range(num_workers)
            ]

            # batch_id -> index of the worker handling it; only populated while
            # the batch is in flight inside process_batch.
            self.batch_to_worker = {}
            # Index of the worker that receives the next batch (round-robin).
            self.current_worker = 0

            # Not read by any method of this class; presumably consulted by an
            # external training loop to decide when to call sync_parameters — verify.
            self.epoch_counter = 0
            self.sync_frequency = 10
        except Exception as e:
            print(f"Error initializing ServerB: {e}")
            raise

    async def process_batch(self, batch_id: str, data: np.ndarray):
        """Produce an embedding for a batch, then apply the gradient received for it.

        The channel entry for the batch is cleared only when the worker reports
        success.

        Returns:
            The value returned by the worker's ``receive_gradient`` call.

        Raises:
            TimeoutError: If no gradient arrives within the polling budget.
            Exception: Any other error is printed and re-raised.
        """
        try:
            worker_id = self.current_worker
            self.current_worker = (self.current_worker + 1) % self.num_workers
            self.batch_to_worker[batch_id] = worker_id
            worker = self.workers[worker_id]

            result = await worker.process_data.remote(batch_id, data)
            await self.channel.send_embedding.remote(batch_id, result['embedding'])

            gradient = await self._wait_for_gradient(batch_id)
            if gradient is None:
                raise TimeoutError(f"Timeout waiting for gradient for batch {batch_id}")

            success = await worker.receive_gradient.remote(batch_id, gradient)
            if success:
                await self.channel.clear_batch.remote(batch_id)

            return success

        except Exception as e:
            print(f"Error processing batch in ServerB: {e}")
            raise
        finally:
            # Always release the bookkeeping entry. Previously it was removed
            # only on success, so failed or timed-out batches leaked forever.
            self.batch_to_worker.pop(batch_id, None)

    async def _wait_for_gradient(self, batch_id: str):
        """Poll ``channel`` for ``batch_id``'s gradient; return ``None`` if it never arrives."""
        for _ in range(self.poll_attempts):
            gradient = await self.channel.receive_gradient.remote(batch_id)
            if gradient is not None:
                return gradient
            await asyncio.sleep(self.poll_interval)
        return None

    async def sync_parameters(self):
        """Fetch every worker's parameters, average them, and push the average to all workers.

        Raises:
            Exception: Any error from the remote calls is printed and re-raised.
        """
        try:
            param_futures = [worker.get_parameters.remote() for worker in self.workers]
            all_params = await asyncio.gather(*param_futures)

            avg_params = self._average_parameters(all_params)

            update_futures = [worker.set_parameters.remote(avg_params) for worker in self.workers]
            await asyncio.gather(*update_futures)

        except Exception as e:
            print(f"Error syncing parameters in ServerB: {e}")
            raise

    def _average_parameters(self, parameters_list: List[Dict]) -> Dict:
        """Average parameter dicts element-wise, key by key.

        The key set is taken from the first dict. Every other dict must contain
        those keys, otherwise a KeyError is raised. Values are combined with
        ``np.mean(..., axis=0)``, so same-named values should have matching
        shapes. An empty list yields an empty dict.
        """
        if not parameters_list:
            return {}
        return {
            name: np.mean([params[name] for params in parameters_list], axis=0)
            for name in parameters_list[0]
        }

    async def cleanup(self):
        """Clear every batch the channel reports as pending.

        Raises:
            Exception: Any error from the remote calls is printed and re-raised.
        """
        try:
            pending_batches = await self.channel.get_pending_batches.remote()

            clear_futures = [self.channel.clear_batch.remote(batch_id) for batch_id in pending_batches]
            await asyncio.gather(*clear_futures)

        except Exception as e:
            print(f"Error cleaning up ServerB: {e}")
            raise