import glob
import json
import math
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

import networkx as nx
import numpy as np

from utils import load_config

# Default scale factor ("alpha"): coherence scores are multiplied by it before
# their variance is taken, which yields the reported KBS value.  It is also the
# default of the --alpha command-line option.
ALPHA = 100.0  # Can adjust this value to test different weights


class GraphEvaluator:
    """Score model answer files against a weighted graph of questions.

    Graph nodes are matched against the question ids found in the result
    files.  For every model the evaluator reports:

    * ``GCS`` -- the mean, over graph nodes that have a result, of
      ``result * neighbourhood coherence``;
    * ``coherence_variance`` -- the variance of the coherence scores;
    * ``KBS`` -- the variance of the coherence scores after multiplying them
      by ``alpha``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Store *config*, loading ``src/config.yaml`` when none is given."""
        if config is None:
            config = load_config("src/config.yaml")
        self.config = config
        # Graph built by the most recent evaluate_batch* call (None until then).
        self.G = None
        # Question records loaded by evaluate_batch_by_category (None until then).
        self.questions_data = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _read_json(path: str) -> Any:
        """Parse the UTF-8 JSON file at *path* and return its content."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _find_result_files(results_path: str) -> List[str]:
        """Return *results_path* itself if it is a file, else its ``*.json`` files, sorted."""
        if os.path.isfile(results_path):
            return [results_path]
        return sorted(glob.glob(os.path.join(results_path, "*.json")))

    @staticmethod
    def _describe(values: List[float]) -> Dict[str, float]:
        """Return mean/std/min/max of a non-empty list as plain Python floats."""
        arr = np.asarray(values, dtype=float)
        return {
            'mean': float(arr.mean()),
            'std': float(arr.std()),
            'min': float(arr.min()),
            'max': float(arr.max())
        }

    @staticmethod
    def _resolve_accuracy(reported: Any, results: Dict[int, int]) -> float:
        """Return a numeric accuracy percentage for a result file.

        ``load_model_results`` yields the string ``'Unknown'`` when the file
        has no ``accuracy_percentage``; formatting that with ``:.2f`` raises.
        A numeric (or numeric-looking) *reported* value is used as is;
        otherwise the percentage of entries in *results* equal to 1 is used.
        """
        if isinstance(reported, (int, float)):
            return reported
        try:
            return float(reported)
        except (TypeError, ValueError):
            pass
        if not results:
            return 0.0
        correct = sum(1 for value in results.values() if value == 1)
        return correct / len(results) * 100

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def load_graph(self, graph_file: str) -> Dict[str, Any]:
        """Load the graph description (a JSON document) from *graph_file*."""
        return self._read_json(graph_file)

    def load_questions_data(self, questions_file: str) -> List[Dict[str, Any]]:
        """Load questions data and extract category information"""
        return self._read_json(questions_file)

    def get_categories_mapping(self, questions_data: List[Dict[str, Any]]) -> Dict[str, Set[int]]:
        """Build category to node ID mapping based on questions data

        Questions without an ``'id'`` are skipped; questions without a
        ``'category'`` are grouped under ``'unknown'``.
        """
        category_nodes = defaultdict(set)

        for question in questions_data:
            question_id = question.get('id')
            if question_id is None:
                continue
            category_nodes[question.get('category', 'unknown')].add(question_id)

        return category_nodes

    def build_subgraph_by_category(self, G: nx.Graph, category_nodes: Set[int]) -> nx.Graph:
        """Return an independent copy of the subgraph of *G* induced by *category_nodes*.

        Ids in *category_nodes* that are not nodes of *G* are ignored.
        NOTE(review): membership is by equality, so node ids in the graph file
        and question ids must share a type (e.g. both int) -- confirm.
        """
        members = [node for node in G.nodes() if node in category_nodes]
        return G.subgraph(members).copy()

    def load_model_results(self, results_file: str) -> Tuple[Dict[int, int], str, Any]:
        """Load model evaluation results

        Returns:
            ``(results, model_name, accuracy)`` where *results* maps question
            id to its ``is_correct`` value (0 when absent), *model_name*
            defaults to ``'Unknown'`` and *accuracy* is the file's
            ``accuracy_percentage`` (``'Unknown'`` when absent).
        """
        data = self._read_json(results_file)

        results = {}
        for item in data.get('results', []):
            question_id = item.get('id')
            # An entry without an id cannot be matched to a graph node.
            if question_id is None:
                continue
            results[question_id] = item.get('is_correct', 0)

        model_name = data.get('model_name', 'Unknown')
        accuracy = data.get('accuracy_percentage', 'Unknown')

        return results, model_name, accuracy

    def build_graph(self, graph_data: Dict[str, Any], min_weight: float = 0.4) -> nx.Graph:
        """Build NetworkX graph

        Only edges whose weight is strictly greater than *min_weight* are
        added, so nodes without any such edge do not appear in the graph.

        Raises:
            KeyError: if *graph_data* has no ``'edges'`` entry or an edge
                lacks ``'source'``, ``'target'`` or ``'weight'``.
        """
        G = nx.Graph()
        for edge in graph_data['edges']:
            weight = edge['weight']
            if weight > min_weight:
                G.add_edge(edge['source'], edge['target'], weight=weight)
        return G

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------
    def compute_neighborhood_coherence(self, node: int, G: nx.Graph, results: Dict[int, int]) -> float:
        """Compute neighborhood coherence for a node

        Coherence_M(v) = Σ_{u∈N(v)} w(v,u) * Res_M(u) / Σ_{u∈N(v)} w(v,u)

        Neighbours that have no entry in *results* count as 0.  Returns 0.0
        for a node without neighbours or when the weights do not sum to a
        positive number.
        """
        weighted_sum = 0.0
        weight_sum = 0.0

        for neighbor in G.neighbors(node):
            weight = G[node][neighbor]['weight']
            weighted_sum += weight * results.get(neighbor, 0)
            weight_sum += weight

        if weight_sum > 0:
            return weighted_sum / weight_sum
        return 0.0

    def compute_coherence_variance(self, coherence_scores: List[float], alpha: float = ALPHA) -> Dict[str, float]:
        """Compute the variance of the coherence scores and its alpha-scaled variant.

        Returns:
            A dict with ``'coherence_variance'`` (variance of the scores),
            ``'KBS'`` (variance of the scores multiplied by *alpha*) and
            ``'alpha'``.  The same keys are returned for an empty input, with
            both variances set to 0.0, so callers can always read ``'KBS'``.
        """
        if not coherence_scores:
            return {
                'coherence_variance': 0.0,
                'KBS': 0.0,
                'alpha': alpha
            }

        coherence_array = np.array(coherence_scores)

        return {
            'coherence_variance': float(np.var(coherence_array)),
            'KBS': float(np.var(coherence_array * alpha)),
            'alpha': alpha
        }

    def evaluate_single_model_on_subgraph(self, results: Dict[int, int], model_name: str,
                                        accuracy: float, G: nx.Graph, alpha: float = ALPHA) -> Dict[str, Any]:
        """Evaluate a single model on a specified subgraph

        Only nodes of *G* that have an entry in *results* are scored.  The
        *accuracy* argument is accepted for interface compatibility; the
        returned ``'accuracy'`` is recomputed from the scored nodes.

        Returns:
            A dict with ``model_name``, ``accuracy`` (percentage of scored
            nodes whose result equals 1), ``GCS``, ``coherence_variance``,
            ``KBS``, ``coherence_stats``, ``variance_details``,
            ``valid_nodes``, ``total_nodes`` and ``total_edges``.
        """
        coherence_scores = []
        total_score = 0.0
        correct_answers = 0

        for node in G.nodes():
            if node not in results:
                continue
            node_result = results[node]
            coherence = self.compute_neighborhood_coherence(node, G, results)

            coherence_scores.append(coherence)
            total_score += node_result * coherence
            if node_result == 1:
                correct_answers += 1

        valid_nodes = len(coherence_scores)
        if valid_nodes > 0:
            gcs = total_score / valid_nodes
            subgraph_accuracy = correct_answers / valid_nodes * 100
            coherence_stats = self._describe(coherence_scores)
            coherence_stats['median'] = float(np.median(coherence_scores))
        else:
            # Nothing to score: report neutral values instead of dividing by zero.
            gcs = 0.0
            subgraph_accuracy = 0.0
            coherence_stats = {'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0}

        variance_metrics = self.compute_coherence_variance(coherence_scores, alpha)

        return {
            'model_name': model_name,
            'accuracy': subgraph_accuracy,  # Accuracy for this subgraph
            'GCS': gcs,
            'coherence_variance': variance_metrics['coherence_variance'],
            'KBS': variance_metrics['KBS'],
            'coherence_stats': coherence_stats,
            'variance_details': variance_metrics,
            'valid_nodes': valid_nodes,
            'total_nodes': G.number_of_nodes(),
            'total_edges': G.number_of_edges()
        }

    def evaluate_single_model(self, results: Dict[int, int], model_name: str,
                            accuracy: float, alpha: float = ALPHA) -> Dict[str, Any]:
        """Evaluate a single model (compatible with original interface)

        Uses the graph cached in ``self.G`` and reports the caller-supplied
        overall *accuracy* instead of the one recomputed on the graph.

        Raises:
            ValueError: if no graph has been built yet.
        """
        if self.G is None:
            raise ValueError("No graph loaded: assign self.G or call evaluate_batch() first")
        metrics = self.evaluate_single_model_on_subgraph(results, model_name, accuracy, self.G, alpha)
        # For full graph evaluation, maintain the original overall accuracy
        metrics['accuracy'] = accuracy
        return metrics

    # ------------------------------------------------------------------
    # Batch drivers
    # ------------------------------------------------------------------
    def evaluate_batch_by_category(self, graph_file: str, questions_file: str, results_path: str,
                                 alpha: float = ALPHA) -> Dict[str, Any]:
        """Batch evaluate multiple models by category

        *results_path* is either one result file or a directory whose
        ``*.json`` files are all evaluated.  A file that fails to process is
        reported on stdout and skipped.

        Returns:
            ``{'alpha', 'graph_info', 'models'}`` where ``models`` maps each
            result file's base name to its full-graph and per-category
            ``accuracy`` / ``GCS`` (as a percentage) / ``KBS``.
        """
        # Load graph and questions data only once
        print("Loading graph...")
        self.G = self.build_graph(self.load_graph(graph_file))

        print("Loading questions data...")
        self.questions_data = self.load_questions_data(questions_file)

        category_mapping = self.get_categories_mapping(self.questions_data)
        print(f"Found categories: {list(category_mapping.keys())}")

        category_subgraphs = {}
        for category, nodes in category_mapping.items():
            subgraph = self.build_subgraph_by_category(self.G, nodes)
            category_subgraphs[category] = subgraph
            print(f"Category '{category}': {len(subgraph.nodes())} nodes, {len(subgraph.edges())} edges")

        print(f"Full Graph: {len(self.G.nodes())} nodes, {len(self.G.edges())} edges")

        json_files = self._find_result_files(results_path)
        print(f"Found {len(json_files)} result files")

        model_results = {}

        for json_file in json_files:
            # Best-effort batch: one unreadable or malformed file must not
            # abort the evaluation of the remaining models.
            try:
                print(f"Processing {os.path.basename(json_file)}...")
                results, model_name, accuracy = self.load_model_results(json_file)
                accuracy = self._resolve_accuracy(accuracy, results)

                full_metrics = self.evaluate_single_model_on_subgraph(results, model_name, accuracy, self.G, alpha)

                categories = {}
                for category, subgraph in category_subgraphs.items():
                    category_metrics = self.evaluate_single_model_on_subgraph(
                        results, model_name, accuracy, subgraph, alpha)
                    categories[category] = {
                        'accuracy': category_metrics['accuracy'],  # Accuracy for this category
                        'GCS': category_metrics['GCS'] * 100,  # Convert to percentage
                        'KBS': category_metrics['KBS']
                    }

                model_results[os.path.basename(json_file)] = {
                    'model_name': model_name,
                    'accuracy': accuracy,
                    'full_graph': {
                        'accuracy': accuracy,  # Full graph accuracy (same as overall accuracy)
                        'GCS': full_metrics['GCS'] * 100,  # Convert to percentage
                        'KBS': full_metrics['KBS']
                    },
                    'categories': categories
                }

                print(f"  Full Graph - Accuracy: {accuracy:.2f}%, GCS: {full_metrics['GCS'] * 100:.2f}%")
                for category, cat_data in categories.items():
                    print(f"  {category} - Accuracy: {cat_data['accuracy']:.2f}%, GCS: {cat_data['GCS']:.2f}%")

            except Exception as e:
                print(f"Error processing {json_file}: {e}")
                continue

        return {
            'alpha': alpha,
            'graph_info': {
                'total_nodes': len(self.G.nodes()),
                'total_edges': len(self.G.edges()),
                'categories': {cat: {'nodes': len(subgraph.nodes()), 'edges': len(subgraph.edges())}
                              for cat, subgraph in category_subgraphs.items()}
            },
            'models': model_results
        }

    def evaluate_batch(self, graph_file: str, results_path: str, alpha: float = ALPHA) -> Dict[str, Any]:
        """Batch evaluate multiple models (compatible with original interface, only evaluate full graph)

        *results_path* is either one result file or a directory whose
        ``*.json`` files are all evaluated.  A file that fails to process is
        reported on stdout and skipped.

        Returns:
            ``{'graph_info', 'summary', 'model_results'}``; ``summary`` holds
            mean/std/min/max of GCS, coherence variance and KBS across the
            evaluated models, or just ``{'num_models': 0}`` when none
            succeeded.
        """
        # Load graph only once
        print("Loading graph...")
        self.G = self.build_graph(self.load_graph(graph_file))

        print(f"Graph: {len(self.G.nodes())} nodes, {len(self.G.edges())} edges")

        json_files = self._find_result_files(results_path)
        print(f"Found {len(json_files)} result files")

        model_results = {}

        for json_file in json_files:
            # Best-effort batch: see evaluate_batch_by_category.
            try:
                print(f"Processing {os.path.basename(json_file)}...")
                results, model_name, accuracy = self.load_model_results(json_file)
                accuracy = self._resolve_accuracy(accuracy, results)

                model_metrics = self.evaluate_single_model(results, model_name, accuracy, alpha)
                print(f"Accuracy: {model_metrics['accuracy']:.2f}%, GCS: {model_metrics['GCS'] * 100:.2f}%")
                model_results[os.path.basename(json_file)] = model_metrics

            except Exception as e:
                print(f"Error processing {json_file}: {e}")
                continue

        if model_results:
            evaluated = list(model_results.values())
            summary = {
                'num_models': len(evaluated),
                'gcs_stats': self._describe([m['GCS'] for m in evaluated]),
                'variance_stats': self._describe([m['coherence_variance'] for m in evaluated]),
                'KBS_stats': self._describe([m['KBS'] for m in evaluated])
            }
        else:
            summary = {'num_models': 0}

        return {
            'graph_info': {
                'total_nodes': len(self.G.nodes()),
                'edges': len(self.G.edges()),
                'alpha': alpha
            },
            'summary': summary,
            'model_results': model_results
        }


def main():
    """Command-line entry point.

    Runs the category-based evaluation when ``--questions`` is given and the
    full-graph evaluation otherwise, prints a report to stdout and writes the
    metrics as JSON to ``--output``.
    """
    import argparse

    def _pct(value):
        """Format a number as ``12.34%``; show anything else (e.g. 'Unknown') verbatim."""
        if isinstance(value, (int, float)):
            return f"{value:.2f}%"
        return str(value)

    parser = argparse.ArgumentParser(description="Compute GCS and Coherence Variance")
    parser.add_argument("--graph", "-g", required=True, help="Graph JSON file")
    parser.add_argument("--results", "-r", required=True, help="Model results JSON file or directory")
    parser.add_argument("--questions", "-q", help="Questions JSON file (for category-based evaluation)")
    parser.add_argument("--output", "-o", required=True, help="Output JSON file")
    parser.add_argument("--alpha", type=float, default=ALPHA, help=f"Alpha weight for coherence variance calculation (default: {ALPHA})")

    args = parser.parse_args()

    evaluator = GraphEvaluator()

    # Choose evaluation method based on whether questions file is provided
    if args.questions:
        print("Running category-based evaluation...")
        all_metrics = evaluator.evaluate_batch_by_category(args.graph, args.questions, args.results, args.alpha)
        graph_info = all_metrics['graph_info']

        print(f"\n{'='*60}")
        print("CATEGORY-BASED BATCH EVALUATION RESULTS")
        print(f"{'='*60}")
        print(f"Alpha parameter: {all_metrics['alpha']}")
        print(f"Models evaluated: {len(all_metrics['models'])}")

        for model_file, model_data in all_metrics['models'].items():
            print(f"\n{'='*15} {model_data['model_name']} {'='*15}")
            print(f"File: {model_file}")
            print(f"Accuracy: {model_data['accuracy']}")

            full_graph = model_data['full_graph']
            print("\nFull Graph:")
            # The reported accuracy may be a non-numeric placeholder when the
            # result file lacks it, so it is formatted defensively.
            print(f"  Accuracy: {_pct(full_graph['accuracy'])}")
            print(f"  GCS: {full_graph['GCS']:.2f}%")
            print(f"  KBS: {full_graph['KBS']:.4f}")

            for category, cat_data in model_data['categories'].items():
                cat_info = graph_info['categories'][category]
                print(f"\nCategory '{category}' ({cat_info['nodes']} nodes, {cat_info['edges']} edges):")
                print(f"  Accuracy: {_pct(cat_data['accuracy'])}")
                print(f"  GCS: {cat_data['GCS']:.2f}%")
                print(f"  KBS: {cat_data['KBS']:.4f}")

        print("\nGraph Info:")
        print(f"  Total Nodes: {graph_info['total_nodes']}")
        print(f"  Total Edges: {graph_info['total_edges']}")
        print(f"  Alpha: {all_metrics['alpha']}")

    else:
        print("Running full-graph evaluation...")
        all_metrics = evaluator.evaluate_batch(args.graph, args.results, args.alpha)
        summary = all_metrics['summary']

        print(f"\n{'='*50}")
        print("BATCH EVALUATION RESULTS")
        print(f"{'='*50}")
        print(f"Models evaluated: {summary['num_models']}")
        print(f"Alpha parameter: {args.alpha}")

        if summary['num_models'] > 0:
            print("\nGCS Summary:")
            gcs_stats = summary['gcs_stats']
            print(f"  Mean: {gcs_stats['mean'] * 100:.2f}%")
            print(f"  Std:  {gcs_stats['std'] * 100:.2f}%")
            print(f"  Range: [{gcs_stats['min'] * 100:.2f}%, {gcs_stats['max'] * 100:.2f}%]")

            # Both remaining summaries share one layout; only title and key differ.
            for title, key in (("\nCoherence Variance Summary:", 'variance_stats'),
                               (f"\nKBS Summary (alpha={args.alpha}):", 'KBS_stats')):
                stats = summary[key]
                print(title)
                print(f"  Mean: {stats['mean']:.4f}")
                print(f"  Std:  {stats['std']:.4f}")
                print(f"  Range: [{stats['min']:.4f}, {stats['max']:.4f}]")

            print("\nTop 5 Models by GCS:")
            ranked = sorted(all_metrics['model_results'].items(),
                            key=lambda item: item[1]['GCS'], reverse=True)[:5]
            for rank, (filename, metrics) in enumerate(ranked, 1):
                print(f"  {rank}. {filename}: Accuracy={_pct(metrics['accuracy'])}, GCS={metrics['GCS'] * 100:.2f}%, Variance={metrics['coherence_variance']:.4f}")

        print("\nGraph Info:")
        print(f"  Nodes: {all_metrics['graph_info']['total_nodes']}")
        print(f"  Edges: {all_metrics['graph_info']['edges']}")
        print(f"  Alpha: {all_metrics['graph_info']['alpha']}")

    # Save results; create the target directory first so a finished
    # evaluation is not lost to a missing folder.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(all_metrics, f, indent=2)
    print(f"\nResults saved to: {args.output}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 