# File: helm/clients/proxy_tuning_client.py
from helm.clients.client import Client
from helm.common.cache import CacheConfig
from helm.tokenizers.tokenizer import Tokenizer
from helm.common.cache import Cache
from helm.common.request import Request, RequestResult, GeneratedOutput
from helm.proxy.retry import NonRetriableException


from typing import Optional, Dict, Any
import torch, os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from transformers.generation.utils import (
    ModelOutput,
)
from collections import defaultdict
import tqdm
from transformers import (
    LogitsProcessorList,
    BitsAndBytesConfig, 
)


import json
from datetime import datetime
from typing import List

from typing import Mapping            # used for cache_key type
from helm.clients.client import CachingClient  # used for make_cache_key
from helm.common.request import wrap_request_time  # used in make_request
from helm.common.cache import Cache
   
# Delimiters placed around the system prompt when it is appended to the chat
# prefix ("[INST]") by the Llama wrapper classes in this module.
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def _p_and_preds_for(next_token_logits: torch.Tensor, next_token: torch.Tensor):
    """Return the probability of the chosen token and the greedy prediction.

    The logits are normalised with a softmax over the last dimension; the
    probability at index ``next_token`` is picked out of that distribution and
    the arg-max index of the raw logits is reported alongside it.

    Args:
        next_token_logits: Logits whose last dimension is the vocabulary.
        next_token: Token indices, one dimension fewer than the logits.

    Returns:
        ``(p, pred)``: both detached tensors shaped like ``next_token``.
    """
    with torch.no_grad():
        distribution = torch.softmax(next_token_logits, dim=-1)
        # gather needs an index with the same rank as the distribution
        chosen_index = torch.unsqueeze(next_token, -1)
        chosen_prob = torch.gather(distribution, -1, chosen_index).squeeze(-1)
        greedy_token = torch.argmax(next_token_logits, dim=-1)
    return chosen_prob.detach(), greedy_token.detach()

class DExpertsLlama:
    """Decode from a base model steered by an expert / anti-expert pair.

    At every decoding step the next-token logits of the three models are
    combined as ``base + alpha * (expert - antiexpert)`` and the next token is
    chosen from the combined logits. All three models are then fed the same
    chosen token.
    """

    def __init__(
        self,
        base_name: str,
        expert_name: str,
        antiexpert_name: str,
        tokenizer: AutoTokenizer,
        system_prompt: Optional[str] = None,
        alpha: float = 1.0,
        chat_response_prefix: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Load the three models and set up the optional chat formatting.

        base_name / expert_name / antiexpert_name: identifiers passed to
        ``AutoModelForCausalLM.from_pretrained``. A model whose name contains
        "chat" (case-insensitive) is fed chat-formatted prompts.
        tokenizer: tokenizer shared by all three models.
        system_prompt: optional system prompt inserted into the chat prefix.
        alpha: weight of the (expert - antiexpert) logit difference.
        chat_response_prefix: For llama chat models, it can be helpful for the response
        to start with a certain prefix to constrain the generation to directly answer
        the question. This makes evaluation on MC datasets easier.
        model_kwargs: extra keyword arguments for ``from_pretrained``; ``None``
        means no extra arguments.
        """
        # The default of None cannot be unpacked with ** below, so normalise it.
        model_kwargs = model_kwargs or {}

        self.base = AutoModelForCausalLM.from_pretrained(
            base_name, **model_kwargs
        )
        self.expert = AutoModelForCausalLM.from_pretrained(
            expert_name, **model_kwargs
        )
        self.antiexpert = AutoModelForCausalLM.from_pretrained(
            antiexpert_name, **model_kwargs
        )

        self.base.eval()
        self.expert.eval()
        self.antiexpert.eval()

        self.tokenizer = tokenizer
        self.alpha = alpha
        self.device = self.base.device
        self.chat_response_prefix = chat_response_prefix

        # Llama chat experts need different formatting
        self.use_chat_format_for_expert = 'chat' in expert_name.lower()
        self.use_chat_format_for_antiexpert = 'chat' in antiexpert_name.lower()
        self.use_chat_format_for_base = 'chat' in base_name.lower()

        if self.use_chat_format_for_expert or self.use_chat_format_for_antiexpert or self.use_chat_format_for_base:
            print("using chat format")
            # chat_prefix goes before the query, and chat_suffix goes after it
            self.chat_prefix = "[INST]"
            self.chat_suffix = "[/INST]"

            if system_prompt:
                self.chat_prefix += f"{B_SYS}{system_prompt}{E_SYS}"

            if self.chat_response_prefix:
                self.chat_suffix += f" {chat_response_prefix}"

    def forward(
        self,
        base_inputs,
        expert_inputs,
        antiexpert_inputs,
        return_dict=None
    ):
        """Run one forward pass of each model and return the three outputs
        in the order (base, expert, antiexpert)."""
        base_outputs = self.base(**base_inputs, return_dict=return_dict)
        expert_outputs = self.expert(**expert_inputs, return_dict=return_dict)
        antiexpert_outputs = self.antiexpert(**antiexpert_inputs, return_dict=return_dict)

        return base_outputs, expert_outputs, antiexpert_outputs

    def _get_tokenized_chat_inputs(self, input_ids):
        """Decode input_ids and encode again to insert chat formatting"""

        prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # remove response_prefix (e.g., "Answer:") from the prompt if it's already there
        if self.chat_response_prefix:
            cleaned_prompts = []
            for p in prompts:
                if self.chat_response_prefix in p:
                    p = p.replace(self.chat_response_prefix, '').rstrip()
                cleaned_prompts.append(p)
        else:
            cleaned_prompts = prompts

        chat_prompts = [f'{self.chat_prefix} {p} {self.chat_suffix}' for p in cleaned_prompts]

        chat_inputs = self.tokenizer(
            chat_prompts, padding="longest", return_tensors="pt",
            add_special_tokens=True
        )
        chat_inputs.input_ids = chat_inputs.input_ids.to(self.device)
        chat_inputs.attention_mask = chat_inputs.attention_mask.to(self.device)

        return chat_inputs

    def update_analysis_data(self, analysis_data, next_tokens, next_token_logits_dict):
        """Append the decoded tokens, token ids and per-model logits of one
        decoding step to ``analysis_data`` and return it."""
        analysis_data['tokens'].append([self.tokenizer.decode(t) for t in next_tokens])
        analysis_data['token_ids'].append(next_tokens)

        # logits from each model for the next token
        for model in next_token_logits_dict.keys():
            analysis_data[f'logits_{model}'].append(next_token_logits_dict[model].unsqueeze(dim=1))

        return analysis_data

    def _select_model_inputs(self, label, use_chat_format, input_ids, chat_inputs, model_kwargs):
        """Pick the prompt ids for one model and set its attention mask.

        When the model uses the chat format, the chat-formatted ids are returned
        and ``model_kwargs['attention_mask']`` is replaced by the matching mask;
        otherwise the raw ``input_ids`` are returned and ``model_kwargs`` is left
        untouched.
        """
        if use_chat_format:
            print(f"[generate] Using chat format for {label}")
            model_kwargs['attention_mask'] = chat_inputs.attention_mask
            return chat_inputs.input_ids.to(input_ids.device)
        print(f"[generate] Not using chat format for {label}")
        return input_ids.to(input_ids.device)

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = 100,
        do_sample: bool = False,
        temperature: float = 1.0,
        logits_processor: Optional[LogitsProcessorList] = None,
        return_logits_for_analysis: bool = False,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ):
        """Generate up to ``max_new_tokens`` tokens with the combined logits.

        Args:
            input_ids: prompt token ids, one row per sequence.
            max_new_tokens: upper bound on the number of decoding steps.
            do_sample: sample from the softmax of the combined logits instead
                of taking the arg-max.
            temperature: divisor applied to the combined logits when != 1.0.
            logits_processor: optional callable applied as
                ``logits_processor(base_input_ids, next_token_logits)``.
            return_logits_for_analysis: also return per-step statistics.
            prefix_allowed_tokens_fn: accepted for interface compatibility;
                not used by this method.
            **kwargs: copied into the per-model kwargs that are handed to each
                model's ``prepare_inputs_for_generation`` (e.g. ``attention_mask``).

        Returns:
            ``input_ids`` extended with the generated tokens. With
            ``return_logits_for_analysis`` a tuple ``(input_ids, results)`` where
            ``results`` is a one-element list holding a dict of 1-D tensors
            (token ids, chosen-token probabilities in fp16 and arg-max
            predictions for the combined, base, expert and anti-expert logits).
            NOTE: only the first sequence of the batch (index 0) is recorded.
        """
        base_kwargs = kwargs.copy()
        expert_kwargs = kwargs.copy()
        antiexpert_kwargs = kwargs.copy()

        # The chat-formatted prompt depends only on input_ids, so tokenize it
        # once and share it between every model that needs it.
        needs_chat_inputs = (
            self.use_chat_format_for_expert
            or self.use_chat_format_for_antiexpert
            or self.use_chat_format_for_base
        )
        chat_inputs = self._get_tokenized_chat_inputs(input_ids) if needs_chat_inputs else None

        expert_input_ids = self._select_model_inputs(
            "expert", self.use_chat_format_for_expert, input_ids, chat_inputs, expert_kwargs
        )
        antiexpert_input_ids = self._select_model_inputs(
            "antiexpert", self.use_chat_format_for_antiexpert, input_ids, chat_inputs, antiexpert_kwargs
        )
        base_input_ids = self._select_model_inputs(
            "base", self.use_chat_format_for_base, input_ids, chat_inputs, base_kwargs
        )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        eos_token_id_tensor = torch.tensor([self.tokenizer.eos_token_id]).to(input_ids.device)

        T = max_new_tokens
        if return_logits_for_analysis:
            device = input_ids.device
            # Pre-allocated length-T buffers on the input device; only the first
            # t_write entries are filled and returned.
            p_dexperts = torch.empty(T, device=device, dtype=torch.float16)
            p_base     = torch.empty(T, device=device, dtype=torch.float16)
            p_expert   = torch.empty(T, device=device, dtype=torch.float16)
            p_anti     = torch.empty(T, device=device, dtype=torch.float16)

            preds_dexperts = torch.empty(T, device=device, dtype=torch.int32)
            preds_base     = torch.empty(T, device=device, dtype=torch.int32)
            preds_expert   = torch.empty(T, device=device, dtype=torch.int32)
            preds_anti     = torch.empty(T, device=device, dtype=torch.int32)

            token_ids_out  = torch.empty(T, device=device, dtype=torch.int32)
            t_write = 0

        for step in range(max_new_tokens):
            # prepare model inputs with past_key_values and attention_mask
            base_inputs = self.base.prepare_inputs_for_generation(base_input_ids, **base_kwargs)
            expert_inputs = self.expert.prepare_inputs_for_generation(expert_input_ids, **expert_kwargs)
            antiexpert_inputs = self.antiexpert.prepare_inputs_for_generation(antiexpert_input_ids, **antiexpert_kwargs)

            # DExperts
            base_outputs, expert_outputs, antiexpert_outputs = self.forward(
                base_inputs, expert_inputs, antiexpert_inputs, return_dict=True
            )

            base_next_token_logits = base_outputs.logits[..., -1, :]
            expert_next_token_logits = expert_outputs.logits[..., -1, :]
            antiexpert_next_token_logits = antiexpert_outputs.logits[..., -1, :]

            # sometimes our experts have extra (irrelevant) tokens at the end of the normal vocabulary;
            # cut both expert and anti-expert down to the base vocabulary so the
            # arithmetic below operates on matching shapes
            vocab_size = base_next_token_logits.shape[-1]
            expert_next_token_logits = expert_next_token_logits[:, :vocab_size]
            antiexpert_next_token_logits = antiexpert_next_token_logits[:, :vocab_size]

            # DExperts! # this is the actual proxy tuning part
            next_token_logits = (
                base_next_token_logits +
                self.alpha * (expert_next_token_logits - antiexpert_next_token_logits)
            )

            # forbid EOS during the first two steps so the output is never empty
            if step < 2:
                next_token_logits[:, self.tokenizer.eos_token_id] = -float("inf")

            # pre-process logits
            if logits_processor:
                next_token_logits = logits_processor(base_input_ids, next_token_logits)

            # warp logits
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # decode
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # finished sequences keep emitting the pad token
            next_tokens = (
                next_tokens * unfinished_sequences +
                self.tokenizer.pad_token_id * (1 - unfinished_sequences)
            )

            if return_logits_for_analysis:
                p_dexp, pred_dexp = _p_and_preds_for(next_token_logits, next_tokens)
                p_b,    pred_b    = _p_and_preds_for(base_next_token_logits,   next_tokens)
                p_e,    pred_e    = _p_and_preds_for(expert_next_token_logits, next_tokens)
                p_a,    pred_a    = _p_and_preds_for(antiexpert_next_token_logits, next_tokens)

                # write scalars to 1D slots (first sequence of the batch only)
                p_dexperts[t_write] = p_dexp[0].to(torch.float16)
                p_base[t_write]     = p_b[0].to(torch.float16)
                p_expert[t_write]   = p_e[0].to(torch.float16)
                p_anti[t_write]     = p_a[0].to(torch.float16)

                preds_dexperts[t_write] = pred_dexp[0].to(torch.int32)
                preds_base[t_write]     = pred_b[0].to(torch.int32)
                preds_expert[t_write]   = pred_e[0].to(torch.int32)
                preds_anti[t_write]     = pred_a[0].to(torch.int32)

                token_ids_out[t_write]  = next_tokens[0].to(torch.int32)
                t_write += 1

            # update model inputs for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            expert_input_ids = torch.cat([expert_input_ids, next_tokens[:, None]], dim=-1)
            antiexpert_input_ids = torch.cat([antiexpert_input_ids, next_tokens[:, None]], dim=-1)
            base_input_ids = torch.cat([base_input_ids, next_tokens[:, None]], dim=-1)

            # update kwargs
            base_kwargs = self._update_model_kwargs_for_generation(base_outputs, base_kwargs)
            expert_kwargs = self._update_model_kwargs_for_generation(expert_outputs, expert_kwargs)
            antiexpert_kwargs = self._update_model_kwargs_for_generation(antiexpert_outputs, antiexpert_kwargs)

            # if eos_token was found in one sentence, set sentence to finished
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                break

        if return_logits_for_analysis:
            sl = slice(0, t_write)
            results = [{
                'token_ids':        token_ids_out[sl],
                'p_dexperts':       p_dexperts[sl],
                'preds_dexperts':   preds_dexperts[sl],
                'p_base':           p_base[sl],
                'preds_base':       preds_base[sl],
                'p_expert':         p_expert[sl],
                'preds_expert':     preds_expert[sl],
                'p_antiexpert':     p_anti[sl],
                'preds_antiexpert': preds_anti[sl],
                # (optional) decode later if you want strings
            }]
            return input_ids, results

        return input_ids

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Carry the generation state of one model over to the next step.

        Stores the new ``past_key_values``, extends ``attention_mask`` (when
        present) by one column of ones, and advances ``cache_position``.
        NOTE: when neither ``outputs`` nor ``kwargs`` provide a
        ``cache_position``, it is derived from ``kwargs["attention_mask"]``,
        which must therefore be present on the first step.
        """
        # update past_key_values
        kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in kwargs:
            attention_mask = kwargs["attention_mask"]
            kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if getattr(outputs, "cache_position", None) is not None:
            # some models already return it
            kwargs["cache_position"] = outputs.cache_position
        else:
            if "cache_position" in kwargs:
                kwargs["cache_position"] = kwargs["cache_position"] + 1
            else:
                # first step: position is sequence-length-1
                seq_len = kwargs["attention_mask"].shape[1]
                kwargs["cache_position"] = torch.arange(seq_len - 1, seq_len, device=kwargs["attention_mask"].device)

        return kwargs

class RegularLlama:
    """Single-model counterpart of the DExperts wrapper.

    Runs the same token-by-token decoding loop, but the next-token logits come
    from the base model alone (no expert / anti-expert adjustment).
    """

    def __init__(
        self,
        base_name: str,
        tokenizer: AutoTokenizer,
        system_prompt: Optional[str] = None,
        alpha: float = 1.0,
        chat_response_prefix: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Load the base model and set up the optional chat formatting.

        base_name: identifier passed to ``AutoModelForCausalLM.from_pretrained``;
        a name containing "chat" (case-insensitive) enables chat formatting.
        tokenizer: tokenizer used for decoding / re-encoding prompts.
        system_prompt: optional system prompt inserted into the chat prefix.
        alpha: stored on the instance; not used by this class's decoding loop.
        chat_response_prefix: For llama chat models, it can be helpful for the response
        to start with a certain prefix to constrain the generation to directly answer
        the question. This makes evaluation on MC datasets easier.
        model_kwargs: extra keyword arguments for ``from_pretrained``; ``None``
        means no extra arguments.
        """
        # The default of None cannot be unpacked with ** below, so normalise it.
        model_kwargs = model_kwargs or {}

        self.base = AutoModelForCausalLM.from_pretrained(
            base_name, **model_kwargs
        )

        self.base.eval()

        self.tokenizer = tokenizer
        self.alpha = alpha
        self.device = self.base.device
        self.chat_response_prefix = chat_response_prefix

        # Llama chat models need different formatting
        self.use_chat_format_for_base = 'chat' in base_name.lower()

        if self.use_chat_format_for_base:
            print("using chat format")
            # chat_prefix goes before the query, and chat_suffix goes after it
            self.chat_prefix = "[INST]"
            self.chat_suffix = "[/INST]"

            if system_prompt:
                self.chat_prefix += f"{B_SYS}{system_prompt}{E_SYS}"

            if self.chat_response_prefix:
                self.chat_suffix += f" {chat_response_prefix}"

    def forward(
        self,
        base_inputs,
        return_dict=None
    ):
        """Run one forward pass of the base model and return its outputs."""
        base_outputs = self.base(**base_inputs, return_dict=return_dict)

        return base_outputs

    def _get_tokenized_chat_inputs(self, input_ids):
        """Decode input_ids and encode again to insert chat formatting"""

        prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # remove response_prefix (e.g., "Answer:") from the prompt if it's already there
        if self.chat_response_prefix:
            cleaned_prompts = []
            for p in prompts:
                if self.chat_response_prefix in p:
                    p = p.replace(self.chat_response_prefix, '').rstrip()
                cleaned_prompts.append(p)
        else:
            cleaned_prompts = prompts

        chat_prompts = [f'{self.chat_prefix} {p} {self.chat_suffix}' for p in cleaned_prompts]

        chat_inputs = self.tokenizer(
            chat_prompts, padding="longest", return_tensors="pt",
            add_special_tokens=True
        )
        chat_inputs.input_ids = chat_inputs.input_ids.to(self.device)
        chat_inputs.attention_mask = chat_inputs.attention_mask.to(self.device)

        return chat_inputs

    def update_analysis_data(self, analysis_data, next_tokens, next_token_logits_dict):
        """Append the decoded tokens, token ids and per-model logits of one
        decoding step to ``analysis_data`` and return it."""
        analysis_data['tokens'].append([self.tokenizer.decode(t) for t in next_tokens])
        analysis_data['token_ids'].append(next_tokens)

        # logits from each model for the next token
        for model in next_token_logits_dict.keys():
            analysis_data[f'logits_{model}'].append(next_token_logits_dict[model].unsqueeze(dim=1))

        return analysis_data

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = 100,
        do_sample: bool = False,
        temperature: float = 1.0,
        logits_processor: Optional[LogitsProcessorList] = None,
        return_logits_for_analysis: bool = False,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ):
        """Generate up to ``max_new_tokens`` tokens from the base model.

        Args:
            input_ids: prompt token ids, one row per sequence.
            max_new_tokens: upper bound on the number of decoding steps.
            do_sample: sample from the softmax of the logits instead of taking
                the arg-max.
            temperature: divisor applied to the logits when != 1.0.
            logits_processor: optional callable applied as
                ``logits_processor(base_input_ids, next_token_logits)``.
            return_logits_for_analysis: when true the return value becomes the
                tuple ``(input_ids, [])``; no per-step statistics are collected
                by this class.
            prefix_allowed_tokens_fn: accepted for interface compatibility;
                not used by this method.
            **kwargs: copied into the kwargs handed to the model's
                ``prepare_inputs_for_generation`` (e.g. ``attention_mask``).

        Returns:
            ``input_ids`` extended with the generated tokens, or
            ``(input_ids, [])`` when ``return_logits_for_analysis`` is set.
        """
        base_kwargs = kwargs.copy()

        # prepare inputs for model
        if self.use_chat_format_for_base:
            print("[generate] Using chat format for base")
            chat_inputs = self._get_tokenized_chat_inputs(input_ids)
            base_input_ids = chat_inputs.input_ids.to(input_ids.device)
            base_kwargs['attention_mask'] = chat_inputs.attention_mask
        else:
            print("[generate] Not using chat format for base")
            base_input_ids = input_ids.to(input_ids.device)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        eos_token_id_tensor = torch.tensor([self.tokenizer.eos_token_id]).to(input_ids.device)

        for step in range(max_new_tokens):
            # prepare model inputs with past_key_values and attention_mask
            base_inputs = self.base.prepare_inputs_for_generation(base_input_ids, **base_kwargs)

            base_outputs = self.forward(
                base_inputs, return_dict=True
            )

            base_next_token_logits = base_outputs.logits[..., -1, :]

            # plain decoding: the base model's logits are used unchanged
            next_token_logits = base_next_token_logits

            # forbid EOS during the first two steps so the output is never empty
            if step < 2:
                next_token_logits[:, self.tokenizer.eos_token_id] = -float("inf")

            # pre-process logits
            if logits_processor:
                next_token_logits = logits_processor(base_input_ids, next_token_logits)

            # warp logits
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # decode
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # finished sequences keep emitting the pad token
            next_tokens = (
                next_tokens * unfinished_sequences +
                self.tokenizer.pad_token_id * (1 - unfinished_sequences)
            )

            # update model inputs for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            base_input_ids = torch.cat([base_input_ids, next_tokens[:, None]], dim=-1)

            # update kwargs
            base_kwargs = self._update_model_kwargs_for_generation(base_outputs, base_kwargs)
            # if eos_token was found in one sentence, set sentence to finished
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                break

        if return_logits_for_analysis:
            return input_ids, []

        return input_ids

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Carry the generation state of the model over to the next step.

        Stores the new ``past_key_values``, extends ``attention_mask`` (when
        present) by one column of ones, and advances ``cache_position``.
        NOTE: when neither ``outputs`` nor ``kwargs`` provide a
        ``cache_position``, it is derived from ``kwargs["attention_mask"]``,
        which must therefore be present on the first step.
        """
        # update past_key_values
        kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in kwargs:
            attention_mask = kwargs["attention_mask"]
            kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if getattr(outputs, "cache_position", None) is not None:
            # some models already return it
            kwargs["cache_position"] = outputs.cache_position
        else:
            if "cache_position" in kwargs:
                kwargs["cache_position"] = kwargs["cache_position"] + 1
            else:
                # first step: position is sequence-length-1
                seq_len = kwargs["attention_mask"].shape[1]
                kwargs["cache_position"] = torch.arange(seq_len - 1, seq_len, device=kwargs["attention_mask"].device)

        return kwargs


def ensure_dir(d):
    """Create directory ``d`` (including missing parents) if it does not exist.

    An empty string refers to the current working directory, which already
    exists, so it is a no-op; passing it on to ``os.makedirs`` would raise
    ``FileNotFoundError``. A path that already exists is left untouched.
    """
    if d == "":
        return
    if not os.path.exists(d):
        # exist_ok guards against another process creating it in between
        os.makedirs(d, exist_ok=True)


@torch.inference_mode()
def generate_completions(
    model,
    tokenizer,
    prompts,
    batch_size=1,
    add_special_tokens=True,
    disable_tqdm=False,
    temperature=1.0,
    logits_processor=None,
    prefix_allowed_tokens_fn=None,
    return_logits_for_analysis=False,
    **generation_kwargs,
):
    """Generate completions for ``prompts`` in batches of ``batch_size``.

    Each batch is tokenized with ``padding="longest"``, moved to the model's
    device, and passed to ``model.generate``. The generated ids are decoded with
    ``skip_special_tokens=True``; the completion is obtained by cutting the
    decoded prompt's length off the front of the decoded output.

    Args:
        model: object exposing ``generate`` and either a ``device`` attribute
            or ``parameters()``.
        tokenizer: callable tokenizer that also provides ``batch_decode``.
        prompts: list of prompt strings.
        batch_size: number of prompts per ``generate`` call; must be >= 1.
        add_special_tokens: forwarded to the tokenizer call.
        disable_tqdm: when true no progress bar is created.
        temperature, logits_processor, prefix_allowed_tokens_fn: forwarded to
            ``model.generate``.
        return_logits_for_analysis: when true ``model.generate`` is expected to
            return ``(output_ids, results)`` and ``results`` is collected.
        **generation_kwargs: forwarded to ``model.generate``;
            ``num_return_sequences`` (default 1) is also used to line prompts
            up with outputs.

    Returns:
        ``(outputs, generations, all_results)``: full decoded outputs, the
        completions with the prompt removed, and the collected analysis results
        (empty unless ``return_logits_for_analysis`` is set).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)

    generations = []
    outputs = []
    all_results = []

    progress = None
    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    try:
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i+batch_size]
            tokenized_prompts = tokenizer(
                batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens
            )

            if hasattr(model, "device"):                 # DExpertsLlama
                device = model.device
            else:                                        # vanilla HF model
                device = next(model.parameters()).device
            batch_input_ids = tokenized_prompts['input_ids'].to(device)
            attention_mask = tokenized_prompts['attention_mask'].to(device)

            if return_logits_for_analysis:
                batch_outputs, results = model.generate(
                    input_ids=batch_input_ids,
                    attention_mask=attention_mask,
                    logits_processor=logits_processor,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    temperature=temperature,
                    return_logits_for_analysis=return_logits_for_analysis,
                    **generation_kwargs
                )
                all_results.extend(results)
            else:
                batch_outputs = model.generate(
                    input_ids=batch_input_ids,
                    attention_mask=attention_mask,
                    logits_processor=logits_processor,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    temperature=temperature,
                    **generation_kwargs
                )

            # to support the logits processing below when using DExperts with mixed tokenizers
            if isinstance(batch_input_ids, dict):
                batch_input_ids = batch_input_ids['llama']

            batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
            print("batch_outputs: ", batch_outputs)
            batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)

            # duplicate the prompts to match the number of return sequences
            batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
            # the completion is whatever follows the decoded prompt
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
            ]

            generations += batch_generations
            outputs += batch_outputs

            if progress is not None:
                progress.update(len(batch_prompts)//num_return_sequences)
    finally:
        # release the progress bar even when generation fails part-way
        if progress is not None:
            progress.close()

    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    assert len(outputs) == len(prompts) * num_return_sequences, "number of outputs should be equal to number of prompts * num_return_sequences"
    return outputs, generations, all_results


def add_pad_token(tokenizer, padding_side="left"):
    """Make sure ``tokenizer`` can pad, then set its padding side.

    A tokenizer without a pad token reuses its EOS token (and EOS id) for
    padding. The same tokenizer object is modified in place and returned.
    """
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        # fall back to EOS for both the token string and its id
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    tokenizer.padding_side = padding_side
    return tokenizer

def load_dexperts_model_and_tokenizer(
    base_name: str,
    expert_name: str,
    antiexpert_name: str,
    tokenizer_name: str,
    device_map: str = "auto",
    system_prompt: str = None,
    alpha: float = 1.0,
    chat_response_prefix: str = None,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    use_fast_tokenizer: bool = True,
    padding_side: str = "left",
    proxy_tune: bool = False,
):
    """Build the tokenizer and the model wrapper used for generation.

    All ``*_name`` arguments are keys into ``MODEL_PATHS``. With
    ``proxy_tune`` a ``DExpertsLlama`` (base + expert + anti-expert) is
    created; otherwise a ``RegularLlama`` wrapping only the base model, in
    which case the expert and anti-expert names are not looked up.

    Quantization: ``load_in_8bit`` selects an 8-bit config, ``load_in_4bit`` an
    NF4 4-bit config; when both are set the 4-bit config is the one used.

    Returns:
        ``(model, tokenizer)``.
    """
    quant_cfg = (
        BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
        if load_in_8bit
        else None
    )
    if load_in_4bit:
        # assigned last, so it overrides the 8-bit config when both flags are set
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model_kwargs = dict(
        device_map=device_map,
        torch_dtype=torch.float16,
        quantization_config=quant_cfg,
        low_cpu_mem_usage=True,
    )

    tokenizer_path = MODEL_PATHS[tokenizer_name]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=use_fast_tokenizer)
    print(f"[Loader] Tokenizer   : {tokenizer_path}")
    tokenizer = add_pad_token(tokenizer, padding_side)

    base_path = MODEL_PATHS[base_name]

    if not proxy_tune:
        model = RegularLlama(
            base_name=base_path,
            tokenizer=tokenizer,
            system_prompt=system_prompt,
            alpha=alpha,
            chat_response_prefix=chat_response_prefix,
            model_kwargs=model_kwargs,
        )
        print(f"[Loader] Base model   : {base_path}")
        return model, tokenizer

    expert_path = MODEL_PATHS[expert_name]
    antiexpert_path = MODEL_PATHS[antiexpert_name]
    model = DExpertsLlama(
        base_name=base_path,
        expert_name=expert_path,
        antiexpert_name=antiexpert_path,
        tokenizer=tokenizer,
        system_prompt=system_prompt,
        alpha=alpha,
        chat_response_prefix=chat_response_prefix,
        model_kwargs=model_kwargs,
    )
    print(f"[Loader] Base model   : {base_path}")
    print(f"[Loader] Expert model : {expert_path}")
    print(f"[Loader] Anti‑expert  : {antiexpert_path}")
    return model, tokenizer


def _safe_tag(model_name: str) -> str:
    """Turn a model name into a tag usable inside a file or directory name.

    Slashes become underscores; spaces, dots and hyphens are dropped, e.g.
    "proxy_tuning/llama70b_mellama13bchat" -> "proxy_tuning_llama70b_mellama13bchat".
    """
    substitutions = str.maketrans({"/": "_", " ": None, ".": None, "-": None})
    return model_name.translate(substitutions)

def setup_run_dirs(model_name: str, root=""):
    """
    Prepare the on-disk layout for one run and hand back its paths.

    Layout:
      <root>/<TAG>_<YYYYMMDD_HHMMSS>/
          ├─ <TAG>_<YYYYMMDD_HHMMSS>.csv   (written with the header row only)
          └─ logits_analysis/

    TAG is `_safe_tag(model_name)`; the timestamp is taken from `datetime.now()`
    at call time.

    Returns: (run_dir, csv_path, logits_dir)
    """
    ensure_dir(root)

    # A single "<TAG>_<timestamp>" stem names both the run directory and its CSV.
    run_stem = "_".join((_safe_tag(model_name), datetime.now().strftime("%Y%m%d_%H%M%S")))

    run_dir = os.path.join(root, run_stem)
    ensure_dir(run_dir)

    # Start the per-request log with its column header.
    csv_path = os.path.join(run_dir, run_stem + ".csv")
    with open(csv_path, "w") as csv_file:
        csv_file.write("timestamp,request_id,model_name,prompt,output,logits_path\n")

    logits_dir = os.path.join(run_dir, "logits_analysis")
    ensure_dir(logits_dir)

    print(f"[TokenLog] created run dir: {run_dir}")
    print(f"[TokenLog] csv: {csv_path}")
    print(f"[TokenLog] logits dir: {logits_dir}")

    return run_dir, csv_path, logits_dir

# def append_token_record(model_name: str, dir_path=""):
#     ensure_dir(dir_path)
#     tag = model_name.split("/")[-1].replace(" ", "").replace(".", "").replace("-", "")
#     fname_ts = datetime.now().strftime("%m%d%y%H%M%S")
#     filename = f"{tag}_{fname_ts}.csv"
#     path = os.path.join(dir_path, filename)
#     with open(path, "w") as f:
#         f.write("timestamp,request_id,model_name,prompt,output,logits_path\n")
#     print(f"[TokenLog] created {path}")
#     return path

def append_request_row(
    csv_path: str,
    request_id: str,
    model_name: str,
    prompt: str,
    output: str,
    logits_path: Optional[str],
):
    """Append one request record to the run's CSV log.

    The row has six comma-separated fields, in the order
    timestamp, request_id, model_name, prompt, output, logits_path.

    Free-text fields are escaped so a record always stays on one physical line
    with exactly six fields: carriage returns and newlines become the literal
    two-character sequences ``\\r`` and ``\\n``, and commas become ``&#44;``.
    A ``None`` field is written as an empty string.

    Args:
        csv_path: Path of the CSV file to append to.
        request_id: Identifier of the request; written as-is (not escaped).
        model_name: Name of the model that served the request.
        prompt: Prompt text sent to the model.
        output: Text produced by the model.
        logits_path: Path of the saved logits file, or None if none was written.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def esc(s: Optional[str]) -> str:
        if s is None:
            return ""
        # "\r" is escaped as well: a bare carriage return is treated as a row
        # break by many CSV readers and would corrupt the log.
        return s.replace("\r", "\\r").replace("\n", "\\n").replace(",", "&#44;")

    row = ",".join(
        [ts, str(request_id), esc(model_name), esc(prompt), esc(output), esc(logits_path)]
    )
    # Explicit UTF-8 so non-ASCII prompts/outputs never depend on the locale encoding.
    with open(csv_path, "a", encoding="utf-8") as f:
        f.write(row + "\n")


class ProxyTuningClient(Client):
    """
    A HELM client that uses ProxyTuning for inference instead of directly calling the model.

    On construction the client resolves `model_name` to a loader configuration,
    creates a per-run output directory (CSV request log plus a `logits_analysis`
    sub-folder) and loads the model(s) and tokenizer in-process. Each call to
    `make_request` generates a completion, appends a row to the CSV log and, for
    proxy-tuned routes, saves the per-request analysis results as a `.pt` file.
    """

    # Loader arguments that are identical for every route.
    _COMMON_LOAD_KWARGS: Dict[str, Any] = {
        "tokenizer_name": "llama-7b-base",
        "load_in_8bit": False,
        "use_fast_tokenizer": True,
        "device_map": "auto",
    }

    # model_name -> route-specific arguments for load_dexperts_model_and_tokenizer.
    # Routes with proxy_tune=False pass "none" for the expert/anti-expert names.
    _MODEL_ROUTES: Dict[str, Dict[str, Any]] = {
        "proxy_tuning/llama70b_mellama13bchat": {
            "base_name": "llama-70b-chat",
            "expert_name": "mellama-13b-chat",
            "antiexpert_name": "llama-13b-base",
            "load_in_4bit": True,
            "proxy_tune": True,
        },
        "proxy_tuning/llama70b_mellama13bbase": {
            "base_name": "llama-70b-chat",
            "expert_name": "mellama-13b-base",
            "antiexpert_name": "llama-13b-base",
            "load_in_4bit": True,
            "proxy_tune": True,
        },
        "proxy_tuning/llama70b_mellama13bchat_base": {
            "base_name": "llama-70b-chat",
            "expert_name": "mellama-13b-chat",
            "antiexpert_name": "mellama-13b-base",
            "load_in_4bit": True,
            "proxy_tune": True,
        },
        "proxy_tuning/llama70b": {
            "base_name": "llama-70b-chat",
            "expert_name": "none",
            "antiexpert_name": "none",
            "load_in_4bit": True,
            "proxy_tune": False,
        },
        "proxy_tuning/mellama70b": {
            "base_name": "mellama-70b-chat",
            "expert_name": "none",
            "antiexpert_name": "none",
            "load_in_4bit": True,
            "proxy_tune": False,
        },
        "proxy_tuning/mellama13bbase": {
            "base_name": "mellama-13b-base",
            "expert_name": "none",
            "antiexpert_name": "none",
            "load_in_4bit": False,
            "proxy_tune": False,
        },
        "proxy_tuning/mellama13bchat": {
            "base_name": "mellama-13b-chat",
            "expert_name": "none",
            "antiexpert_name": "none",
            "load_in_4bit": False,
            "proxy_tune": False,
        },
        "proxy_tuning/llama7b": {
            "base_name": "llama-7b-base",
            "expert_name": "none",
            "antiexpert_name": "none",
            "load_in_4bit": True,
            "proxy_tune": False,
        },
    }

    # Fixed generation budget per request; request.max_tokens is not consulted.
    _MAX_NEW_TOKENS = 600

    def __init__(
        self,
        tokenizer: Tokenizer,
        tokenizer_name: str,
        cache_config: CacheConfig,
        model_name: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Initializes the ProxyTuningClient.

        Args:
            tokenizer (Tokenizer): Tokenizer instance (unused but required by HELM interface).
            tokenizer_name (str): Name of the tokenizer (unused but required by HELM interface).
            cache_config (CacheConfig): Configuration for caching.
            model_name (str): Route name selecting which model(s) to load; must be
                one of the keys of `_MODEL_ROUTES`.
            api_base (str): Unused by this client.
            api_key (str): Unused by this client.

        Raises:
            NonRetriableException: If `model_name` is not a known route.
        """
        self.cache = Cache(cache_config)

        # Validate the route before touching the filesystem, so that a bad or
        # missing model_name fails with the intended error and leaves no empty
        # run directory behind.
        route = self._MODEL_ROUTES.get(model_name)
        if route is None:
            raise NonRetriableException(f"Unknown model_name route: {model_name}")

        # Per-run layout: <tag>_<timestamp>/ holding the request CSV and a
        # logits_analysis/ folder that receives one .pt file per request.
        self.run_dir, self.token_log_path, self.logits_dir = setup_run_dirs(model_name)
        self.model_name = model_name
        self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.req_seq = 0  # incremented once per request to build request ids

        self.model, self.hf_tokenizer = load_dexperts_model_and_tokenizer(
            **self._COMMON_LOAD_KWARGS,
            **route,
        )

        self.is_proxy = isinstance(self.model, DExpertsLlama)

    def make_request(self, request: Request) -> RequestResult:
        """
        Handles a request by generating a completion for its prompt with the
        locally loaded model.

        If `request.messages` is set, the prompt is the contents of all
        non-system messages joined by single spaces; otherwise `request.prompt`
        is used. Generation is greedy (`do_sample=False`) with a fixed budget of
        `_MAX_NEW_TOKENS` new tokens. Every request is appended to the run's CSV
        log; for proxy-tuned models the analysis results returned by
        `generate_completions` are additionally saved under `self.logits_dir`.

        Args:
            request (Request): The request object containing the prompt.

        Returns:
            RequestResult: A HELM-compatible response object with a single
            completion. `logprob` and `tokens` of the completion are
            placeholders (0.0 and an empty list).
        """
        import time  # local import: only needed to time the generation call

        prompt_text = request.prompt

        if request.messages:
            prompt_text = " ".join(msg["content"] for msg in request.messages if msg.get("role") != "system")

        print("prompt_text: ", prompt_text)

        # generate_completions takes a list of prompts; wrap the single prompt.
        prompts = [prompt_text]

        start_time = time.time()
        _, predicted_labels, all_results = generate_completions(
            model=self.model,
            tokenizer=self.hf_tokenizer,
            prompts=prompts,
            max_new_tokens=self._MAX_NEW_TOKENS,
            do_sample=False,
            logits_processor=None,
            prefix_allowed_tokens_fn=None,
            num_return_sequences=1,
            return_logits_for_analysis=self.is_proxy,
        )
        request_time = time.time() - start_time

        output_text = predicted_labels[0]
        print("output_text: ", output_text)

        self.req_seq += 1
        request_id = f"{self.run_id}_r{self.req_seq:04d}"

        # Analysis results are only requested (and saved) for proxy-tuned models.
        logits_path = None
        if self.is_proxy and all_results:
            logits_path = os.path.join(self.logits_dir, f"logits_{request_id}.pt")
            torch.save(all_results, logits_path)
            print(f"[Logits] wrote {logits_path}")

        append_request_row(
            csv_path=self.token_log_path,
            request_id=request_id,
            model_name=self.model_name,
            prompt=prompt_text,
            output=output_text,
            logits_path=logits_path,
        )

        # Return a HELM-compatible RequestResult, including timing information
        # so consumers of the result can see how long generation took.
        output = GeneratedOutput(text=output_text, logprob=0.0, tokens=[])
        return RequestResult(
            success=True,
            cached=False,
            request_time=request_time,
            request_datetime=int(start_time),
            completions=[output],
            embedding=[],
        )
#     def _get_messages_from_request(self, request: Request) -> List[Dict]:
#         if request.prompt and request.messages:
#             raise ValueError(f"Only one of `prompt` and `messages` may be set in request: {request}")
#         if request.multimodal_prompt:
#             raise ValueError("`multimodal_prompt` is not supported by ProxyTuningClient")
#         if request.messages:
#             return [{"role": message["role"], "content": message["content"]} for message in request.messages]
#         else:
#             return [{"role": "user", "content": request.prompt}]
        
#     def _convert_request_to_raw_request(self, request: Request) -> Dict:
#         raw_request = {
#             "messages": self._get_messages_from_request(request),
#             "model": request.model.split("/")[-1],
#             "logprobs": bool(request.top_k_per_token),
#             "max_tokens": request.max_tokens,
#             "n": request.num_completions,
#             "stop": request.stop_sequences,
#             "temperature": request.temperature,
#             "top_p": request.top_p,
#         }
#         if request.response_format and request.response_format.json_schema:
#             raw_request["response_format"] = {
#                 "type": "json_schema",
#                 "json_schema": {
#                     "schema": request.response_format.json_schema,
#                 },
#             }
#         return raw_request

    
#     def make_request(self, request: Request) -> RequestResult:
#         raw_request = self._convert_request_to_raw_request(request)
#         cache_key: Mapping = CachingClient.make_cache_key(raw_request, request)

#         def do_it() -> Dict[Any, Any]:
#             prompt_text = request.prompt

#             if request.messages:
#                 prompt_text = " ".join(msg["content"] for msg in request.messages if msg.get("role") != "system")

        
#             print("prompt_text: ", prompt_text)
#             prompts = [prompt_text]
#              # turn prompt into a [] 
#             outputs, predicted_labels, all_results = generate_completions(
#                 model=self.model,
#                 tokenizer=self.hf_tokenizer,
#                 prompts=prompts,
#                 max_new_tokens=600,       
#                 do_sample=False,        
#                 logits_processor=None, 
#                 prefix_allowed_tokens_fn=None,
#                 num_return_sequences=1,
#                 return_logits_for_analysis=self.is_proxy, 
#             )
#             output_text = predicted_labels[0]
#             print("output_text: ", output_text)

#             self.req_seq += 1
#             request_id = f"{self.run_id}_r{self.req_seq:04d}"

#             logits_path = None
#             if self.is_proxy and all_results:
#                 logits_path = os.path.join(self.logits_dir, f"logits_{request_id}.pt")
#                 torch.save(all_results, logits_path)
#                 print(f"[Logits] wrote {logits_path}")

#             append_request_row(
#                 csv_path=self.token_log_path,
#                 request_id=request_id,
#                 model_name=self.model_name,
#                 prompt=prompt_text,
#                 output=output_text,
#                 logits_path=logits_path,
#             )
#             return {"output_text": output_text}
        
#         try:
#             raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
#         except Exception as error:
#             return RequestResult(
#                 success=False,
#                 cached=False,
#                 error=str(error),
#                 completions=[],
#                 embedding=[],
#             )

#         output = GeneratedOutput(text=raw_response["output_text"], logprob=0.0, tokens=[])

#         return RequestResult(
#             success=True,
#             cached=cached,
#             request_time=raw_response["request_time"],
#             request_datetime=raw_response["request_datetime"],
#             completions=[output],
#             embedding=[],
#         )

    