import transformers
from transformers import AutoTokenizer
from transformers import (
    AutoModelForCausalLM,
)
from transformers import pipeline, set_seed, LogitsProcessor

import torch
import arithmeticcoding
import io
import numpy as np
import peft
from peft import LoraConfig, get_peft_model
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
import seaborn as sns
import json
# Global matplotlib styling applied to every figure created after import:
# serif fonts, only the left/bottom spines drawn, and a framed, rounded,
# shadowed legend. The settings are chosen to look at home in a LaTeX document.
plt.rcParams.update({
    'font.size': 12,
    'font.family': 'serif',
    'text.usetex': False,  # Set to True if you have LaTeX installed
    'figure.figsize': (10, 8),
    'axes.linewidth': 1,
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'xtick.major.size': 7,
    'xtick.minor.size': 4,
    'ytick.major.size': 7,
    'ytick.minor.size': 4,
    'legend.frameon': True,
    'legend.fancybox': True,
    'legend.shadow': True
})

# Bit width handed to arithmeticcoding.ArithmeticEncoder. TextZipper scales
# token probabilities by 2**(PRECISION - 3) when building frequency tables.
PRECISION = 32

class TextZipper(object):
    """Arithmetic-code text using a causal language model as the symbol model.

    The tokenizer and model are loaded with Hugging Face ``transformers``; the
    model is moved to CUDA, so a GPU is required. An optional adapter can be
    loaded on top of the base model via ``adapter_path``.
    """

    def __init__(self, *args, modelname="facebook/opt-350m", adapter_path = None, **kwargs):
        """Load tokenizer and model.

        Args:
            modelname: Hugging Face model identifier or local path.
            adapter_path: Optional adapter passed to ``model.load_adapter``.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, torch_dtype=torch.float32, low_cpu_mem_usage=True).cuda()
        if adapter_path:
            self.model.load_adapter(adapter_path)

        self.model = self.model.eval()

        self.precision = PRECISION-3 # less than quarter range to allow for rounding

    def encode(self, bitstream, input_text, prompt="", max_length=None):
        """Arithmetic-encode ``input_text`` into ``bitstream``.

        Args:
            bitstream: Binary file-like object wrapped by
                ``arithmeticcoding.BitOutputStream``.
            input_text: Text to compress.
            prompt: Optional prefix. It is prepended (with a space) to the
                text and used as context; coding starts at the last prompt
                token position.
            max_length: If given, keep only the first ``max_length`` token ids.

        Returns:
            ``(H, padding)`` where ``H`` is the sum of ``-log2 p(symbol)`` over
            the coded symbols (a Python float) and ``padding`` is whatever
            ``ArithmeticEncoder.finish`` returns (callers add it to ``H`` as a
            bit count; confirm against the arithmeticcoding module).
        """
        if prompt:
            input_text = prompt + " " + input_text
            prompt_mask = self.tokenizer([prompt], return_tensors="pt")["attention_mask"]
            # Plain int rather than a 0-d tensor, so range()/indexing are unambiguous.
            prompt_end = int(prompt_mask.sum()) - 1
        else:
            prompt_end = 0

        inputs = self.tokenizer([input_text], return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        if max_length is not None:
            input_ids = input_ids[:, :max_length]
        seq_len = input_ids.shape[1]

        with torch.no_grad():
            # The full-sequence pass provides the logits buffer; every coded
            # position is then overwritten with the logits obtained from the
            # prefix alone.
            # NOTE(review): presumably this mirrors the incremental passes a
            # decoder has to perform, keeping both sides numerically
            # identical -- confirm before replacing with a single pass.
            logits = self.model.forward(input_ids, return_dict=True)['logits']
            for i in range(prompt_end, seq_len):
                prefix_out = self.model.forward(input_ids[:, :i+1], return_dict=True)
                logits[:, i] = prefix_out['logits'][:, -1]

        probs = logits.softmax(dim=-1)

        bitout = arithmeticcoding.BitOutputStream(bitstream)
        ac_enc = arithmeticcoding.ArithmeticEncoder(PRECISION, bitout)

        # probs[0, i] is the distribution of the token following position i.
        next_tokens = input_ids[0, 1:]
        eos_id = self.tokenizer.eos_token_id

        H = 0.0
        for i in range(prompt_end, seq_len):
            freqs = self.probs_to_freq(probs[:, i])

            # After the last input token the coded symbol is EOS.
            symbol = eos_id if i == seq_len - 1 else int(next_tokens[i])
            H += -torch.log2(probs[0, i][symbol])
            ac_enc.write(freqs, symbol)
        padding = ac_enc.finish(randomize=False)

        # float() handles both the tensor accumulator and the untouched 0.0
        # (empty coding range), where the previous H.item() raised.
        return float(H), padding

    def probs_to_freq(self, probs):
        """Turn the first row of ``probs`` into a ``SimpleFrequencyTable``.

        Probabilities are scaled by ``2**self.precision`` and rounded up, so
        every symbol with non-zero probability gets a frequency of at least 1.
        """
        p = probs[0]
        freqs = torch.ceil(p.float() * 2**self.precision).long().cpu().numpy().tolist()
        freqs = arithmeticcoding.SimpleFrequencyTable(freqs)
        return freqs
def compute_bit_length_distribution_json(json_path, N, modelname="facebook/opt-125m", adapter_path=None):
    """
    Compute bit length distribution for the first N tokens of each sample's "original" text in a JSON file.

    The list of samples is read from ``data['texts']`` or, if that key is
    absent, from ``data['root']['texts']``. Each sample is expected to be a
    dict with an ``'original'`` text entry; other samples are skipped with a
    warning.

    Args:
        json_path: Path of the UTF-8 JSON file.
        N: Number of leading tokens encoded per sample (passed as ``max_length``).
        modelname: Model identifier forwarded to ``TextZipper``.
        adapter_path: Optional adapter forwarded to ``TextZipper``.

    Returns:
        ``None`` if no sample list is found; otherwise a dict with
        ``'bit_lengths'`` (one value per encoded sample), ``'distribution'``
        (value -> count, sorted by value) and ``'stats'`` (mean, std, min,
        max, median, q25, q75). All three are empty when nothing was encoded.
    """
    # Read and validate the input first so a bad file fails before the
    # (expensive) model load.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    text_items = None
    if isinstance(data, dict):
        if 'texts' in data:
            text_items = data['texts']
        elif isinstance(data.get('root'), dict) and 'texts' in data['root']:
            text_items = data['root']['texts']
    if text_items is None:
        print("Error: JSON file must have the structure {'texts': [...]} or {'root': {'texts': [...]}}")
        return None

    print(f"Found {len(text_items)} samples in JSON file.")
    print(f"Processing samples with first {N} tokens each...")

    if not text_items:
        print("No samples were successfully processed.")
        return {'bit_lengths': [], 'distribution': {}, 'stats': {}}

    zipper = TextZipper(modelname=modelname, adapter_path=adapter_path)

    bit_lengths = []

    for idx, item in enumerate(text_items):
        if idx > 0 and idx % 100 == 0:
            print(f"Processing sample {idx}/{len(text_items)}")

        # Samples that are not dicts cannot carry an "original" key either.
        if not isinstance(item, dict) or 'original' not in item:
            print(f"Warning: 'original' key not found in sample {idx}. Skipping.")
            continue

        text = str(item['original'])

        # Best effort: one failing sample must not abort the whole run.
        try:
            bitstream = io.BytesIO()
            entropy, padding = zipper.encode(bitstream, text, max_length=N)
            bit_lengths.append(entropy + padding)
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            continue

    if not bit_lengths:
        print("No samples were successfully processed.")
        return {'bit_lengths': [], 'distribution': {}, 'stats': {}}

    # Value -> count, ordered by value.
    distribution = dict(sorted(Counter(bit_lengths).items()))

    return {
        'bit_lengths': bit_lengths,
        'distribution': distribution,
        'stats': {
            'mean': np.mean(bit_lengths),
            'std': np.std(bit_lengths),
            'min': np.min(bit_lengths),
            'max': np.max(bit_lengths),
            'median': np.median(bit_lengths),
            'q25': np.percentile(bit_lengths, 25),
            'q75': np.percentile(bit_lengths, 75)
        }
    }

def compute_bit_length_distribution(csv_path, N, modelname="facebook/opt-125m", adapter_path=None):
    """
    Compute bit length distribution for the first N tokens of each sample in a CSV file.

    The text column is ``'text'`` if present, otherwise the first column whose
    name contains "text" (case-insensitive), otherwise the first column.
    Rows whose text cell is missing are skipped with a warning.

    Args:
        csv_path: Path of the CSV file (read with ``pandas.read_csv``).
        N: Number of leading tokens encoded per sample (passed as ``max_length``).
        modelname: Model identifier forwarded to ``TextZipper``.
        adapter_path: Optional adapter forwarded to ``TextZipper``.

    Returns:
        Dict with ``'bit_lengths'`` (one value per encoded sample),
        ``'distribution'`` (value -> count, sorted by value) and ``'stats'``
        (mean, std, min, max, median, q25, q75). All three are empty when no
        sample was encoded.
    """
    # Read the data first so a missing/invalid file fails before the
    # (expensive) model load.
    df = pd.read_csv(csv_path)

    if 'text' in df.columns:
        text_column = 'text'
    else:
        candidates = [col for col in df.columns if 'text' in str(col).lower()]
        # Fall back to the first column when nothing looks like a text column.
        text_column = candidates[0] if candidates else df.columns[0]

    print(f"Using column '{text_column}' as text column")
    print(f"Processing {len(df)} samples with first {N} tokens each...")

    if df.empty:
        print("No samples were successfully processed.")
        return {'bit_lengths': [], 'distribution': {}, 'stats': {}}

    zipper = TextZipper(modelname=modelname, adapter_path=adapter_path)

    bit_lengths = []
    total = len(df)

    for idx, value in df[text_column].items():
        if idx % 100 == 0:
            print(f"Processing sample {idx}/{total}")

        # A missing cell would otherwise be encoded as the literal string "nan".
        if pd.isna(value):
            print(f"Warning: empty text in sample {idx}. Skipping.")
            continue

        text = str(value)

        # Best effort: one failing sample must not abort the whole run.
        try:
            bitstream = io.BytesIO()
            entropy, padding = zipper.encode(bitstream, text, max_length=N)
            bit_lengths.append(entropy + padding)
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            continue

    # Without samples the numpy reductions below would raise.
    if not bit_lengths:
        print("No samples were successfully processed.")
        return {'bit_lengths': [], 'distribution': {}, 'stats': {}}

    # Value -> count, ordered by value.
    distribution = dict(sorted(Counter(bit_lengths).items()))

    return {
        'bit_lengths': bit_lengths,
        'distribution': distribution,
        'stats': {
            'mean': np.mean(bit_lengths),
            'std': np.std(bit_lengths),
            'min': np.min(bit_lengths),
            'max': np.max(bit_lengths),
            'median': np.median(bit_lengths),
            'q25': np.percentile(bit_lengths, 25),
            'q75': np.percentile(bit_lengths, 75)
        }
    }

def plot_histogram(results, N, save_path):
    """Draw a 50-bin histogram of the bit lengths and save it as a PDF.

    The mean and median from ``results['stats']`` are marked with dashed
    vertical lines and listed in the legend.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.hist(
        results['bit_lengths'],
        bins=50,
        alpha=0.7,
        color='steelblue',
        edgecolor='black',
        linewidth=0.5,
    )

    # Reference lines for the central tendency.
    mean_value = results['stats']['mean']
    median_value = results['stats']['median']
    ax.axvline(mean_value, color='red', linestyle='--', linewidth=2,
               label=f"Mean: {mean_value:.2f}")
    ax.axvline(median_value, color='orange', linestyle='--', linewidth=2,
               label=f"Median: {median_value:.2f}")

    ax.set_xlabel('Bit Length')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Bit Length Distribution (First {N} Tokens)')
    ax.grid(True, alpha=0.3)
    ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Histogram saved to {save_path}")

def plot_boxplot(results, N, save_path):
    """Draw a horizontal box plot of the bit lengths and save it as a PDF.

    A text box in the upper-left corner lists the summary statistics taken
    from ``results['stats']``.
    """
    fig, ax = plt.subplots(figsize=(10, 4))

    box_style = dict(facecolor='lightblue', alpha=0.7)
    median_style = dict(color='red', linewidth=2)
    ax.boxplot(results['bit_lengths'], vert=False, patch_artist=True,
               boxprops=box_style, medianprops=median_style)

    ax.set_xlabel('Bit Length')
    ax.set_title(f'Box Plot of Bit Lengths (First {N} Tokens)')
    ax.grid(True, alpha=0.3)

    # One line per statistic, shown inside the axes.
    stats = results['stats']
    summary_lines = [
        f"Mean: {stats['mean']:.2f}",
        f"Std: {stats['std']:.2f}",
        f"Min: {stats['min']:.2f}",
        f"Max: {stats['max']:.2f}",
        f"Q25: {stats['q25']:.2f}",
        f"Median: {stats['median']:.2f}",
        f"Q75: {stats['q75']:.2f}",
    ]
    text_box = dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    ax.text(0.02, 0.98, "\n".join(summary_lines), transform=ax.transAxes,
            bbox=text_box, verticalalignment='top', fontsize=10)

    fig.tight_layout()
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Box plot saved to {save_path}")

def plot_density(results, N, save_path):
    """Draw a density-normalised histogram with a KDE curve and save it as a PDF.

    The mean and median from ``results['stats']`` are marked with dashed
    vertical lines and listed in the legend.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # seaborn draws both the histogram and the KDE overlay.
    sns.histplot(
        data=results['bit_lengths'],
        kde=True,
        stat='density',
        alpha=0.6,
        color='steelblue',
        ax=ax,
    )

    mean_value = results['stats']['mean']
    median_value = results['stats']['median']
    ax.axvline(mean_value, color='red', linestyle='--', linewidth=2,
               label=f"Mean: {mean_value:.2f}")
    ax.axvline(median_value, color='orange', linestyle='--', linewidth=2,
               label=f"Median: {median_value:.2f}")

    ax.set_xlabel('Bit Length')
    ax.set_ylabel('Density')
    ax.set_title(f'Bit Length Density Distribution (First {N} Tokens)')
    ax.grid(True, alpha=0.3)
    ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Density plot saved to {save_path}")

def plot_cumulative_distribution(results, N, save_path):
    """Draw the empirical CDF of the bit lengths and save it as a PDF.

    Dotted horizontal lines mark the 0.25/0.5/0.75 levels and dotted vertical
    lines mark the q25/median/q75 values from ``results['stats']``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Empirical CDF: the k-th smallest value has cumulative probability k/n.
    ordered = np.sort(results['bit_lengths'])
    sample_count = len(ordered)
    cdf = np.arange(1, sample_count + 1) / sample_count

    ax.plot(ordered, cdf, linewidth=2, color='steelblue')
    ax.fill_between(ordered, cdf, alpha=0.3, color='steelblue')

    guide_style = dict(color='gray', linestyle=':', alpha=0.7)

    # Quartile levels (labelled) ...
    for level, name in ((0.25, 'Q1'), (0.5, 'Median'), (0.75, 'Q3')):
        ax.axhline(level, label=name, **guide_style)

    # ... and the bit lengths at which they are reached.
    for stat_key in ('q25', 'median', 'q75'):
        ax.axvline(results['stats'][stat_key], **guide_style)

    ax.set_xlabel('Bit Length')
    ax.set_ylabel('Cumulative Probability')
    ax.set_title(f'Cumulative Distribution Function (First {N} Tokens)')
    ax.grid(True, alpha=0.3)
    ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Cumulative distribution plot saved to {save_path}")

def plot_summary_statistics(results, N, save_path):
    """Save a 2x2 summary figure as a PDF.

    Panels: histogram with mean line, box plot, normal Q-Q plot, and a table
    of the values in ``results['stats']``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    hist_ax, box_ax = axes[0]
    qq_ax, table_ax = axes[1]

    lengths = results['bit_lengths']
    summary = results['stats']

    # Panel 1: histogram
    hist_ax.hist(lengths, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
    hist_ax.axvline(summary['mean'], color='red', linestyle='--', label='Mean')
    hist_ax.set_xlabel('Bit Length')
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title('Histogram')
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # Panel 2: box plot
    box_ax.boxplot(lengths, patch_artist=True,
                   boxprops=dict(facecolor='lightblue', alpha=0.7))
    box_ax.set_ylabel('Bit Length')
    box_ax.set_title('Box Plot')
    box_ax.grid(True, alpha=0.3)

    # Panel 3: quantiles of the data against those of a normal distribution
    from scipy import stats as scipy_stats
    scipy_stats.probplot(lengths, dist="norm", plot=qq_ax)
    qq_ax.set_title('Q-Q Plot (vs Normal Distribution)')
    qq_ax.grid(True, alpha=0.3)

    # Panel 4: statistics rendered as a table
    table_ax.axis('tight')
    table_ax.axis('off')
    labelled_keys = (
        ('Mean', 'mean'),
        ('Std Dev', 'std'),
        ('Min', 'min'),
        ('Q25', 'q25'),
        ('Median', 'median'),
        ('Q75', 'q75'),
        ('Max', 'max'),
    )
    header = ['Statistic', 'Value']
    rows = [['Count', f"{len(lengths)}"]]
    rows.extend([label, f"{summary[key]:.3f}"] for label, key in labelled_keys)

    table = table_ax.table(cellText=rows, colLabels=header,
                           cellLoc='center', loc='center', bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)
    table_ax.set_title('Summary Statistics')

    fig.suptitle(f'Bit Length Analysis Summary (First {N} Tokens)', fontsize=16)
    fig.tight_layout()
    fig.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Summary statistics plot saved to {save_path}")

def save_results(results, N, output_path):
    """Save per-sample bit lengths and summary statistics as two CSV files.

    Args:
        results: Dict with ``'bit_lengths'`` (sequence of numbers) and
            ``'stats'`` (mapping of statistic name to value).
        N: Unused; kept so existing callers keep working.
        output_path: Destination of the per-sample CSV. The summary goes next
            to it: ``_summary`` is inserted before the file name's last
            ``.csv``; if the name has no ``.csv``, its extension is replaced
            by ``_summary.csv``.
    """
    import os  # local import: the module only imports os inside its __main__ block

    output_path = os.fspath(output_path)

    # One row per sample, indexed by position.
    bit_lengths = results['bit_lengths']
    df_results = pd.DataFrame({
        'sample_index': range(len(bit_lengths)),
        'bit_length': bit_lengths
    })
    df_results.to_csv(output_path, index=False)

    # Derive the summary name from the file name only, so directory names are
    # never rewritten and the result can never equal output_path (which would
    # overwrite the per-sample file just written).
    directory, filename = os.path.split(output_path)
    head, marker, tail = filename.rpartition('.csv')
    if marker:
        summary_name = f"{head}_summary.csv{tail}"
    else:
        summary_name = f"{os.path.splitext(filename)[0]}_summary.csv"
    stats_path = os.path.join(directory, summary_name)

    stats_df = pd.DataFrame([results['stats']])
    stats_df.to_csv(stats_path, index=False)

    print(f"Results saved to {output_path}")
    print(f"Summary statistics saved to {stats_path}")
    
def plot_percentage_above_256(results_dict, save_path, threshold=256):
    """
    Plot percentage of samples with bit length above a threshold for different N values

    Args:
        results_dict: Dictionary with N values as keys and results as values
            (each result provides a 'bit_lengths' sequence)
        save_path: Path to save the plot (written as PDF)
        threshold: Bit length above which a sample is counted (default 256)

    Returns:
        Dictionary mapping each N to its percentage. An N without samples
        maps to 0.0.
    """
    N_values = sorted(results_dict)
    percentages = []

    for N in N_values:
        bit_lengths = results_dict[N]['bit_lengths']

        above = sum(1 for length in bit_lengths if length > threshold)
        total_samples = len(bit_lengths)
        # An N without any samples contributes 0% instead of dividing by zero.
        percentage = (above / total_samples) * 100 if total_samples else 0.0
        percentages.append(percentage)

        print(f"N={N}: {above}/{total_samples} samples ({percentage:.2f}%) above {threshold} bits")

    fig, ax = plt.subplots(figsize=(10, 6))

    # Line with markers, shaded down to the x-axis.
    ax.plot(N_values, percentages, marker='o', linewidth=2, markersize=8,
            color='steelblue', markerfacecolor='lightblue', markeredgecolor='steelblue')
    ax.fill_between(N_values, percentages, alpha=0.3, color='steelblue')

    # Value label above each point.
    for n, pct in zip(N_values, percentages):
        ax.annotate(f'{pct:.1f}%', (n, pct), textcoords="offset points",
                    xytext=(0, 10), ha='center', fontsize=10,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7))

    ax.set_xlabel('Number of Tokens (N)')
    ax.set_ylabel(f'Percentage of Samples Above {threshold} Bits (%)')
    ax.set_title(f'Percentage of Samples with Bit Length > {threshold}')
    ax.grid(True, alpha=0.3)

    # One tick per evaluated N.
    ax.set_xticks(N_values)

    # Start at 0 with 10% headroom. Fall back to a unit range when there is
    # no data or every percentage is 0, which would give identical limits.
    upper = max(percentages, default=0.0) * 1.1
    ax.set_ylim(0, upper if upper > 0 else 1.0)

    plt.tight_layout()
    plt.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Percentage above {threshold} plot saved to {save_path}")

    # Return the data for further analysis
    return dict(zip(N_values, percentages))
# Main execution
if __name__ == "__main__":
    import os

    # Dataset to analyse. Previously used alternative:
    # "/home/gevennou/text_reconstruction/trial/finetune_llm_results/test_sets/wikitext_test.csv"
    CSV_PATH = "test_splits/pixmo_challenge.csv"

    # One full analysis (statistics, five PDF plots, CSV dumps) per prefix length N.
    for N in range(10, 51, 5):
        output_dir = f"bit_length_pixmo_challenge/bit_length_analysis_N{N}"
        os.makedirs(output_dir, exist_ok=True)

        print(f"Computing bit length distribution for first {N} tokens...")
        results = compute_bit_length_distribution(CSV_PATH, N)

        print("\nSummary Statistics:")
        for stat_name, stat_value in results['stats'].items():
            print(f"{stat_name}: {stat_value:.3f}")

        print("\nGenerating PDF plots...")
        plot_jobs = (
            (plot_histogram, f"histogram_N{N}.pdf"),
            (plot_boxplot, f"boxplot_N{N}.pdf"),
            (plot_density, f"density_N{N}.pdf"),
            (plot_cumulative_distribution, f"cumulative_N{N}.pdf"),
            (plot_summary_statistics, f"summary_N{N}.pdf"),
        )
        for make_plot, pdf_name in plot_jobs:
            make_plot(results, N, os.path.join(output_dir, pdf_name))

        # Per-sample values and summary statistics as CSV
        save_results(results, N, os.path.join(output_dir, f"bit_length_results_N{N}.csv"))

        print(f"\nProcessed {len(results['bit_lengths'])} samples")
        print(f"Unique bit lengths: {len(results['distribution'])}")
        print(f"All files saved in directory: {output_dir}")
        