#!/usr/bin/env python3

import argparse
import logging
import sys
import torch
from pathlib import Path
from typing import List, Dict, Any, Optional

from motifagent.config import MotifAgentConfig
from motifagent.agents.actor import LLMActor
from motifagent.agents.critic import CentralizedCritic
from motifagent.agents.coordinator import CentralizedCoordinator
from motifagent.inference.reconstruction import MolecularReconstructor
from motifagent.inference.generation import MolecularGenerator
from motifagent.utils.io import DataLoader


def setup_logging():
    """Configure root logging to emit INFO-and-above records on stdout.

    Uses ``logging.basicConfig``, so it has no effect if the root logger
    already has handlers attached.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        format=log_format,
        level=logging.INFO,
    )


def load_model(checkpoint_path: str, config: MotifAgentConfig, device: torch.device):
    """Build the actor/critic/coordinator trio and restore trained weights.

    Args:
        checkpoint_path: Checkpoint file readable by ``torch.load``; it must
            contain the keys ``'actor_state_dict'`` and ``'critic_state_dict'``.
        config: Configuration whose ``model`` section supplies the network
            hyperparameters.
        device: Device the networks are moved to and the checkpoint is mapped onto.

    Returns:
        Tuple ``(actor, critic, coordinator)``. Actor and critic are switched
        to eval mode; the coordinator is freshly constructed (not restored).
    """
    logging.info(f"Loading model from {checkpoint_path}")

    model_cfg = config.model
    actor = LLMActor(
        model_name=model_cfg.llm_model_name,
        max_length=model_cfg.max_length,
    ).to(device)
    critic = CentralizedCritic(
        llm_model_name=model_cfg.llm_model_name,
        hidden_dim=model_cfg.hidden_dim,
        max_motifs=model_cfg.max_motifs,
    ).to(device)
    coordinator = CentralizedCoordinator()

    # NOTE(review): depending on the installed torch version, torch.load may
    # fully unpickle the file (weights_only defaulted to False before 2.6),
    # which can execute arbitrary code -- only load trusted checkpoints.
    saved_state = torch.load(checkpoint_path, map_location=device)
    for network, state_key in ((actor, 'actor_state_dict'), (critic, 'critic_state_dict')):
        network.load_state_dict(saved_state[state_key])
        network.eval()

    return actor, critic, coordinator


def run_reconstruction(actor, critic, coordinator, smiles: str, config: MotifAgentConfig,
                      output_dir: Path) -> Dict[str, Any]:
    """Reconstruct a single molecule and write the detailed result to JSON.

    Args:
        actor: Policy model handed to ``MolecularReconstructor``.
        critic: Not used by this function.
        coordinator: Coordinator handed to ``MolecularReconstructor``.
        smiles: Target molecule as a SMILES string.
        config: Supplies ``environment.max_steps`` and
            ``inference.temperature`` / ``inference.beam_width``.
        output_dir: Directory that receives
            ``reconstruction_<sanitized smiles>.json``.

    Returns:
        The dict that was serialized to the JSON file.
    """
    import json

    logging.info(f"Reconstructing molecule: {smiles}")

    reconstructor = MolecularReconstructor(actor, coordinator)

    result = reconstructor.reconstruct_molecule(
        target_smiles=smiles,
        max_steps=config.environment.max_steps,
        temperature=config.inference.temperature,
        beam_width=config.inference.beam_width
    )

    # Save detailed results
    result_data = {
        'input_smiles': smiles,
        'success': result.success,
        'steps_taken': result.steps_taken,
        'accuracy_metrics': result.accuracy_metrics,
        'reasoning_log': result.reasoning_log,
        'topology_analysis': result.topology_analysis
    }

    # '/' and '\' can occur in SMILES but are path separators, so they are
    # replaced before the SMILES is used as a file name (truncated to 50 chars).
    # This is done outside the f-string: a backslash inside an f-string
    # expression is a SyntaxError before Python 3.12.
    safe_name = smiles.replace('/', '_').replace('\\', '_')[:50]
    output_file = output_dir / f"reconstruction_{safe_name}.json"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(result_data, f, indent=2, default=str)

    logging.info(f"Reconstruction result saved to {output_file}")

    return result_data


def run_batch_reconstruction(actor, critic, coordinator, smiles_list: List[str],
                           config: MotifAgentConfig, output_dir: Path) -> Dict[str, Any]:
    """Reconstruct every SMILES in ``smiles_list`` and save a combined JSON report.

    Args:
        actor: Policy model handed to ``MolecularReconstructor``.
        critic: Not used by this function.
        coordinator: Coordinator handed to ``MolecularReconstructor``.
        smiles_list: Target molecules as SMILES strings.
        config: Supplies ``environment.max_steps`` and
            ``inference.temperature`` / ``inference.beam_width``.
        output_dir: Directory that receives ``batch_reconstruction_results.json``.

    Returns:
        The report dict: inputs, count, aggregate statistics and
        per-molecule summaries.
    """
    import json

    logging.info(f"Running batch reconstruction for {len(smiles_list)} molecules")

    reconstructor = MolecularReconstructor(actor, coordinator)
    results = reconstructor.batch_reconstruct(
        smiles_list,
        max_steps=config.environment.max_steps,
        temperature=config.inference.temperature,
        beam_width=config.inference.beam_width,
    )
    stats = reconstructor.get_reconstruction_statistics(results)

    # Pair each input with its result. zip() stops at the shorter sequence.
    per_molecule = []
    for smiles, result in zip(smiles_list, results):
        per_molecule.append({
            'smiles': smiles,
            'success': result.success,
            'steps_taken': result.steps_taken,
            'accuracy_metrics': result.accuracy_metrics,
        })

    report = {
        'input_smiles': smiles_list,
        'num_molecules': len(smiles_list),
        'statistics': stats,
        'individual_results': per_molecule,
    }

    report_path = output_dir / "batch_reconstruction_results.json"
    with open(report_path, 'w') as handle:
        json.dump(report, handle, indent=2, default=str)

    logging.info(f"Batch reconstruction results saved to {report_path}")
    logging.info(f"Success rate: {stats.get('success_rate', 0):.2%}")

    return report


def run_generation(actor, critic, coordinator, num_molecules: int, target_properties: Optional[Dict[str, float]],
                  config: MotifAgentConfig, output_dir: Path) -> Dict[str, Any]:
    """Generate molecules (optionally property-guided) and save a JSON report.

    When ``target_properties`` is truthy, ``MolecularGenerator.guided_generation``
    is used with ``num_molecules`` attempts. Otherwise ``batch_generate`` is
    used with no targets.

    Args:
        actor: Policy model handed to ``MolecularGenerator``.
        critic: Not used by this function.
        coordinator: Coordinator handed to ``MolecularGenerator``.
        num_molecules: How many molecules (or guided attempts) to produce.
        target_properties: Optional mapping of property name to target value.
        config: Supplies ``environment.max_steps`` and ``inference.temperature``.
        output_dir: Directory that receives ``generation_results.json``.

    Returns:
        The report dict that was written to disk.
    """
    import json

    logging.info(f"Generating {num_molecules} molecules")
    if target_properties:
        logging.info(f"Target properties: {target_properties}")

    generator = MolecularGenerator(actor, coordinator)

    if not target_properties:
        # No targets: plain batch generation.
        results = generator.batch_generate(
            num_molecules=num_molecules,
            target_properties=None,
            max_steps=config.environment.max_steps,
            temperature=config.inference.temperature,
        )
    else:
        # Targets supplied: steer generation toward them.
        results = generator.guided_generation(
            target_properties=target_properties,
            num_attempts=num_molecules,
            max_steps=config.environment.max_steps,
            temperature=config.inference.temperature,
        )

    stats = generator.get_generation_statistics(results)

    per_molecule = []
    for result in results:
        per_molecule.append({
            'success': result.success,
            'steps_taken': result.steps_taken,
            'properties': result.properties,
            'generation_metrics': result.generation_metrics,
        })

    report = {
        'num_molecules': num_molecules,
        'target_properties': target_properties,
        'statistics': stats,
        'individual_results': per_molecule,
    }

    report_path = output_dir / "generation_results.json"
    with open(report_path, 'w') as handle:
        json.dump(report, handle, indent=2, default=str)

    logging.info(f"Generation results saved to {report_path}")
    logging.info(f"Success rate: {stats.get('success_rate', 0):.2%}")

    return report


def run_interactive_mode(actor, critic, coordinator, config: MotifAgentConfig, output_dir: Path):
    """Run a text menu that dispatches reconstruction/generation requests.

    Each request is delegated to ``run_reconstruction`` / ``run_generation``,
    so results are also written as JSON files under ``output_dir``. The loop
    ends on menu choice 4 (or 'quit' / 'q' / 'exit'), Ctrl-C, or end of input.

    Args:
        actor: Policy model forwarded to the run_* helpers.
        critic: Forwarded to the run_* helpers.
        coordinator: Coordinator forwarded to the run_* helpers.
        config: Configuration forwarded to the run_* helpers.
        output_dir: Directory where result files are written.
    """
    logging.info("Starting interactive mode. Type 'quit' to exit.")

    # The startup message advertises 'quit', so accept it alongside '4'.
    quit_choices = {'4', 'q', 'quit', 'exit'}

    while True:
        try:
            print("\nChoose an option:")
            print("1. Reconstruct molecule (provide SMILES)")
            print("2. Generate molecules")
            print("3. Generate with target properties")
            print("4. Quit")

            choice = input("Enter choice (1-4): ").strip()

            if choice.lower() in quit_choices:
                break

            if choice == '1':
                smiles = input("Enter SMILES: ").strip()
                if not smiles:
                    print("No SMILES entered")
                    continue
                result_data = run_reconstruction(
                    actor, critic, coordinator, smiles, config, output_dir
                )
                print(f"Reconstruction {'successful' if result_data['success'] else 'failed'}")
                print(f"Steps taken: {result_data['steps_taken']}")
                if result_data['accuracy_metrics']:
                    print(f"Edge F1 score: {result_data['accuracy_metrics'].get('edge_f1', 0):.3f}")

            elif choice == '2':
                try:
                    num_molecules = int(input("Number of molecules to generate: ").strip())
                except ValueError:
                    print("Invalid number entered")
                    continue
                if num_molecules <= 0:
                    print("Number of molecules must be positive")
                    continue
                result_data = run_generation(
                    actor, critic, coordinator, num_molecules, None, config, output_dir
                )
                print(f"Generated {num_molecules} molecules")
                print(f"Success rate: {result_data['statistics'].get('success_rate', 0):.2%}")

            elif choice == '3':
                try:
                    num_molecules = int(input("Number of molecules to generate: ").strip())
                    if num_molecules <= 0:
                        print("Number of molecules must be positive")
                        continue
                    print("Enter target properties (press Enter to skip):")

                    target_props = {}
                    mw = input("Molecular weight: ").strip()
                    if mw:
                        target_props['molecular_weight'] = float(mw)

                    logp = input("LogP: ").strip()
                    if logp:
                        target_props['logp'] = float(logp)
                except ValueError:
                    print("Invalid input entered")
                    continue

                # An empty dict means "no targets"; pass None as main() does,
                # so the saved JSON and the message below stay accurate.
                result_data = run_generation(
                    actor, critic, coordinator, num_molecules, target_props or None,
                    config, output_dir
                )
                if target_props:
                    print(f"Generated {num_molecules} molecules with target properties")
                else:
                    print(f"Generated {num_molecules} molecules (no target properties given)")
                print(f"Success rate: {result_data['statistics'].get('success_rate', 0):.2%}")

            else:
                print("Invalid choice")

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (Ctrl-D or an exhausted pipe). If the
            # generic handler below caught it, every input() call would fail
            # again and the loop would never end.
            print("\nExiting...")
            break
        except Exception as e:
            # Keep the session alive after a failed request; keep the traceback
            # available at debug level for diagnosis.
            logging.debug("Interactive request failed", exc_info=True)
            print(f"Error: {e}")


def main():
    """Parse CLI arguments, load the trained model and run the selected mode.

    Command-line values for ``--device``, ``--temperature`` and ``--beam-width``
    override the configuration only when they are actually supplied. Otherwise
    the values from ``--config`` (or the default ``MotifAgentConfig``) apply.
    Exits with status 1 if reconstruction mode is requested without
    ``--smiles`` or ``--smiles-file``.
    """
    parser = argparse.ArgumentParser(description='MotifAgent Inference')

    parser.add_argument('--checkpoint', type=str, required=True,
                       help='Path to model checkpoint')
    parser.add_argument('--config', type=str, help='Path to configuration file')
    parser.add_argument('--mode', type=str,
                       choices=['reconstruction', 'generation', 'interactive'],
                       default='interactive', help='Inference mode')

    # Reconstruction options
    parser.add_argument('--smiles', type=str, help='SMILES string to reconstruct')
    parser.add_argument('--smiles-file', type=str, help='File containing SMILES strings')

    # Generation options
    parser.add_argument('--num-molecules', type=int, default=10,
                       help='Number of molecules to generate')
    parser.add_argument('--target-mw', type=float, help='Target molecular weight')
    parser.add_argument('--target-logp', type=float, help='Target LogP')

    # General options
    parser.add_argument('--output-dir', type=str, default='inference_results',
                       help='Output directory for results')
    # default=None so these only override the config when explicitly given; a
    # non-None default would silently clobber the config file's values.
    parser.add_argument('--temperature', type=float, default=None,
                       help='Sampling temperature (defaults to the config value)')
    parser.add_argument('--beam-width', type=int, default=None,
                       help='Beam search width for reconstruction (defaults to the config value)')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                       default='auto', help='Device to use')

    args = parser.parse_args()

    setup_logging()

    # Validate mode-specific arguments before the (potentially slow) model load.
    if args.mode == 'reconstruction' and not (args.smiles or args.smiles_file):
        logging.error("For reconstruction mode, provide either --smiles or --smiles-file")
        sys.exit(1)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load configuration
    if args.config:
        config = MotifAgentConfig.load(args.config)
    else:
        config = MotifAgentConfig()

    # Override config with command line arguments that were actually given.
    # Compare against None so explicit falsy values are not ignored.
    if args.device != 'auto':
        config.system.device = args.device
    if args.temperature is not None:
        config.inference.temperature = args.temperature
    if args.beam_width is not None:
        config.inference.beam_width = args.beam_width

    device = config.get_device()
    logging.info(f"Using device: {device}")

    # Load model
    actor, critic, coordinator = load_model(args.checkpoint, config, device)

    # Run inference based on mode
    if args.mode == 'reconstruction':
        if args.smiles:
            # Single molecule reconstruction (takes precedence over --smiles-file)
            run_reconstruction(actor, critic, coordinator, args.smiles, config, output_dir)
        else:
            # Batch reconstruction; --smiles-file is guaranteed by the check above
            smiles_data = DataLoader.load_smiles_dataset(args.smiles_file)
            smiles_list = [entry['smiles'] for entry in smiles_data]
            run_batch_reconstruction(actor, critic, coordinator, smiles_list, config, output_dir)

    elif args.mode == 'generation':
        # Prepare target properties. Use `is not None` so that a legitimate
        # target of 0.0 (e.g. LogP = 0) is not silently dropped.
        target_properties = {}
        if args.target_mw is not None:
            target_properties['molecular_weight'] = args.target_mw
        if args.target_logp is not None:
            target_properties['logp'] = args.target_logp

        run_generation(
            actor, critic, coordinator, args.num_molecules, target_properties or None,
            config, output_dir
        )

    elif args.mode == 'interactive':
        # Interactive mode
        run_interactive_mode(actor, critic, coordinator, config, output_dir)

    logging.info("Inference completed!")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()