import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from model.modules import *
import copy
import logging

# Module-level logger. It uses the fixed name 'GFedCL' rather than __name__,
# so every record emitted by this client goes to the 'GFedCL' logger.
logger = logging.getLogger('GFedCL')

class ModifiedClient(nn.Module):
    """
    Modified GFedCL Client that works with a centralized server discriminator

    Main changes:
    1. No local discriminator
    2. Uses server's discriminator for loss computation
    3. Sends encoded samples to server for discriminator training
    4. Can generate encodings without training for the server's discriminator

    Notes:
        The local copy of the server discriminator is frozen. Its parameters
        have ``requires_grad=False``, and it is kept in eval mode even when
        ``train()`` is called on the client. During training it is still
        evaluated with autograd enabled, so the adversarial (GAN) term of the
        encoder loss can back-propagate *through* it into the encoder.
    """
    def __init__(self, client_id, opt):
        """
        Build the client networks, optimizer and learning-rate scheduler.

        Args:
            client_id: Index of this client; also the hot position of the
                one-hot graph embedding used when no relational graph row
                is available
            opt: Options object. This class reads device, batch_size,
                use_g_encode, lr_e, beta1, lambda_gan and num_clients; the
                networks built below receive the whole object.
        """
        super(ModifiedClient, self).__init__()

        # set output format
        np.set_printoptions(suppress=True, precision=6)
        self.client_id = client_id
        self.opt = opt
        self.device = opt.device
        self.batch_size = opt.batch_size

        # visualization
        self.use_g_encode = opt.use_g_encode

        # Initialize the neural networks
        self.netE = FeatureEncoder(opt).to(opt.device)  # Encoder-decoder
        self.netF = PredNet(opt).to(opt.device)        # Predictor for ILI
        self.netG = GNet(opt).to(opt.device)           # Graph embedding generator

        # Note: No local discriminator (netD) anymore

        # Initialize weights
        self.__init_weight__(self.netE)
        self.__init_weight__(self.netF)
        self.__init_weight__(self.netG)

        # Set up optimizers - only for encoder and predictor (netG is not optimized locally)
        EF_parameters = list(self.netE.parameters()) + list(self.netF.parameters())
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )

        # Set up learning rate schedulers: the LR halves every 100 scheduler steps
        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )

        self.lr_schedulers = [self.lr_scheduler_EF]

        # Define loss names for tracking
        self.loss_names = ["E_pred", "E_gan"]

        # Initialize tracking variables
        self.relational_graph = None
        self.task_ID = None
        self.server_discriminator = None  # Will hold the server's discriminator for inference

    def train(self, mode=True):
        """
        Set training/evaluation mode for the client networks.

        Overrides ``nn.Module.train`` so the local copy of the server
        discriminator always stays in eval mode. Once assigned it is a
        registered submodule, so the default implementation would switch it
        to train mode (enabling dropout / batch-norm statistic updates)
        whenever ``learn()`` or ``generate_encodings()`` calls
        ``self.train()``. ``eval()`` delegates here as well.

        Args:
            mode: True for training mode, False for evaluation mode

        Returns:
            self, as ``nn.Module.train`` does
        """
        super(ModifiedClient, self).train(mode)
        discriminator = getattr(self, 'server_discriminator', None)
        if discriminator is not None:
            discriminator.eval()
        return self

    def getId(self):
        """Return client ID"""
        return self.client_id

    def set_server_discriminator(self, discriminator_state_dict):
        """
        Set the server's discriminator for loss computation

        The local copy is created on first use, then overwritten with the
        given weights, frozen (no gradients are accumulated on its
        parameters) and put in eval mode.

        Args:
            discriminator_state_dict: State dict of server's discriminator
        """
        if self.server_discriminator is None:
            # Initialize a local copy of the discriminator
            self.server_discriminator = GraphDNet(self.opt).to(self.device)

        # Load state dict
        self.server_discriminator.load_state_dict(discriminator_state_dict)

        # Freeze the parameters: no local optimizer owns them, so gradients must
        # not accumulate on them when the encoder loss back-propagates through it.
        self.server_discriminator.requires_grad_(False)

        # Set to eval mode since we don't train it locally
        self.server_discriminator.eval()

    def learn(self, epoch, task, relational_graphs, dataloader, generate=False):
        """
        Train the client model using relational graph and dataloader

        Args:
            epoch: Current epoch number
            task: Current task ID
            relational_graphs: Relational graphs for all tasks
            dataloader: DataLoader containing samples
            generate: Flag to control synthetic sample generation

        Returns:
            dict: Loss values (averaged over batches) and the detached
            per-batch encodings / graph embeddings for server training, or
            None when no server discriminator has been set.
        """
        # Ensure we have a server discriminator to compute losses
        if self.server_discriminator is None:
            logger.error(f"Client {self.client_id}: No server discriminator available for training")
            return None

        # Set model to training mode (the server discriminator stays in eval mode)
        self.train()

        # Store the task ID and relational graph for use in training
        self.epoch = epoch
        self.generate = generate
        self.task_ID = task
        self.relational_graph = relational_graphs

        # Initialize loss tracking
        loss_values = {loss: 0 for loss in self.loss_names}
        count = 0

        # Lists to collect encoded samples and graph embeddings
        collected_encodings = []
        collected_graph_embeddings = []

        # Training loop
        for data in dataloader:
            count += 1

            # Set the input data - this will also create synthetic data if needed
            self.__set_input__(data, self.generate)

            # Forward pass
            self.__train_forward__()

            # Calculate losses and update weights for encoder and predictor
            new_loss_values = self.__optimize__()

            # Track loss values
            for key, loss in new_loss_values.items():
                loss_values[key] += loss

            # Collect encoded samples and graph embeddings for server training
            collected_encodings.append(self.e_seq.detach().clone())
            collected_graph_embeddings.append(self.z_seq.detach().clone())

        # Calculate average loss
        if count > 0:
            for key in loss_values.keys():
                loss_values[key] /= count

        # Log progress periodically
        status_msg = f"Client {self.client_id}, Task {task}, Epoch {self.epoch}"
        if self.generate:
            status_msg += " (with synthetic data)"
        status_msg += f": {loss_values}"
        logger.info(status_msg)

        # Apply learning rate decay (named `scheduler` to avoid shadowing the
        # imported `lr_scheduler` module)
        for scheduler in self.lr_schedulers:
            scheduler.step()

        # Return loss values and collected data for server training
        return {
            'loss_values': loss_values,
            'encodings': collected_encodings,
            'graph_embeddings': collected_graph_embeddings
        }

    def test(self, task_id, dataloader):
        """
        Test the model on a dataset - adapted for ILI regression

        Args:
            task_id: Task ID (used only in the log message)
            dataloader: Iterable of (inputs, targets) batches

        Returns:
            dict with "loss"/"mse" (mean squared error per target element),
            "mae", "rmse", "r2" (over all flattened predictions) and "acc"
            (r2 * 100, kept for compatibility). MSE/MAE/RMSE are inf when the
            dataloader yields no elements; r2 is 0 then, or when the targets
            have zero variance.
        """
        self.eval()

        # Track metrics
        total_mse = 0.0
        total_mae = 0.0
        total_samples = 0

        # For R² calculation
        all_predictions = []
        all_targets = []

        # Test loop
        with torch.no_grad():
            for data in dataloader:
                # Set input data
                self.__set_input__(data, generate=False, train=False)

                # Forward pass
                self.__test_forward__()

                # For regression, f_seq contains continuous predictions
                mse = F.mse_loss(self.f_seq, self.y_seq, reduction='none')
                mae = F.l1_loss(self.f_seq, self.y_seq, reduction='none')

                # Update total metrics; count every element so the averages are
                # per element (batch x states for 2-D targets) for any target rank
                total_mse += mse.sum().item()
                total_mae += mae.sum().item()
                total_samples += mse.numel()

                # Collect predictions and targets for R²
                all_predictions.append(self.f_seq.cpu())
                all_targets.append(self.y_seq.cpu())

        # Calculate final metrics
        avg_mse = total_mse / total_samples if total_samples > 0 else float('inf')
        avg_mae = total_mae / total_samples if total_samples > 0 else float('inf')
        avg_rmse = np.sqrt(avg_mse)

        # Calculate R²
        if all_predictions:
            all_preds = torch.cat(all_predictions, dim=0).numpy()
            all_targs = torch.cat(all_targets, dim=0).numpy()

            all_preds_flat = all_preds.flatten()
            all_targs_flat = all_targs.flatten()

            ss_res = np.sum((all_targs_flat - all_preds_flat) ** 2)
            ss_tot = np.sum((all_targs_flat - np.mean(all_targs_flat)) ** 2)
            r2_score = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
        else:
            r2_score = 0

        logger.info(f"Client {self.client_id}, Task {task_id} Test - "
                f"MSE: {avg_mse:.6f}, MAE: {avg_mae:.6f}, "
                f"RMSE: {avg_rmse:.6f}, R²: {r2_score:.4f}")

        return {
            "loss": avg_mse,
            "mse": avg_mse,
            "mae": avg_mae,
            "rmse": avg_rmse,
            "r2": r2_score,
            "acc": r2_score * 100,  # For compatibility with existing code
        }

    def __set_input__(self, data, generate=False, train=True):
        """
        Sets the input data for model training/testing.

        Stores x_seq / y_seq on the device and, when `generate` is set for a
        task ID > 0, a random-noise replacement input x_seq_synthetic. It also
        sets one_hot_seq and client_relations (this client's row of the
        current task's relational graph, or the one-hot fallback).

        Args:
            data: Tuple of (inputs, targets) from dataloader
            generate: Whether to generate synthetic samples
            train: Whether in training mode (currently unused)
        """
        # DataLoader returns (layout as documented by the author):
        # data[0] = time series input [batch_size, sequence_length * num_states]
        # data[1] = target values [batch_size, num_states]; the losses treat them as continuous

        # Unpack the data
        inputs, targets = data

        # Move to device
        self.x_seq = inputs.to(self.device)
        self.y_seq = targets.to(self.device)

        # Get batch size from actual data
        self.batch_size = self.x_seq.size(0)

        # Create synthetic data if needed (for continual learning); guard against
        # an unset (None) task ID, which cannot be compared with 0
        task_id = getattr(self, 'task_ID', None)
        if generate and task_id is not None and task_id > 0:
            # Create random noise with the same shape as input data
            self.x_seq_synthetic = torch.randn_like(self.x_seq, device=self.device)

        # Extract client relations from graph if available
        if hasattr(self, 'relational_graph') and hasattr(self, 'task_ID'):
            # Create one-hot encoding for client ID with proper batch size
            one_hot = torch.zeros(self.batch_size, self.opt.num_clients, device=self.device)
            one_hot[:, self.client_id] = 1.0
            self.one_hot_seq = one_hot

            try:
                # Get client's row from the relational graph for current task
                if isinstance(self.relational_graph, list) and len(self.relational_graph) > self.task_ID:
                    graph = self.relational_graph[self.task_ID]

                    # Extract client-specific relationships
                    if isinstance(graph, np.ndarray) and graph.shape[0] > self.client_id:
                        # Ensure float32 dtype
                        if graph.dtype != np.float32:
                            graph = graph.astype(np.float32)
                        client_relations = torch.tensor(graph[self.client_id], device=self.device, dtype=torch.float32)
                        # Expand to match batch size
                        self.client_relations = client_relations.unsqueeze(0).expand(self.batch_size, -1)
                    else:
                        # If the graph doesn't have the expected structure, use one-hot encoding
                        self.client_relations = self.one_hot_seq
                else:
                    # If there's no graph for this task, use one-hot encoding
                    self.client_relations = self.one_hot_seq
            except Exception as e:
                logger.error(f"Error accessing relational graph: {str(e)}")
                # Fallback to identity relationships
                self.client_relations = self.one_hot_seq

    def _graph_embedding(self):
        """
        Return the float32 input for netG for the current batch.

        Uses (a copy of) the client relations set by __set_input__ when
        available; otherwise a one-hot encoding of this client's ID sized to
        the current batch.
        """
        if hasattr(self, 'client_relations'):
            graph_embedding = self.client_relations.clone()
        else:
            graph_embedding = torch.zeros(self.x_seq.size(0), self.opt.num_clients, device=self.device)
            graph_embedding[:, self.client_id] = 1.0
        if graph_embedding.dtype != torch.float32:
            graph_embedding = graph_embedding.float()
        return graph_embedding

    def _use_synthetic(self, generate):
        """Return True when the synthetic input from __set_input__ should replace x_seq."""
        task_id = getattr(self, 'task_ID', None)
        return bool(generate) and task_id is not None and task_id > 0 and hasattr(self, 'x_seq_synthetic')

    def __train_forward__(self):
        """
        Forward pass during training - adapted for ILI time series

        Sets z_seq (graph embedding), e_seq (encoding), f_seq (predictions)
        and d_seq (frozen server-discriminator output on e_seq).
        """
        graph_embedding = self._graph_embedding()

        # Use appropriate input data
        input_data = self.x_seq_synthetic if self._use_synthetic(self.generate) else self.x_seq
        self.x_seq_processed = input_data

        # Forward pass through networks
        self.z_seq = self.netG(graph_embedding)

        # Encode time series data with labels and graph embedding
        self.e_seq = self.netE(input_data, self.y_seq, self.z_seq)

        # Predict ILI values for each state
        self.f_seq = self.netF(self.e_seq)

        # The discriminator's parameters are frozen (see set_server_discriminator),
        # but it must NOT run under torch.no_grad(): the GAN term of the encoder
        # loss needs gradients to flow through it back into e_seq / netE.
        self.d_seq = self.server_discriminator(self.e_seq)

    def __test_forward__(self):
        """
        Forward pass during testing

        Sets z_seq, e_seq and f_seq; the discriminator is not used.
        """
        # Generate graph embeddings
        self.z_seq = self.netG(self._graph_embedding())

        # Encode the data
        self.e_seq = self.netE(self.x_seq, self.y_seq, self.z_seq)

        # Generate predictions (continuous values)
        self.f_seq = self.netF(self.e_seq)

    def __optimize__(self):
        """
        Optimize encoder and predictor models

        Returns:
            dict: Loss values
        """
        loss_value = dict()

        # We only optimize encoder and predictor, not the discriminator
        self.loss_E, self.loss_E_pred, self.loss_E_gan = self.__loss_EF__()
        self.optimizer_EF.zero_grad()
        self.loss_E.backward()
        self.optimizer_EF.step()

        loss_value["E_pred"] = self.loss_E_pred.item()
        loss_value["E_gan"] = self.loss_E_gan.item()

        return loss_value

    def generate_encodings(self, task, relational_graphs, dataloader, generate=False):
        """
        Generate encodings without training for the server's discriminator

        Runs netG/netE under torch.no_grad() in eval mode and switches the
        model back to train mode before returning.

        Args:
            task: Task ID
            relational_graphs: Relational graphs for all tasks
            dataloader: DataLoader containing samples
            generate: Whether to use synthetic samples

        Returns:
            dict: Encodings and graph embeddings (one tensor per batch); both
            lists are empty when no server discriminator has been set.
        """
        # Ensure we have a server discriminator
        if self.server_discriminator is None:
            logger.error(f"Client {self.client_id}: No server discriminator available")
            return {'encodings': [], 'graph_embeddings': []}

        # Set model to eval mode to prevent batch norm, dropout effects
        self.eval()

        # Store the task ID and relational graph for use
        self.task_ID = task
        self.relational_graph = relational_graphs

        # Lists to collect encoded samples and graph embeddings
        collected_encodings = []
        collected_graph_embeddings = []

        # Forward pass through dataloader without training
        with torch.no_grad():
            for data in dataloader:
                # Set the input data
                self.__set_input__(data, generate)

                # Use appropriate input data
                input_data = self.x_seq_synthetic if self._use_synthetic(generate) else self.x_seq

                # Forward pass through networks
                self.z_seq = self.netG(self._graph_embedding())
                self.e_seq = self.netE(input_data, self.y_seq, self.z_seq)

                # Collect encoded samples and graph embeddings
                collected_encodings.append(self.e_seq.clone())
                collected_graph_embeddings.append(self.z_seq.clone())

        # Return to train mode (the server discriminator stays in eval mode)
        self.train()

        # Return the collected data
        return {
            'encodings': collected_encodings,
            'graph_embeddings': collected_graph_embeddings
        }

    def __loss_EF__(self):
        """
        Encoder and predictor loss calculation for regression
        Uses MSE loss instead of NLL loss

        Returns:
            (loss_E, loss_E_pred, loss_E_gan) where
            loss_E = lambda_gan * loss_E_gan + loss_E_pred + 0.1 * L1(f_seq, y_seq)
            and loss_E_gan is the negative MSE between the discriminator
            output and the graph embedding z_seq.
        """
        # GAN loss computation (same as before)
        batch_size = self.e_seq.size(0)

        # Ensure z_seq matches batch size
        if self.z_seq.size(0) != batch_size:
            if self.z_seq.size(0) == 1:
                target = self.z_seq.expand(batch_size, -1)
            else:
                target = self.z_seq[:batch_size]
        else:
            target = self.z_seq

        if self.d_seq.dim() > 2:
            predicted = self.d_seq.reshape(-1, self.d_seq.size(-1))
        else:
            predicted = self.d_seq

        # Compute GAN loss as negative MSE: the encoder is rewarded when the
        # discriminator fails to recover the graph embedding from e_seq
        criterion = nn.MSELoss()
        try:
            loss_E_gan = -criterion(predicted, target)
        except Exception as e:
            logger.error(f"Error computing GAN loss: {str(e)}")
            loss_E_gan = torch.tensor(0.0, device=self.device, requires_grad=True)

        # Regression loss using MSE
        # f_seq shape: [batch_size, num_states] (continuous values)
        # y_seq shape: [batch_size, num_states] (continuous targets)
        loss_E_pred = F.mse_loss(self.f_seq, self.y_seq)

        # Optionally add L1 loss for better robustness
        loss_E_l1 = F.l1_loss(self.f_seq, self.y_seq) * 0.1  # Weight the L1 loss

        # Combined loss
        loss_E = loss_E_gan * self.opt.lambda_gan + loss_E_pred + loss_E_l1

        return loss_E, loss_E_pred, loss_E_gan

    def get_weights(self):
        """
        Get the model weights as state dictionaries

        Returns:
            dict: Deep copies of the encoder, predictor and graph-generator
            state dicts (the server discriminator is not included)
        """
        return {
            'encoder': copy.deepcopy(self.netE.state_dict()),
            'predictor': copy.deepcopy(self.netF.state_dict()),
            'graph_generator': copy.deepcopy(self.netG.state_dict())
        }

    def set_weights(self, weights):
        """
        Set the model weights from state dictionaries

        Args:
            weights: Dictionary containing model weights; any of the keys
                'encoder', 'predictor', 'graph_generator' may be omitted
        """
        if 'encoder' in weights:
            self.netE.load_state_dict(weights['encoder'])
        if 'predictor' in weights:
            self.netF.load_state_dict(weights['predictor'])
        if 'graph_generator' in weights and hasattr(self, 'netG'):
            self.netG.load_state_dict(weights['graph_generator'])

    def __init_weight__(self, net=None):
        """
        Initialize weights for the network

        Every nn.Linear inside `net` (default: the whole client) gets
        N(0, 0.01) weights and a zero bias; bias-free layers are supported.
        """
        if net is None:
            net = self
        for m in net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                # nn.Linear(..., bias=False) has m.bias == None
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)