import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np


class SimpleGIPLDTrainer:
    """Minimal training loop that feeds PyG mini-batches to a dense-input model.

    The model is called as ``model(x, adj, y, training=True)`` during training
    and ``model(x, adj, training=False)`` during validation. The training call
    must return a dict with a ``'losses'`` dict holding a ``'classification'``
    tensor; the validation call must return a dict whose
    ``['predictions'][0]['invariant']`` entry holds per-class scores.

    Recognised ``config`` keys (all optional):
        device: device for the model and the dense batches (defaults to
            ``'cuda'`` when available, otherwise ``'cpu'``).
        lr: Adam learning rate (default 0.001).
        weight_decay: Adam weight decay (default 1e-4).
        epochs: number of epochs used by :meth:`train` (default 10).
        feature_dim: number of node-feature columns kept in the dense batch
            (default 10).
    """

    def __init__(self, model, train_loader, val_loader, config):
        """Store the loaders, place the model on the device, build the optimizer."""
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.feature_dim = config.get('feature_dim', 10)

        # Batches are moved to ``self.device`` in ``batch_to_dense``; the model
        # has to live on the same device or the forward pass fails.
        self.model = model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=config.get('lr', 0.001),
            weight_decay=config.get('weight_decay', 1e-4)
        )

        self.history = {
            'train_loss': [],
            'val_acc': [],
            'best_val_acc': 0
        }

    def batch_to_dense(self, batch):
        """Convert a PyG batch into zero-padded dense tensors.

        Args:
            batch: object exposing ``to_data_list()``; each element must
                provide ``num_nodes``, ``x``, ``y`` and optionally
                ``edge_index``.

        Returns:
            dict with tensors on ``self.device``:
                ``'x'``   -- ``(batch, max_nodes, feature_dim)`` node features,
                ``'adj'`` -- ``(batch, max_nodes, max_nodes)`` adjacency,
                ``'y'``   -- ``(batch,)`` long labels.
        """
        from torch_geometric.utils import to_dense_adj

        # Separate graphs
        data_list = batch.to_data_list()

        batch_size = len(data_list)
        max_nodes = max((data.num_nodes for data in data_list), default=0)
        feature_dim = self.feature_dim

        x_batch = torch.zeros(batch_size, max_nodes, feature_dim)
        adj_batch = torch.zeros(batch_size, max_nodes, max_nodes)
        y_batch = torch.zeros(batch_size, dtype=torch.long)

        for i, data in enumerate(data_list):
            num_nodes = data.num_nodes

            # Keep at most ``feature_dim`` columns; graphs with fewer columns
            # are zero-padded instead of raising a shape mismatch.
            features = data.x[:, :feature_dim]
            x_batch[i, :num_nodes, :features.size(1)] = features

            # ``edge_index`` may be absent or None, so check the value itself
            # rather than relying on ``hasattr``.
            edge_index = getattr(data, 'edge_index', None)
            if edge_index is not None and edge_index.size(1) > 0:
                adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
                adj_batch[i, :num_nodes, :num_nodes] = adj
            else:
                # If no edges, create fully connected graph
                adj_batch[i, :num_nodes, :num_nodes] = 1

            y_batch[i] = data.y

        return {
            'x': x_batch.to(self.device),
            'adj': adj_batch.to(self.device),
            'y': y_batch.to(self.device)
        }

    @staticmethod
    def _classification_loss(losses):
        """Return the ``'classification'`` tensor from ``losses``.

        Raises:
            ValueError: if the entry is missing or is not a tensor, since
                nothing could be back-propagated in that case.
        """
        loss = losses.get('classification')
        if not torch.is_tensor(loss):
            raise ValueError(
                "model output 'losses' must contain a tensor under the key "
                f"'classification' (got {type(loss).__name__})"
            )
        return loss

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: epoch index, used only for the progress-bar label.

        Returns:
            Mean classification loss over the processed batches, or ``0.0``
            when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        for batch in pbar:
            # Convert to dense format
            dense_batch = self.batch_to_dense(batch)

            # Forward pass
            outputs = self.model(
                dense_batch['x'],
                dense_batch['adj'],
                dense_batch['y'],
                training=True
            )

            # Total loss (simplified: only use classification loss)
            loss = self._classification_loss(outputs['losses'])

            # Backward propagation with gradient-norm clipping
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        if num_batches == 0:
            return 0.0
        return total_loss / num_batches

    def validate(self):
        """Compute classification accuracy on ``val_loader``.

        Returns:
            Accuracy in percent (0-100), or ``0.0`` when the loader yields no
            samples.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in self.val_loader:
                dense_batch = self.batch_to_dense(batch)

                outputs = self.model(
                    dense_batch['x'],
                    dense_batch['adj'],
                    training=False
                )

                # Use predictions from first environment
                predictions = outputs['predictions'][0]['invariant']
                predicted = predictions.argmax(dim=1)

                total += dense_batch['y'].size(0)
                correct += (predicted == dense_batch['y']).sum().item()

        if total == 0:
            return 0.0
        return 100 * correct / total

    def train(self, epochs=None):
        """Train and validate for ``epochs`` epochs, then plot the history.

        The model's ``state_dict`` is written to ``best_model_simple.pth``
        whenever the validation accuracy exceeds the best value seen so far.

        Args:
            epochs: number of epochs; falls back to ``config['epochs']``
                (default 10) when ``None``.

        Returns:
            The ``history`` dict with keys ``'train_loss'``, ``'val_acc'`` and
            ``'best_val_acc'``.
        """
        if epochs is None:
            epochs = self.config.get('epochs', 10)

        for epoch in range(epochs):
            # Train for one epoch
            train_loss = self.train_epoch(epoch)
            self.history['train_loss'].append(train_loss)

            # Validate
            val_acc = self.validate()
            self.history['val_acc'].append(val_acc)

            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Acc = {val_acc:.2f}%')

            # Save best model
            if val_acc > self.history['best_val_acc']:
                self.history['best_val_acc'] = val_acc
                torch.save(self.model.state_dict(), 'best_model_simple.pth')
                print(f'  ✅ Saved best model, validation accuracy: {val_acc:.2f}%')

        # Visualization
        self.plot_training_history()

        return self.history

    def plot_training_history(self):
        """Plot training loss and validation accuracy side by side.

        The figure is saved to ``simple_training_history.png``, shown, and
        then closed so repeated calls do not accumulate open figures.
        """
        fig = plt.figure(figsize=(10, 4))

        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'])
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        plt.plot(self.history['val_acc'])
        plt.title('Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('simple_training_history.png', dpi=150)
        print("Training history plot saved to: simple_training_history.png")

        plt.show()
        plt.close(fig)