import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import numpy as np


class CompleteDIGLTrainer:
    """Training / evaluation driver for a CompleteDIGL model.

    Contract, as used by this class:

    * Every loader yields ``dict`` batches. ``batch['y']`` holds the class
      labels and the optional ``batch['env']`` holds environment ids.
    * The model is called as ``model(batch, labels=..., env_labels=...,
      training=True)`` while training and as ``model(batch, training=False)``
      during evaluation, and must return a mapping.  ``'total_loss'`` and
      (optionally) ``'accuracy'`` are read while training; ``'task_logits'``
      is read during validation and testing.

    Recognised ``config`` keys (all optional, defaults in parentheses):
    ``device`` (CUDA if available, else CPU), ``epochs`` (100), ``lr``
    (0.001), ``weight_decay`` (1e-4), ``alpha`` (1.0), ``beta`` (1.0),
    ``gamma`` (0.5), ``patience`` (10) and ``output_dir``
    (``./results/good``).
    """

    def __init__(self, model, train_loader, val_loader, test_loader, config):
        """Store the collaborators and build optimizer, scheduler and output dir.

        Args:
            model: ``torch.nn.Module`` to train; moved to the configured device.
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            test_loader: Iterable of test batches.
            config: Mapping with the keys listed in the class docstring.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config

        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Coerce explicitly: config values may arrive as strings (e.g. from YAML/CLI).
        self.epochs = int(config.get('epochs', 100))
        self.lr = float(config.get('lr', 0.001))
        self.weight_decay = float(config.get('weight_decay', 1e-4))

        # Loss weights. They are only stored here; no method of this class reads them.
        self.alpha = float(config.get('alpha', 1.0))  # Prototype alignment
        self.beta = float(config.get('beta', 1.0))    # Disentanglement
        self.gamma = float(config.get('gamma', 0.5))  # Causal intervention

        print("   Trainer configuration:")
        print(f"     Epochs: {self.epochs}")
        print(f"     Learning rate: {self.lr}")
        print(f"     Weight decay: {self.weight_decay}")

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
        )

        # Cosine schedule spanning the whole run; stepped once per epoch in train().
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )

        # Early-stopping bookkeeping.
        self.patience = int(config.get('patience', 10))
        self.best_val_acc = 0.0
        self.patience_counter = 0
        self.best_model_state = None

        # Checkpoints are written below this directory.
        self.output_dir = Path(config.get('output_dir', './results/good'))
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _as_1d(value):
        """Squeeze a tensor while keeping at least one dimension.

        A plain ``squeeze()`` turns a single-sample batch into a 0-dim
        tensor, which breaks ``size(0)`` / ``len()`` downstream.  Non-tensor
        values (including ``None``) are returned unchanged.
        """
        if not isinstance(value, torch.Tensor):
            return value
        squeezed = value.squeeze()
        if squeezed.dim() == 0:
            squeezed = squeezed.reshape(1)
        return squeezed

    def _snapshot_state(self):
        """Return a copy of the model state that later training cannot mutate.

        ``state_dict()`` hands out tensors that share storage with the live
        parameters, so a shallow ``dict.copy()`` would silently track every
        subsequent optimizer step.  Each tensor is cloned instead.
        """
        return {
            name: value.detach().clone() if isinstance(value, torch.Tensor) else value
            for name, value in self.model.state_dict().items()
        }

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index (accepted for API symmetry; unused).

        Returns:
            ``(avg_loss, accuracy)``: the sample-weighted mean of
            ``output['total_loss']`` over labelled batches, and the
            sample-weighted mean of ``output['accuracy']`` over the batches
            that reported it (``0.0`` when none did).
        """
        self.model.train()

        loss_sum = 0.0
        loss_samples = 0
        correct_sum = 0.0
        acc_samples = 0

        for batch in self.train_loader:
            batch = self._move_to_device(batch)

            labels = self._as_1d(batch.get('y'))
            env_labels = self._as_1d(batch.get('env'))

            output = self.model(batch, labels=labels, env_labels=env_labels, training=True)

            loss = output.get('total_loss')
            if loss is None:
                loss = torch.tensor(0.0, device=self.device)

            # A loss without a graph (e.g. the zero fallback above) has nothing
            # to optimise; calling backward() on it would raise.
            if loss.requires_grad:
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            if isinstance(labels, torch.Tensor):
                batch_size = labels.size(0)
                # Loss and accuracy keep separate denominators so a model that
                # omits 'accuracy' still yields a correctly averaged loss.
                loss_sum += loss.item() * batch_size
                loss_samples += batch_size

                if 'accuracy' in output:
                    correct_sum += output['accuracy'].item() * batch_size
                    acc_samples += batch_size

        avg_loss = loss_sum / loss_samples if loss_samples > 0 else 0.0
        accuracy = correct_sum / acc_samples if acc_samples > 0 else 0.0

        return avg_loss, accuracy

    def validate(self):
        """Return the accuracy of ``argmax(task_logits)`` over ``val_loader``.

        Batches without tensor labels or without ``'task_logits'`` in the
        model output are ignored; ``0.0`` is returned when nothing was scored.
        """
        self.model.eval()

        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._move_to_device(batch)

                labels = batch.get('y')
                if not isinstance(labels, torch.Tensor):
                    continue
                labels = self._as_1d(labels)

                output = self.model(batch, training=False)
                task_logits = output.get('task_logits')
                if task_logits is None:
                    continue

                preds = task_logits.argmax(dim=1)
                total_correct += (preds == labels).sum().item()
                total_samples += labels.size(0)

        return total_correct / total_samples if total_samples > 0 else 0.0

    def test(self):
        """Evaluate on ``test_loader`` overall and per environment.

        Returns:
            dict with ``'accuracy'`` (float), ``'fairness_gap'`` (max minus
            min per-environment accuracy, ``0.0`` with fewer than two
            environments), ``'env_accuracies'`` (``{int env id: float}``),
            ``'total_samples'`` and ``'total_correct'`` (ints).  All numbers
            are built-in Python types so the result can be serialised as is.
        """
        self.model.eval()

        total_correct = 0
        total_samples = 0
        env_correct = {}
        env_total = {}

        with torch.no_grad():
            for batch in self.test_loader:
                batch = self._move_to_device(batch)

                labels = batch.get('y')
                envs = batch.get('env')
                if not isinstance(labels, torch.Tensor):
                    continue
                labels_np = self._as_1d(labels).cpu().numpy()

                output = self.model(batch, training=False)
                task_logits = output.get('task_logits')
                if task_logits is None:
                    continue

                preds = task_logits.argmax(dim=1).cpu().numpy()
                hits = preds == labels_np

                total_correct += int(hits.sum())
                total_samples += int(labels_np.shape[0])

                if isinstance(envs, torch.Tensor):
                    envs_np = self._as_1d(envs).cpu().numpy()
                    for env in np.unique(envs_np):
                        mask = envs_np == env
                        env_key = int(env)
                        env_correct[env_key] = env_correct.get(env_key, 0) + int(hits[mask].sum())
                        env_total[env_key] = env_total.get(env_key, 0) + int(mask.sum())

        overall_acc = total_correct / total_samples if total_samples > 0 else 0.0

        # Every env key was inserted together with at least one sample, so the
        # denominator is never zero.
        env_accuracies = {
            env: env_correct[env] / env_total[env] for env in env_correct
        }

        fairness_gap = 0.0
        if len(env_accuracies) >= 2:
            fairness_gap = max(env_accuracies.values()) - min(env_accuracies.values())

        return {
            'accuracy': overall_acc,
            'fairness_gap': fairness_gap,
            'env_accuracies': env_accuracies,
            'total_samples': total_samples,
            'total_correct': total_correct
        }

    def train(self):
        """Run the full training loop with early stopping.

        Each epoch trains, validates and steps the scheduler.  A strictly
        better validation accuracy snapshots the weights and writes a
        checkpoint; ``patience`` consecutive epochs without improvement stop
        the run.  The best snapshot (if any) is loaded back before returning.

        Returns:
            dict with per-epoch lists ``'train_loss'``, ``'train_acc'``,
            ``'val_acc'`` and the scalar ``'best_val_acc'``.
        """
        print(f"Training for {self.epochs} epochs...")

        history = {
            'train_loss': [],
            'train_acc': [],
            'val_acc': [],
            'best_val_acc': 0.0
        }

        for epoch in range(self.epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_acc = self.validate()
            self.scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)

            print(f"Epoch {epoch+1:3d}/{self.epochs}: "
                  f"Loss={train_loss:.4f}, "
                  f"Train Acc={train_acc:.2%}, "
                  f"Val Acc={val_acc:.2%}")

            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                # Must be a real copy: the tensors returned by state_dict()
                # alias the parameters that the next epoch will update.
                self.best_model_state = self._snapshot_state()
                history['best_val_acc'] = val_acc

                self._save_checkpoint(epoch, val_acc)
                print(f"   🔥 New best validation accuracy: {val_acc:.2%}")
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.patience:
                    print(f"   ⏹️ Early stopping at epoch {epoch+1}")
                    break

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"Restored best model with val acc: {self.best_val_acc:.2%}")

        return history

    def _save_checkpoint(self, epoch, val_acc):
        """Write model/optimizer state, ``val_acc`` and config to ``output_dir``.

        The file name embeds the one-based epoch and the accuracy, so each
        improvement produces its own file.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_accuracy': val_acc,
            'config': self.config
        }

        checkpoint_path = self.output_dir / f"checkpoint_epoch{epoch+1}_acc{val_acc:.4f}.pth"
        torch.save(checkpoint, checkpoint_path)

    def _move_to_device(self, batch):
        """Return ``batch`` with its tensors moved to ``self.device``.

        Dicts are rebuilt with every top-level tensor value moved (other
        values are kept as is); a bare tensor is moved; anything else is
        returned unchanged.
        """
        if isinstance(batch, dict):
            return {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        return batch