import re
import string

from scipy import stats
import os
import copy
import json
import pandas as pd
import time
import numpy as np
# from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
import gc
from utils.utils import upload_blob_from_memory
from datetime import datetime
from pathlib import Path

if os.uname()[0] == 'Darwin':
    # macOS host: tensors go to the 'mps' device and there is no CUDA cache to clear.
    device_name = 'mps'


    def flush():
        """Do nothing; exists so callers can invoke ``flush()`` on every platform."""
        return None
else:
    # Every non-macOS host is treated as a CUDA machine.
    from accelerate.utils import release_memory
    import gc

    device_name = 'cuda'


    def flush():
        """Run the garbage collector, then empty the CUDA cache and reset its peak-memory stats."""
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


def format_messages(messages, toker, sample_config):
    """Render a list of chat messages in the prompt format expected by the model.

    The template is chosen by matching ``sample_config.model_name_hf`` against
    an ordered list of regular expressions. When several patterns match, the
    LAST matching rule wins.

    Args:
        messages: List of dicts, each with a ``'role'`` and a ``'content'`` key.
        toker: Tokenizer-like object. Only used for template-based formats,
            where its ``chat_template`` attribute is overwritten and
            ``apply_chat_template`` is called.
        sample_config: Object exposing ``model_name_hf`` and, for
            template-based formats, ``_paths.sub_path`` (string prefix of the
            directory containing ``ch_temp/chat_templates/``).

    Returns:
        The formatted prompt string. For ``'gpt'`` models the ``messages``
        list itself is returned unchanged.

    Raises:
        KeyError: If a message has a role other than system/user/assistant in
            one of the hand-written (non-template) formats.
        OSError: If the ``.jinja`` template file cannot be read.
    """
    # Order matters: rules are applied top to bottom and later matches override earlier ones.
    template_rules = (
        ('Mistral-.*-Instruct', 'mistral-instruct'),
        ('Orca', 'chatml'),
        ('.*gemma.*-it.*', 'gemma-it'),
        ('.*CausalLM.*', 'chatml'),
        ('vicuna', 'vicuna'),
        ('Llama-.*-chat', 'llama-2-chat'),
        ('llama2_.*_uncensored', 'llama-2-uncensored'),
        ('.*Qwen.*-Chat.*', 'chatml'),
        ('gpt-', 'gpt'),
    )
    model_id = sample_config.model_name_hf
    template_name = None
    for pattern, candidate in template_rules:
        if re.search(pattern, model_id):
            template_name = candidate

    if template_name == 'gpt':
        # API-style models take the structured message list as-is.
        return messages

    if template_name == 'llama-2-uncensored':
        role_map = {'system': '### HUMAN:\n', 'user': '\n### HUMAN:\n', 'assistant': '\n### RESPONSE:\n'}
        return ''.join([role_map[msg['role']] + msg['content'] for msg in messages])

    if template_name is None:
        # No known chat format: concatenate the contents, one per line.
        role_map = {'system': '\n', 'user': '\n', 'assistant': '\n'}
        return ''.join([msg['content'] + role_map[msg['role']] for msg in messages])

    template_path = sample_config._paths.sub_path + 'ch_temp/chat_templates/' + template_name + '.jinja'
    # Context manager guarantees the template file handle is closed.
    with open(template_path) as template_file:
        chat_template = template_file.read()
    # Collapse the human-readable template (4-space indents, newlines) onto one line.
    chat_template = chat_template.replace('    ', '').replace('\n', '')
    toker.chat_template = chat_template
    messages_formatted = toker.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Only mistral-instruct keeps both literal '<s>' and '</s>'; vicuna keeps just '</s>'.
    if template_name != 'mistral-instruct':
        messages_formatted = messages_formatted.replace('<s>', '')
        if template_name != 'vicuna':
            messages_formatted = messages_formatted.replace('</s>', '')

    return messages_formatted

def setup_model(sample_config):
    """Load the tokenizer and causal language model described by ``sample_config``.

    Reads ``model_name_hf``, ``model_name`` and ``revision`` from
    ``sample_config``. The model is loaded onto the module-level
    ``device_name`` in bfloat16 with eager attention and switched to eval mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    hf_name = sample_config.model_name_hf
    # Models listed here keep their own pad token; all others reuse EOS as padding.
    keeps_own_pad_token = ['gemma-2-2b-it']
    reuse_eos_as_pad = sample_config.model_name not in keeps_own_pad_token

    tokenizer = AutoTokenizer.from_pretrained(hf_name, use_fast=True, trust_remote_code=True)
    if reuse_eos_as_pad:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    load_kwargs = {
        'device_map': device_name,
        'trust_remote_code': True,
        'revision': sample_config.revision,
        'attn_implementation': 'eager',
        'torch_dtype': torch.bfloat16,
    }
    model = AutoModelForCausalLM.from_pretrained(hf_name, **load_kwargs)
    model.eval()

    if reuse_eos_as_pad:
        generation_config = model.generation_config
        generation_config.pad_token_id = generation_config.eos_token_id

    return model, tokenizer


def texts_to_msg(text_data, sample_config, toker, rem_tag=False):
    """Build one chat-formatted prompt per unique value of ``text_data['sub']``.

    For each unique ``sub`` value, the first matching entry of the ``text``
    column is wrapped as a single user message and passed through
    ``format_messages``.

    Args:
        text_data: Table with ``'sub'`` and ``'text'`` columns (indexed like a
            pandas DataFrame).
        sample_config: Configuration forwarded to ``format_messages``; its
            ``model_name_sshort`` is read only when ``rem_tag`` is true.
        toker: Tokenizer forwarded to ``format_messages``.
        rem_tag: When true, strip the end-of-user-turn / start-of-assistant-turn
            markers for the 'MistralOo' and 'gemma2' model families.

    Returns:
        Dict mapping each ``sub`` value to its formatted prompt.
    """
    # Turn-handoff markers that the template appends after the user message.
    chatml_handoff = r"<\|im_end\|>\n<\|im_start\|>assistant\n"
    gemma_handoff = r"<end_of_turn>\n<start_of_turn>model\n"

    prompts_by_sub = {}
    for sub_id in text_data['sub'].unique():
        rows_for_sub = text_data['sub'] == sub_id
        first_text = text_data['text'][rows_for_sub].values[0]
        prompt = format_messages([{'role': 'user', 'content': first_text}], toker, sample_config)

        if rem_tag:
            if sample_config.model_name_sshort == 'MistralOo':
                prompt = re.sub(chatml_handoff, '', prompt)
            if 'gemma2' in sample_config.model_name_sshort:
                prompt = re.sub(gemma_handoff, '', prompt)

        prompts_by_sub[sub_id] = prompt

    return prompts_by_sub


def forward_pass_whs(inputs, model, tokenizer, sample_config):
    """Run one forward pass and return hidden states sampled at evenly spaced token positions.

    For every batch element, ``sample_config.hsT`` positions are picked with
    ``np.linspace`` between the first and last non-padded token (padding is
    excluded via the attention mask). The embedding-layer output (index 0 of
    ``hidden_states``) is dropped, and of the remaining entries only indices
    ``sample_config.L // 2`` through ``sample_config.L - 1`` are kept.

    Args:
        inputs: Text or list of texts passed to ``tokenizer``.
        model: Model called with ``output_hidden_states=True``.
        tokenizer: Tokenizer called with padding enabled, returning PyTorch tensors.
        sample_config: Object exposing ``hsT`` (number of sampled positions)
            and ``L`` (upper bound of the layer range).

    Returns:
        CPU tensor indexed as (batch, selected layer, sampled position, feature).
    """
    inputs_ids = tokenizer(inputs, padding=True, return_tensors='pt', return_attention_mask=True).to(device_name)

    # Move the mask to the CPU once instead of once per layer and batch element.
    attention_mask_cpu = inputs_ids.attention_mask.detach().cpu()
    token_counts = attention_mask_cpu.sum(dim=1).numpy()
    loc_idxs = [np.linspace(0, n_tokens - 1, sample_config.hsT, dtype=int) for n_tokens in token_counts]
    keep_masks = attention_mask_cpu == 1

    # Outputs are detached below, so no autograd graph is needed; skipping it saves memory.
    with torch.no_grad():
        out_ids = model(inputs_ids.input_ids, attention_mask=inputs_ids.attention_mask, output_hidden_states=True)
    hs_ts = out_ids.hidden_states
    layer_idxs = np.arange(sample_config.L // 2, sample_config.L)

    # Per layer and per batch element: drop padded positions, then pick the sampled positions.
    per_layer = [
        torch.stack([
            hs_ts_l[ii, keep, :][loc_idx, :].detach().cpu()
            for ii, (keep, loc_idx) in enumerate(zip(keep_masks, loc_idxs))
        ])
        for hs_ts_l in hs_ts
    ]
    # (layer, batch, pos, feat) -> (batch, layer, pos, feat); skip embeddings, then select layers.
    hs_ts_new = torch.stack(per_layer).permute(1, 0, 2, 3)[:, 1:][:, layer_idxs]

    # Drop every reference to the full set of hidden states before returning.
    del per_layer, hs_ts, out_ids

    return hs_ts_new
