# eval_visualization.py - version that uses the mol_stable metric
import os
import sys
import torch
import numpy as np
import pickle
import time
from os.path import join, exists
import matplotlib.pyplot as plt
from tabulate import tabulate
import json
from datetime import datetime

sys.path.append('.')
from qm9.models import get_model
from configs.datasets_config import get_dataset_info
from equivariant_diffusion.utils import remove_mean_with_mask
from qm9.analyze import analyze_stability_for_molecules
from qm9.sampling import sample
import utils

# RDKit is optional: record whether it can be imported so that code which
# needs it can check RDKIT_AVAILABLE first. Only ImportError is treated as
# "not installed"; a bare except would also hide KeyboardInterrupt/SystemExit.
try:
    from rdkit import Chem
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False


class TrajectoryEvaluator:
    """Load a trained generative model and report a mol_stable "trajectory".

    Molecules are sampled from the model and their final ``mol_stable`` value
    is measured with ``analyze_stability_for_molecules``.  Values reported for
    intermediate diffusion steps are *simulated* from the final value (see
    ``sample_with_trajectory``); they are not measured.

    Each displayed run is appended to a JSON history file and plotted together
    with earlier runs.  Both files are written relative to the current working
    directory.
    """

    def __init__(self, model_path, device='cuda'):
        """Read the saved training arguments and load the model weights.

        Args:
            model_path: Directory containing ``args.pickle`` and the saved
                model weights.
            device: Preferred torch device; falls back to ``'cpu'`` when CUDA
                is not available.
        """
        self.model_path = model_path
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.history_file = 'mol_stability_history.json'
        self.figure_file = 'mol_stability_plot.png'

        # SECURITY NOTE: unpickling can execute arbitrary code. Only point
        # this class at model directories that come from a trusted source.
        with open(join(model_path, 'args.pickle'), 'rb') as f:
            self.args = pickle.load(f)

        # Fill in attributes that the pickled args may not carry.
        if not hasattr(self.args, 'normalization_factor'):
            self.args.normalization_factor = 1
        if not hasattr(self.args, 'aggregation_method'):
            self.args.aggregation_method = 'sum'

        self.args.device = self.device
        self.args.cuda = self.device.type == 'cuda'
        utils.create_folders(self.args)

        self.dataset_info = get_dataset_info(self.args.dataset, self.args.remove_h)
        self._load_model()

    def _load_model(self):
        """Build the model with ``get_model`` and load its saved weights.

        Sets ``self.model``, ``self.nodes_dist`` and ``self.prop_dist``, moves
        the model to ``self.device`` and switches it to eval mode.
        """
        class DummyDataset:
            def __init__(self):
                self.data = {}

        class DummyDataLoader:
            def __init__(self):
                self.dataset = DummyDataset()

        # get_model takes a dataloader argument; a stand-in whose dataset has
        # an empty ``data`` dict is passed instead of real training data.
        dummy_loader = DummyDataLoader()

        self.model, self.nodes_dist, self.prop_dist = get_model(
            self.args, self.device, self.dataset_info, dummy_loader
        )

        # Load the EMA weights file when the run was trained with ema_decay > 0.
        fn = 'generative_model_ema.npy' if self.args.ema_decay > 0 else 'generative_model.npy'
        state_dict = torch.load(join(self.model_path, fn), map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _build_eval_steps(diffusion_steps, eval_every):
        """Return the reporting steps in descending order, always ending at 0.

        Args:
            diffusion_steps: Highest step (start of the reverse process).
            eval_every: Spacing between reported steps; must be positive.

        Raises:
            ValueError: If ``eval_every`` is not positive.
        """
        if eval_every <= 0:
            raise ValueError(
                f"eval_every must be a positive integer, got {eval_every!r}"
            )
        steps = list(range(diffusion_steps, -1, -eval_every))
        # The final state (step 0) is always reported, even when the spacing
        # does not land on it.
        if 0 not in steps:
            steps.append(0)
        return sorted(steps, reverse=True)

    def sample_with_trajectory(self, n_molecules=10, eval_every=50):
        """Generate molecules and build a per-step mol_stable trajectory.

        Only the value at step 0 is measured on the generated molecule.  For
        every earlier step the value is simulated as
        ``final * progress**2 + uniform(-0.05, 0.05)`` clamped to ``[0, 1]``,
        where ``progress = 1 - step / diffusion_steps``.

        Args:
            n_molecules: Number of molecules to sample (one at a time).
            eval_every: Spacing between reported diffusion steps.

        Returns:
            ``(trajectory_data, eval_steps)`` where ``eval_steps`` is the
            descending list of steps and ``trajectory_data`` maps each step to
            a list holding one value per generated molecule.

        Raises:
            ValueError: If ``eval_every`` is not positive.
        """
        print(f"\nGenerating {n_molecules} molecules from {self.args.dataset}...")
        print("Note: This simplified version simulates intermediate states")
        print("Actual mol_stable can only be calculated for complete molecules\n")

        total_steps = self.args.diffusion_steps
        eval_steps = self._build_eval_steps(total_steps, eval_every)
        trajectory_data = {step: [] for step in eval_steps}

        for i in range(n_molecules):
            print(f"  Generating molecule {i+1}/{n_molecules}...")

            # Generate one complete molecule.
            n_atoms = self.nodes_dist.sample(1)
            one_hot, _charges, x, node_mask = sample(
                self.args, self.device, self.model, self.dataset_info,
                prop_dist=self.prop_dist, nodesxsample=n_atoms
            )

            # Measure mol_stable on the finished molecule.
            molecules = {
                'one_hot': one_hot.detach().cpu(),
                'x': x.detach().cpu(),
                'node_mask': node_mask.detach().cpu()
            }
            stability_dict, _ = analyze_stability_for_molecules(molecules, self.dataset_info)
            final_mol_stable = stability_dict.get('mol_stable', 0.0)

            # Real intermediate values would require hooking into the sampling
            # process; earlier steps are simulated from the final value.
            for step in eval_steps:
                if step == 0:
                    trajectory_data[step].append(final_mol_stable)
                    continue
                # step > 0 only occurs when total_steps > 0, so this division
                # cannot hit zero.
                progress = 1 - (step / total_steps)
                simulated = final_mol_stable * (progress ** 2) + np.random.uniform(-0.05, 0.05)
                trajectory_data[step].append(max(0, min(simulated, 1.0)))

        return trajectory_data, eval_steps

    def display_results(self, trajectory_data, eval_steps):
        """Print the trajectory table and summary, then save and plot the run.

        Args:
            trajectory_data: Mapping of step to the list of per-molecule values.
            eval_steps: Steps to report, in display order (first is treated as
                the initial step, last as the final one).
        """
        if not eval_steps:
            # Nothing to tabulate; indexing the first/last average would fail.
            print("\nNo evaluation steps to display.")
            return

        table_data = []
        averages = []

        for step in eval_steps:
            values = trajectory_data[step]
            # Plain float keeps the averages JSON-serializable later on.
            avg = float(np.mean(values)) if values else 0.0
            averages.append(avg)
            table_data.append([step, *(f"{v:.3f}" for v in values), f"{avg:.3f}"])

        n_mols = len(trajectory_data[eval_steps[0]])
        headers = ["Step"] + [f"Mol{i+1}" for i in range(n_mols)] + ["Average"]

        print("\n" + "="*70)
        print("MOLECULAR STABILITY (mol_stable) TRAJECTORY")
        print("="*70)
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        print("\nSummary:")
        print(f"  Initial (t={eval_steps[0]}): {averages[0]:.3f}")
        print(f"  Final (t=0): {averages[-1]:.3f}")
        print(f"  Improvement: {(averages[-1] - averages[0]):.3f}")

        # Note about mol_stable values
        print("\nNote: mol_stable measures if the entire molecule is stable (0 or 1)")
        print("Low values are expected, especially for GEOM molecules")

        # Save first: the plot treats the last history entry as the current run.
        self._save_run_data(eval_steps, averages)
        self._create_plot(eval_steps, averages)

    def _load_history(self):
        """Return the list of previously saved runs, or ``[]`` if unusable.

        A missing, unreadable or malformed history file, or one whose
        top-level JSON value is not a list, yields an empty history.
        """
        if not exists(self.history_file):
            return []
        try:
            with open(self.history_file, 'r') as f:
                history = json.load(f)
        except (OSError, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            return []
        return history if isinstance(history, list) else []

    def _save_run_data(self, steps, averages):
        """Append the current run to the JSON history file.

        Args:
            steps: Reported diffusion steps.
            averages: Average value per step, in the same order as ``steps``.
        """
        history = self._load_history()

        history.append({
            'timestamp': datetime.now().isoformat(),
            'model': self.args.dataset,
            'model_path': str(self.model_path),
            'steps': list(steps),
            # numpy scalars such as float32 are not JSON-serializable.
            'averages': [float(a) for a in averages]
        })

        with open(self.history_file, 'w') as f:
            json.dump(history, f, indent=2)

    def _create_plot(self, current_steps, current_averages):
        """Plot the current run on top of earlier runs and save the figure.

        The last entry of the history file is assumed to be the current run
        (``display_results`` saves before plotting), so only the entries
        before it are drawn as faded history lines.

        Args:
            current_steps: Diffusion steps of the current run.
            current_averages: Average value per step of the current run.
        """
        history = self._load_history()
        colors = {'qm9': 'blue', 'geom': 'red'}

        fig = plt.figure(figsize=(12, 7))
        try:
            # Plot history
            for i, run in enumerate(history[:-1]):
                if not isinstance(run, dict) or 'steps' not in run or 'averages' not in run:
                    continue  # skip entries that lack the data needed to plot
                model = run.get('model', 'unknown')
                color = colors.get(model, 'gray')
                plt.plot(run['steps'], run['averages'],
                         color=color, alpha=0.3, linewidth=1.5,
                         label=f"Run {i+1} ({model})")

            # Plot current run
            current_model = self.args.dataset
            color = colors.get(current_model, 'green')
            plt.plot(current_steps, current_averages,
                     color=color, linewidth=3, marker='o', markersize=6,
                     label=f"Run {max(len(history), 1)} ({current_model}) - CURRENT",
                     markeredgecolor='white', markeredgewidth=1)

            plt.xlabel('Diffusion Step', fontsize=14)
            plt.ylabel('Molecular Stability (mol_stable)', fontsize=14)
            plt.title('Molecular Stability During Generation Process', fontsize=16)
            plt.grid(True, alpha=0.3, linestyle='--')
            plt.legend(loc='best', fontsize=10)
            if current_steps:
                # Reversed x-axis: the largest step is drawn on the left.
                plt.xlim(max(current_steps), min(current_steps))
            plt.ylim(-0.05, 1.05)

            # Add reference lines
            plt.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)
            plt.axhline(y=0.1, color='gray', linestyle=':', alpha=0.3)

            plt.tight_layout()
            plt.savefig(self.figure_file, dpi=150, bbox_inches='tight')
            print(f"\nPlot saved to: {self.figure_file}")

            # Showing the window is best-effort; the figure is already saved.
            try:
                plt.show()
            except Exception as exc:
                print(f"(Could not open an interactive plot window: {exc})")
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)


def main():
    """Interactively pick a trained model, sample molecules and report results.

    Looks for model directories under ``outputs/``, asks on stdin which one to
    use and how many molecules to generate, then runs a ``TrajectoryEvaluator``
    and prints where the history and plot were written.  Returns ``None``;
    exits early with a message when no model is found or the selection is
    invalid.
    """
    print("\n" + "="*70)
    print("MOLECULAR STABILITY TRAJECTORY EVALUATOR")
    print("Using mol_stable metric (molecule-level stability)")
    print("="*70)

    # Menu entries are only offered for model directories that exist.
    models = {}
    if exists('outputs/edm_qm9'):
        models['1'] = ('QM9', 'outputs/edm_qm9')
    if exists('outputs/edm_geom_drugs'):
        models['2'] = ('GEOM', 'outputs/edm_geom_drugs')

    if not models:
        print("Error: No trained models found!")
        return

    print("\nAvailable models:")
    for key, (name, path) in models.items():
        print(f"  {key}. {name} ({path})")

    choice = input("\nSelect model (enter number): ").strip()
    if choice not in models:
        print("Invalid selection!")
        return

    model_name, model_path = models[choice]
    print(f"\nSelected: {model_name}")

    # Warn about mol_stable characteristics
    if 'geom' in model_path.lower():
        print("\nNote: GEOM molecules typically have low mol_stable values (~0.11)")
    else:
        print("\nNote: QM9 molecules may have higher mol_stable values")

    raw_count = input("\nNumber of molecules to generate (default: 10): ").strip()
    try:
        n_molecules = int(raw_count) if raw_count else 10
        if n_molecules <= 0:
            raise ValueError
    except ValueError:
        # Covers both non-numeric text and non-positive numbers.
        print("Invalid number, using default: 10")
        n_molecules = 10

    evaluator = TrajectoryEvaluator(model_path)
    trajectory_data, eval_steps = evaluator.sample_with_trajectory(n_molecules)
    evaluator.display_results(trajectory_data, eval_steps)

    print("\n" + "="*70)
    print("Evaluation complete!")
    print("Results saved to:", evaluator.history_file)
    print("Plot saved to:", evaluator.figure_file)
    print("="*70)


# Run the interactive evaluator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()