import torch
import requests
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import List, Dict, Any
import importlib
from openai import OpenAI
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is bound

# Load the shared helper module directly from its file path rather than via
# the regular import system (the path is an absolute workspace location).
utils_path = "/workspace/rlhf-code/code/utils.py"
spec = importlib.util.spec_from_file_location("utils", utils_path)
if spec is None or spec.loader is None:
    # Fail with a clear message instead of an AttributeError on `None` below.
    raise ImportError(f"Cannot build an import spec for the utils module at {utils_path!r}")
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)

# Module-level aliases for the two helpers used by the classes in this file.
pad_to_length = utils.pad_to_length
get_reward = utils.get_reward


class GenerateModel:
    """Sample several completions for a prompt through an OpenAI-style client.

    The instance is callable: it takes a column-oriented batch (dict of
    sequences), uses the first row only, requests ``n_samples`` completions
    from the endpoint at ``base_url`` and returns a one-element list holding
    that row plus a ``'model_response'`` list of generated strings.

    Args:
        gen_model_path: Model identifier sent as ``model`` in every request.
        tokenizer_path: Path/name used to load the tokenizer; also inspected
            to decide between the chat and the plain completions endpoint.
        n_samples: Number of completions requested per prompt (``n``).
        max_new_tokens: Upper bound on generated tokens per completion.
        top_k: Stored on the instance; NOTE(review): it is not forwarded to
            the request because the request built here has no such field.
        top_p: Nucleus-sampling parameter forwarded to the request.
        temperature: Sampling temperature forwarded to the request.
        cache_dir: Optional cache directory passed to the tokenizer loader.
        base_url: Base URL of the OpenAI-compatible server.
        api_key: API key handed to the client.
    """

    def __init__(self, gen_model_path, tokenizer_path, n_samples, max_new_tokens, top_k, top_p, temperature,  cache_dir = None,
                 base_url = "http://localhost:8000/v1", api_key = "token-abc123"):
        self.gen_model_path = gen_model_path
        self.is_instruct_model = self._is_instruct_name(tokenizer_path)
        self.n_samples = n_samples
        self.max_new_tokens = max_new_tokens
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        # Kept as a public attribute; generation itself no longer needs it.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, cache_dir = cache_dir)
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

    @staticmethod
    def _is_instruct_name(path: str) -> bool:
        """Return True when *path* looks like an instruction/chat-tuned model name."""
        lowered = path.lower()
        if 'instruct' in lowered or 'chat' in lowered:
            return True
        # 'it' must be a standalone name component (e.g. "gemma-2b-it"); a raw
        # substring test would also fire on names such as "openai-community/gpt2".
        components = "".join(ch if ch.isalnum() else " " for ch in lowered).split()
        return 'it' in components

    def __call__(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate ``n_samples`` responses for the first row of *batch*.

        Args:
            batch: Mapping of column name to an indexable column; must contain
                a ``"prompt"`` column. Only index 0 of each column is read.

        Returns:
            A single-element list with the row's fields plus
            ``'model_response'`` (list of generated strings). A
            ``'model_response'`` column already present in *batch* takes
            precedence over the generated list, as the row is merged last.
        """
        row = {key: column[0] for key, column in batch.items()}
        prompt = row["prompt"]

        request = {
            "model": self.gen_model_path,
            # `max_tokens` limits generated tokens only, so the prompt length
            # must not be added on top of it.
            "max_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": self.n_samples,
        }

        if self.is_instruct_model:
            completion = self.client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}], **request
            )
            responses = [choice.message.content for choice in completion.choices]
        else:
            completion = self.client.completions.create(prompt=prompt, **request)
            responses = [choice.text for choice in completion.choices]

        return [{'model_response': responses, **row}]


    
class BoNModel:
    """Best-of-N selector: keep the candidate response with the highest reward.

    The instance is callable on a column-oriented batch whose
    ``'model_response'`` column holds, per item, a sequence of candidate
    strings. Each candidate is tokenized with the reward model's tokenizer,
    scored through ``get_reward`` and the top-scoring candidate string
    replaces the list in the returned row.

    Args:
        reward_model_path: Path/name of the sequence-classification reward
            model (also used to load its tokenizer).
        max_seq_len: Length every tokenized candidate is padded/truncated to.
        batch_size: Stored on the instance; not read by the code in this class.
        reward_chunk_size: Number of candidates scored per ``get_reward`` call.
        cache_dir: Cache directory passed to the model/tokenizer loaders.
    """

    def __init__(self, reward_model_path, max_seq_len, batch_size, reward_chunk_size, cache_dir):
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_path, torch_dtype=torch.bfloat16, num_labels=1, cache_dir = cache_dir
        ).to('cuda')
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.reward_chunk_size = reward_chunk_size
        self.tokenizer = AutoTokenizer.from_pretrained(reward_model_path, cache_dir = cache_dir)

    def _encode(self, texts):
        """Tokenize *texts* into one ``[len(texts), max_seq_len]`` id tensor."""
        rows = [
            pad_to_length(
                # Truncate so that no row can exceed max_seq_len; otherwise a
                # single long candidate makes the rows impossible to concatenate.
                # NOTE(review): assumes pad_to_length pads shorter rows up to
                # max_seq_len along the last dim — confirm against utils.py.
                self.tokenizer(
                    text, return_tensors='pt', truncation=True, max_length=self.max_seq_len
                ).input_ids,
                self.max_seq_len,
                pad_value=0,
            )
            for text in texts
        ]
        return torch.cat(rows, dim=0)

    def _score(self, token_ids):
        """Return a flat reward vector for *token_ids*, scored chunk by chunk."""
        device = next(self.reward_model.parameters()).device
        token_ids = token_ids.to(device)

        chunk_rewards = []
        # Pure inference: rewards only feed an argmax, so skip autograd bookkeeping.
        with torch.no_grad():
            for chunk in torch.split(token_ids, self.reward_chunk_size):
                # get_reward returns a 3-tuple; only the middle element is used.
                # pad_token_id matches the pad_value used in _encode.
                _, rewards, _ = get_reward(
                    model=self.reward_model,
                    query_responses=chunk,
                    pad_token_id=0,
                    context_length=0
                )
                chunk_rewards.append(rewards)
        return torch.cat(chunk_rewards, dim=0).reshape(-1)

    @staticmethod
    def _argmax_per_item(flat_rewards, counts):
        """Return, per item, the index of its best reward within its own candidates."""
        # Upcast + move once so the per-item argmax does not sync per element;
        # bfloat16 -> float32 is exact, so the ordering is unchanged.
        segments = torch.split(flat_rewards.reshape(-1).float().cpu(), counts)
        return [int(torch.argmax(segment)) for segment in segments]

    def __call__(self, batch):
        """Pick the highest-reward candidate for every item in *batch*.

        Args:
            batch: Mapping of column name to an indexable column; the
                ``'model_response'`` column holds a sequence of candidate
                strings per item.

        Returns:
            One dict per item with all original fields, where
            ``'model_response'`` is the selected candidate string exactly as
            it was supplied (no tokenize/decode round trip).

        Raises:
            ValueError: If an item has no candidate responses.
        """
        candidates_per_item = batch['model_response']
        counts = [len(candidates) for candidates in candidates_per_item]
        if not counts:
            return []
        if min(counts) == 0:
            raise ValueError("every batch item needs at least one candidate in 'model_response'")

        # Flatten so items may carry different numbers of candidates.
        flat_texts = [text for candidates in candidates_per_item for text in candidates]
        flat_rewards = self._score(self._encode(flat_texts))
        best_indices = self._argmax_per_item(flat_rewards, counts)

        res = []
        for i, best_idx in enumerate(best_indices):
            itm = {k : v[i] for k, v in batch.items()}
            itm['model_response'] = candidates_per_item[i][best_idx]
            res.append(itm)
        return res

