import transformers
from transformers import AutoTokenizer
from transformers import (
    AutoModelForCausalLM,
)
from transformers import pipeline, set_seed, LogitsProcessor
import torch
import arithmeticcoding
import io
import numpy as np
import peft
from peft import LoraConfig, get_peft_model
import time
from typing import Tuple, List, Dict
import json
import gc
from collections import defaultdict

# Precision argument handed to arithmeticcoding.ArithmeticEncoder / ArithmeticDecoder.
# TextZipper quantises model probabilities to 2**(PRECISION - 3) when building frequency tables.
PRECISION = 32

class TextZipper(object):
    """Compress and decompress text by arithmetic-coding tokens under a causal LM.

    At each position the model's next-token distribution is quantised into an
    integer frequency table (see ``probs_to_freq``) and handed to the
    ``arithmeticcoding`` encoder/decoder. Encoder and decoder must derive the
    same tables, so both go through ``probs_to_freq`` on a ``(1, vocab)``
    probability row.
    """

    def __init__(self, *args, modelname="facebook/opt-125m", adapter_path=None, **kwargs):
        """Load tokenizer and model.

        Args:
            modelname: Hugging Face model identifier passed to ``from_pretrained``.
            adapter_path: optional adapter passed to ``model.load_adapter``.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.model = AutoModelForCausalLM.from_pretrained(
            modelname, torch_dtype=torch.float32, low_cpu_mem_usage=True
        )
        # Use the GPU when there is one, but do not crash on CPU-only hosts.
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        if adapter_path:
            self.model.load_adapter(adapter_path)
        self.model = self.model.eval()

        self.precision = PRECISION - 3  # less than quarter range to allow for rounding

    @staticmethod
    def _strip_prompt(text, prompt):
        """Remove ``prompt`` from the start of ``text``.

        Only the leading occurrence is removed, so text that happens to repeat
        the prompt later on is left intact. If ``text`` does not start with
        ``prompt`` the legacy behaviour (remove every occurrence) is used.
        """
        if not prompt:
            return text
        if text.startswith(prompt):
            return text[len(prompt):]
        return text.replace(prompt, "")

    def encode(self, bitstream, input_text, prompt="", max_length=None):
        """Arithmetic-encode ``input_text`` into ``bitstream``.

        Args:
            bitstream: writable binary stream wrapped in ``BitOutputStream``.
            input_text: text to compress.
            prompt: optional conditioning text; its tokens are fed to the model
                but not encoded.
            max_length: if given, the token sequence (prompt included) is
                truncated to this many tokens.

        Returns:
            ``(H, padding)``: the summed ``-log2 p(symbol)`` over encoded
            symbols, and whatever ``ArithmeticEncoder.finish`` returns.
        """
        if prompt:
            input_text = prompt + " " + input_text
            prompt_mask = self.tokenizer([prompt], return_tensors="pt")["attention_mask"]
            prompt_end = int(prompt_mask.sum()) - 1  # index of the last prompt token
        else:
            prompt_end = 0

        inputs = self.tokenizer([input_text], return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        if max_length is not None:
            input_ids = input_ids[:, :max_length]

        bitout = arithmeticcoding.BitOutputStream(bitstream)
        ac_enc = arithmeticcoding.ArithmeticEncoder(PRECISION, bitout)

        seq_len = input_ids.shape[1]
        eos_id = self.tokenizer.eos_token_id
        # Tensor accumulator so the result has .item() even when no symbol is coded.
        entropy_bits = torch.zeros((), dtype=torch.float32, device=input_ids.device)

        for i in range(prompt_end, seq_len):
            # Re-run the model on the prefix, exactly as decode() does, so the
            # frequency tables are computed along the same path on both sides.
            with torch.no_grad():
                outputs = self.model.forward(input_ids[:, :i + 1], return_dict=True)
            probs = outputs['logits'][:, -1].softmax(dim=-1)  # (1, vocab)
            freqs = self.probs_to_freq(probs)

            if i == seq_len - 1:  # last symbol is EOS
                symbol = eos_id
            else:
                symbol = int(input_ids[0, i + 1])
            entropy_bits += -torch.log2(probs[0, symbol])
            ac_enc.write(freqs, symbol)
        padding = ac_enc.finish(randomize=False)

        return entropy_bits.item(), padding

    def probs_to_freq(self, probs):
        """Quantise the first row of ``probs`` into a ``SimpleFrequencyTable``.

        Each probability is scaled by ``2**self.precision`` and rounded up, so
        every symbol with non-zero probability has frequency >= 1.
        """
        p = probs[0]
        freqs = torch.ceil(p.float() * 2**self.precision).long().cpu().numpy().tolist()
        freqs = arithmeticcoding.SimpleFrequencyTable(freqs)
        return freqs

    def decode_batch(self, bitstreams: List[io.BytesIO], prompt: str = "", max_length: int = 30) -> List[str]:
        """Decode several bitstreams in lock-step using the model's KV cache.

        Args:
            bitstreams: one readable binary stream per sequence.
            prompt: conditioning text shared by every sequence.
            max_length: maximum number of symbols decoded per sequence.

        Returns:
            One decoded string per input stream, prompt removed and
            left-stripped.

        NOTE(review): this path uses cached keys/values and batched forwards,
        while ``encode`` recomputes every prefix unbatched. Confirm the
        resulting probabilities quantise identically on the target hardware.
        """
        device      = self.model.device
        batch_size  = len(bitstreams)
        eos_id      = self.tokenizer.eos_token_id
        decoders    = [arithmeticcoding.ArithmeticDecoder(
                           PRECISION,
                           arithmeticcoding.BitInputStream(bs)) for bs in bitstreams]

        # prompt tokens (can be empty)
        prompt_ids  = self.tokenizer(prompt, return_tensors="pt").input_ids[0].to(device)
        seqs        = [prompt_ids.clone() for _ in range(batch_size)]   # 1-D LongTensor per seq
        pasts       = [None] * batch_size
        finished    = [False] * batch_size

        def stack_past(past_subset):
            """Pad and stack a list of per-sequence pasts into one batched past."""
            num_layers = len(past_subset[0])
            out = []
            for l in range(num_layers):
                k_list = [p[l][0] for p in past_subset]
                v_list = [p[l][1] for p in past_subset]
                mlen   = max(k.shape[2] for k in k_list)

                def pad_to_mlen(t, mlen=mlen):
                    # pad the sequence axis (dim 2) on the right up to mlen
                    return torch.nn.functional.pad(t, (0, 0, 0, mlen - t.shape[2]))

                out.append((torch.cat([pad_to_mlen(k) for k in k_list], dim=0),
                            torch.cat([pad_to_mlen(v) for v in v_list], dim=0)))
            return tuple(out)

        # Inference only: do not build autograd graphs.
        with torch.no_grad():
            for _ in range(max_length):
                active = [i for i, f in enumerate(finished) if not f]
                if not active:
                    break

                # first step: feed full prompt once; later steps: only last token + cached KV
                if pasts[active[0]] is None:             # all actives share this on first step
                    inp = torch.stack([seqs[i] for i in active]).to(device)
                    out = self.model(inp, use_cache=True, return_dict=True)
                else:
                    # seqs[i][-1:] is shape (1,), so stacking already yields (B, 1)
                    last_tok = torch.stack([seqs[i][-1:] for i in active])
                    past_in  = stack_past([pasts[i] for i in active])
                    out      = self.model(last_tok,
                                          past_key_values=past_in,
                                          use_cache=True,
                                          return_dict=True)

                probs      = out.logits[:, -1].softmax(-1)      # (B, |V|)
                new_past   = out.past_key_values                # tuple of (k,v)

                # split past back into per-sequence slices
                for b, idx in enumerate(active):
                    pasts[idx] = tuple((kv[0][b:b+1].contiguous(),
                                        kv[1][b:b+1].contiguous()) for kv in new_past)

                    sym       = decoders[idx].read(self.probs_to_freq(probs[b:b+1]))
                    seqs[idx] = torch.cat([seqs[idx], torch.tensor([sym], device=device)])
                    finished[idx] = (sym == eos_id)

        # detokenise
        return [self._strip_prompt(self.tokenizer.decode(s, skip_special_tokens=True), prompt).lstrip()
                for s in seqs]

    def decode(self, bitstream, prompt="", max_length=30):
        """Decode one bitstream produced by ``encode``.

        Args:
            bitstream: readable binary stream wrapped in ``BitInputStream``.
            prompt: the same conditioning text that was given to ``encode``.
            max_length: maximum number of symbols to decode.

        Returns:
            The decoded text with the prompt and the joining space removed.
        """
        bitin = arithmeticcoding.BitInputStream(bitstream)
        ac_dec = arithmeticcoding.ArithmeticDecoder(PRECISION, bitin)

        # tokenize prompt
        inputs = self.tokenizer([prompt], return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)

        eos_id = self.tokenizer.eos_token_id
        for _ in range(max_length):
            # Full forward over the prefix (no KV cache), mirroring encode().
            with torch.no_grad():
                outputs = self.model.forward(input_ids, return_dict=True)
            probs = outputs['logits'][:, -1].softmax(dim=-1)

            # rebuild freqs and decode the next token
            symbol = ac_dec.read(self.probs_to_freq(probs))
            next_tokens = torch.tensor([symbol], device=self.model.device)

            # append to sequence
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            if symbol == eos_id:
                break

        decoded_text = self.tokenizer.batch_decode(
            input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        if prompt:
            # drop prompt and the space encode() inserted after it
            decoded_text = self._strip_prompt(decoded_text, prompt)[1:]

        return decoded_text

def text_to_bits(text: str, text_zipper, max_bits: int = 256, max_length=30, truncate=True) -> Tuple[torch.Tensor, int]:
    """Compress ``text`` with ``text_zipper`` and return it as a 0/1 float tensor.

    The encoded bytes are expanded most-significant-bit first. Shorter
    payloads are zero-padded to ``max_bits``; longer ones are cut to
    ``max_bits`` only when ``truncate`` is true.

    Returns:
        ``(bits, original_length)`` where ``original_length`` is the bit count
        before padding or truncation.
    """
    buffer = io.BytesIO()
    text_zipper.encode(buffer, text, max_length=max_length)
    payload = buffer.getvalue()

    # Expand every byte into its 8 bits, MSB first.
    bit_array = [
        (octet >> shift) & 1
        for octet in payload
        for shift in range(7, -1, -1)
    ]

    original_length = len(bit_array)
    overflow = original_length - max_bits
    if overflow > 0:
        if truncate:
            del bit_array[max_bits:]
            print("truncated")
    else:
        bit_array.extend([0] * -overflow)

    return torch.tensor(bit_array, dtype=torch.float32), original_length

def bits_to_text(bits: torch.Tensor, text_zipper, max_length=30) -> str:
    """Pack a bit tensor into bytes, decode it with ``text_zipper``, keep the first line.

    Bits are consumed eight at a time, most-significant first; a trailing
    group of fewer than eight bits is ignored.
    """
    bit_values = bits.int().cpu().numpy()
    packed = bytearray()
    for group in range(len(bit_values) // 8):
        value = 0
        for bit in bit_values[8 * group: 8 * group + 8]:
            value = (value << 1) | bit
        packed.append(value)

    stream = io.BytesIO(bytes(packed))
    decoded = text_zipper.decode(stream, max_length=max_length)

    first_line, _, _ = decoded.partition("\n")
    return first_line

def bits_to_text_batch(bits_batch: List[torch.Tensor], text_zipper, max_length=30) -> List[str]:
    """Pack each bit tensor into a byte stream, batch-decode, keep each first line.

    Bits are consumed eight at a time, most-significant first; a trailing
    group of fewer than eight bits is ignored.
    """
    def as_stream(bit_tensor):
        values = bit_tensor.int().cpu().numpy()
        packed = bytearray()
        for group in range(len(values) // 8):
            octet = 0
            for bit in values[8 * group: 8 * group + 8]:
                octet = (octet << 1) | bit
            packed.append(octet)
        return io.BytesIO(bytes(packed))

    streams = [as_stream(bit_tensor) for bit_tensor in bits_batch]
    decoded = text_zipper.decode_batch(streams, max_length=max_length)
    return [entry.partition("\n")[0] for entry in decoded]

class TextZipperBenchmark:
    """Timing harness for single-sequence and batched TextZipper decoding.

    "Tokens" in the reported throughput are whitespace-separated words of the
    decoded text, not model tokens.
    """

    def __init__(self, text_zipper, device='cuda'):
        self.text_zipper = text_zipper
        self.device = device

    @staticmethod
    def _bits_to_bytes(bits):
        """Pack a 0/1 tensor into bytes, MSB first, dropping a trailing partial byte."""
        bits_np = bits.int().cpu().numpy()
        byte_data = bytearray()
        for start in range(0, len(bits_np) - len(bits_np) % 8, 8):
            value = 0
            for bit in bits_np[start:start + 8]:
                value = (value << 1) | int(bit)
            byte_data.append(value)
        return bytes(byte_data)

    def benchmark_single_decoding(self, test_texts, max_length=30, num_runs=100):
        """Benchmark single sequence decoding.

        Encodes every text once, warms up, then decodes ``num_runs`` times
        (cycling through the encoded texts) and records timing and memory.

        Returns:
            The metrics dict (lists of per-run measurements plus counters).
        """
        print("\n" + "="*60)
        print("SINGLE SEQUENCE DECODING BENCHMARK")
        print("="*60)

        # Prepare encoded data
        encoded_data = []
        for text in test_texts:
            try:
                bit_msg, length = text_to_bits(text, self.text_zipper, max_bits=256, max_length=max_length)
            except Exception as e:
                print(f"Failed to encode text: {text[:50]}... Error: {e}")
                continue
            encoded_data.append((bit_msg, text, length))

        print(f"Successfully encoded {len(encoded_data)} texts for benchmarking")

        metrics = {
            'decode_times': [],
            'memory_usage': [],
            'tokens_per_second': [],
            'successful_decodes': 0,
            'failed_decodes': 0,
            'total_tokens_generated': 0,
            'model_forward_times': [],
            'tokenizer_times': [],
            'arithmetic_decode_times': []
        }

        if not encoded_data:
            # Nothing to decode; avoid the modulo-by-zero in the run loop below.
            self.print_single_results(metrics, num_runs)
            return metrics

        use_cuda = torch.cuda.is_available()

        # Warmup (best effort: a failing warmup must not abort the benchmark)
        print("Warming up...")
        for bit_msg, _, _ in encoded_data[:5]:
            try:
                bits_to_text(bit_msg.squeeze(), self.text_zipper, max_length=max_length)
            except Exception as e:
                print(f"Warmup decode failed: {e}")

        if use_cuda:
            torch.cuda.empty_cache()

        print(f"Starting benchmark with {num_runs} runs...")

        for run in range(num_runs):
            if run % 20 == 0:
                print(f"Run {run}/{num_runs}")

            bit_msg, _original_text, _original_length = encoded_data[run % len(encoded_data)]
            bits = bit_msg.squeeze()

            memory_before = 0
            if use_cuda:
                torch.cuda.synchronize()
                memory_before = torch.cuda.memory_allocated()

            start_time = time.perf_counter()

            try:
                decoded_text = self.benchmark_decode_detailed(bits, max_length, metrics)
            except Exception as e:
                print(f"Decode failed for run {run}: {e}")
                metrics['failed_decodes'] += 1
                continue

            decode_time = time.perf_counter() - start_time

            if use_cuda:
                torch.cuda.synchronize()
                memory_used = torch.cuda.memory_allocated() - memory_before
            else:
                memory_used = 0

            # Word count of the decoded text stands in for the token count.
            estimated_tokens = len(decoded_text.split())
            tokens_per_sec = estimated_tokens / decode_time if decode_time > 0 else 0

            metrics['decode_times'].append(decode_time)
            metrics['memory_usage'].append(memory_used)
            metrics['tokens_per_second'].append(tokens_per_sec)
            metrics['successful_decodes'] += 1
            metrics['total_tokens_generated'] += estimated_tokens

        self.print_single_results(metrics, num_runs)
        return metrics

    def benchmark_batch_decoding(self, texts, batch_sizes=(1,2,4,8,16,32), runs=5, max_length=30):
        """Benchmark ``bits_to_text_batch`` over several batch sizes.

        Prints one summary line per batch size. The speedup column compares
        per-sequence time against the first batch size in ``batch_sizes``.

        Returns:
            ``{batch_size: metrics}`` where ``metrics`` has the keys consumed
            by ``print_batch_results`` / ``print_batch_comparison``
            (``decode_times``, ``tokens_per_second``, ``sequences_per_second``,
            ``memory_usage``, ``successful_batches``). Empty when there is
            nothing to run.
        """
        results = {}
        if not texts or not batch_sizes or runs < 1:
            print("Batch benchmark skipped: no texts, batch sizes or runs")
            return results

        enc = [text_to_bits(t, self.text_zipper, 256, max_length)[0].squeeze() for t in texts]
        enc *= ((max(batch_sizes)*runs)//len(enc)+1)  # recycle to cover all runs

        use_cuda = torch.cuda.is_available()
        base_per_seq = None  # per-sequence time of the first batch size
        for B in batch_sizes:
            metrics = {
                'decode_times': [],
                'tokens_per_second': [],
                'sequences_per_second': [],
                'memory_usage': [],
                'successful_batches': 0,
            }
            for r in range(runs):
                batch = enc[r*B:(r+1)*B]
                memory_before = 0
                if use_cuda:
                    torch.cuda.synchronize()
                    memory_before = torch.cuda.memory_allocated()
                t0 = time.perf_counter()
                decoded = bits_to_text_batch(batch, self.text_zipper, max_length)
                if use_cuda:
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - t0

                words = sum(len(text.split()) for text in decoded)
                metrics['decode_times'].append(elapsed)
                metrics['tokens_per_second'].append(words / elapsed if elapsed > 0 else 0)
                metrics['sequences_per_second'].append(B / elapsed if elapsed > 0 else 0)
                if use_cuda:
                    metrics['memory_usage'].append(torch.cuda.memory_allocated() - memory_before)
                metrics['successful_batches'] += 1

            avg = sum(metrics['decode_times']) / runs
            if base_per_seq is None:
                base_per_seq = avg / B
            print(f"B={B:<3} {1000*avg/B:.2f} ms/seq | {B/avg:.1f} seq/s | ×{base_per_seq/(avg/B):.1f}")
            results[B] = metrics
        return results

    def benchmark_decode_detailed(self, bits, max_length, metrics):
        """Decode one bit tensor while timing the model and arithmetic-decoder parts.

        Appends the per-call totals to ``metrics['model_forward_times']``,
        ``metrics['tokenizer_times']`` and ``metrics['arithmetic_decode_times']``.

        Returns:
            The decoded text (special tokens skipped).
        """
        zipper = self.text_zipper
        device = zipper.model.device
        eos_id = zipper.tokenizer.eos_token_id
        use_cuda = torch.cuda.is_available()

        bitstream = io.BytesIO(self._bits_to_bytes(bits))
        bitin = arithmeticcoding.BitInputStream(bitstream)
        ac_dec = arithmeticcoding.ArithmeticDecoder(PRECISION, bitin)

        inputs = zipper.tokenizer([""], return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)

        model_forward_time = 0
        arithmetic_time = 0

        for _ in range(max_length):
            forward_start = time.perf_counter()
            with torch.no_grad():
                outputs = zipper.model.forward(input_ids, return_dict=True)
            scores = outputs['logits'][:,-1]
            probs = scores.softmax(dim=-1)
            if use_cuda:
                # make sure queued GPU work is included in the forward timing
                torch.cuda.synchronize()
            model_forward_time += time.perf_counter() - forward_start

            arith_start = time.perf_counter()
            freqs = zipper.probs_to_freq(probs)
            symbol = ac_dec.read(freqs)
            arithmetic_time += time.perf_counter() - arith_start

            next_tokens = torch.tensor([symbol], device=device)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            if symbol == eos_id:
                break

        tokenizer_start = time.perf_counter()
        decoded_text = zipper.tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        tokenizer_time = time.perf_counter() - tokenizer_start

        metrics['model_forward_times'].append(model_forward_time)
        metrics['tokenizer_times'].append(tokenizer_time)
        metrics['arithmetic_decode_times'].append(arithmetic_time)

        return decoded_text

    def print_single_results(self, metrics, num_runs):
        """Print single sequence benchmark results."""
        print("\n" + "="*60)
        print("SINGLE SEQUENCE RESULTS")
        print("="*60)

        if metrics['successful_decodes'] == 0:
            print("No successful decodes!")
            return

        decode_times = np.array(metrics['decode_times'])
        tokens_per_sec = np.array(metrics['tokens_per_second'])

        print(f"Total runs: {num_runs}")
        print(f"Successful decodes: {metrics['successful_decodes']}")
        print(f"Success rate: {metrics['successful_decodes']/num_runs*100:.1f}%")
        print(f"Mean decode time: {decode_times.mean()*1000:.2f}ms")
        print(f"Mean throughput: {tokens_per_sec.mean():.2f} tokens/sec")
        print(f"Overall throughput: {metrics['total_tokens_generated']/decode_times.sum():.2f} tokens/sec")

        if metrics['model_forward_times']:
            forward_times = np.array(metrics['model_forward_times'])
            arith_times = np.array(metrics['arithmetic_decode_times'])

            print(f"\nComponent breakdown:")
            print(f"Model forward: {forward_times.mean()*1000:.2f}ms ({forward_times.mean()/decode_times.mean()*100:.1f}%)")
            print(f"Arithmetic decode: {arith_times.mean()*1000:.2f}ms ({arith_times.mean()/decode_times.mean()*100:.1f}%)")

    def print_batch_results(self, batch_size, metrics, num_runs, baseline_seq_per_sec=None):
        """Print batch benchmark results.

        Args:
            batch_size: batch size the metrics were collected with.
            metrics: dict with ``successful_batches``, ``decode_times``,
                ``tokens_per_second``, ``sequences_per_second`` and (on CUDA)
                ``memory_usage``.
            num_runs: number of attempted batches, used for the success rate.
            baseline_seq_per_sec: sequences/sec measured at batch size 1; the
                speedup is reported relative to it. Without it no speedup can
                be computed, so "n/a" is printed.
        """
        if metrics['successful_batches'] == 0:
            print(f"Batch size {batch_size}: No successful batches!")
            return

        decode_times = np.array(metrics['decode_times'])
        tokens_per_sec = np.array(metrics['tokens_per_second'])
        sequences_per_sec = np.array(metrics['sequences_per_second'])

        print(f"Batch size {batch_size}:")
        print(f"  Success rate: {metrics['successful_batches']/num_runs*100:.1f}%")
        print(f"  Mean batch time: {decode_times.mean()*1000:.2f}ms")
        print(f"  Mean time per sequence: {decode_times.mean()/batch_size*1000:.2f}ms")
        print(f"  Throughput: {tokens_per_sec.mean():.2f} tokens/sec")
        print(f"  Sequences/sec: {sequences_per_sec.mean():.2f}")
        if baseline_seq_per_sec:
            print(f"  Speedup vs batch_size=1: {sequences_per_sec.mean()/baseline_seq_per_sec:.2f}x")
        else:
            print("  Speedup vs batch_size=1: n/a (no baseline supplied)")

        if torch.cuda.is_available() and len(metrics['memory_usage']) > 0:
            memory_usage = np.array(metrics['memory_usage'])
            print(f"  Memory per batch: {memory_usage.mean()/1024/1024:.2f}MB")
            print(f"  Memory per sequence: {memory_usage.mean()/batch_size/1024/1024:.2f}MB")

    def print_batch_comparison(self, batch_results):
        """Print comparison across batch sizes.

        ``batch_results`` maps batch size to a metrics dict. The speedup
        column is relative to the smallest batch size that had at least one
        successful batch.
        """
        print("\n" + "="*60)
        print("BATCH SIZE COMPARISON")
        print("="*60)

        print(f"{'Batch Size':<10} {'Time/Seq(ms)':<12} {'Tokens/Sec':<12} {'Seq/Sec':<10} {'Speedup':<8} {'Memory/Seq(MB)':<15}")
        print("-" * 75)

        baseline_seq_per_sec = None

        for batch_size in sorted(batch_results.keys()):
            metrics = batch_results[batch_size]

            if metrics['successful_batches'] == 0:
                continue

            decode_times = np.array(metrics['decode_times'])
            tokens_per_sec = np.array(metrics['tokens_per_second'])
            sequences_per_sec = np.array(metrics['sequences_per_second'])

            time_per_seq = decode_times.mean() / batch_size * 1000
            tokens_per_sec_mean = tokens_per_sec.mean()
            seq_per_sec_mean = sequences_per_sec.mean()

            if baseline_seq_per_sec is None:
                baseline_seq_per_sec = seq_per_sec_mean
                speedup = 1.0
            else:
                speedup = seq_per_sec_mean / baseline_seq_per_sec

            if torch.cuda.is_available() and len(metrics['memory_usage']) > 0:
                memory_usage = np.array(metrics['memory_usage'])
                memory_per_seq = memory_usage.mean() / batch_size / 1024 / 1024
            else:
                memory_per_seq = 0

            print(f"{batch_size:<10} {time_per_seq:<12.2f} {tokens_per_sec_mean:<12.2f} {seq_per_sec_mean:<10.2f} {speedup:<8.2f} {memory_per_seq:<15.2f}")

def create_test_texts(num_texts=50):
    """Return ``num_texts`` benchmark sentences.

    The first 20 are fixed seed sentences. Any further entries are hybrids:
    the first half of one seed sentence, a space, then the second half of the
    next seed sentence, cycling through the seeds.
    """
    seed_sentences = [
        "The cat is screaming in the research center named INRIA at Los Angeles",
        "Machine learning models are revolutionizing artificial intelligence research",
        "The quick brown fox jumps over the lazy dog",
        "Climate change represents one of the greatest challenges facing humanity",
        "Artificial intelligence will transform society in unprecedented ways",
        "The ocean waves crashed against the rocky shore with tremendous force",
        "Education is the most powerful weapon which you can use to change the world",
        "Space exploration continues to push the boundaries of human knowledge",
        "The ancient library contained thousands of rare manuscripts and scrolls",
        "Innovation drives progress and creates new opportunities for growth",
        "The forest was dense with towering trees and abundant wildlife",
        "Technology advances faster than human adaptation can keep pace",
        "The sun was shining brightly in the clear blue sky today",
        "Scientific research requires patience, precision, and creativity",
        "The city streets were bustling with activity and energy",
        "Music has the power to transcend cultural and linguistic barriers",
        "The mountains stood majestically against the horizon",
        "Democracy requires active participation from all citizens",
        "The digital revolution has transformed how we communicate",
        "Art reflects the spirit and values of its time",
    ]
    seed_count = len(seed_sentences)

    def splice(left, right):
        # front half of `left` joined to back half of `right`
        return left[:len(left) // 2] + " " + right[len(right) // 2:]

    hybrids = [
        splice(seed_sentences[k % seed_count], seed_sentences[(k + 1) % seed_count])
        for k in range(num_texts - seed_count)
    ]

    return (seed_sentences + hybrids)[:num_texts]

# Script entry point: run the single-sequence and batch decoding benchmarks
# and dump the collected metrics to a timestamped JSON file.
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    MAX_LENGTH = 30
    test_texts = ["At West Palm Beach, the majority of the damage was confined to vegetation. Several coconut and royal palms that withstood the 1928 hurricane snapped, littering streets with broken trunks. Winds"]
    # Initialize text zipper
    text_zipper = TextZipper()
    text_zipper.model.to(device)
    print(f"Initialized LLMZip text compressor on {device}")

    # Create benchmark instance
    benchmark = TextZipperBenchmark(text_zipper, device)

    # create_test_texts(num_texts=30) can be used instead of the single
    # hard-coded text above for a more varied run.

    # Run comprehensive benchmark
    print("\n" + "="*80)
    print("COMPREHENSIVE TEXTZIPPER BENCHMARK")
    print("="*80)

    all_results = {}

    # 1. Single sequence benchmark
    print("\n1. Running single sequence benchmark...")
    single_metrics = benchmark.benchmark_single_decoding(
        test_texts=test_texts,
        max_length=MAX_LENGTH,
        num_runs=100
    )
    all_results['single_sequence'] = single_metrics

    # 2. Batch benchmark
    print("\n2. Running batch benchmark...")
    batch_metrics = benchmark.benchmark_batch_decoding(
        texts=test_texts,
        batch_sizes=[1, 2, 4, 8, 16, 32,64,128,256],
        max_length=MAX_LENGTH,
        runs=5
    )
    all_results['batch_decoding'] = batch_metrics

    # Save results
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    results_file = f"textzipper_comprehensive_benchmark_{timestamp}.json"

    def convert_numpy_to_list(obj):
        """Recursively replace numpy arrays/scalars with JSON-serialisable Python values."""
        if isinstance(obj, dict):
            return {
                (k.item() if isinstance(k, np.generic) else k): convert_numpy_to_list(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_numpy_to_list(item) for item in obj]
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # .item() keeps integers as int and floats as float, for every numpy scalar type
            return obj.item()
        return obj

    json_results = convert_numpy_to_list(all_results)

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(json_results, f, indent=2)

    print(f"\n{'='*80}")
    print("COMPREHENSIVE BENCHMARK COMPLETE!")
    print(f"Results saved to: {results_file}")
    print("="*80)