import ray
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, Optional, List




@ray.remote
class WorkerA:
    """Label-holding party in a two-party split-learning setup.

    WorkerA owns the labels, its own bottom model and the top model. Every
    batch needs two inputs, which may arrive in either order:

    * ``cache_data`` - WorkerA's own features and labels;
    * ``receive_embedding`` - the embedding WorkerB computed for the same batch.

    As soon as both are cached the batch is trained end to end and the result
    (including the gradient w.r.t. WorkerB's embedding) is returned from
    whichever call completed the pair. Presumably the caller forwards
    ``gradient_b`` to WorkerB - confirm with the orchestrator.
    """

    def __init__(self, worker_id: int, input_dim: int, embedding_dim: int):
        self.worker_id = worker_id
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.embedding_dim = embedding_dim
        self.bottom_model = self._create_bottom_model(input_dim, embedding_dim)
        # The top model consumes [own embedding | WorkerB embedding]; this assumes
        # WorkerB was built with the same embedding_dim - TODO confirm.
        self.top_model = self._create_top_model(embedding_dim * 2)

        self.bottom_optimizer = optim.Adam(self.bottom_model.parameters())
        self.top_optimizer = optim.Adam(self.top_model.parameters())
        # top_model already ends in LogSoftmax, so NLLLoss is the matching
        # criterion; CrossEntropyLoss would apply log-softmax a second time.
        self.criterion = nn.NLLLoss().to(self.device)

        self.embedding_cache = {}  # {batch_id: embedding received from WorkerB}
        self.data_cache = {}  # {batch_id: (data, labels)}
        self.processing_batches = set()  # batch ids with at least one input cached

    def _create_bottom_model(self, input_dim: int, embedding_dim: int):
        """Build WorkerA's feature extractor mapping input_dim -> embedding_dim."""
        model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, embedding_dim),
            nn.ReLU()
        ).to(self.device)
        return model

    def _create_top_model(self, embedding_dim: int):
        """Build the classifier head producing log-probabilities for 2 classes.

        ``embedding_dim`` here is the width of the concatenated embedding.
        """
        model = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
            nn.LogSoftmax(dim=1)
        ).to(self.device)
        return model

    async def cache_data(self, batch_id: int, data, labels):
        """Cache WorkerA's features/labels for a batch and train if possible.

        Returns the ``process_batch`` result if WorkerB's embedding for this
        batch was already cached, otherwise ``None``.
        """
        try:
            self.data_cache[batch_id] = (data, labels)
            self.processing_batches.add(batch_id)
            return await self._try_process_batch(batch_id)
        except Exception as e:
            print(f"Error in cache_data for batch {batch_id}: {e}")
            return None

    async def receive_embedding(self, batch_id: int, embedding_b):
        """Cache WorkerB's embedding for a batch and train if possible.

        The embedding may arrive before ``cache_data`` for the same batch. In
        that case it is kept, and the later ``cache_data`` call completes the
        batch. Returns the ``process_batch`` result if the data was already
        cached, otherwise ``None``.
        """
        try:
            # Register the batch even if the data has not arrived yet; dropping
            # the embedding here would leave the batch permanently incomplete.
            self.processing_batches.add(batch_id)
            self.embedding_cache[batch_id] = embedding_b
            return await self._try_process_batch(batch_id)
        except Exception as e:
            print(f"Error in receive_embedding for batch {batch_id}: {e}")
            return None

    async def _try_process_batch(self, batch_id: int):
        """Train on ``batch_id`` once both its data and embedding are cached.

        Returns ``None`` while either input is still missing. Once both are
        present they are consumed: the batch is evicted from every cache whether
        training succeeds or fails, so failed batches cannot accumulate. A
        failed batch must be re-sent in full (data and embedding) to retry.
        """
        if batch_id not in self.processing_batches:
            return None
        if batch_id not in self.data_cache or batch_id not in self.embedding_cache:
            return None  # still waiting for the other input

        data, labels = self.data_cache[batch_id]
        embedding_b = self.embedding_cache[batch_id]
        try:
            return await self.process_batch(batch_id, data, labels, embedding_b)
        except Exception as e:
            print(f"Error in _try_process_batch for batch {batch_id}: {e}")
            return None
        finally:
            self._cleanup_batch(batch_id)

    def _cleanup_batch(self, batch_id: int):
        """Forget every cached input for ``batch_id``; safe if already gone."""
        self.data_cache.pop(batch_id, None)
        self.embedding_cache.pop(batch_id, None)
        self.processing_batches.discard(batch_id)

    async def process_batch(self, batch_id: int, data, labels, worker_b_embedding):
        """Run one joint training step (both WorkerA models) for a batch.

        Args:
            batch_id: Identifier echoed back in the result.
            data: WorkerA's features; presumably ``(batch, input_dim)`` - TODO confirm.
            labels: Class indices; the top model has 2 outputs, so each must be 0 or 1.
            worker_b_embedding: WorkerB's embedding for the same rows. It is
                concatenated after WorkerA's embedding along dim 1.

        Returns:
            On success, a dict with these keys:

            * ``batch_id``;
            * ``gradient_b`` - the loss gradient w.r.t. ``worker_b_embedding``
              as a NumPy array;
            * ``loss`` - a Python float;
            * ``predictions`` - argmax class per row.

            Returns ``None`` if the step failed.
        """
        try:
            data = torch.tensor(data, dtype=torch.float32).to(self.device)
            labels = torch.tensor(labels, dtype=torch.long).to(self.device)
            worker_b_embedding = torch.tensor(worker_b_embedding, dtype=torch.float32).to(self.device)

            self.bottom_model.train()
            self.top_model.train()

            # Bottom model forward
            self.bottom_optimizer.zero_grad()
            embedding_a = self.bottom_model(data)

            if embedding_a.shape[1] != self.embedding_dim:
                print(f"Warning: Unexpected embedding_a shape: {embedding_a.shape}")
                return None

            combined_embedding = torch.cat([embedding_a, worker_b_embedding], dim=1)
            # Non-leaf tensor: keep its .grad so WorkerB's slice can be returned.
            combined_embedding.retain_grad()

            # Top model forward
            self.top_optimizer.zero_grad()
            output = self.top_model(combined_embedding)

            loss = self.criterion(output, labels)

            # Backward pass
            loss.backward()

            if combined_embedding.grad is None:
                print(f"Warning: No gradient for combined_embedding in batch {batch_id}")
                return None

            # Columns after embedding_dim belong to WorkerB's embedding.
            gradient_b = combined_embedding.grad[:, self.embedding_dim:].detach().cpu().numpy()

            self.bottom_optimizer.step()
            self.top_optimizer.step()

            return {
                'batch_id': batch_id,
                'gradient_b': gradient_b,
                'loss': loss.item(),
                'predictions': output.argmax(dim=1).detach().cpu().numpy()
            }

        except Exception as e:
            print(f"Error processing batch {batch_id} in WorkerA: {e}")
            return None

    async def get_parameters(self):
        """Export trainable parameters as NumPy arrays keyed ``bottom.*`` / ``top.*``.

        Only ``named_parameters()`` are exported; BatchNorm running statistics
        (buffers) are not included. Returns ``None`` on failure.
        """
        try:
            params = {}
            for prefix, model in (('bottom', self.bottom_model), ('top', self.top_model)):
                for name, param in model.named_parameters():
                    params[f'{prefix}.{name}'] = param.detach().cpu().numpy()
            return params
        except Exception as e:
            print(f"Error in WorkerA get_parameters: {e}")
            return None

    async def set_parameters(self, parameters):
        """Load parameters produced by ``get_parameters``.

        Keys missing from ``parameters`` are left unchanged. Returns ``True`` on
        success, ``False`` on failure.
        """
        try:
            with torch.no_grad():
                for prefix, model in (('bottom', self.bottom_model), ('top', self.top_model)):
                    for name, param in model.named_parameters():
                        key = f'{prefix}.{name}'
                        if key in parameters:
                            param.copy_(torch.tensor(parameters[key], device=self.device))
            return True
        except Exception as e:
            print(f"Error in WorkerA set_parameters: {e}")
            return False

@ray.remote
class WorkerB:
    """Feature-only party in a two-party split-learning setup.

    ``process_data`` computes and returns an embedding for a batch.
    ``receive_gradient`` takes the gradient w.r.t. that embedding, replays the
    forward pass with the same RNG state (so dropout uses the same mask) and
    updates the bottom model.
    """

    def __init__(self, worker_id: int, input_dim: int, embedding_dim: int):
        self.worker_id = worker_id
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.bottom_model = self._create_bottom_model(input_dim, embedding_dim)
        self.optimizer = optim.Adam(self.bottom_model.parameters())

        self.data_cache = {}  # {batch_id: data}
        # {batch_id: (cpu_rng_state, cuda_rng_state or None)}, captured right
        # before the forward pass whose embedding was returned.
        self.rng_cache = {}
        self.pending_batches = set()  # batch ids awaiting a gradient

    def _create_bottom_model(self, input_dim: int, embedding_dim: int):
        """Build WorkerB's feature extractor mapping input_dim -> embedding_dim."""
        model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, embedding_dim),
            nn.ReLU()
        ).to(self.device)
        return model

    def _get_rng_state(self):
        """Snapshot the CPU (and, on CUDA, the device) RNG state."""
        cuda_state = torch.cuda.get_rng_state(self.device) if self.device.type == 'cuda' else None
        return torch.get_rng_state(), cuda_state

    def _set_rng_state(self, state):
        """Restore a snapshot taken by ``_get_rng_state``."""
        cpu_state, cuda_state = state
        torch.set_rng_state(cpu_state)
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state, self.device)

    def process_data(self, batch_id: int, data):
        """Compute WorkerB's embedding for a batch and remember it for the update.

        ``data`` is cast to float32 to match the model weights. The batch is
        registered as pending only after the forward pass succeeds.

        Returns ``{'batch_id': ..., 'embedding': np.ndarray}``.
        """
        inputs = torch.tensor(data, dtype=torch.float32).to(self.device)
        self.bottom_model.train()
        # Snapshot the RNG so receive_gradient can reproduce this dropout mask.
        rng_state = self._get_rng_state()
        with torch.no_grad():
            embedding = self.bottom_model(inputs)

        self.data_cache[batch_id] = data
        self.rng_cache[batch_id] = rng_state
        self.pending_batches.add(batch_id)

        return {
            'batch_id': batch_id,
            'embedding': embedding.cpu().numpy()
        }

    def receive_gradient(self, batch_id: int, gradient):
        """Backpropagate the gradient w.r.t. this batch's embedding and step.

        Returns ``False`` if the batch is unknown or was already consumed.
        Otherwise it updates the model, evicts the batch and returns ``True``.

        NOTE(review): the replayed forward pass updates BatchNorm running
        statistics a second time for this batch. If other batches were updated
        in between, the replay also uses the newer weights.
        """
        if batch_id not in self.pending_batches:
            return False

        if batch_id not in self.data_cache:
            return False

        inputs = torch.tensor(self.data_cache[batch_id], dtype=torch.float32).to(self.device)
        grad_tensor = torch.tensor(gradient, dtype=torch.float32).to(self.device)
        rng_state = self.rng_cache.get(batch_id)

        self.bottom_model.train()
        self.optimizer.zero_grad()
        # fork_rng restores the global RNG afterwards, so replaying an old state
        # does not make later randomness repeat.
        fork_devices = [self.device] if self.device.type == 'cuda' else []
        with torch.random.fork_rng(devices=fork_devices):
            if rng_state is not None:
                self._set_rng_state(rng_state)
            embedding = self.bottom_model(inputs)
        embedding.backward(grad_tensor)
        self.optimizer.step()

        self.pending_batches.remove(batch_id)
        del self.data_cache[batch_id]
        self.rng_cache.pop(batch_id, None)

        return True

    async def get_parameters(self):
        """Export trainable parameters as NumPy arrays keyed ``bottom.*``.

        BatchNorm running statistics (buffers) are not included. Returns
        ``None`` on failure.
        """
        try:
            params = {}
            for name, param in self.bottom_model.named_parameters():
                params[f'bottom.{name}'] = param.detach().cpu().numpy()
            return params
        except Exception as e:
            print(f"Error in WorkerB get_parameters: {e}")
            return None

    async def set_parameters(self, parameters):
        """Load ``bottom.*`` parameters; missing keys are left unchanged.

        Returns ``True`` on success, ``False`` on failure.
        """
        try:
            with torch.no_grad():
                for name, param in self.bottom_model.named_parameters():
                    key = f'bottom.{name}'
                    if key in parameters:
                        param.copy_(torch.tensor(parameters[key], device=self.device))
            return True
        except Exception as e:
            print(f"Error in WorkerB set_parameters: {e}")
            return False