"""
Offline AI Educational Chatbot System
Complete implementation with synthetic data generation and realistic simulations
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel, AdamW
import numpy as np
import json
import random
import logging
import os
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import sqlite3
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import time

# Module-wide logging: INFO level with timestamped, level-tagged lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

@dataclass
class EducationalExample:
    """A single synthetic training record: a query paired with its reference answer.

    The declaration order of the fields below is also the order of the
    generated positional constructor.
    """

    query: str  # question text fed to the model
    response: str  # reference answer text
    subject: str  # e.g. "Mathematics", "Science" or "Non-Educational"
    grade_level: int  # school grade; distractor examples use 0
    learning_objectives: List[str]  # objective labels; empty for distractors
    curriculum_aligned: bool  # True for educational examples, False for distractors

class SyntheticEducationalDataGenerator:
    """Generates realistic synthetic educational data for training and evaluation.

    Each topic has several (query, response) templates that use different
    placeholders. For every topic a helper returns values for *all* placeholders
    used by that topic's templates, so whichever template is drawn formats
    cleanly (``str.format`` ignores keyword arguments a template does not use).

    Note: ``__init__`` seeds the process-wide ``random`` and ``numpy.random``
    generators, and all sampling below uses the module-level ``random``.
    """

    def __init__(self, seed: int = 42):
        random.seed(seed)
        np.random.seed(seed)

        # Educational content templates
        self.subjects = ["Mathematics", "Science", "English", "History", "Geography"]
        self.grade_levels = [6, 7, 8, 9, 10, 11, 12]

        # Mathematics question templates: topic -> [(query_template, response_template), ...]
        self.math_templates = {
            "algebra": [
                ("What is the value of x in the equation {eq}?", "To solve {eq}, we {steps}. Therefore, x = {answer}."),
                ("Simplify the expression {expr}", "To simplify {expr}, we {process}. The result is {result}."),
                ("Factor the quadratic {quad}", "To factor {quad}, we look for two numbers that {explanation}. The factored form is {factored}.")
            ],
            "calculus": [
                ("Find the derivative of f(x) = {func}", "Using the {rule} rule, the derivative of f(x) = {func} is f'(x) = {derivative}."),
                ("Evaluate the integral ∫{integrand} dx", "To integrate {integrand}, we use {method}. The result is {result} + C."),
                ("Find the limit as x approaches {value} of {expression}", "Evaluating the limit: {process}. The limit is {limit_value}.")
            ]
        }

        # Science question templates: topic -> [(query_template, response_template), ...]
        self.science_templates = {
            "physics": [
                ("What is {concept}?", "{concept} is {definition}. For example, {example}."),
                ("How does {phenomenon} work?", "{phenomenon} works through {mechanism}. The key principle is {principle}."),
                ("Calculate the {quantity} given {parameters}", "Using the formula {formula}, we substitute {values} to get {result}.")
            ],
            "chemistry": [
                ("What happens when {reactant1} reacts with {reactant2}?", "When {reactant1} reacts with {reactant2}, {reaction_description}. The chemical equation is {equation}."),
                ("Explain the concept of {chem_concept}", "{chem_concept} refers to {definition}. This is important because {significance}."),
                ("How do you balance the equation {unbalanced}?", "To balance {unbalanced}, we {balancing_steps}. The balanced equation is {balanced}.")
            ]
        }

        # Learning objectives database
        self.learning_objectives = {
            "Mathematics": [
                "Solve linear equations", "Factor quadratic expressions", "Apply calculus rules",
                "Interpret graphs", "Use mathematical reasoning", "Apply geometric principles"
            ],
            "Science": [
                "Understand scientific method", "Analyze experimental data", "Apply physics laws",
                "Explain chemical reactions", "Understand atomic structure", "Apply conservation laws"
            ]
        }

    @staticmethod
    def _algebra_values() -> Dict[str, str]:
        """Return placeholder values covering every algebra template."""
        from fractions import Fraction  # exact answers like "4/3" instead of 1.3333333333333333

        # Linear equation a*x + b = c  ->  x = (c - b) / a
        a, b, c = random.randint(1, 10), random.randint(1, 10), random.randint(1, 20)
        answer = Fraction(c - b, a)

        # Like terms: p*x + q*x + r simplifies to (p + q)*x + r
        p, q, r = random.randint(1, 9), random.randint(1, 9), random.randint(1, 20)

        # Monic quadratic built from its factors (x + m)(x + n)
        m, n = random.randint(1, 9), random.randint(1, 9)

        return {
            "eq": f"{a}x + {b} = {c}",
            "steps": f"subtract {b} from both sides and divide by {a}",
            "answer": str(answer),
            "expr": f"{p}x + {q}x + {r}",
            "process": "combine the like terms",
            "result": f"{p + q}x + {r}",
            "quad": f"x^2 + {m + n}x + {m * n}",
            "explanation": f"multiply to {m * n} and add to {m + n}",
            "factored": f"(x + {m})(x + {n})",
        }

    @staticmethod
    def _calculus_values() -> Dict[str, str]:
        """Return placeholder values covering every calculus template."""
        # function -> (derivative, name of the differentiation rule used)
        derivatives = {
            "x^2": ("2x", "power"),
            "3x^3": ("9x^2", "power"),
            "sin(x)": ("cos(x)", "trigonometric"),
            "cos(x)": ("-sin(x)", "trigonometric"),
            "e^x": ("e^x", "exponential"),
            "ln(x)": ("1/x", "logarithmic"),
        }
        func = random.choice(list(derivatives))
        derivative, rule = derivatives[func]

        # integrand -> (antiderivative without the constant, method description)
        integrals = {
            "x": ("x^2/2", "the power rule for integration"),
            "3x^2": ("x^3", "the power rule for integration"),
            "cos(x)": ("sin(x)", "the standard integral of cosine"),
            "sin(x)": ("-cos(x)", "the standard integral of sine"),
            "e^x": ("e^x", "the fact that e^x is its own antiderivative"),
            "1/x": ("ln|x|", "the standard integral of 1/x"),
        }
        integrand = random.choice(list(integrals))
        antiderivative, method = integrals[integrand]

        # A polynomial is continuous, so its limit is found by direct substitution.
        point, k = random.randint(-5, 5), random.randint(1, 10)

        return {
            "func": func,
            "rule": rule,
            "derivative": derivative,
            "integrand": integrand,
            "method": method,
            "result": antiderivative,
            "value": str(point),
            "expression": f"x^2 + {k}",
            "process": f"x^2 + {k} is continuous, so substitute x = {point} to get ({point})^2 + {k}",
            "limit_value": str(point * point + k),
        }

    @staticmethod
    def _physics_values() -> Dict[str, str]:
        """Return placeholder values covering every physics template."""
        definitions = {
            "velocity": "the rate of change of position with respect to time",
            "acceleration": "the rate of change of velocity with respect to time",
            "force": "an interaction that changes the motion of an object",
            "energy": "the capacity to do work or produce heat",
            "momentum": "the product of mass and velocity",
            "power": "the rate at which work is done or energy is transferred"
        }
        concept = random.choice(list(definitions))
        example = f"when you push a {random.choice(['car', 'box', 'ball'])}"

        # phenomenon -> (mechanism, key principle)
        phenomena = {
            "friction": (
                "microscopic interactions between two surfaces in contact",
                "that friction opposes the relative motion of the surfaces",
            ),
            "buoyancy": (
                "the pressure difference between the top and bottom of a submerged object",
                "Archimedes' principle, which states that the upward force equals the weight of the displaced fluid",
            ),
            "electromagnetic induction": (
                "a changing magnetic field inducing a voltage in a conductor",
                "Faraday's law, which states that the induced voltage is proportional to the rate of change of magnetic flux",
            ),
        }
        phenomenon = random.choice(list(phenomena))
        mechanism, principle = phenomena[phenomenon]

        mass, accel, speed = random.randint(1, 20), random.randint(1, 10), random.randint(1, 10)
        quantity, parameters, formula, substituted, result = random.choice([
            ("force", f"m = {mass} kg and a = {accel} m/s^2", "F = m * a",
             f"m = {mass} kg and a = {accel} m/s^2", f"F = {mass * accel} N"),
            ("kinetic energy", f"m = {mass} kg and v = {speed} m/s", "KE = 0.5 * m * v^2",
             f"m = {mass} kg and v = {speed} m/s", f"KE = {0.5 * mass * speed ** 2:g} J"),
        ])

        return {
            "concept": concept,
            "definition": definitions[concept],
            "example": example,
            "phenomenon": phenomenon,
            "mechanism": mechanism,
            "principle": principle,
            "quantity": quantity,
            "parameters": parameters,
            "formula": formula,
            "values": substituted,
            "result": result,
        }

    @staticmethod
    def _chemistry_values() -> Dict[str, str]:
        """Return placeholder values covering every chemistry template."""
        # (reactant1, reactant2) -> (what forms, balanced equation)
        reactions = {
            ("sodium", "chlorine"): ("they form the ionic compound sodium chloride", "2Na + Cl2 → 2NaCl"),
            ("hydrogen", "oxygen"): ("they form the covalent compound water", "2H2 + O2 → 2H2O"),
            ("carbon", "oxygen"): ("they form the covalent compound carbon dioxide", "C + O2 → CO2"),
        }
        reactant1, reactant2 = random.choice(list(reactions))
        reaction_description, equation = reactions[(reactant1, reactant2)]

        # concept -> (definition, why it matters)
        concepts = {
            "electronegativity": (
                "the tendency of an atom to attract shared electrons in a chemical bond",
                "it predicts whether a bond is nonpolar, polar or ionic",
            ),
            "oxidation": (
                "the loss of electrons by an atom, ion or molecule",
                "it underlies combustion, corrosion and batteries",
            ),
            "the mole": (
                "an amount of substance containing about 6.022 × 10^23 particles",
                "it links the mass of a sample to the number of particles it contains",
            ),
        }
        chem_concept = random.choice(list(concepts))
        definition, significance = concepts[chem_concept]

        # unbalanced equation -> (balancing steps, balanced equation)
        balancing = {
            "H2 + O2 → H2O": (
                "place a 2 in front of H2O to balance oxygen, then a 2 in front of H2 to balance hydrogen",
                "2H2 + O2 → 2H2O",
            ),
            "Na + Cl2 → NaCl": (
                "place a 2 in front of NaCl to balance chlorine, then a 2 in front of Na to balance sodium",
                "2Na + Cl2 → 2NaCl",
            ),
            "CH4 + O2 → CO2 + H2O": (
                "place a 2 in front of H2O to balance hydrogen, then a 2 in front of O2 to balance oxygen",
                "CH4 + 2O2 → CO2 + 2H2O",
            ),
        }
        unbalanced = random.choice(list(balancing))
        balancing_steps, balanced = balancing[unbalanced]

        return {
            "reactant1": reactant1,
            "reactant2": reactant2,
            "reaction_description": reaction_description,
            "equation": equation,
            "chem_concept": chem_concept,
            "definition": definition,
            "significance": significance,
            "unbalanced": unbalanced,
            "balancing_steps": balancing_steps,
            "balanced": balanced,
        }

    def generate_math_example(self, grade_level: int) -> "EducationalExample":
        """Generate a mathematics example; calculus is only offered above grade 8."""
        topic = "algebra" if grade_level <= 8 else random.choice(["algebra", "calculus"])
        query_template, response_template = random.choice(self.math_templates[topic])
        values = self._algebra_values() if topic == "algebra" else self._calculus_values()

        return EducationalExample(
            query=query_template.format(**values),
            response=response_template.format(**values),
            subject="Mathematics",
            grade_level=grade_level,
            learning_objectives=random.sample(self.learning_objectives["Mathematics"], 2),
            curriculum_aligned=True
        )

    def generate_science_example(self, grade_level: int) -> "EducationalExample":
        """Generate a physics or chemistry example (``grade_level`` is recorded only)."""
        topic = random.choice(["physics", "chemistry"])
        query_template, response_template = random.choice(self.science_templates[topic])
        values = self._physics_values() if topic == "physics" else self._chemistry_values()

        return EducationalExample(
            query=query_template.format(**values),
            response=response_template.format(**values),
            subject="Science",
            grade_level=grade_level,
            learning_objectives=random.sample(self.learning_objectives["Science"], 2),
            curriculum_aligned=True
        )

    def generate_distractor_example(self) -> "EducationalExample":
        """Generate a non-educational example (with a redirecting reply) for negative sampling."""
        distractors = [
            ("What's the weather like?", "I'm an educational assistant focused on learning. Let me help you with your studies instead!"),
            ("Tell me a joke", "I'm here to help with educational content. What subject would you like to explore?"),
            ("What's your favorite movie?", "I focus on educational topics. Is there a school subject I can help you with?"),
            ("How do I cook pasta?", "I'm designed for educational assistance. Let's work on homework or learning together!"),
        ]

        query, response = random.choice(distractors)

        return EducationalExample(
            query=query,
            response=response,
            subject="Non-Educational",
            grade_level=0,
            learning_objectives=[],
            curriculum_aligned=False
        )

    def generate_dataset(self, num_examples: int = 10000) -> List["EducationalExample"]:
        """Generate a shuffled dataset: 70% educational examples, the rest distractors."""
        dataset = []

        num_educational = int(0.7 * num_examples)
        num_distractors = num_examples - num_educational

        logger.info(f"Generating {num_educational} educational examples...")
        for _ in range(num_educational):
            subject = random.choice(["Mathematics", "Science"])
            grade_level = random.choice(self.grade_levels)

            if subject == "Mathematics":
                example = self.generate_math_example(grade_level)
            else:
                example = self.generate_science_example(grade_level)

            dataset.append(example)

        logger.info(f"Generating {num_distractors} distractor examples...")
        for _ in range(num_distractors):
            dataset.append(self.generate_distractor_example())

        random.shuffle(dataset)
        return dataset

class EducationalDataset(Dataset):
    """PyTorch dataset that tokenizes EducationalExample records when accessed.

    Each item contains the tokenized prompt (query plus subject and grade),
    the tokenized reference response, the curriculum-alignment label as a
    float tensor, and the raw subject / grade fields.
    """

    def __init__(self, examples: List["EducationalExample"], tokenizer, max_length: int = 512):
        self.examples = examples
        self.tokenizer = tokenizer
        # Every sequence is padded / truncated to exactly this many tokens.
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def _encode(self, text: str):
        """Tokenize *text* to ``max_length`` tokens; return 1-D (input_ids, attention_mask)."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # squeeze(0) removes only the batch axis added by return_tensors='pt'.
        # A bare squeeze() would also drop the sequence axis when max_length == 1,
        # producing 0-d tensors that collate into the wrong shape.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        example = self.examples[idx]

        # The prompt embeds subject and grade alongside the query text.
        input_ids, attention_mask = self._encode(
            f"Query: {example.query} Subject: {example.subject} Grade: {example.grade_level}"
        )
        target_ids, target_mask = self._encode(example.response)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_ids': target_ids,
            'target_mask': target_mask,
            'curriculum_aligned': torch.tensor(1.0 if example.curriculum_aligned else 0.0),
            'subject': example.subject,
            'grade_level': example.grade_level
        }

class EducationalChatbotModel(nn.Module):
    """Educational chatbot model: a DistilBERT encoder plus task-specific heads.

    ``forward`` returns a dict with:
      * ``logits`` - per-token vocabulary logits, shape (batch, seq_len, vocab_size)
      * ``alignment_score`` - per-example sigmoid score, shape (batch, 1)
      * ``safety_score`` - per-example sigmoid score, shape (batch, 1)
      * ``generation_loss`` - only when ``target_ids`` is given
    """

    def __init__(self, model_name: str = "distilbert-base-uncased", vocab_size: int = 30522):
        super().__init__()

        # Base transformer
        self.bert = DistilBertModel.from_pretrained(model_name)
        self.vocab_size = vocab_size
        # Encoder width (768 for distilbert-base-uncased), read from the config
        # so other DistilBERT checkpoints also work.
        hidden_size = self.bert.config.hidden_size

        # Educational-specific layers, shared by the token-level and pooled paths
        self.educational_head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Output layers
        self.lm_head = nn.Linear(128, vocab_size)
        self.alignment_head = nn.Linear(128, 1)
        self.safety_head = nn.Linear(128, 1)

    def forward(self, input_ids, attention_mask, target_ids=None):
        sequence_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Masked mean pooling: padding positions must not dilute the sentence vector.
        mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
        pooled_output = (sequence_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        # Sequence-level scores come from the pooled representation.
        edu_features = self.educational_head(pooled_output)
        alignment_score = torch.sigmoid(self.alignment_head(edu_features))
        safety_score = torch.sigmoid(self.safety_head(edu_features))

        # Token-level LM logits, shape (batch, seq_len, vocab). Computing them from
        # the pooled vector gave (batch, vocab), so the shift below sliced the batch
        # axis and cross_entropy failed on mismatched sizes.
        # NOTE: this tensor is large (batch * seq_len * vocab floats).
        lm_logits = self.lm_head(self.educational_head(sequence_output))

        outputs = {
            'logits': lm_logits,
            'alignment_score': alignment_score,
            'safety_score': safety_score
        }

        if target_ids is not None:
            # Simplified generation loss: position t predicts target token t + 1.
            # Align lengths in case prompt and target were padded differently.
            seq_len = min(lm_logits.size(1), target_ids.size(1))
            shift_logits = lm_logits[:, :seq_len - 1, :]
            shift_labels = target_ids[:, 1:seq_len]

            gen_loss = F.cross_entropy(
                shift_logits.reshape(-1, self.vocab_size),
                shift_labels.reshape(-1),
                # presumably the tokenizer's pad id (0 for distilbert-base-uncased) - confirm
                ignore_index=0,
            )
            outputs['generation_loss'] = gen_loss

        return outputs

class EducationalTrainer:
    """Training framework for the educational chatbot.

    Optimizes a weighted multi-task loss (generation, curriculum alignment,
    safety), keeps per-epoch histories for plotting, and early-stops on
    validation loss while checkpointing the best weights.
    """

    def __init__(self, model: "EducationalChatbotModel", tokenizer, device: str = 'cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(device)

        # Per-epoch history, appended by train_epoch / validate.
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def _forward_batch(self, batch):
        """Move one collated batch to the device and run the model.

        Returns ``(outputs, alignment, labels, gen_loss)`` where ``alignment``
        (predicted probability) and ``labels`` are 1-D per-example tensors.
        """
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        target_ids = batch['target_ids'].to(self.device)
        labels = batch['curriculum_aligned'].to(self.device)

        outputs = self.model(input_ids, attention_mask, target_ids)

        # squeeze(-1), not squeeze(): a batch of one keeps shape (1,), matching
        # the label shape that binary_cross_entropy requires.
        alignment = outputs['alignment_score'].squeeze(-1)
        gen_loss = outputs.get('generation_loss')
        if gen_loss is None:
            gen_loss = torch.zeros((), device=alignment.device)
        return outputs, alignment, labels, gen_loss

    def train_epoch(self, train_loader: DataLoader, optimizer, epoch: int):
        """Train for one epoch and return the mean weighted loss per batch.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            outputs, alignment, labels, gen_loss = self._forward_batch(batch)

            align_loss = F.binary_cross_entropy(alignment, labels)
            # The goal is HIGH safety scores, so penalize (1 - score). Minimizing
            # the raw mean, as before, pushed the scores toward 0.
            safety_loss = (1.0 - outputs['safety_score']).mean()

            batch_loss = 0.7 * gen_loss + 0.2 * align_loss + 0.1 * safety_loss

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()

            total_loss += batch_loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}, Loss: {batch_loss.item():.4f}")

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg_loss = total_loss / num_batches
        self.train_losses.append(avg_loss)
        return avg_loss

    def validate(self, val_loader: DataLoader):
        """Return ``(avg_loss, alignment_accuracy)`` over ``val_loader``.

        The loss is the unweighted sum of generation and alignment losses
        (unlike the weighted training objective). Raises ValueError if the
        loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        correct_alignment = 0
        total_samples = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                _, alignment, labels, gen_loss = self._forward_batch(batch)

                align_loss = F.binary_cross_entropy(alignment, labels)
                total_loss += gen_loss.item() + align_loss.item()

                predicted_alignment = (alignment > 0.5).float()
                correct_alignment += (predicted_alignment == labels).sum().item()
                total_samples += labels.size(0)
                num_batches += 1

        if num_batches == 0 or total_samples == 0:
            raise ValueError("val_loader yielded no samples")

        avg_loss = total_loss / num_batches
        accuracy = correct_alignment / total_samples

        self.val_losses.append(avg_loss)
        self.val_accuracies.append(accuracy)

        return avg_loss, accuracy

    def train(self, train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 5,
              patience: int = 3, checkpoint_path: str = 'best_educational_model.pt'):
        """Train with early stopping; the best weights are saved to ``checkpoint_path``.

        Training stops after ``patience`` consecutive epochs without a lower
        validation loss.
        """
        # torch's AdamW instead of the deprecated transformers.AdamW; eps=1e-6
        # matches the transformers default so the optimization is unchanged.
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.01)

        best_val_loss = float('inf')
        patience_counter = 0

        logger.info("Starting training...")

        for epoch in range(num_epochs):
            logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")

            train_loss = self.train_epoch(train_loader, optimizer, epoch)
            val_loss, val_accuracy = self.validate(val_loader)

            logger.info(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(self.model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1

            if patience_counter >= patience:
                logger.info("Early stopping triggered")
                break

        logger.info("Training completed!")

class EducationalEvaluator:
    """Evaluation framework: alignment-based accuracy, latency and model size.

    Every ``evaluate_*`` method stores its metric in ``self.results`` and
    returns it.
    """

    def __init__(self, model: "EducationalChatbotModel", tokenizer, device: str = 'cpu'):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.results = {}

    def generate_response(self, query: str, subject: str = "Mathematics", grade_level: int = 9, max_length: int = 100):
        """Generate a response for *query* by greedy per-position decoding.

        Args:
            query: user question.
            subject, grade_level: metadata placed in the prompt, as in training.
            max_length: maximum number of predicted tokens decoded into the response.

        Returns:
            dict with ``response`` (str), ``alignment_score`` and ``safety_score`` (floats).
        """
        self.model.eval()

        input_text = f"Query: {query} Subject: {subject} Grade: {grade_level}"
        encoded = self.tokenizer(
            input_text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.model(**encoded)

        # Greedy decoding (a real implementation would use beam search). reshape(-1)
        # yields a flat id list whether logits are per-token (1, seq, vocab) or (1, vocab).
        predicted_ids = outputs['logits'].argmax(dim=-1)
        token_ids = predicted_ids[0].reshape(-1).tolist()[:max_length]
        response = self.tokenizer.decode(token_ids, skip_special_tokens=True)

        return {
            'response': response,
            'alignment_score': outputs['alignment_score'].item(),
            'safety_score': outputs['safety_score'].item()
        }

    def evaluate_educational_accuracy(self, test_examples: List["EducationalExample"], threshold: float = 0.7):
        """Fraction of curriculum-aligned examples whose alignment score exceeds *threshold*.

        Non-aligned (distractor) examples are skipped; returns 0 when none remain.
        """
        correct = 0
        total = 0

        for example in test_examples:
            if not example.curriculum_aligned:
                continue
            result = self.generate_response(example.query, example.subject, example.grade_level)
            # Heuristic: a high alignment score counts as a correct educational response.
            if result['alignment_score'] > threshold:
                correct += 1
            total += 1

        accuracy = correct / total if total > 0 else 0
        self.results['educational_accuracy'] = accuracy
        return accuracy

    def evaluate_response_time(self, num_queries: int = 100):
        """Mean wall-clock latency of ``generate_response`` in milliseconds."""
        times = []

        for _ in range(num_queries):
            query = f"What is {random.randint(1, 10)} + {random.randint(1, 10)}?"

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            self.generate_response(query)
            times.append(time.perf_counter() - start_time)

        avg_time = np.mean(times) * 1000  # seconds -> milliseconds
        self.results['avg_response_time_ms'] = avg_time
        return avg_time

    def evaluate_memory_usage(self):
        """Estimate parameter memory in MiB using each parameter's real element size."""
        params = list(self.model.parameters())
        total_params = sum(p.numel() for p in params)
        # element_size() is 4 for float32 but 2 for fp16/bf16 and 8 for float64.
        memory_mb = sum(p.numel() * p.element_size() for p in params) / (1024 * 1024)

        self.results['memory_usage_mb'] = memory_mb
        self.results['total_parameters'] = total_params
        return memory_mb

    def run_comprehensive_evaluation(self, test_examples: List["EducationalExample"]):
        """Run all evaluations and return the accumulated ``results`` dict."""
        logger.info("Running comprehensive evaluation...")

        edu_acc = self.evaluate_educational_accuracy(test_examples)
        logger.info(f"Educational Accuracy: {edu_acc:.3f}")

        resp_time = self.evaluate_response_time()
        logger.info(f"Average Response Time: {resp_time:.2f} ms")

        memory = self.evaluate_memory_usage()
        logger.info(f"Memory Usage: {memory:.1f} MB")

        # One timestamp so 'timestamp' and 'evaluation_date' always agree.
        now = datetime.now()
        self.results.update({
            'timestamp': now.isoformat(),
            'model_architecture': 'DistilBERT-based Educational Chatbot',
            'evaluation_date': now.strftime('%Y-%m-%d')
        })

        return self.results

def _save_training_curves(trainer, results_dir: str) -> None:
    """Plot per-epoch train/validation loss and validation accuracy to a PNG in results_dir."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(trainer.train_losses, label='Train Loss')
    plt.plot(trainer.val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Curves')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(trainer.val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{results_dir}/training_curves.png", dpi=300, bbox_inches='tight')
    plt.close()


def _log_demo_responses(evaluator) -> None:
    """Log the model's response and scores for a few fixed demo queries."""
    logger.info("\n=== DEMONSTRATION ===")
    demo_queries = [
        "What is the derivative of x^2?",
        "How do you solve 2x + 5 = 13?",
        "Explain photosynthesis",
        "What's your favorite color?",  # Non-educational
    ]

    for query in demo_queries:
        result = evaluator.generate_response(query)
        logger.info(f"Query: {query}")
        logger.info(f"Response: {result['response']}")
        logger.info(f"Educational Alignment: {result['alignment_score']:.3f}")
        logger.info(f"Safety Score: {result['safety_score']:.3f}")
        logger.info("-" * 50)


def run_comprehensive_experiments():
    """Generate data, train, evaluate, save results/plots and run a short demo.

    Returns:
        (results_dir, results): the output directory and the evaluation results dict.
    """
    logger.info("Starting comprehensive experiments...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    data_generator = SyntheticEducationalDataGenerator()

    logger.info("Generating synthetic dataset...")
    all_examples = data_generator.generate_dataset(num_examples=5000)  # Smaller for demo

    # 80/10/10 split (the generator already shuffled the examples)
    train_size = int(0.8 * len(all_examples))
    val_size = int(0.1 * len(all_examples))

    train_examples = all_examples[:train_size]
    val_examples = all_examples[train_size:train_size + val_size]
    test_examples = all_examples[train_size + val_size:]

    logger.info(f"Dataset split: Train={len(train_examples)}, Val={len(val_examples)}, Test={len(test_examples)}")

    train_loader = DataLoader(EducationalDataset(train_examples, tokenizer), batch_size=8, shuffle=True)
    val_loader = DataLoader(EducationalDataset(val_examples, tokenizer), batch_size=8, shuffle=False)

    model = EducationalChatbotModel()
    trainer = EducationalTrainer(model, tokenizer, device)
    trainer.train(train_loader, val_loader, num_epochs=3)  # Reduced for demo

    # The trainer writes its best weights to this path.
    checkpoint_path = 'best_educational_model.pt'
    if os.path.exists(checkpoint_path):
        # map_location keeps loading working regardless of the device the checkpoint was saved from.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        logger.warning(f"No checkpoint found at {checkpoint_path}; evaluating the final weights instead")

    evaluator = EducationalEvaluator(model, tokenizer, device)
    results = evaluator.run_comprehensive_evaluation(test_examples)

    results_dir = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(results_dir, exist_ok=True)

    with open(os.path.join(results_dir, "evaluation_results.json"), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    _save_training_curves(trainer, results_dir)

    logger.info(f"Experiments completed! Results saved in {results_dir}/")

    _log_demo_responses(evaluator)

    return results_dir, results

def format_metric(value, spec: str) -> str:
    """Format *value* with format *spec*; non-numeric values (e.g. 'N/A') are returned as text.

    Without this, ``f"{'N/A':.3f}"`` raises ValueError whenever a metric is missing.
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


if __name__ == "__main__":
    results_dir, results = run_comprehensive_experiments()
    print(f"Experiments completed successfully! Results saved in: {results_dir}")
    print("Key Results:")
    print(f"- Educational Accuracy: {format_metric(results.get('educational_accuracy', 'N/A'), '.3f')}")
    print(f"- Average Response Time: {format_metric(results.get('avg_response_time_ms', 'N/A'), '.2f')} ms")
    print(f"- Memory Usage: {format_metric(results.get('memory_usage_mb', 'N/A'), '.1f')} MB")