import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import time

# Handle wandb import gracefully
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("⚠️  wandb not available, using local logging only")


class DIGLTrainer:
    """
    DIGL Trainer - Handles training, validation, and testing.

    The wrapped model is called as
    ``model(x=..., adj=None, labels=..., training=...)`` and must return a dict
    containing ``'logits'`` (class scores, argmax taken over dim 1) and, whenever
    labels are passed, ``output['losses']['total_loss']`` (a scalar tensor).
    Batches yielded by the loaders must expose ``.x``, ``.y`` and
    ``.to(device)``.

    Recognised config keys: ``device``, ``lr``, ``weight_decay``,
    ``output_dir``, ``use_wandb``, ``epochs``.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, config):
        """
        Initialize trainer.

        Args:
            model: DIGL model (see the class docstring for the call contract)
            train_loader: Training data loader
            val_loader: Validation data loader
            test_loader: Test data loader
            config: Training configuration dictionary
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config

        # Device: an explicit config value wins, otherwise prefer CUDA if present
        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Optimizer
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('lr', 0.001),
            weight_decay=config.get('weight_decay', 1e-4)
        )

        # Halve the LR when validation accuracy (maximised, hence mode='max')
        # has not improved for 10 epochs. `verbose` is deliberately not passed:
        # it is deprecated in PyTorch 2.2+; train() reports LR reductions itself.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=10
        )

        # Training history (per-epoch lists plus best/test summary values)
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_acc': None,
            'best_val_acc': 0,
            'best_epoch': 0
        }

        # Create output directory
        self.output_dir = Path(config.get('output_dir', './results'))
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize wandb only if requested AND the package imported successfully
        self.use_wandb = config.get('use_wandb', False) and WANDB_AVAILABLE
        if self.use_wandb:
            wandb.init(project="DIGL", config=config)

    def train_epoch(self, epoch):
        """
        Train the model for one pass over ``self.train_loader``.

        Args:
            epoch: Epoch number, used for the progress bar and logging.

        Returns:
            Tuple ``(avg_loss, train_acc)``: the mean of the per-batch losses and
            the overall sample accuracy. Both are 0 if the loader is empty.
        """
        self.model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        num_batches = 0  # counted explicitly; the loader may be empty or unsized

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            batch = batch.to(self.device)

            # Forward pass
            output = self.model(
                x=batch.x,
                adj=None,  # Virtual dataset doesn't have adjacency
                labels=batch.y,
                training=True
            )
            loss = output['losses']['total_loss']

            # Backward pass, with gradients clipped to unit norm
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Statistics: pull each scalar off the device only once per batch
            batch_loss = loss.item()
            preds = output['logits'].argmax(dim=1)
            batch_correct = (preds == batch.y).sum().item()
            batch_size = len(batch.y)

            total_loss += batch_loss
            total_correct += batch_correct
            total_samples += batch_size
            num_batches += 1

            # Update progress bar
            current_acc = total_correct / total_samples if total_samples > 0 else 0
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{current_acc:.2%}'
            })

            # Log to wandb every 10 batches
            if self.use_wandb and batch_idx % 10 == 0:
                wandb.log({
                    'batch_loss': batch_loss,
                    'batch_accuracy': batch_correct / batch_size if batch_size > 0 else 0.0
                })

        # Epoch statistics (guarded against an empty loader)
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        train_acc = total_correct / total_samples if total_samples > 0 else 0

        # Save to history
        self.history['train_loss'].append(avg_loss)
        self.history['train_acc'].append(train_acc)

        # Log to wandb
        if self.use_wandb:
            wandb.log({
                'epoch': epoch,
                'train_loss': avg_loss,
                'train_accuracy': train_acc
            })

        return avg_loss, train_acc

    def validate(self, loader=None):
        """
        Evaluate loss and accuracy on a loader without updating weights.

        Args:
            loader: Loader to evaluate; defaults to ``self.val_loader``.

        Returns:
            Tuple ``(avg_loss, val_acc)``: the mean of the per-batch losses and
            the overall sample accuracy. Both are 0 if the loader is empty.
        """
        if loader is None:
            loader = self.val_loader

        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        num_batches = 0

        with torch.no_grad():
            for batch in loader:
                batch = batch.to(self.device)

                # training=True so the model returns a loss for the given labels.
                # NOTE(review): whether this flag also toggles other
                # training-only behaviour inside the model is not visible here.
                output = self.model(
                    x=batch.x,
                    adj=None,
                    labels=batch.y,
                    training=True  # Need labels for loss calculation
                )

                # Calculate loss and accuracy
                loss = output['losses']['total_loss']
                preds = output['logits'].argmax(dim=1)

                total_loss += loss.item()
                total_correct += (preds == batch.y).sum().item()
                total_samples += len(batch.y)
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        val_acc = total_correct / total_samples if total_samples > 0 else 0

        return avg_loss, val_acc

    def test(self):
        """
        Compute accuracy on ``self.test_loader`` with the current weights.

        NOTE: this uses whatever weights the model currently holds (after
        ``train()`` that is the final epoch), not the ``best_model.pth``
        checkpoint.

        Returns:
            Dict ``{'accuracy': test_acc}``; the value is also stored in
            ``self.history['test_acc']``.
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.test_loader:
                batch = batch.to(self.device)

                # Forward pass (no labels, so no loss is requested)
                output = self.model(
                    x=batch.x,
                    adj=None,
                    training=False
                )

                # Calculate accuracy
                preds = output['logits'].argmax(dim=1)
                total_correct += (preds == batch.y).sum().item()
                total_samples += len(batch.y)

        test_acc = total_correct / total_samples if total_samples > 0 else 0
        self.history['test_acc'] = test_acc

        # Log to wandb
        if self.use_wandb:
            wandb.log({'test_accuracy': test_acc})

        return {'accuracy': test_acc}

    def train(self, epochs=None):
        """
        Run the full train/validate loop, then test and save all artifacts.

        The model is saved as ``best_model.pth`` whenever validation accuracy
        improves. It is also saved every 10 epochs as
        ``checkpoint_epoch_{N}.pth`` and as ``final_model.pth`` at the end.
        The history JSON and the training-curve plot are written to
        ``self.output_dir``.

        Args:
            epochs: Number of epochs; defaults to ``config['epochs']`` (50).

        Returns:
            The ``self.history`` dict.
        """
        if epochs is None:
            epochs = self.config.get('epochs', 50)

        print(f"\nStarting training for {epochs} epochs")
        print("-" * 60)

        start_time = time.time()

        for epoch in range(1, epochs + 1):
            # Train
            train_loss, train_acc = self.train_epoch(epoch)

            # Validate
            val_loss, val_acc = self.validate()

            # Save to history
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            # Update learning rate. Reductions are reported here because the
            # scheduler's own `verbose` flag is deprecated.
            prev_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_acc)
            new_lr = self.optimizer.param_groups[0]['lr']
            if new_lr != prev_lr:
                print(f"Learning rate reduced: {prev_lr:.6g} -> {new_lr:.6g}")

            # Save best model. The first recorded epoch always seeds the best
            # checkpoint, so best_model.pth exists even if val accuracy stays 0.
            if self.history['best_epoch'] == 0 or val_acc > self.history['best_val_acc']:
                self.history['best_val_acc'] = val_acc
                self.history['best_epoch'] = epoch
                self.save_model("best_model.pth", epoch=epoch)

            # Print epoch results
            print(f"Epoch {epoch:3d}/{epochs}: "
                  f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2%} | "
                  f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2%} | "
                  f"Best Val Acc={self.history['best_val_acc']:.2%} (epoch {self.history['best_epoch']})")

            # Save checkpoint every 10 epochs
            if epoch % 10 == 0:
                self.save_model(f"checkpoint_epoch_{epoch}.pth", epoch=epoch)
                self.save_history()

        # Training complete
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds!")
        print(f"Best validation accuracy: {self.history['best_val_acc']:.2%} (epoch {self.history['best_epoch']})")

        # Final test (on the final-epoch weights)
        test_results = self.test()
        print(f"Test accuracy: {test_results['accuracy']:.2%}")

        # Save final model and history
        self.save_model("final_model.pth", epoch=epochs)
        self.save_history()

        # Plot training curves
        self.plot_training_curves()

        return self.history

    def save_model(self, filename, epoch=None):
        """
        Save a checkpoint to ``self.output_dir / filename`` via ``torch.save``.

        Args:
            filename: Checkpoint file name inside the output directory.
            epoch: Epoch the saved weights correspond to. Defaults to
                ``history['best_epoch']``, which is the behaviour of older
                callers that did not pass this argument.
        """
        if epoch is None:
            epoch = self.history.get('best_epoch', 0)
        model_path = self.output_dir / filename
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'history': self.history,
            'best_val_acc': self.history['best_val_acc']
        }, model_path)
        print(f"Model saved to: {model_path}")

    def save_history(self):
        """Write the training history to ``training_history.json`` in the output dir."""
        history_path = self.output_dir / "training_history.json"

        # Convert to plain Python numbers for JSON. A missing test accuracy is
        # written as 0.0, which keeps the field a float for existing readers.
        history_dict = {
            'train_loss': [float(x) for x in self.history['train_loss']],
            'train_acc': [float(x) for x in self.history['train_acc']],
            'val_loss': [float(x) for x in self.history['val_loss']],
            'val_acc': [float(x) for x in self.history['val_acc']],
            'test_acc': float(self.history['test_acc']) if self.history['test_acc'] is not None else 0.0,
            'best_val_acc': float(self.history['best_val_acc']),
            'best_epoch': int(self.history['best_epoch'])
        }

        with open(history_path, 'w') as f:
            json.dump(history_dict, f, indent=2)

        print(f"Training history saved to: {history_path}")

    def plot_training_curves(self):
        """Plot loss/accuracy curves and save them as ``training_curves.png``."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))

        # Loss curve
        epochs = range(1, len(self.history['train_loss']) + 1)
        axes[0].plot(epochs, self.history['train_loss'], label='Train Loss', linewidth=2)
        axes[0].plot(epochs, self.history['val_loss'], label='Val Loss', linewidth=2)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training and Validation Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Accuracy curve, with the best validation accuracy as a reference line
        axes[1].plot(epochs, self.history['train_acc'], label='Train Acc', linewidth=2)
        axes[1].plot(epochs, self.history['val_acc'], label='Val Acc', linewidth=2)
        axes[1].axhline(y=self.history['best_val_acc'], color='red', linestyle='--',
                        label=f'Best Val: {self.history["best_val_acc"]:.2%}')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Training and Validation Accuracy')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        # Save figure, closing exactly the figure created above
        plot_path = self.output_dir / "training_curves.png"
        fig.tight_layout()
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

        print(f"Training curves saved to: {plot_path}")


# Legacy alias kept for backward compatibility with code importing GIPLDTrainer
GIPLDTrainer = DIGLTrainer