import os
import random
import argparse
import json
import torch
from typing import List, Dict, Any
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from human_eval.data import read_problems
from human_eval.data import write_jsonl, read_problems, stream_jsonl


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs (plus CUDA when present) with `seed`.

    When CUDA is available, cuDNN is also put into deterministic mode with
    benchmarking disabled. ``PYTHONHASHSEED`` is exported as well; it only
    affects interpreters started after this call, not the running one.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)
    
def compute_cross_entropy_loss(model, input_ids):
    """Run `model` on `input_ids`, using the same ids as labels, and return ``.loss``.

    No loss is computed here; whatever the model's forward returns in its
    ``loss`` attribute is passed back unchanged (gradients stay attached).
    """
    forward_output = model(input_ids=input_ids, labels=input_ids)
    loss = forward_output.loss
    return loss

def calc_intensity(layer_probs, topk):
    """Return the per-position mean of the summed ``-log`` of the top-k weights.

    For every row (position) of `layer_probs`, the `topk` largest values are
    taken, ``-log(w + 1e-12)`` is summed over them, and the result is averaged
    over rows.

    Args:
        layer_probs: 2-D array indexed ``[position, expert]``; the values are
            presumably router probabilities (TODO confirm against
            ``model.get_scores()``).
        topk: number of largest weights per row to include. Values larger than
            the number of experts use every expert; values <= 0 yield 0.0.

    Returns:
        The mean intensity per position, or 0.0 for an empty sequence.
    """
    probs = np.asarray(layer_probs)
    seq_len = probs.shape[0]
    # Guard explicitly: a slice ``[-0:]`` would select *all* columns, and an
    # empty sequence would otherwise divide by zero.
    if seq_len == 0 or topk <= 0:
        return 0.0
    # np.sort is ascending along the expert axis, so the last `topk` columns
    # hold each row's largest weights (ties do not change the sum).
    topk_weights = np.sort(probs, axis=-1)[:, -topk:]
    return np.sum(-np.log(topk_weights + 1e-12)) / seq_len

def setup_layer_weights(model, model_name, input_ids):
    """Compute per-layer weights from router-score intensity on `input_ids`.

    Runs one forward pass (its loss is discarded) so the model records router
    scores, reads them through ``model.get_scores()`` and scores every layer
    with ``calc_intensity``. Layers inside the model's configured "use" range
    receive their intensity min-max normalised to roughly [0, 1]; every other
    layer receives 0.

    For 'deepseek' the score list is offset by one layer (score ``i`` belongs
    to decoder layer ``i + 1``), presumably because its first decoder layer
    has no router -- TODO confirm. Other models map score ``i`` to layer ``i``.

    Args:
        model: model exposing ``get_scores()``, returning a sequence of
            tensors that stack to shape (n_layer, n_token, n_expert).
        model_name: one of 'olmoe', 'deepseek', 'qwen3'.
        input_ids: token ids used for the scoring forward pass.

    Returns:
        ``(ece_weights, use_idx)``: a dict mapping layer index -> weight, and
        the list of layer indices eligible for a non-zero weight.

    Raises:
        KeyError: if `model_name` is not configured.
    """
    # [first used score index, end of used range (exclusive), top-k experts]
    model_config = {
        'olmoe': [3, 16, 8],
        'deepseek': [5, 26, 6],
        'qwen3': [8, 48, 8],
    }
    use_start_idx, layer_len, topk_num = model_config[model_name]

    model.eval()
    compute_cross_entropy_loss(model, input_ids)  # populates the router scores
    scores = model.get_scores()
    # .float(): NumPy has no bfloat16, so upcast before converting (no-op for float32).
    scores = torch.stack(scores).detach().float().cpu().numpy()
    n_layer = scores.shape[0]

    layer_score = np.array(
        [calc_intensity(scores[i_layer], topk_num) for i_layer in range(n_layer)],
        dtype=np.float64,
    )
    score_min = np.min(layer_score)
    score_max = np.max(layer_score)

    use_idx = list(range(use_start_idx, layer_len))
    layer_offset = 0
    if model_name == 'deepseek':
        layer_offset = 1
        use_idx = [cur_use_idx + 1 for cur_use_idx in use_idx]
        print(f"use: {use_idx}")

    # Previously this dict was only built inside the deepseek branch, so every
    # other model hit an unbound `ece_weights` at the return statement.
    use_set = set(use_idx)
    ece_weights = dict()
    for layer_idx in range(layer_offset, n_layer + layer_offset):
        if layer_idx in use_set:
            ece_weights[layer_idx] = (
                (layer_score[layer_idx - layer_offset] - score_min)
                / (score_max - score_min + 1e-6)
            )
        else:
            ece_weights[layer_idx] = 0

    print(f"Before quantile, ece_weights: {ece_weights}")
    return ece_weights, use_idx


def _run_delta_steps(model, optimizer, input_ids, steps, format_log):
    """Take `steps` optimizer updates on the CE loss of `input_ids`, then return to eval mode.

    `format_log(step, loss_value)` builds the progress line printed for each step.
    """
    model.train()
    for step in range(steps):
        loss = compute_cross_entropy_loss(model, input_ids)
        print(format_log(step, loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()  # release cached GPU memory between steps
    model.eval()


def ttf_inference_single_fewshot_progressive_v2(
    model,
    tokenizer,
    generation_params,
    inputs,
    steps,
    model_name,
    ttf_lr,
    chunk_size=50
):
    """Generate a completion while adapting the model's delta parameters at test time.

    Phase 1 zeroes the delta parameters, sets per-layer weights from the prompt
    and takes `steps` AdamW updates on the prompt's cross-entropy loss.
    Phase 2 generates in chunks of at most `chunk_size` tokens. After each chunk
    (unless EOS appeared or the budget is used up) it recomputes the layer
    weights on the extended sequence and takes another `steps` updates on it,
    reusing the same optimizer.

    Args:
        model: causal LM exposing the custom delta API used below
            (``enable_delta_for_layers``, ``get_delta_parameters``,
            ``set_weights``, and ``get_scores`` via ``setup_layer_weights``).
        tokenizer: used for decoding progress output and EOS detection.
        generation_params: a ``GenerationConfig``; its attributes are forwarded
            to ``model.generate`` per chunk, except ``max_new_tokens``, which is
            the overall token budget (None means "until EOS").
        inputs: dict with ``"input_ids"``; only the first row of each generate
            output is kept, so batch size 1 is assumed.
        steps: optimizer steps per phase / per round.
        model_name: key understood by ``setup_layer_weights``.
        ttf_lr: AdamW learning rate for the delta parameters.
        chunk_size: maximum tokens generated per round.

    Returns:
        ``(sequence_ids, delta_params)``: prompt plus generated ids (first row
        only) and the model's delta parameters after adaptation.
    """
    for param in model.parameters():
        param.requires_grad = False
    # Only layers whose MLP has a `gate` get trainable deltas.
    layer_indices = {i for i, layer in enumerate(model.model.layers)
                     if hasattr(layer.mlp, 'gate')}
    model.enable_delta_for_layers(layer_indices)

    print("=== Phase 1: Prompt Optimization with Cross Entropy ===")
    current_input_ids = inputs["input_ids"].clone()

    # Reset deltas so every call starts from the unadapted model.
    for param in model.get_delta_parameters():
        param.data.zero_()

    ece_weights, _ = setup_layer_weights(model, model_name, current_input_ids)
    model.set_weights(layer_indices, ece_weights)
    optimizer = torch.optim.AdamW(model.get_delta_parameters(), lr=ttf_lr, weight_decay=0.00001)

    print(f"Starting prompt optimization with Cross Entropy for {steps} steps...")
    _run_delta_steps(
        model, optimizer, current_input_ids, steps,
        lambda step, loss: f"Prompt TTF step {step+1}/{steps}, Cross Entropy Loss: {loss:.6f}",
    )

    print("=== Phase 2: Progressive Generation with Uncertainty Optimization ===")
    max_new_tokens = generation_params.max_new_tokens
    generated_tokens = 0

    # Forward the caller's generation settings unchanged, including do_sample:
    # forcing do_sample=True makes generate() reject a temperature of 0.
    gen_params = vars(generation_params).copy()

    if max_new_tokens is None:
        print("No max_new_tokens limit set, will only stop on EOS token or generation failure")
        max_new_tokens = float('inf')

    eos_token_id = tokenizer.eos_token_id
    generation_round = 0
    while generated_tokens < max_new_tokens:
        generation_round += 1
        print(f"\n--- Generation Round {generation_round} ---")
        if max_new_tokens == float('inf'):
            print(f"Generated tokens so far: {generated_tokens} (no limit)")
        else:
            print(f"Generated tokens so far: {generated_tokens}/{max_new_tokens}")

        # Cap each chunk at the remaining budget so the total never exceeds max_new_tokens.
        gen_params['max_new_tokens'] = int(min(chunk_size, max_new_tokens - generated_tokens))
        with torch.no_grad():
            chunk_outputs = model.generate(
                input_ids=current_input_ids,
                **gen_params,
            )
        chunk_outputs = chunk_outputs.detach()
        new_tokens = chunk_outputs[0][current_input_ids.shape[1]:]

        if len(new_tokens) == 0:
            print("No new tokens generated, stopping.")
            break

        current_input_ids = chunk_outputs[0:1]
        generated_tokens += len(new_tokens)

        new_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        print(f"Generated text: {new_text}")

        if eos_token_id is not None and eos_token_id in new_tokens:
            print("EOS token detected, stopping generation.")
            break

        if generated_tokens < max_new_tokens and steps > 0:
            print(f"Starting uncertainty optimization after generating {len(new_tokens)} tokens...")
            ece_weights, _ = setup_layer_weights(model, model_name, current_input_ids)
            model.set_weights(layer_indices, ece_weights)
            _run_delta_steps(
                model, optimizer, current_input_ids, steps,
                lambda step, loss, rnd=generation_round: (
                    f"Round {rnd} generation step {step+1}/{steps}, "
                    f"generation_loss: {loss:.6f}"
                ),
            )

    print(f"\n=== Generation Complete ===")
    print(f"Total tokens generated: {generated_tokens}")
    print(f"Final sequence length: {current_input_ids.shape[1]}")
    delta_params = model.get_delta_parameters()
    return current_input_ids, delta_params


def create_chat_prompt(system_prompt: str, fewshot_examples: List[Dict], test_question: str, tokenizer, model_name: str) -> str:
    """Render a system + user chat asking the model to solve `test_question`.

    `fewshot_examples` is accepted for interface compatibility but not used.
    For 'qwen3' the template is rendered with ``enable_thinking=False``.
    Returns the (untokenized) string produced by ``tokenizer.apply_chat_template``.
    """
    user_content = (
        "Write a solution to the following problem:\n"
        f"```python\n{test_question}\n```\n"
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    if model_name == "qwen3":
        template_kwargs["enable_thinking"] = False
    return tokenizer.apply_chat_template(conversation, **template_kwargs)


def postprocess_completion(completion):
    """Strip chat/markdown wrapping from a model completion, keeping just the code.

    Steps, applied in order:
      1. Remove carriage returns and surrounding whitespace.
      2. If a ```python fence is present, keep the text from the first fence up
         to the next closing ``` (fence markers removed).
      3. Drop everything from the line holding a ``__name__ == "__main__"``
         guard (either quote style) onward.
      4. Drop everything from the first "# Example usage" onward.
      5. "The solution is:" / "The answer is:" markers: keep the text after the
         marker (for the former, only up to a following "\\n\\nThe answer is:").

    A diagnostic line is printed for each matched pattern.
    """
    completion = completion.replace("\r", "").strip()

    fence = "```python"
    if fence in completion:
        print("completion matches ```python")
        completion = completion[completion.index(fence):].strip().replace(fence, "")
        close_pos = completion.find("```")
        if close_pos == -1:
            print("wrong completion")
        else:
            completion = completion[:close_pos].strip()

    guard_hits = [
        pos
        for pos in (completion.find('__name__ == "__main__"'),
                    completion.find("__name__ == '__main__'"))
        if pos != -1
    ]
    if guard_hits:
        print("completion matches __name__ == \"__main__\"")
        # Cut at the start of the guard's line so the leading `if` goes too.
        line_start = completion.rfind("\n", 0, min(guard_hits)) + 1
        completion = completion[:line_start].strip()

    example_marker = "# Example usage"
    if example_marker in completion:
        print("completion matches # Example usage")
        completion = completion[:completion.index(example_marker)].strip()

    solution_marker = "The solution is:"
    if solution_marker in completion:
        print("completion matches The solution is:")
        completion = completion[completion.index(solution_marker):].strip().replace(solution_marker, "")
        answer_pos = completion.find("\n\nThe answer is:")
        if answer_pos == -1:
            completion = completion.strip()
            print("maybe wrong completion")
        else:
            completion = completion[:answer_pos].strip()

    answer_marker = "The answer is:"
    if answer_marker in completion:
        print("completion matches The answer is:")
        # replace() removes every marker, so nothing is left to search for; just strip.
        completion = completion[completion.index(answer_marker):].strip().replace(answer_marker, "").strip()

    return completion

def generate_with_hf(args, model, tokenizer, prompts: List[str], generation_config: GenerationConfig,
                    device: torch.device, num_samples: int) -> List[List[str]]:
    """Produce `num_samples` test-time-adapted completions for every prompt.

    Each prompt is tokenized (truncated to 4096 tokens), moved to `device`
    and passed to ``ttf_inference_single_fewshot_progressive_v2``. Only the
    newly generated tokens are decoded. A sample that raises is logged and
    recorded as "" so one failure does not abort the whole run.

    Args:
        args: parsed CLI namespace; reads ``model_path``, ``steps``, ``lr`` and
            ``chunk_size``.
        model, tokenizer: the loaded HF model and tokenizer.
        prompts: fully rendered chat prompts.
        generation_config: forwarded to the generation routine.
        device: device the input tensors are moved to.
        num_samples: completions to generate per prompt.

    Returns:
        One list of `num_samples` decoded strings per prompt.

    Raises:
        ValueError: if ``args.model_path`` matches no supported model family.
    """
    model_name = get_model_name(args.model_path)
    if model_name == 'unknown':
        # Previously this left `model_name` unbound, and every sample silently became "".
        raise ValueError(
            f"Unsupported model path {args.model_path!r}: expected it to contain "
            "'OLMoE-1B', 'Qwen3' or 'DeepSeek'"
        )

    all_outputs = []
    for prompt in tqdm(prompts, desc="Generating with HuggingFace"):
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        input_length = inputs["input_ids"].shape[1]

        completions = []
        for _ in range(num_samples):
            try:
                outputs, _ = ttf_inference_single_fewshot_progressive_v2(
                    model,
                    tokenizer,
                    generation_config,
                    inputs,
                    steps=args.steps,
                    model_name=model_name,
                    ttf_lr=args.lr,
                    chunk_size=args.chunk_size
                )
                generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
                completions.append(generated_text)
            except Exception as e:
                # Deliberate best-effort boundary: keep going with an empty completion.
                print(f"Generation error: {e}")
                completions.append("")

        all_outputs.append(completions)

    return all_outputs


def get_model_name(model_path: str) -> str:
    """Map a model path to its short family key, or 'unknown' if none matches.

    Markers are checked in order (case-sensitive substring match), so the first
    matching marker wins.
    """
    family_markers = (
        ('OLMoE-1B', 'olmoe'),
        ('Qwen3', 'qwen3'),
        ('DeepSeek', 'deepseek'),
    )
    return next((family for marker, family in family_markers if marker in model_path), 'unknown')


def main():
    """Generate test-time-adapted completions for HumanEval problems and save them.

    Parses CLI arguments, seeds the RNGs, loads the model and tokenizer, builds
    chat prompts for the selected slice of problems and runs
    ``generate_with_hf``. Completions are post-processed and written with
    ``write_jsonl``; a short run log goes to ``--log_dir``.
    """
    parser = argparse.ArgumentParser(description="Generate k completions for HumanEval problems")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory for output files")
    parser.add_argument("--prefix", type=str, default="humaneval_generation", help="Output file prefix")
    parser.add_argument("--num_samples", type=int, default=3, help="Number of completions per problem")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    parser.add_argument("--start", type=int, help="Start index")
    parser.add_argument("--end", type=int, help="End index")
    parser.add_argument("--log_dir", type=str, required=True, help="Directory for log files")
    parser.add_argument("--steps", type=int, default=5,
                        help="Delta-optimization steps on the prompt and after each generated chunk")

    parser.add_argument("--temperature", type=float, default=0.3, help="Generation temperature")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="Max new tokens")
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
    parser.add_argument("--chunk_size", type=int, default=50,
                        help="Tokens generated per round between delta-optimization phases")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="Learning rate of the AdamW optimizer over the delta parameters")
    args = parser.parse_args()

    seed_everything(args.seed)

    method = "hf"
    output_dir = f"{args.output_dir}_{method}_num_samples{args.num_samples}"
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    problems = read_problems()
    # Slicing with None bounds covers every combination of --start / --end.
    task_ids = sorted(problems.keys())[args.start:args.end]

    print(f"Processing {len(task_ids)} problems")

    model_name = get_model_name(args.model_path)
    print("Loading model with HuggingFace...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    generation_config = GenerationConfig(
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"Generation config: {generation_config}")

    prompts = []
    for task_id in task_ids:
        # Indentation is tab-encoded for the prompt and converted back below.
        prompt = problems[task_id]['prompt'].replace('    ', '\t')
        chat_prompt = create_chat_prompt("You are a helpful assistant.", [], prompt, tokenizer, model_name)
        prompts.append(chat_prompt)

    log_file = os.path.join(args.log_dir, f"{args.prefix}_hf_num_samples_{args.num_samples}_batch1_log.txt")
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(f"Model Path: {args.model_path}\n")
        f.write(f"Inference method: HuggingFace Transformers\n")
        f.write(f"num_samples: {args.num_samples}\n")
        f.write(f"Batch size: 1\n")
        f.write(f"Device: {device}\n")
        f.write(f"Generation config: {generation_config}\n")

    all_completions = generate_with_hf(args, model, tokenizer, prompts, generation_config, device, args.num_samples)
    results = []
    for i, (task_id, completions) in enumerate(zip(task_ids, all_completions)):
        all_code = [completion.split("### Response:")[-1] for completion in completions]
        completions = [code.replace('\t', '    ') for code in all_code]

        processed_completions = [postprocess_completion(comp) for comp in completions]

        results.append({
            'task_id': task_id,
            'prompt': problems[task_id]['prompt'],
            'completions': processed_completions,
            'raw_completions': all_code
        })

        if (i + 1) % 20 == 0:
            print(f"Processed {i+1}/{len(task_ids)} problems")

    output_file = os.path.join(output_dir, f"{args.prefix}_{args.start}_{args.end}_num_samples{args.num_samples}_results.json")
    write_jsonl(output_file, results)

    print(f"\nGeneration completed!")
    print(f"Total problems: {len(task_ids)}")
    print(f"Completions per problem: {args.num_samples}")
    print(f"Results saved to: {output_file}")


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()