from typing import Optional, Dict, Any
import torch, os, re, time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from transformers.generation.utils import (
    ModelOutput,
)
from collections import defaultdict
import tqdm
from transformers import (
    LogitsProcessorList,
    BitsAndBytesConfig, 
)

import argparse
import json
import random
import pandas as pd

from sklearn.metrics import accuracy_score, f1_score
from datetime import datetime
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from typing import List
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn, build_token_enforcer_tokenizer_data
import math
   
# Delimiters placed around the system prompt when a chat prefix is built
# (the prompt is inserted as B_SYS + system_prompt + E_SYS).
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# Short model aliases mapped to the checkpoint directories they are loaded from.
MODEL_PATHS = {
    "llama-70b-chat":  "/models/meta-llama/Llama-2-70b-chat-hf",
    "llama-13b-base": "/models/meta-llama/Llama-2-13b-hf",
    "llama-7b-base": "/models/meta-llama/Llama-2-7b-hf",
    "mellama-13b-chat": "/models/me-llama/MeLLaMA-13B-chat",
    "mellama-13b-base": "/models/me-llama/MeLLaMA-13B", 
    "mellama-70b-chat": "/models/me-llama/MeLLaMA-70B-chat",
}

# Each dataset should be a JSONL file whose records have `task_id`, `prompt`, and `gold` fields.

# Dataset names mapped to the JSONL file that is read for evaluation.
DATASET_PATHS = {
    "mednli": "/proxy_tuning/datasets/mednli/dev_reason_first.jsonl",
    "mtsample": "/proxy_tuning/datasets/mtsample/val_reason_first.jsonl",
    "fall_prediction": "/proxy_tuning/datasets/falls_prediction/val_reason_first.jsonl",
}
# Label strings for each dataset, keyed by the same names as DATASET_PATHS.
# NOTE(review): presumably the set of valid values for the `gold` field and for
# predicted labels -- confirm against the evaluation code.
LABELS = {
"mednli": ['entailment', 'contradiction', 'neutral'],
    "mtsample": ['Surgery', 'Allergy / Immunology', 'Sleep Medicine', 'Pediatrics - Neonatal', 'SOAP / Chart / Progress Notes', 'Bariatrics', 'Pain Management', 'Lab Medicine - Pathology', 'Dermatology', 'Orthopedic', 'Dentistry', 'Psychiatry / Psychology', 'General Medicine', 'Office Notes', 'Letters', 'Neurosurgery', 'Radiology', 'Cosmetic / Plastic Surgery', 'Nephrology', 'Diets and Nutritions', 'Chiropractic', 'Gastroenterology', 'Cardiovascular / Pulmonary', 'Speech - Language', 'Hospice - Palliative Care', 'Autopsy', 'Endocrinology', 'Emergency Room Reports', 'Discharge Summary', 'ENT - Otolaryngology', 'Urology', 'Physical Medicine - Rehab', 'Neurology', 'Podiatry', 'Ophthalmology', 'Rheumatology', 'IME-QME-Work Comp etc.', 'Hematology - Oncology', 'Consult - History and Phy.', 'Obstetrics / Gynecology'], 
    "fall_prediction": ['fall', 'no fall']
}


def flatten_batch_results(batch):
    """
    Regroup step-major analysis data into one record per prompt.

    ``batch['tokens']`` and ``batch['token_ids']`` hold one entry per generation
    step, each indexable by prompt position; every ``logits*`` entry is indexed
    by prompt position first. Each returned record is cut off just before the
    first ``'</s>'`` token of that prompt (or kept whole if none was produced).
    """
    prompt_count = len(batch['tokens'][0])
    logit_names = [name for name in batch.keys() if name.startswith('logits')]

    per_prompt = []
    for row in range(prompt_count):
        row_tokens = [step_tokens[row] for step_tokens in batch['tokens']]

        # everything from the end-of-sequence marker onwards is discarded
        try:
            cutoff = row_tokens.index('</s>')
        except ValueError:
            cutoff = len(row_tokens)

        record = {
            'tokens': row_tokens[:cutoff],
            'token_ids': [step_ids[row].item() for step_ids in batch['token_ids']][:cutoff],
        }
        for name in logit_names:
            record[name] = batch[name][row, ...][:cutoff, ...]
        per_prompt.append(record)

    return per_prompt


def summarize_results(results):
    """
    Logit vectors are huge, so let's just extract the key information: the probability of the 
    generated next-token and the top prediction from each model.

    For every ``logits_<model>`` entry of each record this adds
    ``p_<model>`` (softmax probability assigned to each generated token id) and
    ``preds_<model>`` (argmax token id at each position), then deletes the raw
    logits. Records are modified in place and returned in the same order.

    An empty ``results`` list yields an empty list. The token-id index tensor is
    created on the same device as the logits, so this works on CPU and on any
    CUDA device rather than only on the default CUDA device.
    """
    if not results:
        return []

    shortened_results = []
    # the first record decides which logit entries are summarized
    logit_keys = [k for k in results[0] if k.startswith('logits')]
    for ex in results:
        for k in logit_keys:
            model = '_'.join(k.split('_')[1:])
            logits = ex[k]
            # explicit long dtype keeps gather() valid even for an empty token list
            token_ids = torch.tensor(ex['token_ids'], dtype=torch.long, device=logits.device)
            probs = logits.softmax(dim=-1)
            ex[f'p_{model}'] = probs.gather(-1, token_ids.unsqueeze(-1)).squeeze()
            ex[f'preds_{model}'] = logits.argmax(dim=-1)
            del ex[k]

        shortened_results.append(ex)
    return shortened_results




class DExpertsLlama:
    """
    Decoding-time combination of three causal LMs (DExperts / proxy tuning).

    At every generation step the next-token logits are combined as
    ``base + alpha * (expert - antiexpert)`` and the next token is chosen from
    the combined logits. The same chosen token is appended to the inputs of all
    three models. Any model whose checkpoint name contains "chat" receives the
    prompt re-wrapped in ``[INST] ... [/INST]`` formatting.
    """

    def __init__(
        self,
        base_name: str,
        expert_name: str,
        antiexpert_name: str,
        tokenizer: AutoTokenizer,
        system_prompt: str = None,
        alpha: float = 1.0,
        chat_response_prefix: str = None,
        model_kwargs: Dict[str, Any] = None
    ):
        """
        Load the three models and record the prompt-formatting settings.

        base_name / expert_name / antiexpert_name: checkpoint identifiers passed
        to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: tokenizer used for decoding/re-encoding prompts and tokens.
        system_prompt: optional system prompt inserted into the chat prefix.
        alpha: weight of the (expert - antiexpert) logit difference.
        chat_response_prefix: For llama chat models, it can be helpful for the response
        to start with a certain prefix to constrain the generation to directly answer
        the question. This makes evaluation on MC datasets easier.
        model_kwargs: extra keyword arguments for ``from_pretrained``; ``None``
        means no extra arguments.
        """
        # `**None` would raise a TypeError, so treat the default as "no extras"
        model_kwargs = {} if model_kwargs is None else model_kwargs

        self.base = AutoModelForCausalLM.from_pretrained(
            base_name, **model_kwargs
        )
        self.expert = AutoModelForCausalLM.from_pretrained(
            expert_name, **model_kwargs
        )
        self.antiexpert = AutoModelForCausalLM.from_pretrained(
            antiexpert_name, **model_kwargs
        )

        self.base.eval()
        self.expert.eval()
        self.antiexpert.eval()

        self.tokenizer = tokenizer
        self.alpha = alpha
        print("alpha is: ", alpha)
        self.device = self.base.device
        self.chat_response_prefix = chat_response_prefix

        # Llama chat experts need different formatting
        self.use_chat_format_for_expert = 'chat' in expert_name.lower()
        self.use_chat_format_for_antiexpert = 'chat' in antiexpert_name.lower()
        self.use_chat_format_for_base = 'chat' in base_name.lower()

        if self.use_chat_format_for_expert or self.use_chat_format_for_antiexpert or self.use_chat_format_for_base:
            print("using chat format")
            # chat_prefix goes before the query, and chat_suffix goes after it
            self.chat_prefix = "[INST]"
            self.chat_suffix = "[/INST]"

            if system_prompt:
                self.chat_prefix += f"{B_SYS}{system_prompt}{E_SYS}"

            if self.chat_response_prefix:
                self.chat_suffix += f" {chat_response_prefix}"

    def forward(
        self,
        base_inputs,
        expert_inputs,
        antiexpert_inputs,
        return_dict=None
    ):
        """Run one forward pass of each model and return (base, expert, antiexpert) outputs."""
        base_outputs = self.base(**base_inputs, return_dict=return_dict)
        expert_outputs = self.expert(**expert_inputs, return_dict=return_dict)
        antiexpert_outputs = self.antiexpert(**antiexpert_inputs, return_dict=return_dict)

        return base_outputs, expert_outputs, antiexpert_outputs

    def _get_tokenized_chat_inputs(self, input_ids):
        """Decode input_ids and encode again to insert chat formatting"""

        prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # remove response_prefix (e.g., "Answer:") from the prompt if it's already there
        if self.chat_response_prefix:
            cleaned_prompts = []
            for p in prompts:
                if self.chat_response_prefix in p:
                    p = p.replace(self.chat_response_prefix, '').rstrip()
                cleaned_prompts.append(p)
        else:
            cleaned_prompts = prompts

        chat_prompts = [f'{self.chat_prefix} {p} {self.chat_suffix}' for p in cleaned_prompts]

        chat_inputs = self.tokenizer(
            chat_prompts, padding="longest", return_tensors="pt",
            add_special_tokens=True
        )
        # callers read these two attributes, so both are moved to the model device
        chat_inputs.input_ids = chat_inputs.input_ids.to(self.device)
        chat_inputs.attention_mask = chat_inputs.attention_mask.to(self.device)

        return chat_inputs

    def update_analysis_data(self, analysis_data, next_tokens, next_token_logits_dict):
        """
        Append this step's decoded tokens, token ids and per-model logits to
        ``analysis_data`` and return it. Logits get a new axis at dim 1 so the
        per-step tensors can later be concatenated along that axis.
        """
        analysis_data['tokens'].append([self.tokenizer.decode(t) for t in next_tokens])
        analysis_data['token_ids'].append(next_tokens)

        # logits from each model for the next token
        for model in next_token_logits_dict.keys():
            analysis_data[f'logits_{model}'].append(next_token_logits_dict[model].unsqueeze(dim=1))

        return analysis_data

    def _select_model_inputs(self, role, use_chat_format, input_ids, chat_inputs, kwargs):
        """
        Pick the starting input ids and a private copy of the generation kwargs
        for one model (``role`` is only used in the log line).
        """
        model_kwargs = kwargs.copy()
        if use_chat_format:
            print(f"[generate] Using chat format for {role}")
            model_input_ids = chat_inputs.input_ids.to(input_ids.device)
            model_kwargs['attention_mask'] = chat_inputs.attention_mask
        else:
            print(f"[generate] Not using chat format for {role}")
            model_input_ids = input_ids

        # The per-step kwargs update needs an attention mask to derive the cache
        # position; fall back to "attend to everything" when none was supplied.
        if model_kwargs.get('attention_mask') is None:
            model_kwargs['attention_mask'] = torch.ones_like(model_input_ids)
        return model_input_ids, model_kwargs

    def _apply_prefix_constraint(
        self, next_token_logits, sequences, unfinished_sequences, prefix_allowed_tokens_fn
    ):
        """
        Mask out every token that ``prefix_allowed_tokens_fn(batch_id, sequence)``
        does not allow, independently for each row of the batch. Rows that are
        already finished are left unconstrained because their next token is
        overwritten with the pad token anyway.
        """
        mask = torch.full_like(next_token_logits, -math.inf)
        for batch_id, sent in enumerate(sequences):
            if unfinished_sequences[batch_id] == 0:
                mask[batch_id] = 0
                continue
            prefix_allowed_tokens = prefix_allowed_tokens_fn(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                    f"This means that the constraint is unsatisfiable. Please check your implementation"
                    f"of `prefix_allowed_tokens_fn` "
                )
            mask[batch_id, prefix_allowed_tokens] = 0
        return next_token_logits + mask

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = 100,
        do_sample: bool = False,
        temperature: float = 1.0,
        logits_processor: Optional[LogitsProcessorList] = None,
        return_logits_for_analysis: bool = False,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ):
        """
        Generate up to ``max_new_tokens`` tokens from the combined logits.

        input_ids: tokenized prompts, one row per prompt.
        do_sample: sample from the softmax of the final logits instead of argmax.
        temperature: divisor applied to the final logits when not equal to 1.0.
        logits_processor: optional processors called as ``processor(base_input_ids, logits)``.
        return_logits_for_analysis: also return per-step tokens and logits.
        prefix_allowed_tokens_fn: optional ``fn(batch_id, sequence) -> allowed token ids``
        applied to every unfinished row.
        kwargs: forwarded (copied per model) to ``prepare_inputs_for_generation``.

        Returns ``input_ids`` with the generated tokens appended, or
        ``(input_ids, analysis_data)`` when ``return_logits_for_analysis`` is set.
        Raises ``ValueError`` if ``prefix_allowed_tokens_fn`` allows no token.
        """
        # The chat-formatted prompt is the same for every model, so tokenize it once.
        needs_chat_inputs = (
            self.use_chat_format_for_expert
            or self.use_chat_format_for_antiexpert
            or self.use_chat_format_for_base
        )
        chat_inputs = self._get_tokenized_chat_inputs(input_ids) if needs_chat_inputs else None

        expert_input_ids, expert_kwargs = self._select_model_inputs(
            "expert", self.use_chat_format_for_expert, input_ids, chat_inputs, kwargs
        )
        antiexpert_input_ids, antiexpert_kwargs = self._select_model_inputs(
            "antiexpert", self.use_chat_format_for_antiexpert, input_ids, chat_inputs, kwargs
        )
        base_input_ids, base_kwargs = self._select_model_inputs(
            "base", self.use_chat_format_for_base, input_ids, chat_inputs, kwargs
        )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        eos_token_id = self.tokenizer.eos_token_id

        if return_logits_for_analysis:
            analysis_data = defaultdict(list)
            print("[generate] Initialized analysis_data: ", analysis_data)

        for step in range(max_new_tokens):
            # prepare model inputs with past_key_values and attention_mask
            base_inputs = self.base.prepare_inputs_for_generation(base_input_ids, **base_kwargs)
            expert_inputs = self.expert.prepare_inputs_for_generation(expert_input_ids, **expert_kwargs)
            antiexpert_inputs = self.antiexpert.prepare_inputs_for_generation(antiexpert_input_ids, **antiexpert_kwargs)

            # DExperts
            base_outputs, expert_outputs, antiexpert_outputs = self.forward(
                base_inputs, expert_inputs, antiexpert_inputs, return_dict=True
            )

            base_next_token_logits = base_outputs.logits[..., -1, :]
            expert_next_token_logits = expert_outputs.logits[..., -1, :]
            antiexpert_next_token_logits = antiexpert_outputs.logits[..., -1, :]

            # sometimes our experts have extra (irrelevant) tokens at the end of the normal vocabulary;
            # trim both the expert and the anti-expert so the arithmetic below lines up
            vocab_size = base_next_token_logits.shape[-1]
            expert_next_token_logits = expert_next_token_logits[:, :vocab_size]
            antiexpert_next_token_logits = antiexpert_next_token_logits[:, :vocab_size]

            # DExperts! # this is the actual proxy tuning part
            next_token_logits = (
                base_next_token_logits +
                self.alpha * (expert_next_token_logits - antiexpert_next_token_logits)
            )

            # forbid end-of-sequence for the first two generated tokens
            if step < 2:
                next_token_logits[:, eos_token_id] = -float("inf")

            # pre-process logits
            if logits_processor:
                next_token_logits = logits_processor(base_input_ids, next_token_logits)

            if prefix_allowed_tokens_fn:
                next_token_logits = self._apply_prefix_constraint(
                    next_token_logits, base_input_ids, unfinished_sequences, prefix_allowed_tokens_fn
                )

            # warp logits
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # decode
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # finished sequences keep receiving the pad token
            next_tokens = (
                next_tokens * unfinished_sequences +
                self.tokenizer.pad_token_id * (1 - unfinished_sequences)
            )

            if return_logits_for_analysis:
                next_token_logits_dict = {
                    'dexperts': next_token_logits,
                    'base': base_next_token_logits,
                    'expert': expert_next_token_logits,
                    'antiexpert': antiexpert_next_token_logits
                }
                analysis_data = self.update_analysis_data(analysis_data, next_tokens, next_token_logits_dict)

            # update model inputs for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            expert_input_ids = torch.cat([expert_input_ids, next_tokens[:, None]], dim=-1)
            antiexpert_input_ids = torch.cat([antiexpert_input_ids, next_tokens[:, None]], dim=-1)
            base_input_ids = torch.cat([base_input_ids, next_tokens[:, None]], dim=-1)

            # update kwargs
            base_kwargs = self._update_model_kwargs_for_generation(base_outputs, base_kwargs)
            expert_kwargs = self._update_model_kwargs_for_generation(expert_outputs, expert_kwargs)
            antiexpert_kwargs = self._update_model_kwargs_for_generation(antiexpert_outputs, antiexpert_kwargs)

            # if eos_token was found in one sentence, set sentence to finished
            unfinished_sequences = unfinished_sequences * next_tokens.ne(eos_token_id).long()

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                break

        if return_logits_for_analysis:
            for k in analysis_data.keys():
                if k.startswith('logits'):
                    analysis_data[k] = torch.cat(analysis_data[k], dim=1)
            return input_ids, analysis_data

        return input_ids

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Carry state from one decoding step to the next: store the new
        ``past_key_values``, extend the attention mask by one column of ones and
        advance ``cache_position``. ``kwargs`` is updated in place and returned.
        Raises ``KeyError`` if the cache position has to be derived and no
        ``attention_mask`` is present.
        """
        # update past_key_values
        kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in kwargs:
            attention_mask = kwargs["attention_mask"]
            kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if getattr(outputs, "cache_position", None) is not None:
            # some models already return it
            kwargs["cache_position"] = outputs.cache_position
        else:
            if "cache_position" in kwargs:
                kwargs["cache_position"] = kwargs["cache_position"] + 1
            else:
                # first step: position is sequence-length-1
                seq_len = kwargs["attention_mask"].shape[1]
                kwargs["cache_position"] = torch.arange(seq_len - 1, seq_len, device=kwargs["attention_mask"].device)

        return kwargs

class RegularLlama:
    """
    Single-model decoder with a hand-written generation loop.

    Exposes ``generate`` with the same arguments as the three-model variant in
    this file, but the next token is chosen from the base model's logits alone.
    A checkpoint name containing "chat" makes the prompt get re-wrapped in
    ``[INST] ... [/INST]`` formatting.
    """

    def __init__(
        self,
        base_name: str,
        tokenizer: AutoTokenizer,
        system_prompt: str = None,
        alpha: float = 1.0,
        chat_response_prefix: str = None,
        model_kwargs: Dict[str, Any] = None
    ):
        """
        Load the model and record the prompt-formatting settings.

        base_name: checkpoint identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: tokenizer used for decoding/re-encoding prompts and tokens.
        system_prompt: optional system prompt inserted into the chat prefix.
        alpha: stored on the instance; not used by this class's generation loop.
        chat_response_prefix: For llama chat models, it can be helpful for the response
        to start with a certain prefix to constrain the generation to directly answer
        the question. This makes evaluation on MC datasets easier.
        model_kwargs: extra keyword arguments for ``from_pretrained``; ``None``
        means no extra arguments.
        """
        # `**None` would raise a TypeError, so treat the default as "no extras"
        model_kwargs = {} if model_kwargs is None else model_kwargs

        self.base = AutoModelForCausalLM.from_pretrained(
            base_name, **model_kwargs
        )
        self.base.eval()

        self.tokenizer = tokenizer
        self.alpha = alpha
        self.device = self.base.device
        self.chat_response_prefix = chat_response_prefix

        # Llama chat models need different formatting
        self.use_chat_format_for_base = 'chat' in base_name.lower()

        if self.use_chat_format_for_base:
            print("using chat format")
            # chat_prefix goes before the query, and chat_suffix goes after it
            self.chat_prefix = "[INST]"
            self.chat_suffix = "[/INST]"

            if system_prompt:
                self.chat_prefix += f"{B_SYS}{system_prompt}{E_SYS}"

            if self.chat_response_prefix:
                self.chat_suffix += f" {chat_response_prefix}"

    def forward(
        self,
        base_inputs,
        return_dict=None
    ):
        """Run one forward pass of the model and return its outputs."""
        base_outputs = self.base(**base_inputs, return_dict=return_dict)

        return base_outputs

    def _get_tokenized_chat_inputs(self, input_ids):
        """Decode input_ids and encode again to insert chat formatting"""

        prompts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # remove response_prefix (e.g., "Answer:") from the prompt if it's already there
        if self.chat_response_prefix:
            cleaned_prompts = []
            for p in prompts:
                if self.chat_response_prefix in p:
                    p = p.replace(self.chat_response_prefix, '').rstrip()
                cleaned_prompts.append(p)
        else:
            cleaned_prompts = prompts

        chat_prompts = [f'{self.chat_prefix} {p} {self.chat_suffix}' for p in cleaned_prompts]

        chat_inputs = self.tokenizer(
            chat_prompts, padding="longest", return_tensors="pt",
            add_special_tokens=True
        )
        # callers read these two attributes, so both are moved to the model device
        chat_inputs.input_ids = chat_inputs.input_ids.to(self.device)
        chat_inputs.attention_mask = chat_inputs.attention_mask.to(self.device)

        return chat_inputs

    def update_analysis_data(self, analysis_data, next_tokens, next_token_logits_dict):
        """
        Append this step's decoded tokens, token ids and per-model logits to
        ``analysis_data`` and return it. Logits get a new axis at dim 1 so the
        per-step tensors can later be concatenated along that axis.
        """
        analysis_data['tokens'].append([self.tokenizer.decode(t) for t in next_tokens])
        analysis_data['token_ids'].append(next_tokens)

        # logits from each model for the next token
        for model in next_token_logits_dict.keys():
            analysis_data[f'logits_{model}'].append(next_token_logits_dict[model].unsqueeze(dim=1))

        return analysis_data

    def _apply_prefix_constraint(
        self, next_token_logits, sequences, unfinished_sequences, prefix_allowed_tokens_fn
    ):
        """
        Mask out every token that ``prefix_allowed_tokens_fn(batch_id, sequence)``
        does not allow, independently for each row of the batch. Rows that are
        already finished are left unconstrained because their next token is
        overwritten with the pad token anyway.
        """
        mask = torch.full_like(next_token_logits, -math.inf)
        for batch_id, sent in enumerate(sequences):
            if unfinished_sequences[batch_id] == 0:
                mask[batch_id] = 0
                continue
            prefix_allowed_tokens = prefix_allowed_tokens_fn(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                    f"This means that the constraint is unsatisfiable. Please check your implementation"
                    f"of `prefix_allowed_tokens_fn` "
                )
            mask[batch_id, prefix_allowed_tokens] = 0
        return next_token_logits + mask

    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: Optional[int] = 100,
        do_sample: bool = False,
        temperature: float = 1.0,
        logits_processor: Optional[LogitsProcessorList] = None,
        return_logits_for_analysis: bool = False,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ):
        """
        Generate up to ``max_new_tokens`` tokens from the model's logits.

        input_ids: tokenized prompts, one row per prompt.
        do_sample: sample from the softmax of the final logits instead of argmax.
        temperature: divisor applied to the final logits when not equal to 1.0.
        logits_processor: optional processors called as ``processor(base_input_ids, logits)``.
        return_logits_for_analysis: also return per-step tokens and raw model logits.
        prefix_allowed_tokens_fn: optional ``fn(batch_id, sequence) -> allowed token ids``
        applied to every unfinished row.
        kwargs: forwarded (as a copy) to ``prepare_inputs_for_generation``.

        Returns ``input_ids`` with the generated tokens appended, or
        ``(input_ids, analysis_data)`` when ``return_logits_for_analysis`` is set.
        Raises ``ValueError`` if ``prefix_allowed_tokens_fn`` allows no token.
        """
        base_kwargs = kwargs.copy()

        # prepare inputs for the model
        if self.use_chat_format_for_base:
            print("[generate] Using chat format for base")
            chat_inputs = self._get_tokenized_chat_inputs(input_ids)
            base_input_ids = chat_inputs.input_ids.to(input_ids.device)
            base_kwargs['attention_mask'] = chat_inputs.attention_mask
        else:
            print("[generate] Not using chat format for base")
            base_input_ids = input_ids

        # The per-step kwargs update needs an attention mask to derive the cache
        # position; fall back to "attend to everything" when none was supplied.
        if base_kwargs.get('attention_mask') is None:
            base_kwargs['attention_mask'] = torch.ones_like(base_input_ids)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        eos_token_id = self.tokenizer.eos_token_id

        if return_logits_for_analysis:
            analysis_data = defaultdict(list)
            print("[generate] Initialized analysis_data: ", analysis_data)

        for step in range(max_new_tokens):
            # prepare model inputs with past_key_values and attention_mask
            base_inputs = self.base.prepare_inputs_for_generation(base_input_ids, **base_kwargs)

            base_outputs = self.forward(
                base_inputs, return_dict=True
            )

            base_next_token_logits = base_outputs.logits[..., -1, :]

            # Work on a copy: the in-place EOS masking below must not leak into the
            # model output or into the raw 'base' logits recorded for analysis.
            next_token_logits = base_next_token_logits.clone()

            # forbid end-of-sequence for the first two generated tokens
            if step < 2:
                next_token_logits[:, eos_token_id] = -float("inf")

            # pre-process logits
            if logits_processor:
                next_token_logits = logits_processor(base_input_ids, next_token_logits)

            if prefix_allowed_tokens_fn:
                next_token_logits = self._apply_prefix_constraint(
                    next_token_logits, base_input_ids, unfinished_sequences, prefix_allowed_tokens_fn
                )

            # warp logits
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # decode
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            # finished sequences keep receiving the pad token
            next_tokens = (
                next_tokens * unfinished_sequences +
                self.tokenizer.pad_token_id * (1 - unfinished_sequences)
            )

            if return_logits_for_analysis:
                next_token_logits_dict = {
                    'base': base_next_token_logits,
                }
                analysis_data = self.update_analysis_data(analysis_data, next_tokens, next_token_logits_dict)

            # update model inputs for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            base_input_ids = torch.cat([base_input_ids, next_tokens[:, None]], dim=-1)

            # update kwargs
            base_kwargs = self._update_model_kwargs_for_generation(base_outputs, base_kwargs)

            # if eos_token was found in one sentence, set sentence to finished
            unfinished_sequences = unfinished_sequences * next_tokens.ne(eos_token_id).long()

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                break

        if return_logits_for_analysis:
            for k in analysis_data.keys():
                if k.startswith('logits'):
                    analysis_data[k] = torch.cat(analysis_data[k], dim=1)
            return input_ids, analysis_data

        return input_ids

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Carry state from one decoding step to the next: store the new
        ``past_key_values``, extend the attention mask by one column of ones and
        advance ``cache_position``. ``kwargs`` is updated in place and returned.
        Raises ``KeyError`` if the cache position has to be derived and no
        ``attention_mask`` is present.
        """
        # update past_key_values
        kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in kwargs:
            attention_mask = kwargs["attention_mask"]
            kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if getattr(outputs, "cache_position", None) is not None:
            # some models already return it
            kwargs["cache_position"] = outputs.cache_position
        else:
            if "cache_position" in kwargs:
                kwargs["cache_position"] = kwargs["cache_position"] + 1
            else:
                # first step: position is sequence-length-1
                seq_len = kwargs["attention_mask"].shape[1]
                kwargs["cache_position"] = torch.arange(seq_len - 1, seq_len, device=kwargs["attention_mask"].device)

        return kwargs


def ensure_dir(d):
    """Create directory ``d`` (including parents) unless something already exists at that path."""
    if os.path.exists(d):
        return
    os.makedirs(d, exist_ok=True)


@torch.inference_mode()
def generate_completions(
    model,
    tokenizer,
    prompts,
    batch_size=1,
    add_special_tokens=True,
    disable_tqdm=False,
    temperature=1.0,
    logits_processor=None,
    prefix_allowed_tokens_fn=None,
    return_logits_for_analysis=False,
    **generation_kwargs, 
    
):
    """
    Generate completions for ``prompts`` in batches of ``batch_size``.

    model: object with a ``generate`` method; its ``device`` attribute is used
        when present, otherwise the device of its first parameter.
    tokenizer: callable tokenizer that also provides ``batch_decode``.
    disable_tqdm: suppress the progress bar.
    return_logits_for_analysis: when true, ``model.generate`` must return
        ``(sequences, analysis_data)``; the analysis data is flattened and
        summarized per prompt.
    generation_kwargs: forwarded to ``model.generate``; ``num_return_sequences``
        (default 1) is also used to line prompts up with outputs.

    Returns ``(outputs, generations, all_results)``: the decoded full sequences,
    the same strings with the decoded prompt stripped from the front, and the
    summarized analysis records (empty unless ``return_logits_for_analysis``).

    The progress bar is closed when the function exits, including on errors.
    """
    generations = []
    outputs = []
    all_results = []
    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)

    progress = None
    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    try:
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i+batch_size]
            tokenized_prompts = tokenizer(
                batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens
            )
            if hasattr(model, "device"):                 # DExpertsLlama
                device = model.device
            else:                                        # vanilla HF model
                device = next(model.parameters()).device
            batch_input_ids = tokenized_prompts['input_ids'].to(device)
            attention_mask = tokenized_prompts['attention_mask'].to(device)

            if return_logits_for_analysis:
                batch_outputs, results = model.generate(
                    input_ids=batch_input_ids,
                    attention_mask=attention_mask,
                    logits_processor=logits_processor,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    temperature=temperature,
                    return_logits_for_analysis=return_logits_for_analysis,
                    **generation_kwargs
                )

                results = flatten_batch_results(results)
                shortened_results = summarize_results(results)
                all_results.extend(shortened_results)
            else:
                batch_outputs = model.generate(
                    input_ids=batch_input_ids,
                    attention_mask=attention_mask,
                    logits_processor=logits_processor,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    temperature=temperature,
                    **generation_kwargs
                )

            # to support the logits processing below when using DExperts with mixed tokenizers
            if isinstance(batch_input_ids, dict):
                batch_input_ids = batch_input_ids['llama']

            batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
            print ("batch_outputs: ", batch_outputs)
            # decode the prompt the same way so it can be stripped by length below
            batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)

            # duplicate the prompts to match the number of return sequences
            batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
            ]

            generations += batch_generations
            outputs += batch_outputs

            if progress is not None:
                progress.update(len(batch_prompts)//num_return_sequences)
    finally:
        if progress is not None:
            progress.close()

    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    assert len(outputs) == len(prompts) * num_return_sequences, "number of outputs should be equal to number of prompts * num_return_sequences"
    return outputs, generations, all_results


def add_pad_token(tokenizer, padding_side="left"):
    """
    Make sure ``tokenizer`` has a pad token and set its padding side.

    When no pad token is configured, the end-of-sequence token (and its id) is
    reused for padding. The tokenizer is modified in place and returned.
    """
    if tokenizer.pad_token is None:
        eos_text, eos_id = tokenizer.eos_token, tokenizer.eos_token_id
        tokenizer.pad_token = eos_text
        tokenizer.pad_token_id = eos_id
    tokenizer.padding_side = padding_side
    return tokenizer

def load_dexperts_model_and_tokenizer(
    base_name: str,
    expert_name: str,
    antiexpert_name: str,
    tokenizer_name: str,
    device_map: str = "auto",
    system_prompt: str = None,
    alpha: float = 1.0,
    chat_response_prefix: str = None,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    use_fast_tokenizer: bool = True,
    padding_side: str = "left",
    proxy_tune: bool = False,
):
    """
    Build the tokenizer and either the three-model or the single-model wrapper.

    All ``*_name`` arguments are keys of ``MODEL_PATHS``. With ``proxy_tune``
    the base, expert and anti-expert checkpoints are loaded into a
    ``DExpertsLlama``; otherwise only the base checkpoint is loaded into a
    ``RegularLlama``. ``load_in_8bit`` / ``load_in_4bit`` select a bitsandbytes
    quantization config; if both are set the 4-bit config is the one used.

    Returns ``(model, tokenizer)``.
    """
    quantization = None
    if load_in_8bit:
        quantization = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
    if load_in_4bit:
        # assigned last, so it overrides the 8-bit config when both flags are given
        quantization = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",         # {nf4, fp4}; nf4 is standard
            bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if BF16 isn’t ideal
        )

    model_kwargs = dict(
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATHS[tokenizer_name], use_fast=use_fast_tokenizer)
    print(f"[Loader] Tokenizer   : {MODEL_PATHS[tokenizer_name]}")
    tokenizer = add_pad_token(tokenizer, padding_side)

    if not proxy_tune:
        model = RegularLlama(
            base_name=MODEL_PATHS[base_name],
            tokenizer=tokenizer,
            system_prompt=system_prompt,
            alpha=alpha,
            chat_response_prefix=chat_response_prefix,
            model_kwargs=model_kwargs,
        )
        print(f"[Loader] Base model   : {MODEL_PATHS[base_name]}")
        return model, tokenizer

    model = DExpertsLlama(
        base_name=MODEL_PATHS[base_name],
        expert_name=MODEL_PATHS[expert_name],
        antiexpert_name=MODEL_PATHS[antiexpert_name],
        tokenizer=tokenizer,
        system_prompt=system_prompt,
        alpha=alpha,
        chat_response_prefix=chat_response_prefix,
        model_kwargs=model_kwargs,
    )
    print(f"[Loader] Base model   : {MODEL_PATHS[base_name]}")
    print(f"[Loader] Expert model : {MODEL_PATHS[expert_name]}")
    print(f"[Loader] Anti‑expert  : {MODEL_PATHS[antiexpert_name]}")
    return model, tokenizer


# Date the module was loaded, formatted as YYYYMMDD (evaluated once at import time).
# NOTE(review): presumably used to tag output files -- usage is not visible here.
START_DATE = datetime.now().strftime("%Y%m%d")

def extract_label(text: str, labels: List[str]) -> str:
    """Extract the predicted class label from a JSON-formatted model output.

    Args:
        text: Raw model output, expected to be a JSON object that carries the
            prediction under the key ``"label"`` (or ``"Label"``).
        labels: The label strings that count as valid predictions.

    Returns:
        The label found in ``text`` when it is a string contained in
        ``labels``; otherwise the literal string ``"unknown"`` (the offending
        output is printed for inspection).
    """
    try:
        obj = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        # Not JSON (or not even a string) - fall through to "unknown".
        obj = None

    # json.loads may legitimately return a list/str/number; only a JSON object
    # can carry a label, so anything else is treated as an unknown prediction
    # instead of crashing on a missing ``.get``.
    if isinstance(obj, dict):
        # handle either "label" or "Label"
        for key in ("label", "Label"):
            lbl = obj.get(key)
            if isinstance(lbl, str) and lbl in labels:
                return lbl

    print("unknown prediction: ", text)
    return "unknown"

from pathlib import Path

def load_jsonl_records(path_str: str):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``; lines that are empty or
    contain only whitespace are ignored.

    Args:
        path_str: Filesystem path of the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If the file has zero bytes.
    """
    dataset_file = Path(path_str)

    # Fail loudly up front instead of silently yielding an empty dataset.
    if not dataset_file.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_file}")
    if dataset_file.stat().st_size == 0:
        raise ValueError(f"Dataset file is empty: {dataset_file}")

    with dataset_file.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # Whitespace-only lines carry no record and are dropped.
        return [json.loads(entry) for entry in stripped_lines if entry]

def main(args):
    """Run one evaluation pass and persist predictions and metrics.

    Steps, in order:
      1. Sample ``args.max_examples`` records from the dataset selected by
         ``args.dataset`` (seeded with ``args.seed``).
      2. Load the model/tokenizer via ``load_dexperts_model_and_tokenizer``.
      3. Generate completions constrained to a ``{"reason", "label"}`` JSON
         schema whose ``label`` is restricted to ``LABELS[args.dataset]``.
      4. Write per-example predictions (JSONL), optionally the logits returned
         for analysis, and append one row of metrics to a CSV table.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``dataset``,
            ``max_examples``, ``load_in_8bit``, ``load_in_4bit``,
            ``base_name``, ``expert_name``, ``anti_expert_name``,
            ``tokenizer_name``, ``alpha``, ``proxy_tune``, ``max_new_tokens``
            and ``return_logits_for_analysis``.

    Raises:
        ValueError: If there are no examples to evaluate.
    """
    start_time = time.time()
    random.seed(args.seed)

    save_dir = f"/fsroproxy_tuning/results/{args.dataset}"
    ensure_dir(save_dir)
    # Predictions are written to this sub-directory at the end; create it now
    # so a long generation run cannot fail late on a missing folder.
    os.makedirs(f"{save_dir}/outputs", exist_ok=True)

    max_examples = args.max_examples
    load_in_8bit = args.load_in_8bit
    load_in_4bit = args.load_in_4bit
    eval_batch_size = 1

    test_data = load_jsonl_records(DATASET_PATHS[args.dataset])

    # random.sample raises when asked for more items than exist; clamp so that
    # an over-large --max_examples simply evaluates the whole dataset.
    sample_size = min(max_examples, len(test_data))
    if sample_size < max_examples:
        print(
            f"[Data] requested {max_examples} examples but only "
            f"{len(test_data)} available; using all of them"
        )
    test_data = random.sample(test_data, sample_size)
    if not test_data:
        # Fail before the (expensive) model load rather than at metric time.
        raise ValueError("No examples to evaluate (empty dataset or max_examples == 0)")

    prompts = [example["prompt"] for example in test_data]

    model, tokenizer = load_dexperts_model_and_tokenizer(
        base_name=args.base_name,
        expert_name=args.expert_name,
        antiexpert_name=args.anti_expert_name,
        tokenizer_name=args.tokenizer_name,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
        use_fast_tokenizer=True,
        device_map='auto',
        alpha=args.alpha,
        proxy_tune=args.proxy_tune,
    )

    class ReasonAndLabel(BaseModel):
        reason: str
        label: str

    choices = LABELS[args.dataset]

    schema = ReasonAndLabel.model_json_schema()
    print("schema", schema)
    # Restrict the generated label to the dataset's label set.
    schema["properties"]["label"]["enum"] = choices
    print("schema after defining", schema)

    json_parser = JsonSchemaParser(schema)

    tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer, tokenizer.vocab_size)
    prefix_func = build_transformers_prefix_allowed_tokens_fn(tokenizer_data, json_parser)

    outputs, predicted_labels, all_results = generate_completions(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        batch_size=eval_batch_size,
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        logits_processor=None,
        prefix_allowed_tokens_fn=prefix_func,
        num_return_sequences=1,
        return_logits_for_analysis=args.return_logits_for_analysis,
    )

    if all_results:
        os.makedirs(f"{save_dir}/logits_analysis", exist_ok=True)
        logits_path = f"{save_dir}/logits_analysis/logits_{args.base_name}_base_{args.expert_name}_expert_{args.anti_expert_name}_anti_{args.tokenizer_name}_tokenizer_{START_DATE}_{args.seed}.pkl"
        torch.save(all_results, logits_path)
        print(f"wrote to {logits_path}")

    # Parse every raw prediction exactly once; the result feeds both the
    # per-example records and the unknown-rate metric (and avoids printing the
    # "unknown prediction" diagnostic twice per example).
    extracted_labels = [extract_label(pred_label, choices) for pred_label in predicted_labels]
    num_unknowns = sum(1 for lbl in extracted_labels if lbl not in choices)

    # NOTE(review): gold labels are compared as-is against string predictions;
    # confirm the dataset files store "gold" as strings.
    records = [
        {
            "task_id": ex["task_id"],
            "prompt": ex["prompt"],
            "full_output": output,
            "prediction": pred_label,
            "predicted_label": extracted,
            "gold": ex["gold"],
        }
        for ex, output, pred_label, extracted in zip(
            test_data, outputs, predicted_labels, extracted_labels
        )
    ]

    predictions_out_path = os.path.join(f"{save_dir}/outputs", f"predictions_{args.base_name}_base_{args.expert_name}_expert_{args.anti_expert_name}_anti_{args.alpha}_alpha_{START_DATE}.jsonl")
    pd.DataFrame(records).to_json(predictions_out_path, orient="records", lines=True)
    print(f"wrote to {predictions_out_path}")

    gold_labels = [r["gold"] for r in records]
    pred_labels = [r["predicted_label"] for r in records]
    acc = accuracy_score(gold_labels, pred_labels)
    f1 = f1_score(gold_labels, pred_labels,
                  labels=choices, average="macro", zero_division=0)
    elapsed = time.time() - start_time

    pct_unknown = 100 * num_unknowns / len(predicted_labels) if predicted_labels else 0.0

    row = {
        "base": args.base_name,
        "expert": args.expert_name,
        "anti_expert": args.anti_expert_name,
        "tokenizer": args.tokenizer_name,
        "alpha": args.alpha,
        "accuracy": round(acc, 4),
        "macro_f1": round(f1, 4),
        "pct_unknown": round(pct_unknown, 2),
        "runtime_sec": round(elapsed, 2),
        "run_date": START_DATE,
    }

    # Append to the shared metrics table, writing the header only when the
    # file is created.
    csv_path = os.path.join(save_dir, "metrics_table_alpha.csv")
    pd.DataFrame([row]).to_csv(
        csv_path,
        mode="a",
        header=not os.path.exists(csv_path),
        index=False,
    )

    # Drop the local references and ask PyTorch to release cached GPU memory.
    del model
    del tokenizer
    import gc
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    # Command-line entry point: collect the run configuration and hand it to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_name",
        type=str,
        required=True,
        help="Key into MODEL_PATHS for the base model.",
    )
    parser.add_argument(
        "--expert_name",
        type=str,
        required=True,
        help="Key into MODEL_PATHS for the expert model (used with --proxy_tune).",
    )
    parser.add_argument(
        "--anti_expert_name",
        type=str,
        required=True,
        help="Key into MODEL_PATHS for the anti-expert model (used with --proxy_tune).",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        required=True,
        help="Key into MODEL_PATHS for the tokenizer to load.",
    )
    parser.add_argument(
        "--dataset",
        choices=["mednli", "mtsample", "fall_prediction"],
        type=str,
        # main() indexes DATASET_PATHS / LABELS with this value, so omitting it
        # must be rejected here rather than surfacing later as a KeyError.
        required=True,
        help="Dataset to evaluate on.",
    )
    parser.add_argument(
        "--proxy_tune",
        action="store_true",
        help="Build the DExperts (base + expert + anti-expert) model instead of the plain base model.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        required=True,
        help="Generation budget passed to generate_completions.",
    )
    parser.add_argument(
        "--max_examples",
        type=int,
        required=True,
        help="Number of dataset records to sample for evaluation.",
    )
    parser.add_argument(
        "--return_logits_for_analysis",
        action="store_true",
        help="Request per-step analysis results from generation and save them to disk.",
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load the model with an 8-bit BitsAndBytes config.",
    )
    parser.add_argument(
        "--load_in_4bit",
        action="store_true",
        help="Load the model with a 4-bit (nf4) BitsAndBytes config; takes precedence over --load_in_8bit.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=True,
        help="Seed for the random sampling of evaluation examples.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        required=True,
        help="Alpha value forwarded to the model wrapper and recorded in the metrics.",
    )

    args = parser.parse_args()
    main(args)
