# training/disc_trainer.py
"""
DisC Dataset Trainer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from pathlib import Path
import yaml
import time


class DisCTrainer:
    """Train, validate and evaluate a model on the DisC dataset.

    The model is expected to expose (all of these are used below):

    * ``model(images, env_labels=...)`` returning a dict with ``'task_logits'``
    * boolean attributes ``use_wasserstein`` and ``use_causal_intervention``
    * ``feature_extractor(images)``
    * ``compute_wasserstein_loss(features, envs, lambda_w=...)``
    * ``compute_causal_loss(features, labels, lambda_c=...)``

    Every batch yielded by the loaders must be a mapping with the keys
    ``'image'``, ``'label'`` and ``'environment'`` (the validation loop only
    reads the first two).
    """

    def __init__(self, model, train_loader, val_loader, test_loader, config):
        """Set up optimizer, scheduler, early-stopping state and history.

        Args:
            model: The network to train; moved to the configured device.
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            test_loader: Iterable of test batches.
            config: Mapping of hyper-parameters. Recognised keys (with
                defaults): ``device`` (CUDA if available, else CPU),
                ``epochs`` (50), ``lr`` (0.001), ``weight_decay`` (1e-4),
                ``lambda_w`` (1.0), ``lambda_c`` (0.5), ``patience`` (10).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config

        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Training parameters
        self.epochs = config.get('epochs', 50)
        self.lr = config.get('lr', 0.001)
        self.weight_decay = config.get('weight_decay', 1e-4)
        self.lambda_w = config.get('lambda_w', 1.0)  # Wasserstein loss weight
        self.lambda_c = config.get('lambda_c', 0.5)  # Causal loss weight

        # Optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # Learning rate scheduler: halve the LR when validation accuracy
        # plateaus. The deprecated ``verbose`` argument is intentionally not
        # passed; ``train()`` reports learning-rate changes itself.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=5,
        )

        # Early stopping
        self.patience = config.get('patience', 10)
        self.best_val_acc = 0
        self.patience_counter = 0
        self.best_model_state = None

        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'best_val_acc': 0,
            'test_acc': 0,
        }

    def _snapshot_model_state(self):
        """Return a copy of the model state that later updates cannot alter.

        ``state_dict()`` hands back tensors that share storage with the live
        parameters, so each tensor is cloned; a shallow ``dict.copy()`` would
        keep tracking the weights as training continues.
        """
        return {
            key: value.detach().clone() if torch.is_tensor(value) else value
            for key, value in self.model.state_dict().items()
        }

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index (currently informational only).

        Returns:
            Tuple ``(avg_loss, accuracy)``: the mean of the per-batch total
            losses and the fraction of correctly classified samples. Both are
            0 when the loader yields no batches.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        seen = 0
        batches_done = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Get data
            images = batch['image'].to(self.device)
            labels = batch['label'].to(self.device)
            envs = batch['environment'].to(self.device)

            self.optimizer.zero_grad()

            # Forward pass and task loss
            output = self.model(images, env_labels=envs)
            task_logits = output['task_logits']
            task_loss = F.cross_entropy(task_logits, labels)

            # Both regularisers consume the extracted features, so compute
            # them once whenever at least one of them is enabled. (Computing
            # them only in the Wasserstein branch left ``features`` undefined
            # for causal-only models.)
            use_wasserstein = self.model.use_wasserstein
            use_causal = self.model.use_causal_intervention
            features = None
            if use_wasserstein or use_causal:
                features = self.model.feature_extractor(images)

            if use_wasserstein:
                wasserstein_loss = self.model.compute_wasserstein_loss(
                    features, envs, lambda_w=self.lambda_w
                )
            else:
                wasserstein_loss = torch.tensor(0.0, device=self.device)

            if use_causal:
                causal_loss = self.model.compute_causal_loss(
                    features, labels, lambda_c=self.lambda_c
                )
            else:
                causal_loss = torch.tensor(0.0, device=self.device)

            # Total loss
            total_batch_loss = task_loss + wasserstein_loss + causal_loss

            # Backward pass, gradient clipping, parameter update
            total_batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Statistics
            batch_loss_value = total_batch_loss.item()
            preds = task_logits.argmax(dim=1)
            hits = preds == labels
            running_loss += batch_loss_value
            correct += hits.sum().item()
            seen += labels.size(0)
            batches_done = batch_idx + 1

            # Print progress
            if batches_done % 50 == 0:
                print(f"   Batch {batches_done}/{len(self.train_loader)}: "
                      f"Loss={batch_loss_value:.4f}, "
                      f"Accuracy={hits.float().mean().item():.2%}")

        avg_loss = running_loss / batches_done if batches_done > 0 else 0.0
        accuracy = correct / seen if seen > 0 else 0

        return avg_loss, accuracy

    def validate(self):
        """Evaluate the task loss and accuracy on ``val_loader``.

        Returns:
            Tuple ``(avg_loss, accuracy)``: the mean per-batch cross-entropy
            and the fraction of correct predictions. Both are 0 when the
            loader yields no batches.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        seen = 0
        batches_done = 0

        with torch.no_grad():
            for batch in self.val_loader:
                images = batch['image'].to(self.device)
                labels = batch['label'].to(self.device)

                task_logits = self.model(images)['task_logits']

                running_loss += F.cross_entropy(task_logits, labels).item()
                correct += (task_logits.argmax(dim=1) == labels).sum().item()
                seen += labels.size(0)
                batches_done += 1

        avg_loss = running_loss / batches_done if batches_done > 0 else 0.0
        accuracy = correct / seen if seen > 0 else 0

        return avg_loss, accuracy

    def test(self):
        """Evaluate on ``test_loader``, overall and per environment.

        Returns:
            Dict with native Python values (safe to serialise):

            * ``'accuracy'``: overall fraction of correct predictions
            * ``'env_accuracies'``: mapping environment label -> accuracy
            * ``'fairness_gap'``: max minus min environment accuracy
              (0 when fewer than two environments were seen)
            * ``'total_samples'`` / ``'total_correct'``: raw counts
        """
        self.model.eval()
        total_correct = 0
        total_samples = 0

        # Statistics by environment
        env_correct = {}
        env_total = {}

        with torch.no_grad():
            for batch in self.test_loader:
                images = batch['image'].to(self.device)
                # Labels and environments are only compared on the host, so
                # there is no need to move them to the device first.
                labels_np = batch['label'].cpu().numpy()
                envs = batch['environment'].cpu().numpy()

                output = self.model(images)
                preds = output['task_logits'].argmax(dim=1).cpu().numpy()
                hits = preds == labels_np

                # Overall statistics (cast away NumPy scalar types)
                total_correct += int(hits.sum())
                total_samples += len(labels_np)

                # Environment statistics
                for env in np.unique(envs):
                    mask = envs == env
                    key = env.item()  # native Python scalar as dict key
                    env_correct[key] = env_correct.get(key, 0) + int(hits[mask].sum())
                    env_total[key] = env_total.get(key, 0) + int(mask.sum())

        overall_acc = total_correct / total_samples if total_samples > 0 else 0.0

        # Compute environment accuracies
        env_accuracies = {
            env: env_correct[env] / env_total[env] for env in env_correct
        }

        # Compute fairness gap
        if len(env_accuracies) >= 2:
            env_accs = list(env_accuracies.values())
            fairness_gap = max(env_accs) - min(env_accs)
        else:
            fairness_gap = 0.0

        return {
            'accuracy': overall_acc,
            'env_accuracies': env_accuracies,
            'fairness_gap': fairness_gap,
            'total_samples': total_samples,
            'total_correct': total_correct,
        }

    def train(self):
        """Run the full training loop with LR scheduling and early stopping.

        After each epoch the scheduler is stepped on validation accuracy. A
        snapshot of the weights is kept whenever validation accuracy
        improves, and training stops once ``patience`` consecutive epochs
        bring no improvement. The best snapshot (if any) is loaded back into
        the model before returning.

        Returns:
            The ``history`` dict with per-epoch losses/accuracies and
            ``'best_val_acc'``.
        """
        print(f"Starting training, total {self.epochs} epochs...")

        for epoch in range(self.epochs):
            start_time = time.time()

            # Training
            train_loss, train_acc = self.train_epoch(epoch)

            # Validation
            val_loss, val_acc = self.validate()

            # Update learning rate and report any reduction explicitly
            # (stands in for the scheduler's deprecated ``verbose`` output).
            previous_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.scheduler.step(val_acc)
            for group_idx, (old_lr, group) in enumerate(
                    zip(previous_lrs, self.optimizer.param_groups)):
                if group['lr'] < old_lr:
                    print(f"   Reducing learning rate of group {group_idx} "
                          f"to {group['lr']:.4e}")

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)

            # Print progress
            epoch_time = time.time() - start_time
            print(f"Epoch {epoch + 1:3d}/{self.epochs}: "
                  f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2%}, "
                  f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2%}, "
                  f"Time={epoch_time:.1f}s")

            # Early stopping check
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                # Must be a real copy: the tensors in state_dict() alias the
                # live parameters and would change with every later step.
                self.best_model_state = self._snapshot_model_state()
                self.history['best_val_acc'] = val_acc
                print(f"   New best validation accuracy: {val_acc:.2%}")
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.patience:
                    print(f"   Early stopping triggered, stopping after {epoch + 1} epochs")
                    break

        # Restore best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"Restored best model, validation accuracy: {self.best_val_acc:.2%}")

        return self.history

    def save_results(self, output_dir, dataset_name):
        """Write model weights, training history and config to disk.

        Files are written to ``<output_dir>/<dataset_name>/`` (created if
        missing): ``best_model.pth`` (the model's current state dict),
        ``training_history.pth`` and ``config.yaml``.

        Args:
            output_dir: Base directory for results.
            dataset_name: Sub-directory name for this run.

        Returns:
            The output directory path as a string.
        """
        output_dir = Path(output_dir) / dataset_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model
        model_path = output_dir / "best_model.pth"
        torch.save(self.model.state_dict(), model_path)

        # Save training history
        history_path = output_dir / "training_history.pth"
        torch.save(self.history, history_path)

        # Save configuration
        config_path = output_dir / "config.yaml"
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f)

        return str(output_dir)