import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class AlternatingOptimizer:
    """Alternating optimizer for invariant representations and Wasserstein-barycenter prototypes.

    Holds two Adam optimizers, each with its own StepLR scheduler:
    one over ``model.invariant_encoder`` parameters and one over
    ``model.prototype_aligner`` parameters. Training alternates between the two
    parameter groups (see ``train_epoch_alternating``).

    Config keys read (all optional):
        lr_invariant, wd_invariant: Adam lr / weight decay for the encoder
            (defaults 0.001, 1e-4).
        lr_prototype, wd_prototype: Adam lr / weight decay for the aligner
            (defaults 0.001, 1e-4).
        alpha: weight of the model's 'alignment' loss in the invariant step
            (default 1.0).
        scheduler_step_size, scheduler_gamma: StepLR settings, counted in
            epochs (defaults 50, 0.5).
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config

        # Separate optimizers so each phase only updates its own parameters.
        self.optimizer_invariant = optim.Adam(
            model.invariant_encoder.parameters(),
            lr=config.get('lr_invariant', 0.001),
            weight_decay=config.get('wd_invariant', 1e-4),
        )
        self.optimizer_prototype = optim.Adam(
            model.prototype_aligner.parameters(),
            lr=config.get('lr_prototype', 0.001),
            weight_decay=config.get('wd_prototype', 1e-4),
        )

        # StepLR decays after `step_size` calls to .step(); step_schedulers()
        # is called exactly once per epoch, so step_size is measured in epochs.
        step_size = config.get('scheduler_step_size', 50)
        gamma = config.get('scheduler_gamma', 0.5)
        self.scheduler_invariant = optim.lr_scheduler.StepLR(
            self.optimizer_invariant, step_size=step_size, gamma=gamma
        )
        self.scheduler_prototype = optim.lr_scheduler.StepLR(
            self.optimizer_prototype, step_size=step_size, gamma=gamma
        )

        # Loss history: one float per optimization step (invariant/prototype)
        # or per barycenter pass (alignment).
        self.history = {
            'invariant_loss': [],
            'prototype_loss': [],
            'alignment_loss': []
        }

    def alternating_step(self, batch_data, epoch, step_type='both'):
        """Run one optimization step on a single dense batch.

        Learning-rate schedulers are intentionally NOT stepped here. StepLR
        counts calls, so stepping per batch would tie the decay to the number
        of batches. Call ``step_schedulers`` once per epoch instead;
        ``train_epoch_alternating`` does this automatically.

        Args:
            batch_data: dict with keys 'x', 'adj' and 'y' (labels).
            epoch: current epoch index. Kept for interface compatibility; it
                no longer affects this step.
            step_type: 'invariant', 'prototype', or 'both'.

        Returns:
            dict mapping 'invariant' and/or 'prototype' to the float loss of
            each phase that ran.

        Raises:
            ValueError: if ``step_type`` is not one of the accepted values.
        """
        if step_type not in ('invariant', 'prototype', 'both'):
            raise ValueError(
                f"step_type must be 'invariant', 'prototype' or 'both', got {step_type!r}"
            )

        x = batch_data['x']
        adj = batch_data['adj']
        labels = batch_data['y']

        losses = {}

        if step_type in ('invariant', 'both'):
            # Phase 1: update the invariant encoder.
            self.optimizer_invariant.zero_grad()

            outputs = self.model(x, adj, labels, training=True)
            model_losses = outputs['losses']

            invariant_loss = (
                model_losses['classification']
                + model_losses['invariance']
                + self.config.get('alpha', 1.0) * model_losses.get('alignment', 0)
            )

            invariant_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.invariant_encoder.parameters(),
                max_norm=1.0
            )
            self.optimizer_invariant.step()

            losses['invariant'] = invariant_loss.item()

        if step_type in ('prototype', 'both'):
            # Phase 2: update the prototype aligner.
            self.optimizer_prototype.zero_grad()

            # Fresh forward pass so this phase sees any encoder update made
            # above; no_grad keeps the embeddings constant, so only the aligner
            # receives gradients from the alignment loss.
            with torch.no_grad():
                outputs = self.model(x, adj, labels, training=True)
            invariant_embeddings = outputs['invariant_embeddings']

            alignment_loss = self.model.prototype_aligner.alignment_loss(
                invariant_embeddings, labels
            )

            alignment_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.prototype_aligner.parameters(),
                max_norm=1.0
            )
            self.optimizer_prototype.step()

            losses['prototype'] = alignment_loss.item()

        return losses

    def step_schedulers(self):
        """Advance both learning-rate schedulers by one epoch.

        Call once per epoch, after that epoch's optimizer steps.

        If an optimizer has not stepped yet, PyTorch emits a one-time
        "lr_scheduler.step() before optimizer.step()" warning. This happens,
        for example, to the prototype optimizer after epoch 0, because even
        epochs only update the encoder.
        """
        self.scheduler_invariant.step()
        self.scheduler_prototype.step()

    @staticmethod
    def _gaussian_stats(class_emb):
        """Return (mean, covariance) of the rows of ``class_emb``.

        With a single row the sample covariance is undefined, so the identity
        matrix is used instead.
        """
        mean = class_emb.mean(dim=0)
        if class_emb.shape[0] > 1:
            cov = torch.cov(class_emb.T)
        else:
            cov = torch.eye(class_emb.shape[1], device=class_emb.device)
        return mean, cov

    def wasserstein_barycenter_iteration(self, invariant_embeddings, labels, num_iterations=10):
        """Alternate between barycenter prototype updates and alignment-loss evaluation.

        Each iteration does two things:
        1. For every class present in ``labels``, fit a Gaussian (mean, cov)
           to that class's embeddings in each environment, and set the class
           prototype to ``prototype_aligner.compute_barycenter`` of those
           Gaussians (uniform environment weights).
        2. Sum ``prototype_aligner.wasserstein_loss`` between each (detached)
           prototype and each per-environment Gaussian.

        No backpropagation happens here. The returned loss carries gradients
        to the embeddings only if the caller computed them with grad enabled.

        The embeddings are fixed for the whole call. Unless
        ``compute_barycenter`` is stateful, every iteration therefore produces
        the same prototypes and loss — TODO(review): confirm this is intended.

        Args:
            invariant_embeddings: dict {env_id: [batch_size, dim] tensor};
                ``labels`` is applied to every environment's rows.
            labels: [batch_size] class labels.
            num_iterations: number of iterations (must be >= 1).

        Returns:
            (class_prototypes, avg_loss). ``class_prototypes`` maps class_id
            to {'mean', 'cov'}. ``avg_loss`` is a tensor, zero when no label
            matches any class.

        Raises:
            ValueError: if ``num_iterations < 1``.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

        prototype_aligner = self.model.prototype_aligner
        num_classes = prototype_aligner.num_classes
        # The aligner is guaranteed to have parameters (Adam in __init__
        # rejects an empty parameter list).
        device = next(prototype_aligner.parameters()).device

        # Lazily initialise one standard-normal prototype per class.
        if not hasattr(self, 'class_prototypes'):
            self.class_prototypes = {
                class_id: {
                    'mean': torch.zeros(prototype_aligner.feature_dim, device=device),
                    'cov': torch.eye(prototype_aligner.feature_dim, device=device),
                }
                for class_id in range(num_classes)
            }

        # Embeddings and labels do not change across iterations, so the
        # per-(class, environment) Gaussian statistics are computed only once.
        class_stats = {}
        for class_id in range(num_classes):
            class_mask = labels == class_id
            if class_mask.any():
                class_stats[class_id] = [
                    self._gaussian_stats(embeddings[class_mask])
                    for embeddings in invariant_embeddings.values()
                ]

        # Tensor zeros (not int 0) so .item() and division work even when no
        # class contributes any term.
        zero = torch.zeros((), device=device)
        total_alignment_loss = zero

        for iteration in range(num_iterations):
            # Step 1: fix representations, move prototypes to the barycenters.
            for class_id, distributions in class_stats.items():
                if not distributions:
                    continue
                weights = torch.ones(len(distributions), device=distributions[0][0].device)
                barycenter = prototype_aligner.compute_barycenter(
                    distributions, weights=weights
                )
                self.class_prototypes[class_id]['mean'] = barycenter[0].detach()
                self.class_prototypes[class_id]['cov'] = barycenter[1].detach()

            # Step 2: fix prototypes, measure the Wasserstein alignment loss.
            iteration_loss = zero
            for class_id, distributions in class_stats.items():
                if not distributions:
                    continue
                prototype = self.class_prototypes[class_id]
                for mean, cov in distributions:
                    iteration_loss = iteration_loss + prototype_aligner.wasserstein_loss(
                        prototype['mean'], prototype['cov'], mean, cov
                    )

            total_alignment_loss = total_alignment_loss + iteration_loss

            if iteration % 5 == 0:
                print(f"  WB Iteration {iteration}: Loss = {iteration_loss.item():.4f}")

        avg_loss = total_alignment_loss / num_iterations
        return self.class_prototypes, avg_loss

    def train_epoch_alternating(self, train_loader, epoch):
        """Train for one epoch, alternating parameter groups by epoch parity.

        Even epochs run the 'invariant' step and odd epochs the 'prototype'
        step. Every 10th batch (starting with the first) also runs a 3-iteration
        barycenter pass without gradients and records its loss. The
        learning-rate schedulers are stepped once at the end of the epoch.

        Args:
            train_loader: iterable of batches, either PyG-style batches or
                dense dicts (see ``_batch_to_dense``).
            epoch: current epoch index.

        Returns:
            Sum of the per-batch step losses divided by the number of batches;
            0.0 for an empty loader.
        """
        self.model.train()

        step_type = 'invariant' if epoch % 2 == 0 else 'prototype'
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f'Alternating Epoch {epoch}')

        for batch_idx, batch_data in enumerate(pbar):
            batch_dense = self._batch_to_dense(batch_data)

            losses = self.alternating_step(batch_dense, epoch, step_type=step_type)

            for key, value in losses.items():
                self.history[f'{key}_loss'].append(value)
                total_loss += value

            pbar.set_postfix({k: f'{v:.4f}' for k, v in losses.items()})

            # Periodic Wasserstein barycenter pass (monitoring / prototype refresh).
            if batch_idx % 10 == 0 and hasattr(self.model, 'prototype_aligner'):
                with torch.no_grad():
                    outputs = self.model(batch_dense['x'], batch_dense['adj'],
                                         batch_dense['y'], training=True)
                    _, wb_loss = self.wasserstein_barycenter_iteration(
                        outputs['invariant_embeddings'],
                        batch_dense['y'],
                        num_iterations=3
                    )
                self.history['alignment_loss'].append(wb_loss.item())

            num_batches += 1

        # Exactly one scheduler step per epoch, so StepLR's step_size counts epochs.
        self.step_schedulers()

        return total_loss / num_batches if num_batches else 0.0

    def _batch_to_dense(self, batch):
        """Convert a PyG-style batch to the dense dict format; pass other inputs through.

        Inputs with a ``batch`` attribute are presumably PyG ``Batch`` objects.
        They are split with ``to_data_list()`` and densified by the project
        helper. That helper is imported lazily, so already-dense inputs do not
        require ``data.utils``.
        """
        if not hasattr(batch, 'batch'):
            # Already in dense format.
            return batch

        from data.utils import graph_to_dense_batch

        return graph_to_dense_batch(batch.to_data_list())