import torch
import numpy as np
import random
import json
import pickle
from typing import Dict, List, Optional
from pathlib import Path

def set_random_seeds(seed: int = 42):
    """Seed every random number generator used by experiments.

    Seeds, in order, PyTorch's global generator, NumPy's legacy global
    generator and Python's ``random`` module. When CUDA is available, the
    current CUDA device and all CUDA devices are seeded as well.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def save_results(results: Dict, filepath: str):
    """Serialize *results* to JSON or pickle, chosen by the file extension.

    Args:
        results: Data to persist. For ``.json`` files, any value the ``json``
            module cannot encode natively is written via ``str()``
            (``default=str``), so such values come back as strings on load.
        filepath: Destination path ending in ``.json`` or ``.pkl``. Missing
            parent directories are created.

    Raises:
        ValueError: If the extension is neither ``.json`` nor ``.pkl``.
    """
    filepath = Path(filepath)

    # Validate before touching the filesystem so a rejected path leaves no
    # stray directories behind.
    if filepath.suffix not in ('.json', '.pkl'):
        raise ValueError("Unsupported file format. Use .json or .pkl")

    filepath.parent.mkdir(parents=True, exist_ok=True)

    if filepath.suffix == '.json':
        # Explicit encoding so the file does not depend on the platform default.
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)
    else:
        with open(filepath, 'wb') as f:
            pickle.dump(results, f)

def load_results(filepath: str) -> Dict:
    """Load results previously written as ``.json`` or ``.pkl``.

    Security note: ``.pkl`` files are read with ``pickle.load``, which can
    execute arbitrary code. Only load pickle files from trusted sources.

    Args:
        filepath: Path to a ``.json`` or ``.pkl`` file.

    Returns:
        The deserialized object.

    Raises:
        FileNotFoundError: If *filepath* does not exist.
        ValueError: If the extension is neither ``.json`` nor ``.pkl``.
    """
    path = Path(filepath)
    if not path.exists():
        raise FileNotFoundError(f"Results file not found: {path}")

    suffix = path.suffix
    if suffix == '.json':
        with path.open('r') as handle:
            return json.load(handle)
    if suffix == '.pkl':
        with path.open('rb') as handle:
            return pickle.load(handle)
    raise ValueError("Unsupported file format. Use .json or .pkl")

def compute_statistical_significance(results1: List[float], results2: List[float],
                                   alpha: float = 0.05) -> Dict[str, float]:
    """Run a paired t-test (``scipy.stats.ttest_rel``) between two result series.

    Args:
        results1: Per-run scores of the first method.
        results2: Per-run scores of the second method, paired index-by-index
            with *results1*.
        alpha: Significance threshold compared against the p-value.

    Returns:
        Dict with native Python values (safe for ``json.dump``):
        ``'t_statistic'`` (float), ``'p_value'`` (float),
        ``'significant'`` (bool, ``p_value < alpha``) and ``'effect_size'``
        (float). Note that ``'effect_size'`` is the raw difference of means
        ``mean(results1) - mean(results2)``, not a standardized effect size.
        If scipy returns NaN for degenerate input, ``NaN < alpha`` is False,
        so ``'significant'`` is False.

    Raises:
        ValueError: If the two lists differ in length.
    """
    from scipy import stats

    if len(results1) != len(results2):
        raise ValueError("Result lists must have the same length")

    t_stat, p_value = stats.ttest_rel(results1, results2)

    # Convert NumPy scalars to builtins. np.bool_ is not a bool, and
    # json.dump(default=str) would otherwise serialize it as "True"/"False".
    p_value = float(p_value)
    return {
        't_statistic': float(t_stat),
        'p_value': p_value,
        'significant': bool(p_value < alpha),
        'effect_size': float(np.mean(results1) - np.mean(results2)),
    }

def format_metrics_table(metrics: Dict[str, float], precision: int = 3) -> str:
    """Render *metrics* as a plain-text two-column table.

    Each metric name has underscores replaced by spaces and is title-cased,
    then left-aligned in a 20-character column. Numeric values are shown
    with *precision* decimal places. Values that cannot be formatted as
    fixed-point (e.g. strings or ``None``) are shown via ``str()`` instead of
    raising.

    Args:
        metrics: Mapping of metric name to value.
        precision: Number of decimal places for numeric values.

    Returns:
        The table text; every line, including the last, ends with ``"\\n"``.
    """
    lines = ["Metrics Summary:", "-" * 40]
    for metric, value in metrics.items():
        formatted_name = metric.replace('_', ' ').title()
        try:
            formatted_value = f"{value:.{precision}f}"
        except (TypeError, ValueError):
            # Non-numeric entries such as a 'status' string used to abort the
            # whole table with an exception.
            formatted_value = str(value)
        lines.append(f"{formatted_name:<20}: {formatted_value}")
    return "\n".join(lines) + "\n"

def validate_config(config) -> bool:
    """Check that *config* has the attributes and values an experiment needs.

    The config must expose ``model_name`` and ``tokenizer_name``. If it also
    has ``learning_rate``, that value must be greater than zero. The first
    problem found is printed.

    Args:
        config: Any object with attribute-style access.

    Returns:
        True if the config passes every check, otherwise False.
    """
    for required in ('model_name', 'tokenizer_name'):
        if hasattr(config, required):
            continue
        print(f"Missing required attribute: {required}")
        return False

    has_lr = hasattr(config, 'learning_rate')
    if has_lr and config.learning_rate <= 0:
        print("Learning rate must be positive")
        return False

    return True

def estimate_memory_usage(model_name: str) -> str:
    """Look up a hard-coded memory estimate for a known model identifier.

    The figures are fixed strings for a handful of GPT-2 and Llama-2
    checkpoints. NOTE(review): they appear to approximate weight memory only;
    the assumed precision is not stated here, so confirm before relying on them.

    Args:
        model_name: Model identifier, e.g. ``'gpt2'`` or
            ``'meta-llama/Llama-2-7b-hf'``.

    Returns:
        A size string such as ``'0.5GB'``, or ``'Unknown'`` for unlisted models.
    """
    known_sizes = dict(
        [
            ('distilgpt2', '0.3GB'),
            ('gpt2', '0.5GB'),
            ('gpt2-medium', '1.5GB'),
            ('gpt2-large', '3GB'),
            ('meta-llama/Llama-2-7b-hf', '14GB'),
            ('meta-llama/Llama-2-13b-hf', '26GB'),
        ]
    )
    try:
        return known_sizes[model_name]
    except KeyError:
        return 'Unknown'

def check_system_requirements():
    """Print and return a short CUDA capability report.

    Returns:
        Dict with keys:
        ``'torch_available'``: despite its name, the result of
            ``torch.cuda.is_available()``.
        ``'cuda_devices'``: number of CUDA devices, or ``0`` without CUDA.
        ``'total_memory'``: total memory of device 0 only, in GB
            (bytes / 1e9), or ``0`` without CUDA.
    """
    has_cuda = torch.cuda.is_available()
    n_devices, gpu_mem_gb = 0, 0
    if has_cuda:
        n_devices = torch.cuda.device_count()
        gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    print("System Requirements Check:")
    print(f"CUDA Available: {has_cuda}")
    print(f"CUDA Devices: {n_devices}")
    if gpu_mem_gb > 0:
        print(f"Total GPU Memory: {gpu_mem_gb:.1f} GB")

    return {
        'torch_available': has_cuda,
        'cuda_devices': n_devices,
        'total_memory': gpu_mem_gb,
    }

def cleanup_gpu_memory():
    """Release PyTorch's cached-but-unused CUDA memory, if CUDA is present.

    Calls ``torch.cuda.empty_cache()`` and prints a confirmation. Without
    CUDA it does nothing.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    print("GPU memory cache cleared")

def log_experiment_info(model_config, training_config, method_name: str):
    """Print a human-readable summary of an experiment's configuration.

    Args:
        model_config: Object exposing ``model_name`` and ``device``.
        training_config: Object exposing ``learning_rate``, ``batch_size``
            and ``max_new_tokens``.
        method_name: Label for the training/evaluation method being run.
    """
    divider = "-" * 40
    print("\nExperiment Configuration:")
    print(f"Model: {model_config.model_name}")
    print(f"Method: {method_name}")
    print(f"Learning Rate: {training_config.learning_rate}")
    print(f"Batch Size: {training_config.batch_size}")
    print(f"Max Tokens: {training_config.max_new_tokens}")
    print(f"Device: {model_config.device}")
    print(divider)

class ExperimentLogger:
    """Buffer timestamped log lines, echo them to stdout, and dump them to a file.

    Args:
        log_dir: Directory where :meth:`save_log` writes files. It is created
            on construction, including any missing parent directories.
    """

    def __init__(self, log_dir: str = "logs"):
        self.log_dir = Path(log_dir)
        # parents=True so nested locations such as "runs/exp1/logs" work,
        # matching how save_results creates its output directories.
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # In-memory buffer of formatted entries, in the order they were logged.
        self.current_log: List[str] = []

    def log(self, message: str, level: str = "INFO") -> None:
        """Print and buffer ``[YYYY-mm-dd HH:MM:SS] LEVEL: message`` (local time)."""
        import datetime

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_entry = f"[{timestamp}] {level}: {message}"
        print(log_entry)
        self.current_log.append(log_entry)

    def save_log(self, filename: Optional[str] = None) -> Path:
        """Write all buffered entries, newline-separated, to ``log_dir / filename``.

        Args:
            filename: Target file name. Defaults to
                ``experiment_<YYYYmmdd_HHMMSS>.log`` based on the current time.

        Returns:
            The path that was written.
        """
        if filename is None:
            import datetime

            filename = f"experiment_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

        log_path = self.log_dir / filename
        # Explicit UTF-8 so non-ASCII messages don't depend on the platform's
        # default encoding.
        with open(log_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(self.current_log))
        print(f"Log saved to {log_path}")
        return log_path

    def clear_log(self) -> None:
        """Drop all buffered entries by rebinding the buffer to a new empty list."""
        self.current_log = []

def _has_entries(values) -> bool:
    """Return True if *values* is not None and has at least one element."""
    # len() rather than truthiness so NumPy-array histories don't raise
    # "truth value of an array is ambiguous".
    return values is not None and len(values) > 0


def _summarize_history(history: Dict) -> Dict:
    """Build the summary entry for one method's per-epoch metric history."""
    accuracy = history.get('factual_accuracy')
    if not _has_entries(accuracy):
        return {'status': 'failed'}

    # cumulative_reward is optional: a missing or empty series counts as 0.
    rewards = history.get('cumulative_reward')
    return {
        'final_factual_accuracy': accuracy[-1],
        'final_hallucination_rate': history['hallucination_rate'][-1],
        'final_coherence_score': history['coherence_score'][-1],
        'best_factual_accuracy': max(accuracy),
        # NOTE: this is the number of recorded epochs, not a detected
        # convergence point; the key name is kept for compatibility.
        'convergence_epoch': len(accuracy),
        'total_reward': rewards[-1] if _has_entries(rewards) else 0,
    }


def create_experiment_summary(results: Dict) -> Dict:
    """Summarize per-model, per-method training histories.

    Args:
        results: Nested mapping ``{model_name: {method_name: history}}``.
            Each history maps metric names to per-epoch sequences. The keys
            read are ``'factual_accuracy'``, ``'hallucination_rate'``,
            ``'coherence_score'`` and, optionally, ``'cumulative_reward'``.

    Returns:
        A mapping with the same nesting. Each leaf holds the final/best
        metrics, or ``{'status': 'failed'}`` when the history has no (or an
        empty) ``'factual_accuracy'`` series.

    Raises:
        KeyError: If a history with factual-accuracy data lacks
            ``'hallucination_rate'`` or ``'coherence_score'``.
        IndexError: If either of those two series is empty.
    """
    summary = {}
    for model_name, model_results in results.items():
        summary[model_name] = {
            method_name: _summarize_history(history)
            for method_name, history in model_results.items()
        }
    return summary
        