import os
import pickle
import torch
import math
from .models import *
from .datasets import *
from .attributions import Attributions
from .utils import *
from tqdm import tqdm

class Experiment:
    def __init__(self, args):
        """Wire up the dataset, model wrapper and attribution engine for one run.

        ``args.dataset`` and ``args.model`` are used as keys into the
        ``dataset_classes`` / ``models_classes`` registries (not defined in this
        file — presumably provided by the star imports). Each selected class is
        instantiated with ``args``. The model wrapper is then passed to
        ``Attributions`` together with ``args``.
        """
        self.args = args

        dataset_cls = dataset_classes[args.dataset]
        self.dataset = dataset_cls(args)

        models_cls = models_classes[args.model]
        self.models = models_cls(args)

        self.attributions = Attributions(self.models, args)
    
    def save(self, outs, quantifiers, force=False, generations=False, tmp=False, experiment_name=None, pref=""):
        filename = self.dataset.get_results_filename(
            self.models.name, 
            quantifiers,
            debug=self.args.debug,
            generations=generations,
            experiment_name=experiment_name,
            pref=pref
        )

        if tmp:
            relative_components = filename.split('/')
            relative_components = [x for x in relative_components if x != "."]
            basename = relative_components[-1]
            tmp_path = os.path.join("/tmp/<name>", *relative_components[:-1], basename)
            filename = tmp_path

        if os.path.exists(filename) and not force:
            self.print(f"File '{filename}' already exists.")
        else:
            # Create file if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, 'wb') as file:
                pickle.dump(outs, file)
            print(f"File '{filename}' created and object pickled.")
    
    def load(self, quantifiers, generations=False, experiment_name=None, pref=""):
        filename = self.dataset.get_results_filename(
            self.models.name, 
            quantifiers,
            debug=self.args.debug,
            generations=generations,
            experiment_name=experiment_name,
            pref=pref,
        )
        if not os.path.exists(filename):
            print(f"NOT FOUND: {filename}")
            return None
        else:
            self.print(f"Loading {filename}")
            with open(filename, 'rb') as file:
                outs = pickle.load(file)
            self.print(f"File '{filename}' loaded")
            return outs
        
    def print(self, *args, **kwargs):
        if self.args.debug:
            print(*args, **kwargs)
    
    def __call__(
        self,
        force_recompute=False,
    ):
        """Compute attributions for every quantifier of the dataset and pickle them.

        For each quantifier from ``self.dataset.get_quantifiers()``:
        - the prompts are deduplicated (first-seen order kept);
        - ``self.attributions.compute`` runs on consecutive chunks of
          ``self.args.save_freq`` prompts;
        - every returned dict is tagged with ``"label"``;
        - the concatenated list is written once via ``self.save(..., force=True)``
          after all chunks are done.

        Args:
            force_recompute: recompute and overwrite even when a results file for
                the quantifier already exists. Without it, quantifiers whose file
                already exists are skipped.

        On a RuntimeError mentioning "CUDA out of memory", ``self.args.batch_size``
        is reduced to 75% (at least 1) and the quantifier is retried, at most 10
        times. Any other RuntimeError propagates.
        """
        self.dataset.prepare(self.models)

        for quantifiers in self.dataset.get_quantifiers():
            n_tries = 0
            clean()
            while n_tries < 10:
                try:
                    print(f"#################### {self.dataset.label_of(quantifiers)} ####################")
                    # The completeness check must not short-circuit an explicit
                    # recompute request (previously force_recompute was ignored
                    # whenever a results file existed).
                    if not force_recompute and self.is_complete_logprobs(quantifiers, None, None, generations=False):
                        print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifiers)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                        break
                    texts = self.dataset.get_prompts(quantifiers, n_prompts=self.args.n_prompts)
                    # dict.fromkeys dedups while keeping first-seen order, so the
                    # output order is reproducible (set order depends on str hashing).
                    texts = list(dict.fromkeys(texts))
                    self.print(texts[:5])

                    # Getting here means no results file exists or a recompute was
                    # forced, so there is nothing worth loading first.
                    outs = []
                    for i in range(0, len(texts), self.args.save_freq):
                        j = min(i + self.args.save_freq, len(texts))
                        new_outs = self.attributions.compute(
                            texts[i:j],
                            batch_size=self.args.batch_size,
                        )
                        for d in new_outs:
                            d["label"] = self.dataset.label_of(quantifiers)
                        outs.extend(new_outs)
                    self.save(outs, quantifiers, force=True)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" not in str(e):
                        raise
                    clean()
                    new_batch_size = max(1, int(0.75 * self.args.batch_size))
                    print("!"*50, f"CUDA out of memory error encountered. Retrying with 0.75x batch size: {self.args.batch_size} -> {new_batch_size}", "!"*50)
                    self.args.batch_size = new_batch_size
                    n_tries += 1
            else:
                # Reached only when every retry hit OOM (no break happened).
                print(f"Giving up on {self.dataset.label_of(quantifiers)} after {n_tries} CUDA out-of-memory retries; no results saved.")

    def _print_model_generations(self, model):
        self.models.load_gpu(model)

        for quantifiers in self.dataset.get_quantifiers():
            texts = self.dataset.get_prompts(quantifiers, n_prompts=self.args.n_prompts)
            texts = list(set(texts))

            print("#"*50)
            print(f"{model} - {self.dataset.label_of(quantifiers)}")
            print("#"*50)

            for i, text in enumerate(texts):
                print(f"SAMPLE {i}")
                completion = self.models.generate(text, max_new_tokens=30, apply_formatting=False)
                print(f"String: \n{repr(completion)}\n")
                print(f"Printed: \n{completion}")
                print('-'*50)
            
            print()
        
        self.models.clean_loaded_model()

    def dump_generations(self, n_samples=16, max_new_tokens=30, only_quantifier=None):
        """Sample generations from the finetuned model for every prompt and pickle them.

        Quantifier selection: all quantifiers, or only
        ``get_quantifiers()[only_quantifier]`` when ``only_quantifier`` is given.

        Generation: each prompt from ``get_prompts`` is repeated ``n_samples`` times.
        The repeats are generated in batches of ``self.args.batch_size`` with
        ``do_sample=True``, ``temperature=0.6`` and up to ``max_new_tokens`` new
        tokens.

        Output: the object saved with ``generations=True`` is a list with one dict
        per prompt::

            {"prompt": str, "model_output": [...], "generations": [...], "gen_start_idxs": [...]}

        ``model_output`` holds each full decoded sequence. ``generations`` holds the
        same sequence with its first ``gen_start_idx`` characters removed, where
        ``gen_start_idx`` is the length of the decoded (special-token-stripped)
        prompt.

        Quantifiers whose generations file already exists are skipped.

        Errors: on a RuntimeError mentioning "CUDA out of memory", the batch size is
        reduced to 75% (at least 1; the log message says "half") and the quantifier
        restarts from scratch, up to 10 times. After that it is skipped without
        saving and without any message.
        """
        # fix_seed is not defined in this file; presumably it seeds the RNGs used for sampling.
        fix_seed(seed=42)
        self.dataset.prepare(self.models)
        self.models.clean_loaded_model()
        self.models.load_gpu("finetuned")
        all_quantifiers = self.dataset.get_quantifiers()
        if only_quantifier is not None:
            # only_quantifier is an index into the quantifier list.
            all_quantifiers = [all_quantifiers[only_quantifier]]
        for quantifiers in all_quantifiers:
            n_tries = 0
            clean()
            while n_tries < 10:
                try:
                    print(f"#################### {self.dataset.label_of(quantifiers)} ####################")
                    if self.is_complete_logprobs(quantifiers, None, None, generations=True):
                        print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifiers)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                        break

                    prompts_initial = self.dataset.get_prompts(quantifiers, n_prompts = self.args.n_prompts)
                    # One accumulator per prompt position (prompts are NOT deduplicated here).
                    all_outputs = [
                        {"prompt": prompt, "model_output": [], "generations": [], "gen_start_idxs": []} 
                        for prompt in prompts_initial
                    ]
                    # Each prompt repeated n_samples times, tagged with its index into all_outputs.
                    prompts = [
                        (prompt, idx )
                        for idx, prompt in enumerate(prompts_initial)
                        for _ in range(n_samples)
                    ]

                    progress_bar = tqdm(total=len(prompts))
                    for i in range(0, len(prompts), self.args.batch_size):
                        j = min(len(prompts), i + self.args.batch_size)
                        batch = [p[0] for p in prompts[i:j]]
                        idxs = [p[1] for p in prompts[i:j]]
                        # batch_max_new = max_new_tokens[i*batch_size:(i+1)*batch_size]

                        # print(i*batch_size, (i+1)*batch_size, batch)

                        batch_inputs = self.models.tokenizer(batch, padding=True, truncation=False, return_tensors='pt')

                        # Inputs are moved to the "cuda" device explicitly (no CPU fallback).
                        batch_input_ids = batch_inputs['input_ids'].to("cuda")
                        batch_attention_mask = batch_inputs['attention_mask'].to("cuda")
                        # position_ids = batch_attention_mask.long().cumsum(-1) - 1
                        # position_ids.masked_fill_(batch_attention_mask == 0, 1)

                        try:
                            outputs = self.models.model_in_gpu.generate(
                                batch_input_ids, 
                                attention_mask=batch_attention_mask, 
                                max_new_tokens=max_new_tokens,
                                do_sample=True,
                                temperature=0.6,
                            )
                            batch_outputs = self.models.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                        except Exception as e:
                            print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! EXCEPTION !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                            print(e)
                            # OOM is re-raised so the outer handler can shrink the batch and retry.
                            if 'out of memory' in str(e):
                                print("CUDA out of memory error occurred")
                                raise e
                            # Any other failure is best-effort: the prompt text itself is recorded
                            # as the "output", so its generation slice is presumably (near-)empty.
                            outputs = None
                            batch_outputs = batch

                        # Character offset where the generation starts in each decoded output
                        # (the comprehension's `i` is local to it and does not clobber the batch index).
                        gen_start_idxs = [len(self.models.tokenizer.decode(batch_input_ids[i], skip_special_tokens=True)) for i in range(len(batch_input_ids))]

                        for idx, out, gen_start_idx in zip(idxs, batch_outputs, gen_start_idxs):
                            all_outputs[idx]["model_output"].append(out)
                            all_outputs[idx]["generations"].append(out[gen_start_idx:])
                            all_outputs[idx]["gen_start_idxs"].append(gen_start_idx)

                            self.print(repr(out[max(0, gen_start_idx-160):]))

                        # clear cache
                        del batch_inputs, batch_input_ids, batch_attention_mask, outputs, batch_outputs
                        torch.cuda.empty_cache()

                        progress_bar.update(j-i)
                    
                    self.save(all_outputs, quantifiers, force=True, generations=True)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e):
                        clean()
                        # NOTE: the message says "half" but the factor applied is 0.75.
                        new_batch_size = max(1, int(0.75*self.args.batch_size))
                        print("!"*50, f"CUDA out of memory error encountered. Retrying with half batch size: {self.args.batch_size} -> {new_batch_size}", "!"*50 )
                        setattr(self.args, "batch_size", new_batch_size)
                        n_tries += 1
                    else:
                        raise e 

    def check_model_generations(self):
        self.dataset.prepare(self.models)
        self._print_model_generations("finetuned")
        self._print_model_generations("pretrained")

    def is_complete_logprobs(self, quantifiers, grid, all_examples, generations=True, **kwargs):
        logprob_data = self.load(quantifiers, generations=generations, **kwargs)
        return logprob_data is not None

    @torch.no_grad()
    def dump_model_written_evals_logprobs(
        self, 
        idx=0, 
        n_splits=1,
        grid = [0.0, 0.25, 0.5, 0.6, 0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25, 1.5, 1.75, 2.0],
        single_gpu=True
    ):
        """Record answer log-probs of the alpha-scaled finetuned model on model_written_evals.

        Work split: quantifiers are split into ``n_splits`` contiguous chunks and
        this call handles chunk ``idx``.

        Computation, for every batch of examples from ``get_raw``:
        - ``get_many_total_probs`` is computed once on the pretrained model's
          output (``use_pretrained_gpu=True``);
        - it is then computed once per ``alpha`` in ``grid`` on the finetuned
          model run with ``self.models.get_callback_shrink(alpha, ...)``.

        Saving: partial results are checkpointed under ``/tmp`` after every batch.
        The final results file is written only once all batches have finished.
        ``is_complete_logprobs`` treats that file as "done", so a crash or an OOM
        retry can never leave a partial file that later gets skipped as complete.

        Errors: on a RuntimeError mentioning "CUDA out of memory", the batch size is
        reduced to 75% (at least 1) and the quantifier restarts from scratch, at
        most 10 times. Other RuntimeErrors propagate.
        """
        # assert self.args.model in ["llama", "gemma"], "Only llama and gemma supported"
        assert self.args.dataset == "model_written_evals", "Only model_written_evals supported"

        self.dataset.prepare(self.models)
        self.models.load_gpu("finetuned")

        all_quantifiers = self.dataset.get_quantifiers()
        size_split = math.ceil(len(all_quantifiers) / n_splits)
        my_quantifiers = all_quantifiers[idx*size_split : min(len(all_quantifiers), (idx+1)*size_split)]

        print(f"My quantifiers: {my_quantifiers}")

        for quantifier in my_quantifiers:
            n_tries = 0
            while n_tries < 10:
                try:
                    clean()
                    all_examples = self.dataset.get_raw(quantifier, n_prompts=self.args.n_prompts)
                    outs = []

                    if self.is_complete_logprobs(quantifier, grid, all_examples):
                        print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifier)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                        break

                    progress_bar = tqdm(total=len(all_examples) * len(grid))
                    for i in range(0, len(all_examples), self.args.batch_size):
                        j = min(len(all_examples), i+self.args.batch_size)
                        examples = all_examples[i:j]
                        prompts = [example["prompt"] for example in examples]
                        # NOTE(review): the answer pair comes from the first example of the
                        # batch and is applied to every example; this assumes all examples
                        # share the same two answer strings — confirm for each behavior.
                        answers = sorted([examples[0]["answer_matching_behavior"], examples[0]["answer_not_matching_behavior"]])

                        # Get output of the actual pretrained model
                        out_pretrained = self.models.batched_inference(
                            prompts, 
                            use_cache=False, 
                            use_pretrained_gpu=True
                        )
                        total_logprobs_pretrained = get_many_total_probs(out_pretrained, answers, self.models)

                        del out_pretrained
                        clean()

                        for alpha in grid:
                            out = self.models.batched_inference(
                                prompts, 
                                use_cache=False, 
                                callback=self.models.get_callback_shrink(alpha, use_timer=False, single_gpu=single_gpu)
                            )
                            total_logprobs = get_many_total_probs(out, answers, self.models)

                            for k, (logprobs, logprobs_pretrained) in enumerate(zip(total_logprobs, total_logprobs_pretrained)):
                                example = examples[k]
                                outs.append({
                                    "idx": i+k,
                                    "behavior": quantifier["behavior"],
                                    "alpha": alpha,
                                    "answer_matching_behavior": example["answer_matching_behavior"],
                                    "answer_not_matching_behavior": example["answer_not_matching_behavior"],
                                    "logprobs": logprobs,
                                    "logprobs_pretrained": logprobs_pretrained
                                })
                            progress_bar.update(j-i)

                        # Checkpoint only under /tmp: the final path must hold complete results.
                        self.save(outs, quantifier, force=True, generations=True, tmp=True)
                    progress_bar.close()
                    self.save(outs, quantifier, force=True, generations=True)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" not in str(e):
                        raise
                    clean()
                    new_batch_size = max(1, int(0.75*self.args.batch_size))
                    print("!"*50, f"CUDA out of memory error encountered. Retrying with 0.75x batch size: {self.args.batch_size} -> {new_batch_size}", "!"*50)
                    self.args.batch_size = new_batch_size
                    n_tries += 1
            else:
                # Reached only when every retry hit OOM (no break happened).
                print(f"Giving up on {self.dataset.label_of(quantifier)} after {n_tries} CUDA out-of-memory retries; no results saved.")

    @torch.no_grad()
    def dump_multiple_choice_alpha_scaling(
        self, 
        idx=0, 
        n_splits=1,
        grid = [0.0, 0.25, 0.5, 0.6, 0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25, 1.5, 1.75, 2.0],
        single_gpu=True
    ):
        """Record completion log-probs of the alpha-scaled finetuned model on multiple-choice tasks.

        Work split: quantifiers are split into ``n_splits`` contiguous chunks. This
        call handles chunk ``idx`` and prepares only the tasks
        (``quantifier["task"]``) that chunk references.

        Examples: each example from ``get_prompts`` is unpacked as
        ``(example_idx, choice_idx, context, completion, is_correct)``.

        Computation, per batch:
        - the pretrained model's values from ``batched_completions_logprobs`` are
          computed once (``use_pretrained_gpu=True``);
        - the finetuned model's values are computed for every ``alpha`` in ``grid``
          with the callback from ``self.models.get_callback_shrink(alpha, use_timer=False)``.

        Saving: one dict per (example, alpha) is collected and written once, with
        ``generations=True``, after all batches are done. Quantifiers with an
        existing results file are skipped. If every quantifier in the chunk is
        already complete, the method returns before loading any model.

        ``single_gpu`` only selects how models are loaded here
        (``load_finetuned_and_pretrained_to_gpu`` vs ``load_gpu("finetuned")``). It
        is NOT forwarded to ``get_callback_shrink`` (that argument is commented
        out below).

        Errors: on a RuntimeError mentioning "CUDA out of memory", the batch size is
        reduced to 75% (at least 1; the log message says "half") and the quantifier
        restarts, up to 10 times.
        """
        # assert self.args.model in ["llama", "gemma"], "Only llama and gemma supported"
        # assert self.args.dataset == "model_written_evals", "Only model_written_evals supported"

        all_quantifiers = self.dataset.get_quantifiers()
        size_split = math.ceil(len(all_quantifiers) / n_splits)
        my_quantifiers = all_quantifiers[idx*size_split : min(len(all_quantifiers), (idx+1)*size_split)]
        my_tasks = list({quantifier["task"] for quantifier in my_quantifiers})

        self.dataset.prepare(self.models, tasks=my_tasks)

        print(f"My quantifiers: {my_quantifiers}")

        # Each check goes through self.load, i.e. fully unpickles any existing results file.
        if all([self.is_complete_logprobs(quantifier, None, None) for quantifier in my_quantifiers]):
            print(f"!!!!!!!!!!!!!!!!!!!!! {my_quantifiers} all complete !!!!!!!!!!!!!!!!!!!!! ")
            return
        

        # self.models.load_gpu("finetuned")
        if not single_gpu:  
            self.models.load_finetuned_and_pretrained_to_gpu()
        else:
            self.models.load_gpu("finetuned")

        for quantifier in my_quantifiers:
            n_tries = 0
            clean()
            while n_tries < 10:
                try:
                    all_examples = self.dataset.get_prompts(quantifier, n_prompts=self.args.n_prompts)
                    if self.args.debug:
                        print(f"Got {len(all_examples)} examples")
                    outs = []

                    is_complete = self.is_complete_logprobs(quantifier, grid, all_examples)
                    if is_complete:
                        print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifier)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                    else:
                        progress_bar = tqdm(total=len(all_examples) * len(grid))
                        for i in range(0, len(all_examples), self.args.batch_size):
                            j = min(len(all_examples), i+self.args.batch_size)
                            examples = all_examples[i:j]
                            # example_idxs = [p[0] for p in examples]
                            # choice_idxs = [p[1] for p in examples]
                            contexts = [p[2] for p in examples]
                            completions = [p[3] for p in examples]

                            # Get output of the actual pretrained model
                            # ("avg" presumably means per-token averaged log-probs — confirm in batched_completions_logprobs.)
                            avg_logprobs_pretrained = self.models.batched_completions_logprobs(
                                contexts, completions, 
                                use_cache=False, 
                                use_pretrained_gpu=True
                            )

                            clean()

                            for alpha in grid:
                                avg_logprobs = self.models.batched_completions_logprobs(
                                    contexts, completions, 
                                    use_cache=False, 
                                    callback=self.models.get_callback_shrink(
                                        alpha, 
                                        use_timer=False, 
                                        # single_gpu=single_gpu
                                    ),
                                )

                                # print(avg_logprobs)

                                for k, (logprobs_pretrained, logprobs) in enumerate(zip(avg_logprobs_pretrained, avg_logprobs)):
                                    # print(f"{k} {alpha} {logprobs}")
                                    example_idx, choice_idx, context, completion, is_correct = examples[k]
                                    outs.append({
                                        "idx": example_idx,
                                        "choice_idx": choice_idx,
                                        "split": quantifier["split"],
                                        "alpha": alpha,
                                        "is_correct": is_correct,
                                        "logprobs": logprobs,
                                        "logprobs_pretrained": logprobs_pretrained
                                    })
                                progress_bar.update(j-i)
                                
                            # self.save(outs, quantifier, force=True, generations=True, tmp=True)
                        # Written once after all batches, so the final file never holds partial results.
                        self.save(outs, quantifier, force=True, generations=True, tmp=False)
                    break
                except RuntimeError as e:
                    if "CUDA out of memory" in str(e):
                        clean()
                        # NOTE: the message says "half" but the factor applied is 0.75.
                        new_batch_size = max(1, int(0.75*self.args.batch_size))
                        print("!"*50, f"CUDA out of memory error encountered. Retrying with half batch size: {self.args.batch_size} -> {new_batch_size}", "!"*50 )
                        setattr(self.args, "batch_size", new_batch_size)
                        n_tries += 1
                    else:
                        raise e 

    @torch.no_grad()
    def dump_multiple_choice_weight_interpolation(
        self, 
        idx=0, 
        n_splits=1,
        grid = [0.0, 0.25, 0.5, 0.6, 0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25, 1.5, 1.75, 2.0],
    ):
        """Record completion log-probs of weight-space-interpolated models on multiple-choice tasks.

        Work split: handles chunk ``idx`` of ``n_splits`` contiguous quantifier
        chunks. Results are stored under ``experiment_name="weight_interpolation"``.

        For each quantifier that is not already complete:

        1. ``clean_loaded_model`` is called. The pretrained model's values from
           ``batched_completions_logprobs`` are then computed for every batch
           (``use_pretrained_gpu=True``) and kept in memory.
        2. ``self.models._model_pretrained_gpu`` is set to ``None``.
        3. For each ``alpha`` in ``grid``, ``self.models.weight_space_interpolation(alpha)``
           is applied. Log-probs are then computed without ``use_pretrained_gpu``,
           i.e. presumably with the interpolated model.

        Examples are unpacked as ``(example_idx, choice_idx, context, completion, is_correct)``.

        Saving: results are checkpointed under ``/tmp`` after every alpha and
        written to the final results file after the last one.

        Errors: unlike the alpha-scaling variant, there is no CUDA out-of-memory
        retry here.
        """
        # assert self.args.model in ["llama", "gemma"], "Only llama and gemma supported"
        # assert self.args.dataset == "model_written_evals", "Only model_written_evals supported"

        all_quantifiers = self.dataset.get_quantifiers()
        size_split = math.ceil(len(all_quantifiers) / n_splits)
        my_quantifiers = all_quantifiers[idx*size_split : min(len(all_quantifiers), (idx+1)*size_split)]
        my_tasks = list({quantifier["task"] for quantifier in my_quantifiers})

        self.dataset.prepare(self.models, tasks=my_tasks)

        print(f"My quantifiers: {my_quantifiers}")

        for quantifier in my_quantifiers:
            all_examples = self.dataset.get_prompts(quantifier, n_prompts=self.args.n_prompts)
            if self.args.debug:
                print(f"Got {len(all_examples)} examples")
            outs = []

            is_complete = self.is_complete_logprobs(quantifier, grid, all_examples, experiment_name="weight_interpolation")
            if is_complete:
                print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifier)} is complete !!!!!!!!!!!!!!!!!!!!! ")
            else:
                # Batch start index -> pretrained log-probs for that batch (reused for every alpha).
                avg_logprobs_pretrained_per_batch = {i:None for i in range(0, len(all_examples), self.args.batch_size)}

                progress_bar = tqdm(total=len(all_examples))
                self.models.clean_loaded_model()
                for i in range(0, len(all_examples), self.args.batch_size):
                    j = min(len(all_examples), i+self.args.batch_size)
                    examples = all_examples[i:j]
                    contexts = [p[2] for p in examples]
                    completions = [p[3] for p in examples]

                    # Get output of the actual pretrained model
                    avg_logprobs_pretrained_per_batch[i] = self.models.batched_completions_logprobs(
                        contexts, completions, 
                        use_cache=False, 
                        use_pretrained_gpu=True
                    )
                    progress_bar.update(j-i)

                # Drop the pretrained GPU model reference (presumably to free memory before interpolating).
                self.models._model_pretrained_gpu = None
                clean()
                
                progress_bar = tqdm(total=len(all_examples) * len(grid))
                timer = Timer(active=self.args.debug)
                timer.checkpoint("start")
                for alpha in grid:
                    # timer.checkpoint(f"interpolate {alpha}")
                    self.models.weight_space_interpolation(alpha)
                    # self.models.load_gpu("finetuned")
                    for i in range(0, len(all_examples), self.args.batch_size):
                        j = min(len(all_examples), i+self.args.batch_size)
                        examples = all_examples[i:j]
                        # example_idxs = [p[0] for p in examples]
                        # choice_idxs = [p[1] for p in examples]
                        contexts = [p[2] for p in examples]
                        completions = [p[3] for p in examples]

                        # Get output of the actual pretrained model
                        # Fallback only: every key was filled by the loop above with the same
                        # batch_size, so this branch is not expected to run.
                        if avg_logprobs_pretrained_per_batch[i] is None:
                            timer.checkpoint(f"get logprobs pretrained i={i}")
                            avg_logprobs_pretrained_per_batch[i] = self.models.batched_completions_logprobs(
                                contexts, completions, 
                                use_cache=False, 
                                use_pretrained_gpu=True
                            )

                        clean()

                        timer.checkpoint(f"get logprobs interpolated i={i}")
                        avg_logprobs = self.models.batched_completions_logprobs(
                            contexts, completions, 
                            use_cache=False,
                        )

                        # print(avg_logprobs)
                        timer.checkpoint(f"append to results i={i}")
                        for k, (logprobs_pretrained, logprobs) in enumerate(zip(avg_logprobs_pretrained_per_batch[i], avg_logprobs)):
                            # print(f"{k} {alpha} {logprobs}")
                            example_idx, choice_idx, context, completion, is_correct = examples[k]
                            outs.append({
                                "idx": example_idx,
                                "choice_idx": choice_idx,
                                "split": quantifier["split"],
                                "alpha": alpha,
                                "is_correct": is_correct,
                                "logprobs": logprobs,
                                "logprobs_pretrained": logprobs_pretrained
                            })


                        progress_bar.update(j-i)
                        
                    # Per-alpha checkpoint under /tmp; the final path is written only after all alphas.
                    self.save(outs, quantifier, force=True, generations=True, tmp=True, experiment_name="weight_interpolation")
                self.save(outs, quantifier, force=True, generations=True, tmp=False, experiment_name="weight_interpolation")

    @torch.no_grad()
    def dump_generations_interpolation(
        self, 
        idx=0, 
        n_splits=1,
        mode="alpha_scaling",
        grid = [0.0, 0.25, 0.5, 0.6, 0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25, 1.5, 1.75, 2.0],
        single_gpu=True,
        max_new_tokens=256
    ):
        """Generate completions from interpolated models over ``grid`` and pickle them.

        Work split: for every quantifier, this call handles slice ``idx`` of
        ``n_splits`` equal slices of its prompts. Examples are
        ``(example_idx, prompt, target)`` tuples. Decoding uses ``do_sample=False``
        and up to ``max_new_tokens`` new tokens.

        Reference: pretrained completions are generated once per batch and reused
        for every ``alpha``.

        ``mode`` selects how the model is interpolated for each ``alpha``:
          * ``"alpha_scaling"``: generation runs with the callback from
            ``self.models.get_callback_shrink(alpha, ...)``. Models are loaded
            according to ``single_gpu``.
          * ``"weight_interpolation"``: ``self.models.weight_space_interpolation(alpha)``
            is applied first and no callback is used.

        Output: each dict holds the interpolated and pretrained completions. Both
        have the prompt prefix removed, are truncated at the dataset's
        ``stop_sequences`` (if the dataset defines them), and are stripped.

        Saving: results are checkpointed under ``/tmp`` after every alpha and
        written to the final results file once all alphas are done.

        Raises:
            ValueError: if ``mode`` is not ``"alpha_scaling"`` or ``"weight_interpolation"``.
        """
        # Fail fast: an unknown mode used to surface only as an UnboundLocalError on
        # ``callback`` after all pretrained completions had already been generated.
        if mode not in ("alpha_scaling", "weight_interpolation"):
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'alpha_scaling' or 'weight_interpolation'"
            )

        my_quantifiers = self.dataset.get_quantifiers()
        self.dataset.prepare(self.models)

        def get_additional_kwargs(input_ids):
            # Dataset-specific extra generation kwargs, forwarded to generate_batch.
            return self.dataset.additional_generation_arguments(self.models.tokenizer, input_ids)

        experiment_name = f"interpolated_generations_{mode}"
        pref = f"idx_{idx}_nsplits_{n_splits}__"

        print(f"My quantifiers: {my_quantifiers}")

        if mode == "alpha_scaling":
            if not single_gpu:  
                print("Loading finetuned and pretrained models in balanced way across devices")
                self.models.load_finetuned_and_pretrained_to_gpu()
            else:
                print("Just loading finetuned because there is only one GPU")
                self.models.load_gpu("finetuned")

        for quantifier in my_quantifiers:
            print("n_prompts", self.args.n_prompts)
            all_examples = self.dataset.get_prompts(quantifier, n_prompts=self.args.n_prompts)
            size_split = math.ceil(len(all_examples) / n_splits)
            all_examples = all_examples[idx*size_split : min(len(all_examples), (idx+1)*size_split)]
            if self.args.debug:
                print(f"Got {len(all_examples)} examples")
            outs = []

            is_complete = self.is_complete_logprobs(quantifier, grid, all_examples, experiment_name=experiment_name, pref=pref)
            if is_complete:
                print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifier)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                continue

            # Batch start index -> decoded pretrained completions for that batch.
            completions_pretrained_per_batch = {i: None for i in range(0, len(all_examples), self.args.batch_size)}

            progress_bar = tqdm(total=len(all_examples))

            if mode == "weight_interpolation":
                self.models.clean_loaded_model()

            for i in range(0, len(all_examples), self.args.batch_size):
                j = min(len(all_examples), i+self.args.batch_size)
                examples = all_examples[i:j]
                prompts = [p[1] for p in examples]

                # Get output of the actual pretrained model
                completions_pretrained_per_batch[i] = self.models.tokenizer.batch_decode(
                    self.models.generate_batch(
                        prompts, use_cache=True, use_pretrained_gpu=True, max_new_tokens=max_new_tokens,
                        get_additional_kwargs=get_additional_kwargs, do_sample=False,
                    ), 
                    skip_special_tokens=True
                )

                progress_bar.update(j-i)
            progress_bar.close()

            progress_bar = tqdm(total=len(all_examples) * len(grid))
            timer = Timer(active=self.args.debug)
            timer.checkpoint("start")
            for alpha in grid:
                if mode == "weight_interpolation":
                    # Drop the pretrained GPU model reference (presumably to free memory)
                    # before building the interpolated weights.
                    self.models._model_pretrained_gpu = None
                    clean()
                    self.models.weight_space_interpolation(alpha)
                    callback = None

                for i in range(0, len(all_examples), self.args.batch_size):
                    j = min(len(all_examples), i+self.args.batch_size)
                    examples = all_examples[i:j]
                    prompts = [p[1] for p in examples]

                    clean()

                    if mode == "alpha_scaling":
                        # NOTE(review): single_gpu=True is passed regardless of the
                        # ``single_gpu`` argument — confirm this is intended.
                        callback = self.models.get_callback_shrink(
                            alpha, 
                            use_timer=False,
                            single_gpu=True
                        )

                    timer.checkpoint(f"get logprobs interpolated i={i}")
                    completions = self.models.tokenizer.batch_decode(
                        self.models.generate_batch(
                            prompts, use_cache=True, callback=callback, max_new_tokens=max_new_tokens,
                            get_additional_kwargs=get_additional_kwargs, do_sample=False,
                        ), 
                        skip_special_tokens=True
                    )

                    timer.checkpoint(f"append to results i={i}")
                    for k, (completion_pretrained, completion) in enumerate(zip(completions_pretrained_per_batch[i], completions)):
                        example_idx, prompt, target = examples[k]

                        # NOTE(review): slicing by len(prompt) assumes the decoded output
                        # starts with the prompt text verbatim (e.g. no special tokens in
                        # the prompt that skip_special_tokens would drop) — confirm.
                        completion = completion[len(prompt):]
                        completion_pretrained = completion_pretrained[len(prompt):]

                        for seq in getattr(self.dataset, "stop_sequences", ()):
                            completion = completion.split(seq)[0]
                            completion_pretrained = completion_pretrained.split(seq)[0]

                        outs.append({
                            "idx": example_idx,
                            "split": quantifier["split"],
                            "alpha": alpha,
                            "target": target,
                            "completion": completion.strip(),
                            "completion_pretrained": completion_pretrained.strip()
                        })

                    progress_bar.update(j-i)

                # Per-alpha checkpoint under /tmp; the final path is written only after all alphas.
                self.save(outs, quantifier, force=True, generations=True, tmp=True, experiment_name=experiment_name, pref=pref)
            progress_bar.close()
            self.save(outs, quantifier, force=True, generations=True, tmp=False, experiment_name=experiment_name, pref=pref)

    
    def dump_msj_attributions(
        self,
        idx=0,
        n_splits=1,
        grid=None,
    ):
        """Compute and pickle alpha-scaled attributions for one shard of every quantifier.

        For each quantifier returned by ``self.dataset.get_quantifiers()``, the
        prompts are split into ``n_splits`` contiguous shards and shard ``idx`` is
        processed. Every example is run through ``self.attributions.compute``
        once per ``alpha`` in ``grid``, with the model callback from
        ``self.models.get_callback_shrink(alpha, ...)``. The annotated results
        are saved via ``self.save``.

        On a CUDA / out-of-memory ``RuntimeError`` the whole quantifier is
        retried (up to 10 times). Before each retry ``self.args.batch_size`` is
        shrunk to 75% (never below 1). Any other ``RuntimeError`` is re-raised.

        Args:
            idx: Index of the shard to process (0-based).
            n_splits: Number of shards the prompt list is divided into.
            grid: Iterable of alpha values. Defaults to
                ``[0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25]``.
        """
        # Avoid a shared mutable default argument.
        if grid is None:
            grid = [0.75, 0.9, 0.95, 1.0, 1.05, 1.1, 1.25]
        max_tries = 10

        self.dataset.prepare(self.models)

        mode = "alpha_scaling"
        experiment_name = f"interpolated_{self.dataset.task_type}_attributions_{mode}"
        pref = f"idx_{idx}_nsplits_{n_splits}__"

        for quantifier in self.dataset.get_quantifiers():
            n_tries = 0
            while n_tries < max_tries:
                try:
                    print("Number of tries:", n_tries)
                    clean()
                    print("n_prompts", self.args.n_prompts)
                    all_examples0 = self.dataset.get_prompts(quantifier, n_prompts=self.args.n_prompts)
                    size_split = math.ceil(len(all_examples0) / n_splits)
                    # Slicing already clamps at the end of the list.
                    all_examples = all_examples0[idx * size_split:(idx + 1) * size_split]
                    print(f"Got {len(all_examples)} examples")

                    is_complete = self.is_complete_logprobs(
                        quantifier, grid, all_examples, generations=False, experiment_name=experiment_name, pref=pref)

                    if is_complete or len(all_examples) == 0:
                        print(f"!!!!!!!!!!!!!!!!!!!!! {self.dataset.label_of(quantifier)} is complete !!!!!!!!!!!!!!!!!!!!! ")
                    else:
                        outs = []
                        progress_bar = tqdm(total=len(all_examples) * len(grid))
                        try:
                            for alpha in grid:
                                for start in range(0, len(all_examples), self.args.save_freq):
                                    examples = all_examples[start:start + self.args.save_freq]
                                    outs.extend(
                                        self._compute_msj_attribution_batch(examples, alpha, quantifier, progress_bar)
                                    )
                        finally:
                            # Close the bar on success and on the OOM retry path alike.
                            progress_bar.close()
                        self.save(outs, quantifier, force=True, tmp=False, experiment_name=experiment_name, pref=pref)
                    break
                except RuntimeError as e:
                    if "CUDA" in str(e) or "out of memory" in str(e) or "CUBLAS_STATUS_NOT_INITIALIZED" in str(e):
                        clean()
                        new_batch_size = max(1, int(0.75 * self.args.batch_size))
                        print("!"*50, f"CUDA out of memory error encountered. Retrying with reduced batch size: {self.args.batch_size} -> {new_batch_size}", "!"*50)
                        self.args.batch_size = new_batch_size
                        n_tries += 1
                    else:
                        raise
            else:
                # Reached only when every retry failed (a success exits via `break`).
                # Make the skipped quantifier visible instead of moving on silently.
                print(f"WARNING: giving up on {self.dataset.label_of(quantifier)} after {max_tries} tries; nothing was saved.")

    def _compute_msj_attribution_batch(self, examples, alpha, quantifier, progress_bar):
        """Run attributions for one chunk of examples at scaling ``alpha`` and annotate them.

        ``examples`` items are unpacked as
        ``(prompt_idx, text, answer_matching_behavior, answer_not_matching_behavior)``.
        Returns the list of dicts produced by ``self.attributions.compute``. Each
        dict is augmented with metadata and the ``logprobs`` from
        ``get_many_total_probs``.
        """
        texts = [example[1] for example in examples]
        new_outs = self.attributions.compute(
            texts,
            batch_size=self.args.batch_size,
            _external_pbar=progress_bar,
            use_cache=False,
            callback=self.models.get_callback_shrink(
                alpha,
                use_timer=False,
                single_gpu=True
            )
        )

        # Rebuild a dense (n_outs, 1, vocab) logits array from each tokenwise_df.
        # Tokens absent from a frame stay at -inf.
        all_logits = np.full((len(new_outs), 1, len(self.models.tokenizer)), -np.inf)
        for row, d in enumerate(new_outs):
            tokenwise_df = d["tokenwise_df"]
            all_logits[row, 0, tokenwise_df.tok.values] = tokenwise_df.logit.values
        all_logits = torch.tensor(all_logits)

        # NOTE(review): the answer keys come from the first example only, which
        # assumes every example in the chunk shares the same answer options — TODO confirm.
        first = examples[0]
        if isinstance(first[2], str):
            keys_text = sorted([first[2], first[3]])
        else:
            keys_text = sorted(first[2] + first[3])
        all_logprobs = get_many_total_probs(SimpleNamespace(logits=all_logits), keys_text, self.models)

        label = self.dataset.label_of(quantifier)
        for (prompt_idx, _, ans_match, ans_not_match), logprobs, d in zip(examples, all_logprobs, new_outs):
            d["idx"] = prompt_idx
            d["n_shots"] = quantifier["nshots"]
            d["label"] = label
            d["alpha"] = alpha
            d["answer_matching_behavior"] = ans_match
            d["answer_not_matching_behavior"] = ans_not_match
            d["logprobs"] = logprobs
        return new_outs
                