import os
import pickle
import json
import torch
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
import logging

logger = logging.getLogger(__name__)

class ModelPersistenceManager:
    """Save and restore a trained Q-learning + hybrid neural system on disk.

    Every artifact lives under one ``models_dir``:

    * ``qlearning_model.pkl``     -- pickled dict of Q-learning state.
    * ``hybrid_neural_model.pth`` -- ``torch.save`` checkpoint of the neural model.
    * ``concept_vocabulary.json`` -- the hybrid system's concept vocabulary.
    * ``training_metadata.json``  -- summary of which components were trained.
    * ``training_history.pkl``    -- pickled per-epoch training history.

    NOTE(review): the ``.pkl`` and ``.pth`` files are read back with
    ``pickle.load`` / ``torch.load``. Only load directories written by this
    code, because unpickling untrusted files can execute arbitrary code.
    """

    # Architecture sizes used when a checkpoint does not record them
    # (checkpoints written by earlier versions never stored ``hidden_dim``).
    _DEFAULT_EMBEDDING_DIM = 256
    _DEFAULT_HIDDEN_DIM = 512

    def __init__(self, models_dir: str = "models"):
        """Create the manager and make sure ``models_dir`` exists.

        Args:
            models_dir: Directory that holds all model artifacts. Missing
                parent directories are created as well.
        """
        self.models_dir = Path(models_dir)
        # parents=True so nested locations such as "runs/exp1/models" work.
        self.models_dir.mkdir(parents=True, exist_ok=True)

        self.qlearning_path = self.models_dir / "qlearning_model.pkl"
        self.hybrid_neural_path = self.models_dir / "hybrid_neural_model.pth"
        self.metadata_path = self.models_dir / "training_metadata.json"
        self.vocab_path = self.models_dir / "concept_vocabulary.json"
        self.training_history_path = self.models_dir / "training_history.pkl"

        logger.info("Model Persistence Manager initialized - models dir: %s", self.models_dir)

    def save_complete_trained_system(self, system_results: Dict) -> Dict[str, str]:
        """Persist every component present in ``system_results`` plus metadata.

        Args:
            system_results: Output of a training run. Recognised keys are
                ``'qlearning_system'``, ``'hybrid_system'`` and
                ``'neural_training_history'``. Metadata is always written.

        Returns:
            Mapping of component name to the file path it was written to, or
            an empty dict if any step raised (the error is printed and logged).
        """
        print("Saving complete trained system...")
        saved_files = {}

        try:
            if 'qlearning_system' in system_results:
                self.save_qlearning_system(system_results['qlearning_system'])
                saved_files['qlearning'] = str(self.qlearning_path)
                print("   Q-Learning system saved")

            if 'hybrid_system' in system_results:
                hybrid_system = system_results['hybrid_system']
                self.save_hybrid_neural_system(hybrid_system)
                # The .pth checkpoint is only written when a neural model
                # exists, so only report it in that case.
                if hybrid_system.neural_model is not None:
                    saved_files['hybrid_neural'] = str(self.hybrid_neural_path)
                print("   Hybrid Neural system saved")

            if 'neural_training_history' in system_results:
                self.save_training_history(system_results['neural_training_history'])
                saved_files['training_history'] = str(self.training_history_path)
                print("   Training history saved")

            metadata = self.create_training_metadata(system_results)
            self.save_metadata(metadata)
            saved_files['metadata'] = str(self.metadata_path)
            print("   Training metadata saved")

            print(f"Complete system saved to {self.models_dir}")
            return saved_files

        except Exception as e:
            # Best-effort boundary: callers treat {} as "saving failed", but
            # keep the traceback in the log instead of discarding it.
            print(f"Error saving system: {e}")
            logger.exception("Error saving trained system to %s", self.models_dir)
            return {}

    def save_qlearning_system(self, qlearning_system):
        """Pickle the Q-learning system's state to ``self.qlearning_path``.

        The Q matrix is stored as nested lists via ``.tolist()``, or ``None``
        when absent. Frequency and co-occurrence containers are converted to
        plain dicts. Learning hyper-parameters and a timestamp are included.
        """
        qlearning_data = {
            'Q_matrix': qlearning_system.Q.tolist() if qlearning_system.Q is not None else None,
            'medical_concepts': qlearning_system.medical_concepts,
            'concept_to_index': qlearning_system.concept_to_index,
            'index_to_concept': qlearning_system.index_to_concept,
            'concept_frequencies': dict(qlearning_system.concept_frequencies),
            'concept_cooccurrences': {
                k: dict(v) for k, v in qlearning_system.concept_cooccurrences.items()
            },
            'mimic_cases_used': qlearning_system.mimic_cases_used,
            'learning_params': {
                'alpha': qlearning_system.alpha,
                'gamma': qlearning_system.gamma,
                'epsilon': qlearning_system.epsilon
            },
            'timestamp': datetime.now().isoformat()
        }

        with open(self.qlearning_path, 'wb') as f:
            pickle.dump(qlearning_data, f)

    def save_hybrid_neural_system(self, hybrid_system):
        """Save the hybrid system's neural checkpoint and concept vocabulary.

        The ``.pth`` checkpoint is written only when
        ``hybrid_system.neural_model`` is not None. The vocabulary JSON is
        always written. The checkpoint records the architecture sizes
        (``embedding_dim`` and, when the model exposes it, ``hidden_dim``), so
        ``load_hybrid_neural_system`` can rebuild a matching model.
        """
        neural_model = hybrid_system.neural_model
        if neural_model is not None:
            model_config = {
                'vocab_size': len(hybrid_system.concept_vocab),
                'embedding_dim': neural_model.embedding_dim,
                'model_class': 'MedicalConceptEmbedding'
            }
            # Record hidden_dim when available so loading does not have to
            # guess it. Only plain ints are stored, to keep the checkpoint
            # loadable with torch.load(weights_only=True).
            hidden_dim = getattr(neural_model, 'hidden_dim', None)
            if isinstance(hidden_dim, int):
                model_config['hidden_dim'] = hidden_dim

            torch.save({
                'model_state_dict': neural_model.state_dict(),
                'model_config': model_config,
                'concept_vocab': hybrid_system.concept_vocab,
                'training_examples_count': len(hybrid_system.training_examples),
                'timestamp': datetime.now().isoformat()
            }, self.hybrid_neural_path)

        with open(self.vocab_path, 'w', encoding='utf-8') as f:
            json.dump(hybrid_system.concept_vocab, f, indent=2)

    def save_training_history(self, training_history: List[Dict]):
        """Pickle the per-epoch history with its epoch count and final entry."""
        training_data = {
            'history': training_history,
            'total_epochs': len(training_history),
            'final_metrics': training_history[-1] if training_history else {},
            'timestamp': datetime.now().isoformat()
        }

        with open(self.training_history_path, 'wb') as f:
            pickle.dump(training_data, f)

    def save_metadata(self, metadata: Dict):
        """Write ``metadata`` as indented JSON to ``self.metadata_path``."""
        with open(self.metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

    def create_training_metadata(self, system_results: Dict) -> Dict:
        """Build a JSON-friendly summary of a training run.

        Args:
            system_results: Output of a training run. Missing keys fall back
                to False / 0 / empty.

        Returns:
            Dict with ``system_components``, ``training_data`` and
            ``model_performance`` sections plus a completion timestamp.
        """
        # Treat a missing *or* None history as empty, so len() never sees None.
        history = system_results.get('neural_training_history') or []
        final_loss = history[-1].get('total_loss', 0) if history else 0

        metadata = {
            'training_completed': True,
            'timestamp': datetime.now().isoformat(),
            'system_components': {
                'qlearning_available': 'qlearning_system' in system_results,
                'hybrid_neural_available': 'hybrid_system' in system_results,
                'rl_feedback_available': 'feedback_loop' in system_results,
                'dynamic_agents_available': 'agent_creator' in system_results,
                'rag_llm_integrated': system_results.get('rag_llm_integrated', False),
                'medical_nlp_enabled': system_results.get('medical_nlp_enabled', False)
            },
            'training_data': {
                'patient_cases_count': len(system_results.get('patient_cases', [])),
                'real_mimic_data': system_results.get('real_mimic_data', False),
                'pure_data_driven': system_results.get('pure_data_driven', False)
            },
            'model_performance': {
                'vocab_size': system_results.get('vocab_size', 0),
                'training_epochs': len(history),
                'final_loss': final_loss
            }
        }

        return metadata

    def check_trained_models_exist(self) -> Dict[str, bool]:
        """Report (and print) which of the five artifact files exist.

        Returns:
            Mapping of artifact name to whether its file exists on disk.
        """
        model_status = {
            'qlearning_model': self.qlearning_path.exists(),
            'hybrid_neural_model': self.hybrid_neural_path.exists(),
            'concept_vocabulary': self.vocab_path.exists(),
            'training_metadata': self.metadata_path.exists(),
            'training_history': self.training_history_path.exists()
        }

        all_exist = all(model_status.values())

        print("Checking for existing trained models:")
        for model_name, exists in model_status.items():
            status = "Present" if exists else "Missing"
            print(f"   {status}: {model_name}")

        if all_exist:
            print("All trained models found! Can skip training.")
        else:
            print("Some models missing - training required.")

        return model_status

    def load_complete_trained_system(self):
        """Load every saved component into a results dict.

        Returns:
            A dict mirroring a training run's results (with
            ``'loaded_from_disk': True``), or None when any artifact is
            missing or loading raised (the error is printed and logged).
        """
        print("Loading complete trained system...")

        model_status = self.check_trained_models_exist()
        if not all(model_status.values()):
            print("Cannot load - some models missing")
            return None

        try:
            metadata = self.load_metadata()
            print(f"   Training metadata: {metadata['timestamp']}")

            qlearning_system = self.load_qlearning_system()
            print(f"   Q-Learning loaded: {len(qlearning_system.medical_concepts)} concepts")

            hybrid_system = self.load_hybrid_neural_system()
            print(f"   Hybrid Neural loaded: vocab size {len(hybrid_system.concept_vocab)}")

            training_history = self.load_training_history()
            print(f"   Training history loaded: {len(training_history)} epochs")

            loaded_system = {
                'qlearning_system': qlearning_system,
                'hybrid_system': hybrid_system,
                'neural_training_history': training_history,
                'training_complete': True,
                'loaded_from_disk': True,
                'metadata': metadata
            }

            print("Complete system loaded successfully!")
            return loaded_system

        except Exception as e:
            # Callers treat None as "retrain"; keep the traceback in the log.
            print(f"Error loading system: {e}")
            logger.exception("Error loading trained system from %s", self.models_dir)
            return None

    def load_qlearning_system(self):
        """Rebuild a ``PureMedicalQLearning`` instance from the saved pickle."""
        from pure_medical_qlearning import PureMedicalQLearning

        with open(self.qlearning_path, 'rb') as f:
            qlearning_data = pickle.load(f)

        qlearning_system = PureMedicalQLearning()

        q_matrix = qlearning_data['Q_matrix']
        # Compare with None explicitly: an empty saved matrix ([]) is still a
        # matrix and must not be silently turned into None.
        qlearning_system.Q = np.array(q_matrix) if q_matrix is not None else None
        qlearning_system.medical_concepts = qlearning_data['medical_concepts']
        qlearning_system.concept_to_index = qlearning_data['concept_to_index']
        qlearning_system.index_to_concept = qlearning_data['index_to_concept']
        # NOTE(review): these two were saved as plain dicts. If
        # PureMedicalQLearning relies on defaultdict/Counter behaviour for
        # them, they need re-wrapping here -- confirm against that class.
        qlearning_system.concept_frequencies = qlearning_data['concept_frequencies']
        qlearning_system.concept_cooccurrences = qlearning_data['concept_cooccurrences']
        qlearning_system.mimic_cases_used = qlearning_data['mimic_cases_used']

        params = qlearning_data['learning_params']
        qlearning_system.alpha = params['alpha']
        qlearning_system.gamma = params['gamma']
        qlearning_system.epsilon = params['epsilon']

        return qlearning_system

    def load_hybrid_neural_system(self):
        """Rebuild a ``HybridNeuralSymbolicSystem`` and its neural model.

        The model is rebuilt with the sizes recorded in the checkpoint. If a
        size was not recorded, the class defaults are used. Weights are loaded
        onto the CPU.
        """
        from hybrid_neural_system import HybridNeuralSymbolicSystem, MedicalConceptEmbedding

        checkpoint = torch.load(self.hybrid_neural_path, map_location='cpu')

        hybrid_system = HybridNeuralSymbolicSystem()
        hybrid_system.concept_vocab = checkpoint['concept_vocab']

        model_config = checkpoint['model_config']
        # Use the saved architecture rather than hard-coded sizes, otherwise
        # load_state_dict fails for any model not trained at 256/512.
        neural_model = MedicalConceptEmbedding(
            vocab_size=model_config['vocab_size'],
            embedding_dim=model_config.get('embedding_dim', self._DEFAULT_EMBEDDING_DIM),
            hidden_dim=model_config.get('hidden_dim', self._DEFAULT_HIDDEN_DIM)
        )

        neural_model.load_state_dict(checkpoint['model_state_dict'])
        hybrid_system.neural_model = neural_model

        return hybrid_system

    def load_training_history(self) -> List[Dict]:
        """Return the per-epoch history list from the saved pickle."""
        with open(self.training_history_path, 'rb') as f:
            training_data = pickle.load(f)

        return training_data['history']

    def load_metadata(self) -> Dict:
        """Read and return the training metadata JSON."""
        with open(self.metadata_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_model_info(self) -> Dict:
        """Summarise the saved metadata.

        Returns:
            ``{'available': False}`` when no metadata file exists. Otherwise
            the timestamp, patient-case count, vocab size, epoch count and
            component flags.
        """
        if not self.metadata_path.exists():
            return {'available': False}

        metadata = self.load_metadata()

        return {
            'available': True,
            'training_timestamp': metadata['timestamp'],
            'patient_cases_used': metadata['training_data']['patient_cases_count'],
            'vocab_size': metadata['model_performance']['vocab_size'],
            'training_epochs': metadata['model_performance']['training_epochs'],
            'components_available': metadata['system_components']
        }

def enhance_main_system_with_persistence(main_system_file: str = "working_main_system.py"):
    """Return a code snippet that wraps training with load/save persistence.

    The snippet defines ``run_with_model_persistence()``. It loads a saved
    system via ``ModelPersistenceManager`` when all artifacts exist.
    Otherwise it calls ``run_complete_mimic_system()`` and saves the result.
    The snippet expects that function to be defined in the target file.

    Args:
        main_system_file: Name of the file the snippet is meant to be pasted
            into. It is only used in the printed hint; no file is read or
            written.

    Returns:
        The snippet as a string. It is returned, not printed.
    """
    print("Enhancing main system with model persistence...")
    
    enhancement_code = '''
from model_persistence import ModelPersistenceManager

def run_with_model_persistence():
    
    print("ENHANCED SYSTEM WITH MODEL PERSISTENCE")
    print("="*70)
    
    persistence = ModelPersistenceManager()
    
    model_status = persistence.check_trained_models_exist()
    
    if all(model_status.values()):
        print("TRAINED MODELS FOUND - LOADING FROM DISK (NO RETRAINING NEEDED)")
        
        loaded_system = persistence.load_complete_trained_system()
        
        if loaded_system:
            print("System loaded successfully - ready for use!")
            return loaded_system
        else:
            print("Loading failed - will retrain")
    
    print("No trained models found - starting training...")
    
    training_results = run_complete_mimic_system()
    
    if training_results and training_results.get('training_complete'):
        print("Saving trained models for future use...")
        saved_files = persistence.save_complete_trained_system(training_results)
        
        if saved_files:
            print("Models saved successfully!")
            training_results['model_files_saved'] = saved_files
        else:
            print("Model saving failed")
    
    return training_results

if __name__ == "__main__":
    print("Starting PERSISTENCE-ENHANCED SYSTEM...")
    results = run_with_model_persistence()
    
    if results and results.get('training_complete'):
        if results.get('loaded_from_disk'):
            print("SYSTEM LOADED FROM DISK - Ready for immediate use!")
        else:
            print("SYSTEM TRAINED AND SAVED - Future runs will be instant!")
        
        persistence = ModelPersistenceManager()
        model_info = persistence.get_model_info()
        print(f"Model Info: {model_info}")
    else:
        print("System startup failed")
'''
    
    print("Enhancement code generated")
    # The snippet is returned rather than printed, so point at the return
    # value, and name the file the caller actually asked about.
    print(f"Copy the returned code into your {main_system_file}")
    
    return enhancement_code

def test_model_persistence():
    """Smoke-check the default models directory and print what is found.

    Builds a ModelPersistenceManager for the default directory (creating it
    if missing). Prints the per-artifact presence map and the metadata
    summary, then returns the manager for further inspection.
    """
    print("Testing Model Persistence")

    manager = ModelPersistenceManager()

    # Each probe prints its own details first; the labelled result follows.
    probes = (
        ("Model status", manager.check_trained_models_exist),
        ("Model info", manager.get_model_info),
    )
    for label, probe in probes:
        print(f"{label}: {probe()}")

    return manager

if __name__ == "__main__":
    # Running this module directly only runs the persistence smoke check:
    # it may create the default models directory, but it does not train,
    # save, or load any model.
    test_model_persistence()