import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from metrics.abstract_metrics import CrossEntropyMetric, KLDivergenceMetricForAtomDist, KLDivergenceForPrior, FocalLossMetric, BCEWithLogitsMetric
from metrics.abstract_metrics import MSEMetric

import wandb

class LossMonitor:
    """Accumulate per-batch argmax accuracies and per-target-class accuracies.

    Overall accuracies live in ``pos_acc``, ``x_acc``, ``e_acc`` and ``adj_acc``
    (one 0-dim tensor per recorded batch). Per-class accuracies live in
    ``target_accuracies``, keyed by ``"<loss_name>-<class index>"`` with a list
    of per-batch accuracies (Python floats) as value.
    """

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: object exposing ``vocab_node`` (indexable by node-class
                index). It is only used to label the node-accuracy plot in
                :meth:`log_target_acc`. It may be None, in which case classes
                are labelled ``UNK_<idx>``.
        """
        self.pos_acc = []
        self.x_acc = []
        self.e_acc = []
        self.adj_acc = []  # NOTE(review): not updated anywhere in this file
        self.target_accuracies = {}  # "<loss_name>-<class idx>" -> per-batch accuracies
        self.tokenizer = tokenizer

    def update_accuracies(self, pred, true, loss_type):
        """Record the argmax accuracy of ``pred`` against ``true``.

        Args:
            pred: (..., classes) scores.
            true: (..., classes) targets; the class is taken as ``argmax`` over the last dim.
            loss_type: prefix of the list to append to, e.g. ``"pos"``, ``"x"``,
                ``"e"`` or ``"adj"`` (selects ``<loss_type>_acc``).

        Returns:
            0-dim float tensor with the batch accuracy (NaN when there are no rows).
        """
        acc = (pred.argmax(dim=-1) == true.argmax(dim=-1)).float().mean()
        # An empty batch yields NaN; recording it would poison the epoch mean.
        if true.numel() > 0:
            getattr(self, f"{loss_type}_acc").append(acc)
        return acc

    def update_target_acc(self, pred, true, loss_name):
        """Record, for every class present in ``true``, the accuracy on its rows.

        Args:
            pred: (..., classes) scores.
            true: (..., classes) targets; the class is taken as ``argmax`` over the last dim.
            loss_name: label used as the prefix of the ``target_accuracies`` keys.
        """
        true_targets = true.argmax(dim=-1)
        if true_targets.numel() == 0:
            return
        correct = pred.argmax(dim=-1) == true_targets

        # Vectorised per-class accuracy: one host sync (``tolist``) instead of
        # one ``.item()`` per class. ``torch.unique`` returns classes sorted,
        # so keys are inserted in ascending class order.
        classes, inverse, counts = torch.unique(
            true_targets, return_inverse=True, return_counts=True
        )
        hits = torch.bincount(inverse[correct], minlength=classes.numel())
        accuracies = hits.float() / counts.float()

        for target_idx, accuracy in zip(classes.tolist(), accuracies.tolist()):
            combined_key = f"{loss_name}-{target_idx}"
            self.target_accuracies.setdefault(combined_key, []).append(accuracy)

    def log_target_acc(self):
        """Log a table and a bar chart of the mean per-class node accuracy to wandb.

        Only the ``node_loss`` entries are plotted, since only node classes have
        vocabulary names. Each class value is the mean of its per-batch
        accuracies, and classes are sorted by ascending accuracy. Does nothing
        when no wandb run is active or nothing was recorded.
        """
        if not wandb.run:
            return

        loss_name = 'node_loss'
        target_means = {}
        for combined_key, accuracies in self.target_accuracies.items():
            # rsplit: only the last '-' separates the class index from the loss name.
            key_loss, target_idx = combined_key.rsplit('-', 1)
            if key_loss == loss_name:
                target_means[int(target_idx)] = sum(accuracies) / len(accuracies)
        if not target_means:
            return

        # Sort by accuracy for consistent visualization (stable for ties).
        sorted_targets = sorted(target_means.items(), key=lambda item: item[1])
        target_indices = [idx for idx, _ in sorted_targets]
        mean_accuracies = [acc for _, acc in sorted_targets]

        # Map class indices to SMILES names; tolerate a missing tokenizer.
        vocab = getattr(self.tokenizer, 'vocab_node', None)
        target_names = [
            vocab[idx] if vocab is not None and idx < len(vocab) else f"UNK_{idx}"
            for idx in target_indices
        ]

        wandb.log({
            f"train/{loss_name}_target_distribution_table": wandb.Table(
                data=[[idx, acc, name] for name, acc, idx in zip(target_names, mean_accuracies, target_indices)],
                columns=["index", "accuracy", "SMILES"]
            ),
            f"train/{loss_name}_target_distribution": wandb.plot.bar(
                wandb.Table(
                    data=[[name, acc] for name, acc in zip(target_names, mean_accuracies)],
                    columns=["SMILES", "accuracy"]
                ),
                "SMILES",
                "accuracy",
                title=f"{loss_name} Accuracy Distribution Across Targets"
            )
        }, commit=True)

    def reset(self):
        """Clear all recorded accuracies."""
        self.pos_acc = []
        self.x_acc = []
        self.e_acc = []
        self.adj_acc = []
        self.target_accuracies = {}

class TrainLossDiscrete(nn.Module):
    """Weighted training loss over node (X), edge (E), position (pos) and y predictions.

    Component losses are computed through metric objects, so that
    :meth:`log_epoch_metrics` can report epoch aggregates. Per-batch accuracies
    are tracked by a :class:`LossMonitor`.
    """

    def __init__(self, lambda_train, weight_node=None, weight_edge=None, tokenizer=None):
        """
        Args:
            lambda_train: indexable of loss weights; entries 0, 1 and 2 weight
                the node, edge and position losses (the y loss is added unweighted).
            weight_node: accepted for backward compatibility; currently unused.
            weight_edge: accepted for backward compatibility; currently unused.
            tokenizer: forwarded to :class:`LossMonitor` for labelling node classes.
        """
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.pos_loss = CrossEntropyMetric()
        self.y_loss = MSEMetric()

        self.lambda_train = lambda_train
        self.loss_monitor = LossMonitor(tokenizer)

    @staticmethod
    def _flatten_valid(pred, true):
        """Flatten to (rows, features) and keep rows whose target is not all-zero.

        Returns:
            (flat_pred, flat_true, mask), where ``mask`` is the boolean row
            selector over the flattened (unfiltered) rows.
        """
        true = true.reshape(-1, true.size(-1))
        pred = pred.reshape(-1, pred.size(-1))
        # All-zero target rows are treated as padding / masked entries.
        mask = (true != 0.).any(dim=-1)
        return pred[mask], true[mask], mask

    @staticmethod
    def _expand_sample_weight(noise_weight, repeats, mask):
        """Broadcast a per-graph weight to every flattened row, then apply ``mask``.

        ``noise_weight`` (bs, 1) becomes (bs * repeats, 1). Each graph's weight
        is repeated for its ``repeats`` consecutive rows, which matches the
        bs-major order produced by ``reshape(-1, features)``.
        """
        return noise_weight.reshape(-1, 1).repeat_interleave(repeats, dim=0)[mask]

    def forward(self, masked_pred_X, masked_pred_E, masked_pred_pos, masked_pred_y, true_X, true_E, true_pos, true_y, node_mask, noise_weight, limit_dist, token_to_atom_count, log: bool):
        """ Compute train metrics
        Args:
            masked_pred_X : tensor -- (bs, n, dx) predicted node features
            masked_pred_E : tensor -- (bs, n, n, de) predicted edge features
            masked_pred_pos : tensor -- (bs, n, n, d_pos) predicted position features
            masked_pred_y : tensor -- (bs, 1, 1024) predicted fingerprint
            true_X : tensor -- (bs, n, dx) ground truth node features
            true_E : tensor -- (bs, n, n, de) ground truth edge features
            true_pos : tensor -- (bs, n, n, d_pos) ground truth position features
            true_y : tensor -- (bs, 1, 1024) ground truth fingerprint
            node_mask : tensor -- (bs, n) mask for valid nodes (unused here;
                valid rows are derived from non-zero targets)
            noise_weight : tensor -- (bs, 1) per-sample weight for the noise level, or None
            limit_dist : object whose ``.X`` shapes the (uniform) per-class node weights
            token_to_atom_count : tensor -- (num_token, num_elements) token to atom distribution (unused)
            log : boolean whether to log metrics (currently unused)
        Returns:
            Scalar total loss: lambda_train[0] * X + lambda_train[1] * E
            + lambda_train[2] * pos + y.
        """
        n_nodes = true_X.shape[1]

        flat_pred_X, flat_true_X, mask_X = self._flatten_valid(masked_pred_X, true_X)
        flat_pred_E, flat_true_E, mask_E = self._flatten_valid(masked_pred_E, true_E)
        flat_pred_pos, flat_true_pos, mask_pos = self._flatten_valid(masked_pred_pos, true_pos)

        # Uniform per-class weights for the node cross-entropy.
        target_weight_X = torch.ones_like(limit_dist.X)

        if noise_weight is not None:
            weight_X = self._expand_sample_weight(noise_weight, n_nodes, mask_X)
            weight_E = self._expand_sample_weight(noise_weight, n_nodes * n_nodes, mask_E)
            weight_pos = self._expand_sample_weight(noise_weight, n_nodes * n_nodes, mask_pos)
        else:
            weight_X = weight_E = weight_pos = None

        monitor = self.loss_monitor
        monitor.update_accuracies(flat_pred_pos, flat_true_pos, 'pos')
        monitor.update_accuracies(flat_pred_X, flat_true_X, 'x')
        monitor.update_accuracies(flat_pred_E, flat_true_E, 'e')
        monitor.update_target_acc(flat_pred_X, flat_true_X, 'node_loss')
        monitor.update_target_acc(flat_pred_E, flat_true_E, 'edge_loss')
        monitor.update_target_acc(flat_pred_pos, flat_true_pos, 'pos_loss')

        # Guard on the *valid* rows: an all-padding batch would otherwise feed
        # empty tensors to the cross-entropy metrics.
        loss_X = (self.node_loss(flat_pred_X, flat_true_X, sample_weight=weight_X, target_weight=target_weight_X)
                  if flat_true_X.numel() > 0 else 0.0)
        loss_E = (self.edge_loss(flat_pred_E, flat_true_E, sample_weight=weight_E)
                  if flat_true_E.numel() > 0 else 0.0)

        if masked_pred_y is not None and true_y is not None:
            # NOTE(review): this regresses y toward zeros rather than toward
            # true_y, although true_y is required to be present. Confirm this
            # is intended (e.g. a regulariser) and not a bug.
            loss_y = self.y_loss(masked_pred_y, torch.zeros_like(masked_pred_y))
        else:
            loss_y = 0.0

        loss_pos = (self.pos_loss(flat_pred_pos, flat_true_pos, sample_weight=weight_pos)
                    if flat_true_pos.numel() > 0 else 0.0)

        lw_X = self.lambda_train[0]
        lw_E = self.lambda_train[1]
        lw_pos = self.lambda_train[2]

        return (lw_X * loss_X +
                lw_E * loss_E +
                lw_pos * loss_pos +
                loss_y)

    def reset(self):
        """Reset every loss metric, including the y loss reported per epoch."""
        for metric in (self.node_loss, self.edge_loss, self.pos_loss, self.y_loss):
            metric.reset()

    def log_epoch_metrics(self, current_epoch, start_epoch_time, log=True, finished=False):
        """Print (and, if a wandb run is active, log) epoch losses and accuracies.

        A loss that received no samples is reported as -1. The per-batch
        accuracy lists of the monitor are averaged and then cleared. ``log`` is
        accepted for compatibility but is currently ignored.

        NOTE: the "Y_BCE" label is kept for dashboard compatibility even though
        ``y_loss`` is an MSE metric.
        """
        def epoch_value(metric):
            return metric.compute() if metric.total_samples > 0 else -1

        epoch_node_loss = epoch_value(self.node_loss)
        epoch_edge_loss = epoch_value(self.edge_loss)
        epoch_pos_loss = epoch_value(self.pos_loss)
        epoch_y_loss = epoch_value(self.y_loss)

        status = "finished" if finished else ""
        duration = time.time() - start_epoch_time if start_epoch_time else 0

        x_acc = torch.tensor(self.loss_monitor.x_acc).mean()
        pos_acc = torch.tensor(self.loss_monitor.pos_acc).mean()
        e_acc = torch.tensor(self.loss_monitor.e_acc).mean()

        print(f"Epoch {current_epoch} {status}: X_CE: {epoch_node_loss:.4f} -- E_CE: {epoch_edge_loss:.4f} -- Y_BCE: {epoch_y_loss:.4f} -- pos_CE: {epoch_pos_loss:.4f} -- x_acc: {x_acc:.4f} -- e_acc: {e_acc:.4f} -- pos_acc: {pos_acc:.4f} -- Time taken {duration:.1f}s")
        # epoch_node_loss <= 0 means the -1 "no samples" sentinel.
        if wandb.run and epoch_node_loss > 0:
            wandb.log({
                "train/epoch_X_CE": epoch_node_loss,
                "train/epoch_E_CE": epoch_edge_loss,
                "train/epoch_Y_BCE": epoch_y_loss,
                "train/epoch_pos_CE": epoch_pos_loss,
                "train/epoch_x_acc": x_acc,
                "train/epoch_e_acc": e_acc,
                "train/epoch_pos_acc": pos_acc,
            }, commit=True)

        # Log per-target accuracies, then start the next epoch fresh.
        self.loss_monitor.log_target_acc()
        self.loss_monitor.reset()

class ValidateLoss(nn.Module):
    """Validation-time accuracy tracker for node, edge and position predictions.

    No losses are computed. Accuracies are accumulated in a
    :class:`LossMonitor` and optionally logged to wandb per call.
    """

    def __init__(self, tokenizer=None):
        """
        Args:
            tokenizer: forwarded to :class:`LossMonitor` for labelling node classes.
        """
        super().__init__()
        self.loss_monitor = LossMonitor(tokenizer)

    @staticmethod
    def _flatten_valid(pred, true):
        """Flatten to (rows, features) and keep rows whose target is not all-zero."""
        true = true.reshape(-1, true.size(-1))
        pred = pred.reshape(-1, pred.size(-1))
        # All-zero target rows are treated as padding / masked entries.
        mask = (true != 0.).any(dim=-1)
        return pred[mask], true[mask]

    def forward(self, masked_pred_X, masked_pred_E, masked_pred_pos, true_X, true_E, true_pos, node_mask, log: bool):
        """ Compute validation metrics
        Args:
            masked_pred_X : tensor -- (bs, n, dx) predicted node features
            masked_pred_E : tensor -- (bs, n, n, de) predicted edge features
            masked_pred_pos : tensor -- (bs, n, n, d_pos) predicted position features
            true_X : tensor -- (bs, n, dx) ground truth node features
            true_E : tensor -- (bs, n, n, de) ground truth edge features
            true_pos : tensor -- (bs, n, n, d_pos) ground truth position features
            node_mask : tensor -- (bs, n) mask for valid nodes (unused here;
                valid rows are derived from non-zero targets)
            log : boolean whether to log the batch accuracies to wandb
        Returns:
            (acc_X, acc_E, acc_pos): 0-dim tensors with this batch's argmax accuracies.
        """
        flat_pred_X, flat_true_X = self._flatten_valid(masked_pred_X, true_X)
        flat_pred_E, flat_true_E = self._flatten_valid(masked_pred_E, true_E)
        flat_pred_pos, flat_true_pos = self._flatten_valid(masked_pred_pos, true_pos)

        monitor = self.loss_monitor
        acc_pos = monitor.update_accuracies(flat_pred_pos, flat_true_pos, 'pos')
        acc_X = monitor.update_accuracies(flat_pred_X, flat_true_X, 'x')
        acc_E = monitor.update_accuracies(flat_pred_E, flat_true_E, 'e')
        monitor.update_target_acc(flat_pred_X, flat_true_X, 'node_loss')
        monitor.update_target_acc(flat_pred_E, flat_true_E, 'edge_loss')
        monitor.update_target_acc(flat_pred_pos, flat_true_pos, 'pos_loss')

        if log and wandb.run:
            wandb.log({
                "validate/x_acc": acc_X,
                "validate/e_acc": acc_E,
                "validate/pos_acc": acc_pos,
            }, commit=True)
        return acc_X, acc_E, acc_pos

    def reset(self):
        """Clear all accumulated validation accuracies."""
        self.loss_monitor.reset()