#!/usr/bin/env python3

import argparse
import logging
import sys
import torch
from pathlib import Path
from typing import List, Dict, Any

from motifagent.config import MotifAgentConfig
from motifagent.agents.actor import LLMActor
from motifagent.agents.critic import CentralizedCritic
from motifagent.agents.coordinator import CentralizedCoordinator
from motifagent.inference.reconstruction import MolecularReconstructor
from motifagent.inference.generation import MolecularGenerator
from motifagent.core.segmentation import BRICSSegmentation
from motifagent.utils.io import DataLoader
from motifagent.utils.chemistry import MolecularUtils


def setup_logging():
    """Configure the root logger to emit INFO-and-above records to stdout.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers configured.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(level=logging.INFO, format=log_format, handlers=[stdout_handler])


def load_model(checkpoint_path: str, config: MotifAgentConfig, device: torch.device):
    """Build the actor/critic/coordinator and restore actor and critic weights.

    Args:
        checkpoint_path: File passed to ``torch.load``; it must contain the
            keys ``'actor_state_dict'`` and ``'critic_state_dict'``.
        config: Supplies the ``config.model`` hyperparameters used to
            construct the networks.
        device: Device the networks are moved to and the checkpoint is mapped onto.

    Returns:
        Tuple ``(actor, critic, coordinator)``; actor and critic are on
        *device* and switched to eval mode. The coordinator has no weights
        to restore.
    """
    logging.info(f"Loading model from {checkpoint_path}")

    model_cfg = config.model

    actor = LLMActor(
        model_name=model_cfg.llm_model_name,
        max_length=model_cfg.max_length,
    ).to(device)
    critic = CentralizedCritic(
        llm_model_name=model_cfg.llm_model_name,
        hidden_dim=model_cfg.hidden_dim,
        max_motifs=model_cfg.max_motifs,
    ).to(device)
    coordinator = CentralizedCoordinator()

    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)

    networks = ((actor, 'actor_state_dict'), (critic, 'critic_state_dict'))
    for net, state_key in networks:
        net.load_state_dict(ckpt[state_key])
    for net, _ in networks:
        net.eval()

    return actor, critic, coordinator


def evaluate_reconstruction(actor, critic, coordinator, test_data: List[Dict[str, Any]],
                          config: MotifAgentConfig) -> Dict[str, Any]:
    """Reconstruct every test molecule and aggregate reconstruction statistics.

    Failures on individual molecules are logged (with traceback) and skipped,
    so one bad entry does not abort the whole evaluation.

    Args:
        actor: Trained actor handed to ``MolecularReconstructor``.
        critic: Unused here; accepted for signature parity with the other
            ``evaluate_*`` helpers.
        coordinator: Coordinator handed to ``MolecularReconstructor``.
        test_data: Entries that each carry a ``'smiles'`` key.
        config: Supplies ``environment.max_steps``, ``inference.temperature``
            and ``inference.beam_width``.

    Returns:
        Dict with the successful ``results``, their ``statistics``, the number
        of input molecules, the count of results whose ``success`` flag is
        set, and the number of molecules whose reconstruction raised.
    """
    logging.info("Evaluating reconstruction performance...")

    reconstructor = MolecularReconstructor(actor, coordinator)

    results = []
    num_failed = 0
    for i, entry in enumerate(test_data):
        logging.info(f"Reconstructing molecule {i+1}/{len(test_data)}: {entry['smiles']}")

        try:
            result = reconstructor.reconstruct_molecule(
                target_smiles=entry['smiles'],
                max_steps=config.environment.max_steps,
                temperature=config.inference.temperature,
                beam_width=config.inference.beam_width
            )
        except Exception as e:
            # Best-effort: keep going, but record the traceback so the
            # failure can actually be diagnosed (logging.error dropped it).
            logging.exception(f"Error reconstructing {entry['smiles']}: {e}")
            num_failed += 1
        else:
            results.append(result)

    # Calculate statistics over the molecules that did not raise
    stats = reconstructor.get_reconstruction_statistics(results)

    return {
        'results': results,
        'statistics': stats,
        'num_molecules': len(test_data),
        'successful_reconstructions': len([r for r in results if r.success]),
        'failed_reconstructions': num_failed,
    }


def evaluate_generation(actor, critic, coordinator, num_molecules: int,
                       config: MotifAgentConfig) -> Dict[str, Any]:
    """Generate a batch of unconditioned molecules and summarize the batch.

    Args:
        actor: Trained actor handed to ``MolecularGenerator``.
        critic: Unused here; accepted for signature parity with the other
            ``evaluate_*`` helpers.
        coordinator: Coordinator handed to ``MolecularGenerator``.
        num_molecules: How many molecules to request from ``batch_generate``.
        config: Supplies ``environment.max_steps`` and ``inference.temperature``.

    Returns:
        Dict with the raw ``results``, their ``statistics`` and ``num_molecules``.
    """
    logging.info(f"Evaluating generation performance for {num_molecules} molecules...")

    gen = MolecularGenerator(actor, coordinator)

    # No target properties: this measures unconditioned generation
    # (could be specified from config).
    generated = gen.batch_generate(
        num_molecules=num_molecules,
        target_properties=None,
        max_steps=config.environment.max_steps,
        temperature=config.inference.temperature,
    )
    summary_stats = gen.get_generation_statistics(generated)

    return dict(results=generated, statistics=summary_stats, num_molecules=num_molecules)


def evaluate_property_optimization(actor, critic, coordinator, target_properties: Dict[str, float],
                                 num_attempts: int, config: MotifAgentConfig) -> Dict[str, Any]:
    """Run property-guided generation and score how closely results hit the targets.

    Each successful result gets an alignment score in ``[0, 1]``: the mean over
    target properties of ``max(0, 1 - |actual - target| / max(|target|, 1))``.
    A target property missing from ``result.properties`` contributes 0.
    With an empty ``target_properties`` no scores are produced.

    Args:
        actor: Trained actor handed to ``MolecularGenerator``.
        critic: Unused here; accepted for signature parity with the other
            ``evaluate_*`` helpers.
        coordinator: Coordinator handed to ``MolecularGenerator``.
        target_properties: Mapping of property name to desired value.
        num_attempts: Number of guided-generation attempts.
        config: Supplies ``environment.max_steps`` and ``inference.temperature``.

    Returns:
        Dict with the raw ``results``, the ``target_properties``, the count of
        successful results, one alignment score per successful result, and
        their mean (0.0 when there are no scores).
    """
    logging.info(f"Evaluating property optimization for {target_properties}")

    generator = MolecularGenerator(actor, coordinator)

    results = generator.guided_generation(
        target_properties=target_properties,
        num_attempts=num_attempts,
        max_steps=config.environment.max_steps,
        temperature=config.inference.temperature
    )

    # Analyze property alignment. With no targets there is nothing to align
    # against, so skip scoring instead of dividing by zero below.
    property_scores = []
    if target_properties:
        for result in results:
            if not result.success:
                continue
            score = 0.0
            for prop, target_val in target_properties.items():
                if prop in result.properties:
                    actual_val = result.properties[prop]
                    # Clamp the denominator to 1.0 so targets near zero
                    # don't make the relative error explode.
                    error = abs(actual_val - target_val) / max(abs(target_val), 1.0)
                    score += max(0, 1 - error)

            property_scores.append(score / len(target_properties))

    return {
        'results': results,
        'target_properties': target_properties,
        'num_successful': len([r for r in results if r.success]),
        'property_alignment_scores': property_scores,
        'mean_property_score': sum(property_scores) / len(property_scores) if property_scores else 0.0
    }


def analyze_molecular_diversity(results: List[Any]) -> Dict[str, float]:
    """Analyze diversity of generated molecules.

    Only results with a truthy ``success`` flag and a ``generated_graph``
    attribute are considered.

    NOTE(review): graph-to-SMILES conversion is not implemented yet. Each
    counted result contributes a unique placeholder string
    (``"generated_<i>"``), so ``diversity_ratio`` is always 1.0. The Tanimoto
    metrics depend on whether ``MolecularUtils.smiles_to_mol`` accepts those
    placeholders (presumably it does not, which yields a mean similarity of
    0.0 and a diversity_score of 1.0). Treat these numbers as unreliable until
    a real conversion is in place.

    Args:
        results: Generation results, as returned by ``batch_generate``.

    Returns:
        ``{}`` for empty input; ``{'diversity_error': ...}`` when no result
        qualifies; otherwise counts, the unique ratio, the mean pairwise
        Tanimoto similarity, and ``diversity_score = 1 - mean similarity``.
    """
    if not results:
        return {}

    # Extract SMILES from successful results
    smiles_list = []
    for result in results:
        if result.success and hasattr(result, 'generated_graph'):
            # TODO: replace with a proper graph-to-SMILES conversion.
            smiles_list.append(f"generated_{len(smiles_list)}")  # Placeholder

    if not smiles_list:
        return {'diversity_error': 'No valid SMILES found'}

    # smiles_list is non-empty here, so the division is safe.
    unique_molecules = len(set(smiles_list))
    diversity_ratio = unique_molecules / len(smiles_list)

    # Parse and fingerprint each molecule once (O(n)) rather than inside the
    # pairwise loop, where the same work was repeated O(n^2) times.
    # Molecules that fail to parse or fingerprint are excluded from all pairs.
    fingerprints = []
    for smiles in smiles_list:
        mol = MolecularUtils.smiles_to_mol(smiles)
        fp = MolecularUtils.calculate_fingerprint(mol, 'morgan') if mol else None
        if fp is not None:
            fingerprints.append(fp)

    # Mean Tanimoto similarity over all unordered pairs of valid fingerprints.
    tanimoto_similarities = []
    for i, fp1 in enumerate(fingerprints):
        for fp2 in fingerprints[i + 1:]:
            tanimoto_similarities.append(MolecularUtils.calculate_tanimoto_similarity(fp1, fp2))

    mean_similarity = sum(tanimoto_similarities) / len(tanimoto_similarities) if tanimoto_similarities else 0.0
    diversity_score = 1.0 - mean_similarity

    return {
        'total_molecules': len(smiles_list),
        'unique_molecules': unique_molecules,
        'diversity_ratio': diversity_ratio,
        'mean_tanimoto_similarity': mean_similarity,
        'diversity_score': diversity_score
    }


def save_evaluation_results(results: Dict[str, Any], output_path: Path):
    """Save evaluation results to *output_path* as indented JSON.

    Any ``'results'`` entry is replaced by a ``'results_summary'`` string
    (``"<n> results"``) instead of the raw result objects. This applies both
    at the top level and inside nested section dicts such as
    ``{'generation': {'results': [...], 'statistics': {...}}}``. String-keyed
    sub-dicts are processed recursively, so their JSON-compatible parts
    (e.g. ``'statistics'``) stay structured. Any other value that ``json``
    cannot encode is stored as ``str(value)``.

    Args:
        results: Evaluation output, typically keyed by section name.
        output_path: Destination file; overwritten if it exists.
    """
    import json

    def _make_serializable(section: Dict[str, Any]) -> Dict[str, Any]:
        # Summarize 'results' lists and recurse into string-keyed sub-dicts.
        converted = {}
        for key, value in section.items():
            if key == 'results':
                # Summarize results instead of saving full objects
                converted[f'{key}_summary'] = f"{len(value)} results"
            elif isinstance(value, dict) and all(isinstance(k, str) for k in value):
                converted[key] = _make_serializable(value)
            else:
                try:
                    json.dumps(value)  # Test if serializable
                    converted[key] = value
                except (TypeError, ValueError):
                    # ValueError covers e.g. circular references.
                    converted[key] = str(value)
        return converted

    serializable_results = _make_serializable(results)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2, default=str)

    logging.info(f"Evaluation results saved to {output_path}")


def main():
    """Command-line entry point for evaluating a trained MotifAgent checkpoint.

    Loads the checkpoint and runs the evaluations selected by ``--mode``. It
    always runs a fixed property-optimization example as well. It writes
    ``evaluation_results.json`` and ``evaluation_summary.json`` into
    ``--output-dir``.
    """
    import json

    parser = argparse.ArgumentParser(description='Evaluate MotifAgent')

    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--config', type=str, help='Path to configuration file')
    parser.add_argument('--mode', type=str, choices=['reconstruction', 'generation', 'both'],
                        default='both', help='Evaluation mode')
    parser.add_argument('--test-data', type=str, help='Path to test dataset')
    parser.add_argument('--num-molecules', type=int, default=100,
                        help='Number of molecules for generation evaluation')
    parser.add_argument('--max-test-molecules', type=int, default=50,
                        help='Maximum number of test molecules used for reconstruction evaluation')
    parser.add_argument('--num-property-attempts', type=int, default=20,
                        help='Number of attempts for the property optimization example')
    parser.add_argument('--output-dir', type=str, default='evaluation_results',
                        help='Output directory for results')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'],
                        default='auto', help='Device to use')

    args = parser.parse_args()

    setup_logging()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load configuration
    config = MotifAgentConfig.load(args.config) if args.config else MotifAgentConfig()

    # CLI override; 'auto' leaves the configured device untouched.
    if args.device != 'auto':
        config.system.device = args.device

    device = config.get_device()
    logging.info(f"Using device: {device}")

    # Load model
    actor, critic, coordinator = load_model(args.checkpoint, config, device)

    evaluation_results = {}
    # Number of test molecules actually evaluated (after truncation). The
    # summary previously re-read the whole dataset and reported its full size.
    num_test_molecules = 0

    # Reconstruction evaluation
    if args.mode in ('reconstruction', 'both'):
        if args.test_data:
            # Truncate to keep evaluation time bounded.
            test_data = DataLoader.load_smiles_dataset(args.test_data)[:args.max_test_molecules]
            num_test_molecules = len(test_data)
            reconstruction_results = evaluate_reconstruction(
                actor, critic, coordinator, test_data, config
            )
            evaluation_results['reconstruction'] = reconstruction_results

            recon_stats = reconstruction_results['statistics']
            logging.info("Reconstruction Results:")
            logging.info(f"  Success Rate: {recon_stats.get('success_rate', 0):.2%}")
            logging.info(f"  Perfect Match Rate: {recon_stats.get('perfect_match_rate', 0):.2%}")
            logging.info(f"  Average Edge F1: {recon_stats.get('average_edge_f1', 0):.3f}")
        else:
            logging.warning("No test data provided for reconstruction evaluation")

    # Generation evaluation
    if args.mode in ('generation', 'both'):
        generation_results = evaluate_generation(
            actor, critic, coordinator, args.num_molecules, config
        )
        evaluation_results['generation'] = generation_results

        gen_stats = generation_results['statistics']
        logging.info("Generation Results:")
        logging.info(f"  Success Rate: {gen_stats.get('success_rate', 0):.2%}")
        logging.info(f"  Connectivity Rate: {gen_stats.get('connectivity_rate', 0):.2%}")
        logging.info(f"  Drug-like MW Rate: {gen_stats.get('druglike_mw_rate', 0):.2%}")

        # Analyze diversity
        diversity_results = analyze_molecular_diversity(generation_results['results'])
        evaluation_results['diversity'] = diversity_results

        logging.info("Diversity Results:")
        logging.info(f"  Diversity Ratio: {diversity_results.get('diversity_ratio', 0):.3f}")
        logging.info(f"  Diversity Score: {diversity_results.get('diversity_score', 0):.3f}")

    # Property optimization evaluation (example with fixed targets; runs in every mode)
    target_props = {'molecular_weight': 300, 'logp': 2.0}
    num_attempts = args.num_property_attempts
    property_results = evaluate_property_optimization(
        actor, critic, coordinator, target_props, num_attempts, config
    )
    evaluation_results['property_optimization'] = property_results

    logging.info("Property Optimization Results:")
    logging.info(f"  Target Properties: {target_props}")
    logging.info(f"  Success Rate: {property_results['num_successful']}/{num_attempts}")
    logging.info(f"  Mean Property Score: {property_results['mean_property_score']:.3f}")

    # Save results
    results_path = output_dir / 'evaluation_results.json'
    save_evaluation_results(evaluation_results, results_path)

    # Save summary; sections without a 'statistics' entry map to {}.
    summary = {
        'checkpoint': args.checkpoint,
        'evaluation_mode': args.mode,
        'num_test_molecules': num_test_molecules,
        'num_generated_molecules': args.num_molecules,
        'device': str(device),
        'summary_statistics': {
            section: stats['statistics'] if isinstance(stats, dict) and 'statistics' in stats else {}
            for section, stats in evaluation_results.items()
            if section != 'diversity'
        }
    }

    summary_path = output_dir / 'evaluation_summary.json'
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, default=str)

    logging.info(f"Evaluation summary saved to {summary_path}")
    logging.info("Evaluation completed successfully!")


# Run the evaluation CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()