import gc
import os, sys
import torch
import torch.nn.functional as F
import pandas as pd
import concurrent
from types import SimpleNamespace
import matplotlib.pyplot as plt
import random
import numpy as np
import json
import time
import transformers
from typing import List
import re
from contextlib import contextmanager

# Lookup table from a dtype's string name to the matching torch floating-point dtype.
dtypes_dict = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("float16", "float32", "float64")
}

@contextmanager
def no_print():
    """Silence everything written to ``sys.stdout`` inside the ``with`` body.

    ``sys.stdout`` is temporarily replaced by a handle on ``os.devnull`` and
    is restored on exit, whether the body finishes normally or raises. The
    devnull handle is closed on exit so that repeated use does not leak file
    descriptors.
    """
    # Keep a reference to whatever stdout currently is so it can be restored.
    original_stdout = sys.stdout

    # The `with` statement closes the devnull handle once the body is done.
    with open(os.devnull, "w") as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Restore stdout before the handle is closed.
            sys.stdout = original_stdout

def set_seed(seed):
    """Seed NumPy's global RNG and torch's RNG with the same ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)

# Taken from Eleuther's lm-evaluation-harness
class RegexFilter:
    """Extract an answer string from model responses with a regular expression."""

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
    ) -> None:
        """
        Args:
            regex_pattern: pattern string passed to `re.compile`.
            group_select: index into the list returned by `findall`, i.e.
                which of the matches found in a response is kept.
            fallback: value returned for a response in which the regex finds
                nothing usable.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def _extract(self, resp):
        """Return the selected match of a single response, or `self.fallback`."""
        matches = self.regex.findall(resp)
        if not matches:
            return self.fallback

        match = matches[self.group_select]
        if isinstance(match, tuple):
            # With several capture groups `findall` yields tuples; keep the
            # first group that actually captured something.
            captured = [group for group in match if group]
            if not captured:
                # The pattern matched but every group is empty (e.g. all
                # groups optional): treat it as "no answer found" instead of
                # raising IndexError.
                return self.fallback
            match = captured[0]
        return match.strip()

    def apply(self, resps, docs):
        """Filter every response of every response set.

        Args:
            resps: a list in which each element is itself a list of model
                responses for one input/target pair. Each inner list is
                processed independently and stays a list.
            docs: accepted for interface compatibility; not used here.

        Returns:
            A list of lists with the same layout as `resps`, each response
            replaced by its extracted match or by `fallback`.
        """
        return [[self._extract(resp) for resp in inst] for inst in resps]

# Taken from Eleuther's lm-evaluation-harness
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence.

    A batch row is also marked as finished when its last
    `REPETITION_WINDOW` generated tokens are all identical. Generation stops
    (the call returns True) once every row of the batch is finished.
    """

    # Number of identical trailing tokens that counts as degenerate repetition.
    REPETITION_WINDOW = 10

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        """
        Args:
            sequence: the stop string to look for in the decoded output.
            tokenizer: tokenizer used to encode `sequence` and to decode the
                generated ids.
            initial_decoder_input_length: number of leading positions of
                `input_ids` that are excluded from the check.
            batch_size: number of rows tracked in `done_tracker`.
        """
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # We look back 2 more tokens than it takes to encode the stop sequence:
        # a model might generate `['\n', '\n']` while `sequence` encodes as
        # `['\n\n']`, and we do not want to miss the stop string just because
        # it was produced with a different tokenization.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Update the per-row done flags and return True when all rows are done."""
        # Drop the first `initial_decoder_input_length` positions so the
        # lookback never reaches into them.
        generated_ids = input_ids[:, self.initial_decoder_input_length :]

        # For efficiency only the trailing tokens are decoded. The window must
        # cover both the stop sequence and the repetition check. The slice
        # start has to be negative: the previous `max(0, -len + 10)` produced
        # a positive start index, which skipped the first generated tokens
        # and never bounded the window.
        window = max(self.sequence_id_len, self.REPETITION_WINDOW)
        lookback_ids_batch = generated_ids[:, -window:]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

        for i, done in enumerate(self.done_tracker):
            if done:
                continue
            stop_seq_occurred = self.sequence in lookback_tokens_batch[i]
            tail = lookback_ids_batch[i][-self.REPETITION_WINDOW :]
            repetition = len(tail) >= self.REPETITION_WINDOW and all(
                x == tail[-1] for x in tail
            )
            self.done_tracker[i] = bool(stop_seq_occurred or repetition)

        return False not in self.done_tracker
    
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build one `MultiTokenEOSCriteria` per stop sequence.

    Every criterion receives the same tokenizer, initial decoder input length
    and batch size; they are returned wrapped in a `StoppingCriteriaList`, in
    the order of `stop_sequences`.
    """
    criteria = []
    for stop_sequence in stop_sequences:
        criterion = MultiTokenEOSCriteria(
            stop_sequence, tokenizer, initial_decoder_input_length, batch_size
        )
        criteria.append(criterion)
    return transformers.StoppingCriteriaList(criteria)

def pad(tokens, padding_token):
    """Left-pad a batch of token-id sequences to a common length.

    Args:
        tokens: a sized collection of token-id sequences, possibly of
            different lengths (individual sequences may be empty).
        padding_token: value written into the unused leading positions.

    Returns:
        A tuple ``(padded_sequence, attention_mask)`` of tensors with shape
        ``(len(tokens), longest sequence length)``. ``attention_mask`` is an
        int32 tensor holding 1 at positions filled from the input and 0 at
        padded positions.
    """
    lengths = [len(sequence) for sequence in tokens]
    # `default=0` lets an empty batch produce empty tensors instead of raising.
    max_length = max(lengths, default=0)

    padded_sequence = torch.full((len(tokens), max_length), padding_token)
    # The mask is built from the sequence lengths rather than by comparing
    # values with `padding_token`, so a real token that happens to equal the
    # padding id is still marked as attended.
    attention_mask = torch.zeros((len(tokens), max_length), dtype=torch.int32)

    for row, (sequence, length) in enumerate(zip(tokens, lengths)):
        if length == 0:
            # Nothing to copy; a `-0:` slice would address the whole row.
            continue
        start = max_length - length  # tokens sit at the right end of the row
        padded_sequence[row, start:] = torch.tensor(sequence)
        attention_mask[row, start:] = 1

    return padded_sequence, attention_mask

class Timer:
    """Print the wall-clock time elapsed between consecutive named checkpoints."""

    def __init__(self, active=True):
        # When `active` is falsy, `checkpoint` does nothing at all.
        self.active = active
        # Name and start time of the most recent checkpoint (none yet).
        self.name = None
        self.t0 = 0

    def checkpoint(self, name):
        """Report the time since the previous checkpoint, then start timing `name`."""
        if self.active:
            now = time.time()
            previous = self.name
            if previous is not None:
                print(f"[{previous}] time taken: {now - self.t0}s")
            self.name, self.t0 = name, now

def save_fig(metric, fig, experiment_name, no_formatting_str, sysprompt_str, format, args, dir, pref=""):
    """Save `fig` under a name built from the experiment settings, then close it.

    Args:
        metric: metric name; last component of the file name and returned as is.
        fig: figure object providing `tight_layout` and `savefig`.
        experiment_name, no_formatting_str, sysprompt_str: file-name components.
        format: output format; used as file extension and passed to `savefig`.
        args: object with `model` and `n_params` attributes (file-name components).
        dir: directory the file is written into.
        pref: optional prefix of the file name.

    Returns:
        The tuple ``(metric, filename)``; `filename` does not include `dir`.
    """
    filename = f"{pref}{args.model}_{args.n_params}_{experiment_name}_{no_formatting_str}_{sysprompt_str}_{metric}.{format}"
    path = os.path.join(dir, filename)

    print(f"Received {path} to save")
    # The parent is taken from the full path, so separators inside `filename`
    # also get their directories created. Skip the call when there is no
    # directory component: `os.makedirs("")` raises FileNotFoundError.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    fig.tight_layout(pad=1.0)
    fig.savefig(path, format=format, dpi=100, bbox_inches='tight', metadata=None)
    print(f"Saved {path}")
    plt.close(fig)
    return metric, filename

def fix_seed(seed=42):
    """Seed every random number generator used here for reproducibility.

    Seeds NumPy, torch (plus all CUDA devices when CUDA is available),
    Python's `random` module and the `transformers` library.

    Args:
        seed: the seed to apply. It is now honoured; previously it was
            overwritten with 42 inside the function, so any other value
            passed by the caller was silently ignored.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    # If you're using GPU, you should also set the seed for torch.cuda
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # For transformers library. Imported under an alias so it cannot be
    # mistaken for this module's own `set_seed`.
    from transformers import set_seed as transformers_set_seed
    transformers_set_seed(seed)

def make_args(model=None, n_params=None, dataset=None, batch_size=5, save_freq=200, 
      n_prompts=1000, dtype='float16', no_formatting=False, use_sysprompt=False, debug=False, **kwargs):
    """Bundle the run settings into a `SimpleNamespace`.

    Unknown keyword arguments are not stored; each one is reported on stdout
    and otherwise ignored.
    """
    for extra_name in kwargs:
        print(f"IGNORING {extra_name} : {kwargs[extra_name]} in make_args")

    settings = {
        "model": model,
        "n_params": n_params,
        "dataset": dataset,
        "batch_size": batch_size,
        "save_freq": save_freq,
        "n_prompts": n_prompts,
        "dtype": dtype,
        "no_formatting": no_formatting,
        "use_sysprompt": use_sysprompt,
        "debug": debug,
    }
    return SimpleNamespace(**settings)

# Instruction-style prompt with a fixed system message between the <<SYS>> tags;
# the single `{}` placeholder receives the user prompt.
PROMPT_TEMPLATE = """<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information.\n<</SYS>>\n\n{}[/INST]\n\n"""


def get_prompt_in_template(prompt):
    """Return `PROMPT_TEMPLATE` with `prompt` substituted for its placeholder."""
    wrapped_prompt = PROMPT_TEMPLATE.format(prompt)
    return wrapped_prompt

def clean():
    """Run two rounds of garbage collection, each followed by `torch.cuda.empty_cache()`."""
    for _ in range(2):
        gc.collect()
        torch.cuda.empty_cache()

def get_hidden_states(these_texts, out):
    """Concatenate the stored hidden states of several texts, entry by entry.

    For every index of the ``"hidden_states_finetuned"`` list of the first
    text, the tensors stored at that index for all texts are concatenated
    along dim 0. The same is done for ``"hidden_states_final_finetuned"``,
    using the same index range.

    Returns:
        Two lists of concatenated tensors:
        ``(hidden_states_finetuned, hidden_states_final_finetuned)``.
    """
    n_entries = len(out[these_texts[0]]["hidden_states_finetuned"])

    def gather(key, index):
        # Join the `index`-th tensor of every text, in the order of `these_texts`.
        return torch.concat([out[text][key][index] for text in these_texts], dim=0)

    stacked = [gather("hidden_states_finetuned", index) for index in range(n_entries)]
    stacked_final = [
        gather("hidden_states_final_finetuned", index) for index in range(n_entries)
    ]
    return stacked, stacked_final

def _make_one_df(p):
    """Build the per-token and the consolidated attribution tables for one text.

    Args:
        p: a tuple ``(text, k, attr_outs, itos, out)`` where
            ``out[text]["true_logits"]`` holds the logits indexed by token,
            ``attr_outs`` maps names to tensors (``"tokenwise_*"`` entries are
            indexed as ``[k, tok]``, ``"isotropic_*"`` entries as ``[k]``),
            ``itos`` is indexable by token id, and ``k`` is the row of
            ``attr_outs`` belonging to this text.

    Returns:
        ``(text, tokenwise_df, consolidated_df)``. ``tokenwise_df`` has one
        row per top-logit token (at most 250), sorted by decreasing logit.
        ``consolidated_df`` has a single row with the mean of three tokenwise
        attributions over all tokens plus every ``"isotropic_*"`` value.
    """
    text, k, attr_outs, itos, out = p
    true_logits = out[text]["true_logits"]
    true_logprobs = F.log_softmax(true_logits, dim=-1)
    true_probs = F.softmax(true_logits, dim=-1)

    # `torch.topk` raises when k exceeds the number of entries, so clamp the
    # 250-token cut-off to the number of logits actually available.
    top_k = min(250, true_logits.shape[-1])
    topk_idx = torch.topk(true_logits, k=top_k)[1].tolist()

    tokenwise_attrs = {
        key.replace("tokenwise_", ""): value
        for key, value in attr_outs.items()
        if "tokenwise_" in key
    }

    tokenwise_df = pd.DataFrame([
        {
            "tok": tok,
            "id": itos[tok],
            **{name: value[k, tok].item() for name, value in tokenwise_attrs.items()},
            "logit": true_logits[tok].item(),
            "logprob": true_logprobs[tok].item(),
            "prob": true_probs[tok].item(),
        }
        for tok in topk_idx
    ])

    tokenwise_df = tokenwise_df.sort_values("logit", ascending=False).reset_index(drop=True)

    consolidated_df = pd.DataFrame([{
        "avg_tokenwise_coarse_attr_finetuned": attr_outs["tokenwise_coarse_attr_finetuned"][k, :].mean().item(),
        "avg_tokenwise_finegrained_attr_finetuned": attr_outs["tokenwise_finegrained_attr_finetuned"][k, :].mean().item(),
        "avg_tokenwise_layerwise_attr_finetuned": attr_outs["tokenwise_layerwise_attr_finetuned"][k, :].mean().item(),
        **{
            key.replace("isotropic_", ""): value[k].item()
            for key, value in attr_outs.items()
            if "isotropic_" in key
        },
    }])

    return text, tokenwise_df, consolidated_df

def get_corresponding_tokens(keys_text, models):
    """Map each key to the token ids whose decoded text equals that key.

    Every id in ``range(len(models.tokenizer))`` is decoded on its own; the
    comparison ignores case and surrounding whitespace on both sides.

    Returns:
        A dict ``{key: [token ids in increasing order]}`` with an entry for
        every key, possibly an empty list.
    """
    tokenizer = models.tokenizer
    matches = {key: [] for key in keys_text}
    normalized_keys = [(key, key.lower().strip()) for key in keys_text]

    for token_id in range(len(tokenizer)):
        token_text = tokenizer.decode(token_id).lower().strip()
        for key, wanted in normalized_keys:
            if wanted == token_text:
                matches[key].append(token_id)
    return matches

def get_total_probs(logprobs, keys_to_toks, normalize=True):
    """Aggregate token log-probabilities per key.

    Args:
        logprobs: tensor indexed as ``logprobs[0, token_ids]``; only row 0 is read.
        keys_to_toks: dict mapping each key to a list of token ids.
        normalize: when False, return for each key the log of the summed
            probabilities of its tokens. When True, return the log of each
            key's share of the mass summed over all keys.

    Returns:
        A dict ``{key: float}``.
    """
    every_tok = [tok for toks in keys_to_toks.values() for tok in toks]
    # Subtract the largest log-prob before exponentiating for numerical stability.
    shift = max(logprobs[0, every_tok])

    shifted_mass = {
        key: torch.exp(logprobs[0, toks] - shift).sum()
        for key, toks in keys_to_toks.items()
    }

    if not normalize:
        # Undo the shift to get back plain log-probabilities.
        return {
            key: (torch.log(mass) + shift).item()
            for key, mass in shifted_mass.items()
        }

    total_prob = sum(shifted_mass.values())
    normalized = {}
    for key, mass in shifted_mass.items():
        if total_prob > 1e-12:
            normalized[key] = torch.log(mass / total_prob).item()
        else:
            normalized[key] = 0
    return normalized

def get_many_total_probs(out, keys_text, models, **kwargs):
    """Compute `get_total_probs` for every batch row at the last position.

    The logits at the last position of ``out.logits`` are cast to float32 and
    turned into log-probabilities; the key-to-token mapping is looked up once
    and shared by all rows. Extra keyword arguments are forwarded to
    `get_total_probs`.

    Returns:
        A list with one result dict per row of ``out.logits``.
    """
    last_position_logits = out.logits[:, -1, :].to(torch.float32)
    logprobs_tensor = torch.log_softmax(last_position_logits, dim=-1)
    keys_to_toks = get_corresponding_tokens(keys_text, models)

    results = []
    for row in range(out.logits.shape[0]):
        # Slice rather than index so the row keeps its leading dimension.
        row_logprobs = logprobs_tensor[row:row + 1]
        results.append(get_total_probs(row_logprobs, keys_to_toks, **kwargs))
    return results