#!/usr/bin/env python3

import argparse
import json
import logging
import sys
from pathlib import Path

import torch

from motifagent.config import MotifAgentConfig, create_reconstruction_config, create_generation_config
from motifagent.agents.actor import LLMActor
from motifagent.agents.critic import CentralizedCritic
from motifagent.agents.coordinator import CentralizedCoordinator
from motifagent.environment import AssemblyEnvironment
from motifagent.rewards import RewardSystem
from motifagent.training.mappo import MAPPOTrainer
from motifagent.training.curriculum import CurriculumScheduler, CurriculumLearning
from motifagent.utils.io import DataLoader


def setup_logging(config: MotifAgentConfig):
    """Configure root logging to write to ``<log_dir>/training.log`` and stdout.

    ``config.logging.log_dir`` is created if it does not exist.
    ``config.logging.log_level`` may be a level name in any case
    (``"info"``, ``"INFO"``, ``" Warning "``) or a numeric ``logging`` level.

    Note: per the stdlib docs, ``logging.basicConfig`` does nothing if the
    root logger already has handlers, so this should run before anything
    else configures (or implicitly configures) logging.

    Raises:
        ValueError: if ``log_level`` is a string that does not name a
            ``logging`` level.
    """
    log_dir = Path(config.logging.log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    level = config.logging.log_level
    if not isinstance(level, int):
        # getattr(logging, 'info') returns the logging.info *function*, which
        # basicConfig then rejects; normalise the name and require that it
        # resolves to one of logging's numeric level constants.
        level_name = str(level).strip().upper()
        level = getattr(logging, level_name, None)
        if not isinstance(level, int):
            raise ValueError(f"Unknown log level: {config.logging.log_level!r}")

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit encoding so non-ASCII messages don't depend on the locale.
            logging.FileHandler(log_dir / 'training.log', encoding='utf-8'),
            logging.StreamHandler(sys.stdout),
        ],
    )


def create_models(config: MotifAgentConfig, device: torch.device):
    """Instantiate the actor, critic and coordinator networks.

    The actor and critic are each moved to ``device`` right after they are
    built; the coordinator is constructed without arguments.

    Args:
        config: Full MotifAgent configuration; only ``config.model`` is read.
        device: Torch device for the actor and critic.

    Returns:
        An ``(actor, critic, coordinator)`` tuple.
    """
    logging.info("Creating models...")
    model_cfg = config.model

    policy = LLMActor(
        model_name=model_cfg.llm_model_name,
        max_length=model_cfg.max_length,
    ).to(device)

    value_net = CentralizedCritic(
        llm_model_name=model_cfg.llm_model_name,
        hidden_dim=model_cfg.hidden_dim,
        max_motifs=model_cfg.max_motifs,
    ).to(device)

    return policy, value_net, CentralizedCoordinator()


def create_training_components(config: MotifAgentConfig, actor, critic, coordinator):
    """Build the reward system and the MAPPO trainer that owns it.

    The reward system is configured from ``config.rewards``. The trainer gets
    the given actor/critic/coordinator plus PPO hyper-parameters copied
    one-to-one from ``config.training``; each keyword matches its attribute
    name.

    Returns:
        The constructed ``MAPPOTrainer``.
    """
    logging.info("Creating training components...")

    reward_cfg = config.rewards
    rewards = RewardSystem(
        chemical_weight=reward_cfg.chemical_weight,
        topological_weight=reward_cfg.topological_weight,
        chemical_config=reward_cfg.chemical_weights,
        topological_config=reward_cfg.topological_weights,
    )

    # Trainer keyword arguments that share their name with a config.training field.
    forwarded_fields = (
        'learning_rate',
        'clip_epsilon',
        'entropy_coef',
        'value_loss_coef',
        'max_grad_norm',
        'ppo_epochs',
        'mini_batch_size',
        'gae_lambda',
        'gamma',
        'set_bc_weight',
        'kl_coef',
    )
    train_cfg = config.training
    hyperparams = {field: getattr(train_cfg, field) for field in forwarded_fields}

    return MAPPOTrainer(
        actor=actor,
        critic=critic,
        coordinator=coordinator,
        reward_system=rewards,
        **hyperparams,
    )


def create_curriculum(config: MotifAgentConfig):
    """Return a ``CurriculumLearning`` wrapper, or ``None`` when disabled.

    When ``config.training.use_curriculum`` is truthy, the wrapped scheduler
    is sized to ``config.training.num_iterations``.
    """
    training_cfg = config.training
    if not training_cfg.use_curriculum:
        return None
    return CurriculumLearning(
        CurriculumScheduler(total_iterations=training_cfg.num_iterations)
    )


def load_training_data(config: MotifAgentConfig):
    """Load (and sanity-check) the SMILES training dataset named in the config.

    If ``config.data.train_data_path`` is empty/None, a warning is logged and
    an empty list is returned. Otherwise the dataset is loaded through
    ``DataLoader.load_smiles_dataset`` and truncated to
    ``max_molecules_per_dataset`` entries when that limit is set (truthy).
    The ``DataLoader.validate_smiles_dataset`` summary is logged before
    returning the data.
    """
    logging.info("Loading training data...")
    data_cfg = config.data

    if not data_cfg.train_data_path:
        logging.warning("No training data specified, using synthetic examples")
        return []

    molecules = DataLoader.load_smiles_dataset(
        data_cfg.train_data_path,
        smiles_column=data_cfg.smiles_column,
        properties_columns=data_cfg.properties_columns,
    )

    limit = data_cfg.max_molecules_per_dataset
    if limit:
        molecules = molecules[:limit]

    summary = DataLoader.validate_smiles_dataset(molecules)
    logging.info(f"Training data loaded: {summary}")
    return molecules


def create_environment(config: MotifAgentConfig):
    """Construct the ``AssemblyEnvironment`` described by the config.

    Reads step limit, mode and validation switches from
    ``config.environment`` and the motif cap from ``config.model``.
    """
    env_cfg = config.environment
    return AssemblyEnvironment(
        max_steps=env_cfg.max_steps,
        max_motifs=config.model.max_motifs,
        mode=env_cfg.mode,
        chemical_validation=env_cfg.chemical_validation,
        topology_validation=env_cfg.topology_validation,
    )


def _seed_everything(seed):
    """Seed torch's CPU RNG (and CUDA's when available); do nothing for ``None``."""
    if seed is None:
        return
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def _advance_curriculum(curriculum, trainer):
    """Step the curriculum and, on a stage change, apply the new stage config."""
    scheduler = curriculum.scheduler
    scheduler.step()

    should_transition, new_stage = scheduler.should_transition_stage()
    if not should_transition:
        return

    logging.info(f"Curriculum stage transition: {new_stage}")
    stage_config = scheduler.get_current_config()
    curriculum.update_trainer_config(trainer, stage_config)
    curriculum.update_reward_system(trainer.reward_system, stage_config)


def _performance_from_stats(stats):
    """Return the best-model metric ``stats['rewards_mean']['mean']`` (0.0 if absent)."""
    if 'rewards_mean' in stats:
        return stats['rewards_mean']['mean']
    return 0.0


def train_model(config: MotifAgentConfig, output_dir: Path, log_interval: int = 10):
    """Build all components from ``config``, run MAPPO training, save artifacts.

    Runs ``config.training.num_iterations`` iterations. Each iteration
    advances the curriculum (if enabled), collects trajectories from the
    environment and performs one PPO update.

    Artifacts written under ``output_dir``:

    * ``checkpoints/checkpoint_<iteration>.pt`` every
      ``config.logging.save_frequency`` iterations (0/None disables them);
    * ``best_model.pt`` whenever a checkpoint iteration beats the best
      ``rewards_mean['mean']`` seen so far;
    * ``final_model.pt`` and ``training_stats.json`` after the loop.

    An exception raised while collecting, training, logging or checkpointing
    a single iteration is logged with its traceback and training moves on to
    the next iteration. Exceptions from setup, curriculum stepping or the
    final save propagate.

    Args:
        config: Full MotifAgent configuration.
        output_dir: Directory that receives checkpoints, models and stats.
        log_interval: Log trainer statistics every this many iterations
            (0 disables periodic logging). Defaults to 10.

    Returns:
        The trained ``MAPPOTrainer``.
    """
    device = config.get_device()
    logging.info(f"Using device: {device}")

    _seed_everything(config.system.seed)

    actor, critic, coordinator = create_models(config, device)
    trainer = create_training_components(config, actor, critic, coordinator)
    curriculum = create_curriculum(config)
    # NOTE(review): the dataset is loaded and validated here but never handed
    # to the trainer or the environment below — confirm whether that is intended.
    training_data = load_training_data(config)
    env = create_environment(config)

    logging.info("Starting training...")

    num_iterations = config.training.num_iterations
    save_every = config.logging.save_frequency
    best_performance = float('-inf')
    checkpoint_dir = output_dir / 'checkpoints'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for iteration in range(num_iterations):
        logging.info(f"Training iteration {iteration + 1}/{num_iterations}")

        if curriculum is not None:
            _advance_curriculum(curriculum, trainer)

        try:
            trajectories = trainer.collect_trajectories(
                env,
                num_episodes=config.training.num_episodes_per_iteration,
                max_steps=config.environment.max_steps,
            )
            batch = trainer.create_ppo_batch(trajectories)
            loss_dict = trainer.train_ppo_step(batch)
            trainer.update_training_stats(batch, loss_dict)

            # Falsy intervals disable the corresponding action instead of
            # raising ZeroDivisionError/TypeError on every iteration.
            should_log = bool(log_interval) and iteration % log_interval == 0
            should_save = bool(save_every) and iteration % save_every == 0

            # Fetch stats for *this* iteration whenever they are needed, so
            # best-model selection never compares stale numbers from an
            # earlier logging iteration.
            stats = trainer.get_training_stats() if (should_log or should_save) else None

            if should_log:
                logging.info(f"Iteration {iteration}: {stats}")
                if curriculum is not None:
                    progress_info = curriculum.scheduler.get_progress_info()
                    logging.info(f"Curriculum progress: {progress_info}")

            if should_save:
                checkpoint_path = checkpoint_dir / f'checkpoint_{iteration}.pt'
                trainer.save_checkpoint(str(checkpoint_path))
                logging.info(f"Saved checkpoint: {checkpoint_path}")

                current_performance = _performance_from_stats(stats)
                if current_performance > best_performance:
                    best_performance = current_performance
                    best_model_path = output_dir / 'best_model.pt'
                    trainer.save_checkpoint(str(best_model_path))
                    logging.info(f"New best model saved: {best_model_path}")

        except Exception as e:
            # Best-effort: keep training, but keep the traceback for debugging.
            logging.exception(f"Error in training iteration {iteration}: {e}")

    logging.info("Training completed!")

    final_model_path = output_dir / 'final_model.pt'
    trainer.save_checkpoint(str(final_model_path))
    logging.info(f"Final model saved: {final_model_path}")

    final_stats = trainer.get_training_stats()
    stats_path = output_dir / 'training_stats.json'
    with open(stats_path, 'w', encoding='utf-8') as f:
        # default=str: anything json can't encode natively (e.g. tensors) is stringified.
        json.dump(final_stats, f, indent=2, default=str)

    logging.info(f"Training statistics saved: {stats_path}")

    return trainer


def main():
    """Command-line entry point: parse args, build/validate config, train.

    Configuration comes from ``--config`` when given; ``--mode`` is then not
    consulted. Otherwise it is the preset selected by ``--mode``. ``--data``,
    ``--iterations`` and a non-``auto`` ``--device`` override the loaded
    values. Exits with status 1 if the resulting configuration fails
    validation.
    """
    parser = argparse.ArgumentParser(description='Train MotifAgent')

    parser.add_argument('--config', type=str, help='Path to configuration file')
    parser.add_argument('--output-dir', type=str, default='output',
                        help='Output directory for models and logs')
    parser.add_argument('--mode', type=str, choices=['reconstruction', 'generation'],
                        default='reconstruction', help='Training mode')
    parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
    parser.add_argument('--data', type=str, help='Path to training data')
    parser.add_argument('--iterations', type=int, help='Number of training iterations')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='Device to use')

    args = parser.parse_args()

    if args.config:
        config = MotifAgentConfig.load(args.config)
    elif args.mode == 'reconstruction':
        config = create_reconstruction_config()
    else:
        config = create_generation_config()

    # Command-line overrides. `is not None` so an explicit `--iterations 0`
    # reaches validation instead of being silently ignored.
    if args.data:
        config.data.train_data_path = args.data
    if args.iterations is not None:
        config.training.num_iterations = args.iterations
    if args.device != 'auto':
        config.system.device = args.device

    errors = config.validate()
    if errors:
        logging.error("Configuration validation failed:")
        for error in errors:
            logging.error(f"  - {error}")
        sys.exit(1)

    # Only create the output directory once the configuration is known to be valid.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    setup_logging(config)

    if args.resume:
        # Nothing in this script restores a checkpoint; say so instead of
        # silently starting from scratch while the user expects a resume.
        logging.warning(
            f"--resume is not implemented; ignoring {args.resume} and training from scratch"
        )

    config_path = output_dir / 'config.json'
    config.save(str(config_path))
    logging.info(f"Configuration saved: {config_path}")

    train_model(config, output_dir)

    logging.info("Training completed successfully!")


if __name__ == '__main__':
    # Script entry point: run the CLI only when executed directly, not on import.
    main()