import os
import json
from typing import Dict, List
from pathlib import Path

from collections import defaultdict
from statistics import mean
from itertools import chain
try:
    from rich.console import Console
    from rich.table import Table
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    not_installed = False
except ImportError:
    not_installed = True

from .acceptance_rate import AcceptanceRate


class SpecBench(AcceptanceRate):
    """Acceptance-rate metric specialised for the SpecBench multi-turn benchmark.

    On top of the per-turn acceptance data gathered by ``AcceptanceRate``
    (read here through ``self.prompt_ar``), this metric aggregates acceptance
    rates per request, per category and overall, then writes a JSONL file of
    responses, a JSON dump of ``self.out``, a rich console table and a PNG of
    summary plots into ``self.directory``.
    """

    def __init__(self, requests):
        """Store the benchmark requests.

        Args:
            requests: Sequence of request objects; each must expose
                ``question_id``, ``category`` and ``turns`` (all read below).

        Raises:
            ImportError: if any of the optional reporting dependencies
                (rich, matplotlib, seaborn, pandas) failed to import.
        """
        super().__init__()
        if not_installed:
            raise ImportError("Please install rich, matplotlib, seaborn, and pandas to use the SpecBench metric")
        self.requests = requests

    def process_final(self, text_outputs):
        """Aggregate acceptance rates and emit every SpecBench report.

        Fills ``self.out["Request_AR"]`` (keyed by question id),
        ``self.out["Category_AR"]`` and ``self.out["Average_AR"]``, then
        writes and prints all reports.

        Args:
            text_outputs: One chat transcript per request, in the same order as
                ``self.requests``; each transcript is a list of
                ``{"role": ..., "content": ...}`` dicts.
        """
        lengths = {}
        request_ar = {}
        self.out["Request_AR"] = request_ar
        for request_id, request in enumerate(self.requests):
            # prompt_ar is populated by the parent class; its values are
            # per-turn sequences that are flattened into one list here.
            turns = self.prompt_ar[request_id].values()
            assert len(turns) == len(request.turns), f"Number of turns {len(turns)} does not match number of turns in request {len(request.turns)}"
            request_ar[request.question_id] = mean(list(chain.from_iterable(turns)))
            for turn in turns:
                self._get_lengths(turn, lengths)
            print(request.category, request_ar[request.question_id])

        per_category = defaultdict(list)
        for request in self.requests:
            per_category[request.category].append(request_ar[request.question_id])
        # Every list is non-empty by construction (entries are only created by
        # append), so mean() cannot fail here.
        self.out["Category_AR"] = {
            category_name: mean(category_values)
            for category_name, category_values in per_category.items()
        }
        self.out["Average_AR"] = mean(request_ar.values())

        self._process_lengths(lengths)
        self.write()
        self._format_write_output(text_outputs)
        self._pretty_print_results()
        self._dump_results()
        self._create_visualizations(text_outputs)

    def _format_write_output(self, outputs):
        """Write ``specbench_responses.jsonl`` with one JSON object per request.

        Each line carries the request's ``question_id`` and ``category``, the
        user turns under ``turns`` and the assistant turns under
        ``choices[0]["turns"]``. ``outputs[i]`` is paired with
        ``self.requests[i]``.
        """
        path = os.path.join(self.directory, "specbench_responses.jsonl")
        with open(path, 'w', encoding='utf-8') as outfile:
            for i, messages in enumerate(outputs):
                request = self.requests[i]
                out_line = {
                    'question_id': request.question_id,
                    'category': request.category,
                    'turns': [c['content'] for c in messages if c['role'] == "user"],
                    'choices': [{
                        'index': 0,
                        "turns": [c['content'] for c in messages if c['role'] == "assistant"],
                    }],
                }
                json.dump(out_line, outfile)
                outfile.write('\n')

    def _pretty_print_results(self):
        """Print a rich table of per-category and overall average acceptance rate."""
        console = Console()
        table = Table(title="Acceptance Rate Results", show_header=True, header_style="bold magenta")
        table.add_column("Category", style="cyan", no_wrap=True)
        table.add_column("Average AR", justify="right", style="green")

        # One row per category, alphabetically.
        for category_name, category_ar in sorted(self.out["Category_AR"].items()):
            table.add_row(category_name, f"{category_ar:.4f}")

        # Separator, then the overall summary row.
        table.add_section()
        table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AR']:.4f}[/bold]")

        console.print(table)

    def _dump_results(self):
        """Write ``self.out`` as indented JSON to ``specbench_results.json``."""
        path = os.path.join(self.directory, "specbench_results.json")
        with open(path, 'w', encoding='utf-8') as outfile:
            json.dump(self.out, outfile, indent=4)

    @staticmethod
    def _mean_assistant_length(messages):
        """Return the mean character length of the assistant turns in *messages*.

        Returns NaN when the transcript has no assistant turn, so that request is
        left out of the length-binned plot instead of crashing
        ``statistics.mean`` on an empty list.
        """
        lengths = [len(c['content']) for c in messages if c['role'] == "assistant"]
        return mean(lengths) if lengths else float('nan')

    @staticmethod
    def _plot_grouped_bars(ax, stats, xlabel, title, rotate_labels=False):
        """Draw the mean acceptance rate (with std error bars) per group and label each bar with its count.

        ``stats`` must have the multi-index columns produced in
        ``_create_visualizations``: ``('acceptance_rate', 'mean')``,
        ``('acceptance_rate', 'std')`` and ``('question_id', 'count')``.
        """
        labels = stats.index.tolist()
        means = stats[('acceptance_rate', 'mean')].values
        stds = stats[('acceptance_rate', 'std')].values
        counts = stats[('question_id', 'count')].values
        positions = range(len(labels))

        bars = ax.bar(positions, means, yerr=stds, capsize=5, alpha=0.8)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Acceptance Rate')
        ax.set_title(title)
        ax.set_xticks(positions)
        if rotate_labels:
            ax.set_xticklabels(labels, rotation=45, ha='right')
        else:
            ax.set_xticklabels(labels)

        for bar, count in zip(bars, counts):
            height = bar.get_height()
            # Empty groups have a NaN mean; anchor their "n=0" label at the
            # baseline instead of at a non-finite y coordinate.
            base = 0.0 if pd.isna(height) else height
            ax.text(bar.get_x() + bar.get_width() / 2, base + 0.01,
                    f'n={count}', ha='center', va='bottom', fontsize=8)

    def _create_visualizations(self, text_outputs: List[List[Dict[str, str]]], title: str = "Speculative Decoding Acceptance Rate Analysis"):
        """Save a three-panel summary figure to ``acceptance_rate_analysis.png``.

        Panels:
            1. Mean acceptance rate per category.
            2. Mean acceptance rate per response-length bin (characters).
            3. Histogram of per-request acceptance rates.

        Args:
            text_outputs: One chat transcript per request, aligned with
                ``self.requests``.
            title: Figure super-title.
        """
        request_ar = self.out["Request_AR"]
        # Every column is built from self.requests (not from Request_AR's keys),
        # so rows stay aligned with text_outputs even if question ids repeat.
        df_clean = pd.DataFrame({
            "question_id": [request.question_id for request in self.requests],
            "acceptance_rate": [request_ar[request.question_id] for request in self.requests],
            "category": [request.category for request in self.requests],
            "response_length": [self._mean_assistant_length(messages) for messages in text_outputs],
        })

        if df_clean.empty:
            print("Warning: No successful results to plot")
            return

        df_clean['response_length_bin'] = pd.cut(
            df_clean['response_length'],
            bins=[0, 100, 300, 500, 1000, float('inf')],
            labels=['0-100', '100-300', '300-500', '500-1000', '1000+'],
            include_lowest=True,  # bins are right-closed; keep length 0 in '0-100'
        )

        agg_spec = {'acceptance_rate': ['mean', 'std'], 'question_id': 'count'}
        category_stats = df_clean.groupby('category').agg(agg_spec).round(3)
        # observed=False keeps empty length bins (current pandas behaviour) and
        # silences the pandas>=2.1 deprecation warning about the default.
        length_stats = df_clean.groupby('response_length_bin', observed=False).agg(agg_spec).round(3)

        rates = df_clean['acceptance_rate']
        plot_path = Path(self.directory) / "acceptance_rate_analysis.png"

        # Scope the style to this figure instead of mutating global rcParams.
        with plt.style.context('seaborn-v0_8'):
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
            try:
                fig.suptitle(title, fontsize=16, fontweight='bold')

                self._plot_grouped_bars(ax1, category_stats, 'Category',
                                        'Acceptance Rate by Category', rotate_labels=True)
                self._plot_grouped_bars(ax2, length_stats, 'Response Length (characters)',
                                        'Acceptance Rate by Response Length')

                rate_mean = rates.mean()
                ax3.hist(rates, bins=20, alpha=0.7, edgecolor='black')
                ax3.axvline(rate_mean, color='red', linestyle='--',
                            label=f'Mean: {rate_mean:.3f}')
                ax3.set_xlabel('Acceptance Rate')
                ax3.set_ylabel('Frequency')
                ax3.set_title('Distribution of Acceptance Rates')
                ax3.legend()

                fig.tight_layout()
                fig.savefig(plot_path, dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even if drawing or saving fails.
                plt.close(fig)
        print(f"Plots saved to {plot_path}")
