import sys
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data

from greatx.training.trainer import Trainer


class GRANDTrainer(Trainer):
    """A specialized trainer for GRAND models.

    Compared with a plain single-pass cross-entropy step, the training
    step implemented here:

    - runs the model :attr:`n_samples` times on the same inputs,
    - averages the cross-entropy loss over those passes, and
    - adds a consistency regularizer that pulls every pass towards the
      sharpened average of their softmax outputs.

    Evaluation and prediction use a single forward pass with gradient
    tracking disabled.

    Parameters
    ----------
    model : nn.Module
        the GRAND model used for training
    device : Union[str, torch.device], optional
        the device used for training, by default 'cpu'
    n_samples : int, optional
        number of forward passes used for consistency regularization,
        must be at least 1, by default 2
    reg_consistency : float, optional
        weight for consistency regularization, by default 1.0
    sharpening_temperature : float, optional
        temperature for sharpening in consistency loss, must be
        positive, by default 0.5
    cfg : other keyword arguments, such as `lr` and `weight_decay`.

    Raises
    ------
    ValueError
        if `n_samples` is smaller than 1 or `sharpening_temperature`
        is not strictly positive.

    Example
    -------
    >>> from greatx.nn.models.supervised import GRAND
    >>> from greatx.training import GRANDTrainer
    >>> model = GRAND(num_features, num_classes)
    >>> trainer = GRANDTrainer(model, device='cuda', n_samples=2,
    ...                        reg_consistency=1.0)
    >>> trainer.fit(data, data.train_mask)
    """

    def __init__(self, model, device='cpu', n_samples: int = 2,
                 reg_consistency: float = 1.0,
                 sharpening_temperature: float = 0.5, **cfg):
        # Fail fast: both values are used as divisors in `train_step`,
        # where a zero would only surface as a ZeroDivisionError in the
        # middle of training.
        if n_samples < 1:
            raise ValueError(
                f"`n_samples` must be at least 1, but got {n_samples}.")
        if sharpening_temperature <= 0:
            raise ValueError(
                "`sharpening_temperature` must be positive, "
                f"but got {sharpening_temperature}.")

        super().__init__(model, device, **cfg)
        self.n_samples = n_samples
        self.reg_consistency = reg_consistency
        self.sharpening_temperature = sharpening_temperature

    @staticmethod
    def _forward_args(data: Data) -> tuple:
        """Build the positional arguments passed to the model.

        Parameters
        ----------
        data : Data
            the input data.

        Returns
        -------
        tuple
            ``(data.x, data.adj_t, None)`` if `data` carries a
            non-None `adj_t` attribute, otherwise
            ``(data.x, data.edge_index, edge_weight)`` where
            `edge_weight` is None when `data` has no such attribute.
        """
        adj_t = getattr(data, 'adj_t', None)
        if adj_t is not None:
            return data.x, adj_t, None
        return data.x, data.edge_index, getattr(data, 'edge_weight', None)

    def train_step(self, data: Data, mask: Optional[Tensor] = None) -> dict:
        """One-step training on the inputs with GRAND's special procedure.

        Parameters
        ----------
        data : Data
            the training data.
        mask : Optional[Tensor]
            the mask of training nodes.

        Returns
        -------
        dict
            the output logs, including `loss` and `acc`.
        """
        model = self.model
        # NOTE(review): no optimizer is touched in this method; presumably
        # `zero_grad`/`step` are driven by these callbacks -- confirm
        # against the base `Trainer`.
        self.callbacks.on_train_batch_begin(0)

        model.train()
        data = data.to(self.device)
        y = data.y.squeeze()
        fw_args = self._forward_args(data)

        # Several forward passes in train mode on identical inputs; the
        # passes are expected to differ (e.g. through random dropping
        # inside the model), which is what the regularizer exploits.
        scores = [model(*fw_args) for _ in range(self.n_samples)]

        if mask is not None:
            scores = [score[mask] for score in scores]
            y = y[mask]

        # Supervised term: cross-entropy averaged over the passes.
        train_cost = sum(F.cross_entropy(score, y)
                         for score in scores) / self.n_samples

        # Consistency term. `probs` has one leading axis per pass.
        # NOTE(review): this is computed after masking, i.e. only on the
        # masked nodes. GRAND is usually described as regularizing all
        # nodes, including unlabeled ones -- confirm this is intended.
        probs = torch.stack([score.softmax(dim=-1) for score in scores])
        powed_avg = probs.mean(dim=0).pow(1 / self.sharpening_temperature)
        # The sharpened target is detached so it acts as a constant and
        # gradients only flow through the individual passes.
        sharpened = (powed_avg /
                     powed_avg.sum(dim=-1, keepdim=True)).detach()
        # Squared distance to the target, summed over classes and
        # averaged over nodes and passes.
        consistency = (probs - sharpened).square().sum(dim=-1).mean()

        loss = train_cost + self.reg_consistency * consistency

        loss.backward()
        self.callbacks.on_train_batch_end(0)

        # Accuracy is reported for the first pass only.
        acc = scores[0].argmax(-1).eq(y).float().mean().item()

        return dict(loss=loss.item(), acc=acc)

    @torch.no_grad()
    def test_step(self, data: Data, mask: Optional[Tensor] = None) -> dict:
        """One-step evaluation on the inputs.

        A single forward pass in eval mode is used, and only the
        cross-entropy loss is reported (no consistency term). Gradient
        tracking is disabled so no autograd graph is built.

        Parameters
        ----------
        data : Data
            the testing data.
        mask : Optional[Tensor]
            the mask of testing nodes.

        Returns
        -------
        dict
            the output logs, including `loss` and `acc`.
        """
        model = self.model
        model.eval()
        data = data.to(self.device)
        y = data.y.squeeze()

        out = model(*self._forward_args(data))

        if mask is not None:
            out = out[mask]
            y = y[mask]

        loss = F.cross_entropy(out, y)
        acc = out.argmax(-1).eq(y).float().mean().item()

        return dict(loss=loss.item(), acc=acc)

    @torch.no_grad()
    def predict_step(self, data: Data,
                     mask: Optional[Tensor] = None) -> Tensor:
        """One-step prediction on the inputs.

        A single forward pass in eval mode with gradient tracking
        disabled, so the returned tensor is not attached to an autograd
        graph.

        Parameters
        ----------
        data : Data
            the prediction data.
        mask : Optional[Tensor]
            the mask of prediction nodes.

        Returns
        -------
        Tensor
            the model output, restricted to `mask` when it is given.
        """
        model = self.model
        model.eval()
        data = data.to(self.device)

        out = model(*self._forward_args(data))

        if mask is not None:
            out = out[mask]
        return out