import pandas as pd
import os
import html

# Directory (relative to the current working directory) that holds the
# per-model judge-result CSV files; analyze_failures resolves each CSV
# filename it is given against this directory.
root = "python/followup/judge_claude-sonnet-4/"

def _format_text(text):
    """Normalise one raw CSV cell for display inside an HTML ``<pre>`` element.

    Missing values (NaN/None) become ``"N/A"``. Literal backslash-n sequences
    become real newlines, one surrounding ``['...']`` or ``[...]`` wrapper is
    stripped, escaped quotes are unescaped, and the result is HTML-escaped.
    """
    if pd.isna(text):
        return "N/A"
    text = str(text)
    # The CSV cells contain the two characters "\" "n" instead of real newlines.
    text = text.replace('\\n', '\n')
    # Strip a stringified single-element list wrapper such as "['...']".
    if text.startswith("['") and text.endswith("']"):
        text = text[2:-2]
    elif text.startswith('[') and text.endswith(']'):
        text = text[1:-1]
    text = text.replace("\\'", "'").replace('\\"', '"')
    return html.escape(text)


def _result_status(res):
    """Map a result bucket (0 = fail, 1 = pass) to a ``(css_class, label)`` pair."""
    return ("pass", "Succeeded") if res else ("fail", "Failed")


def _render_model_panel(row, suffix, model_name, res):
    """Return one model's half of the comparison: its answer and, if present, its verify_cot.

    ``suffix`` is ``"_m1"`` or ``"_m2"`` and selects that model's merged columns,
    so a model's answer and reasoning always come from the same CSV.
    ``model_name`` must already be HTML-escaped.
    """
    css_class, label = _result_status(res)
    answer = _format_text(row['IF_answers' + suffix])
    cot_column = 'verify_cot' + suffix
    # `in` on a pandas Series tests its index, i.e. the merged column names.
    if cot_column in row:
        # <pre> (not <p>) so the reasoning's line breaks survive.
        cot = f"\n                <pre>{_format_text(row[cot_column])}</pre>"
    else:
        cot = ""
    return f"""<div class="model-result {css_class}">
                <h4>{model_name} ({label} - Result: {res})</h4>
                <pre>{answer}</pre>{cot}
            </div>"""


def _render_sample_page(row, model1, res1, model2, res2):
    """Return the full HTML page comparing both models on one merged row.

    ``row`` is one row of the dataframe merged in ``analyze_failures``; per-model
    columns carry the ``_m1``/``_m2`` suffixes added by the merge. Prompt-side
    fields (category, instruction, previous answer) are read from the ``_m2``
    copy on the assumption that both CSVs agree on them.
    """
    sample_id = html.escape(str(row['id']))
    category = html.escape(str(row['category_m2']))
    instruction_string = _format_text(row['instruction_string_m2'])
    prev_ans = _format_text(row['prev_ans_m2'])
    name1 = html.escape(str(model1))
    name2 = html.escape(str(model2))
    panel1 = _render_model_panel(row, '_m1', name1, res1)
    panel2 = _render_model_panel(row, '_m2', name2, res2)
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Pairwise Analysis - {sample_id}</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 20px;
            line-height: 1.6;
        }}
        .header {{
            background-color: #f4f4f4;
            padding: 15px;
            border-radius: 5px;
            margin-bottom: 20px;
        }}
        .section {{
            margin-bottom: 30px;
            border: 1px solid #ddd;
            border-radius: 5px;
            padding: 15px;
        }}
        .section h3 {{
            margin-top: 0;
            color: #333;
        }}
        .comparison {{
            display: flex;
            gap: 20px;
        }}
        .model-result {{
            flex: 1;
            border: 1px solid #ccc;
            border-radius: 5px;
            padding: 15px;
        }}
        .fail {{
            background-color: #fff5f5;
            border-left: 4px solid #e53e3e;
        }}
        .pass {{
            background-color: #f0fff4;
            border-left: 4px solid #38a169;
        }}
        pre {{
            background-color: #f8f8f8;
            padding: 10px;
            border-radius: 3px;
            overflow-x: auto;
            white-space: pre-wrap;
        }}
        .id {{
            font-family: monospace;
            background-color: #e2e8f0;
            padding: 2px 6px;
            border-radius: 3px;
        }}
    </style>
</head>
<body>
    <div class="header">
        <h1>Pairwise Analysis</h1>
        <p><strong>ID:</strong> <span class="id">{sample_id}</span></p>
        <p><strong>Category:</strong> {category}</p>
        <p><strong>Result:</strong> {name1} gets {res1}, {name2} gets {res2}</p>
    </div>
    
    <div class="section">
        <h3>Instruction String</h3>
        <pre>{instruction_string}</pre>
    </div>
    
    <div class="section">
        <h3>Previous Answer</h3>
        <pre>{prev_ans}</pre>
    </div>
    
    <div class="section">
        <h3>Model Responses Comparison</h3>
        <div class="comparison">
            {panel1}
            {panel2}
        </div>
    </div>
</body>
</html>"""


def analyze_failures(model1, res1, model2, res2, output, data_dir=None):
    """Write one HTML page per sample where model1 lands in bucket ``res1`` and
    model2 lands in bucket ``res2``, plus an ``index.html`` listing them.

    Args:
        model1, model2: CSV filenames, resolved against ``data_dir``.
        res1, res2: result bucket per model: 1 = pass (``"['1']"``),
            0 = fail (``"['0']"`` or ``"['-1']"``).
        output: suffix of the output directory ``arena-{output}``, created
            relative to the current working directory.
        data_dir: directory holding the CSVs; defaults to the module-level
            ``root``.

    Returns:
        The number of matching samples (one page is written per sample).

    Raises:
        ValueError: if ``res1`` or ``res2`` is not 0 or 1.
    """
    pass_res = ["['1']"]
    fail_res = ["['0']", "['-1']"]
    res_dict = [fail_res, pass_res]
    # Negative indexes would silently select the "pass" bucket; reject them.
    for res in (res1, res2):
        if res not in (0, 1):
            raise ValueError(f"result bucket must be 0 (fail) or 1 (pass), got {res!r}")
    base_dir = root if data_dir is None else data_dir

    print("Reading CSV files...")
    model1_df = pd.read_csv(os.path.join(base_dir, model1))
    model2_df = pd.read_csv(os.path.join(base_dir, model2))

    print(f"Model1 result distribution:\n{model1_df['result'].value_counts()}")
    print(f"Model2 result distribution:\n{model2_df['result'].value_counts()}")

    # Inner join on id: overlapping columns get the _m1 / _m2 suffixes.
    print("Merging dataframes...")
    merged_df = pd.merge(model1_df, model2_df, on='id', suffixes=('_m1', '_m2'))
    print(f"Merged dataframe shape: {merged_df.shape}")

    filtered = merged_df[
        merged_df['result_m1'].isin(res_dict[res1])
        & merged_df['result_m2'].isin(res_dict[res2])
    ]
    print(f"Found {len(filtered)} samples where {model1} gets {res1} but {model2} gets {res2}")

    out_dir = f'arena-{output}'
    os.makedirs(out_dir, exist_ok=True)

    for _, row in filtered.iterrows():
        filename = os.path.join(out_dir, f"{row['id']}.html")
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(_render_sample_page(row, model1, res1, model2, res2))
        print(f"Created {filename}")

    if len(filtered) > 0:
        create_index_html(filtered, model1, res1, model2, res2, output)

    return len(filtered)


def _from_script_globals(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None.

    The original ``create_index_html`` read its settings from globals that only
    exist when this file runs as a script; this keeps such callers working.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise TypeError(f"create_index_html() missing required argument: {name!r}") from None


def _render_index_html(samples, model1, res1, model2, res2):
    """Return the index page listing every sample in ``samples`` (a merged dataframe)."""
    head = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Analysis Index</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
            line-height: 1.6;
        }
        .header {
            background-color: #f4f4f4;
            padding: 20px;
            border-radius: 5px;
            margin-bottom: 30px;
            text-align: center;
        }
        .stats {
            background-color: #e8f4fd;
            padding: 15px;
            border-radius: 5px;
            margin-bottom: 20px;
        }
        table {
            width: 100%;
            border-collapse: collapse;
            margin-top: 20px;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #ddd;
        }
        th {
            background-color: #f8f9fa;
            font-weight: bold;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        .id-cell {
            font-family: monospace;
            font-size: 0.9em;
        }
        .category-cell {
            background-color: #e3f2fd;
            border-radius: 3px;
            padding: 4px 8px;
            display: inline-block;
        }
        a {
            color: #1976d2;
            text-decoration: none;
        }
        a:hover {
            text-decoration: underline;
        }
        .summary {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
            gap: 15px;
            margin-bottom: 20px;
        }
        .summary-card {
            background-color: #fff3e0;
            padding: 15px;
            border-radius: 5px;
            border-left: 4px solid #ff9800;
        }
        .summary-card h4 {
            margin: 0 0 10px 0;
            color: #e65100;
        }
    </style>
</head>"""

    name1 = html.escape(str(model1))
    name2 = html.escape(str(model2))
    # value_counts drops NaN, so this can be empty even when samples is not.
    category_counts = samples['category_m2'].value_counts()
    total_filtered_samples = len(samples)
    if len(category_counts):
        most_common = (f"<strong>{html.escape(str(category_counts.index[0]))}</strong>"
                       f"<br>({category_counts.iloc[0]} samples)")
    else:
        most_common = "<strong>N/A</strong>"

    table_rows = []
    for i, (_, row) in enumerate(samples.iterrows(), 1):
        sample_id = str(row['id'])
        table_rows.append(f"""
            <tr>
                <td>{i}</td>
                <td class="id-cell">{html.escape(sample_id)}</td>
                <td><span class="category-cell">{html.escape(str(row['category_m2']))}</span></td>
                <td><a href="{html.escape(sample_id + '.html')}" target="_blank">View Analysis →</a></td>
            </tr>""")

    breakdown = [
        f"<li><strong>{html.escape(str(category))}</strong>: {count} samples "
        f"({count / total_filtered_samples * 100:.1f}%)</li>"
        for category, count in category_counts.items()
    ]

    body = f"""<body>
    <div class="header">
        <h1>🔍 Pairwise Analysis Index</h1>
        <p> {name1} gets {res1}, {name2} gets {res2} </p>
    </div>
    
    <div class="stats">
        <h3>📊 Summary</h3>
        <div class="summary">
            <div class="summary-card">
                <h4>Filtered Samples</h4>
                <p style="font-size: 1.5em; margin: 0; font-weight: bold;">{total_filtered_samples}</p>
            </div>
            <div class="summary-card">
                <h4>Categories Affected</h4>
                <p style="font-size: 1.5em; margin: 0; font-weight: bold;">{len(category_counts)}</p>
            </div>
            <div class="summary-card">
                <h4>Most Common Category</h4>
                <p style="margin: 0;">{most_common}</p>
            </div>
        </div>
    </div>
    
    <h3>📋 All Samples</h3>
    <table>
        <thead>
            <tr>
                <th>#</th>
                <th>ID</th>
                <th>Category</th>
                <th>Link</th>
            </tr>
        </thead>
        <tbody>{"".join(table_rows)}
        </tbody>
    </table>
    
    <div style="margin-top: 40px; padding: 20px; background-color: #f8f9fa; border-radius: 5px;">
        <h4>📝 Category Breakdown</h4>
        <ul>{"".join(breakdown)}
        </ul>
    </div>
</body>
</html>"""
    return head + body


def create_index_html(samples, model1=None, res1=None, model2=None, res2=None, output=None):
    """Write ``arena-{output}/index.html`` listing all filtered samples.

    Args:
        samples: the filtered merged dataframe (needs ``id`` and ``category_m2``).
        model1, res1, model2, res2: shown in the page header.
        output: suffix of the output directory ``arena-{output}``.

    Any of the settings left as None falls back to the module-level global of
    the same name (set only when this file runs as a script), so the original
    one-argument call still works in that context.

    Raises:
        TypeError: if a setting is omitted and no such global exists.
    """
    model1 = _from_script_globals('model1', model1)
    res1 = _from_script_globals('res1', res1)
    model2 = _from_script_globals('model2', model2)
    res2 = _from_script_globals('res2', res2)
    output = _from_script_globals('output', output)

    out_dir = f'arena-{output}'
    os.makedirs(out_dir, exist_ok=True)
    index_path = os.path.join(out_dir, 'index.html')
    with open(index_path, 'w', encoding='utf-8') as f:
        f.write(_render_index_html(samples, model1, res1, model2, res2))

    print(f"Created {index_path}")

if __name__ == "__main__":
    # Script entry point: compare two judge-result CSVs from `root` and report
    # the samples where model1 lands in bucket res1 (0 = fail) while model2
    # lands in bucket res2 (1 = pass).
    model1 = "claude-sonnet-4.csv"
    res1 = 0
    model2 = "gpt-5-mini.csv"
    res2 = 1
    output = f"{model1}-{res1}-{model2}-{res2}"
    sample_count = analyze_failures(model1, res1, model2, res2, output)
    # Pages are written to arena-{output}, not a 'failures' directory.
    print(f"\nAnalysis complete! Created {sample_count} HTML files in the 'arena-{output}' directory.")


