import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import pandas as pd
import tqdm
import gc

def getresult(model,inputpath,stop_words):
    """Generate a completion per prompt and collect per-layer top tokens.

    Args:
        model: Object exposing ``tokenizer.encode(prompt)``,
            ``generate_text(prompt, max_length)`` and
            ``logits_all_layers(text, index)``.
        inputpath: Path to a text file with one JSON object per line; each
            object must contain a ``'prompt'`` key.
        stop_words: Strings that end a completion. The text is kept up to and
            including the earliest stop word found after the prompt.

    Returns:
        A list holding one ``model.logits_all_layers`` result per prompt.
    """
    fullresults = []
    with open(inputpath, encoding="utf-8") as f:
        for line in tqdm.tqdm(f):
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            record = json.loads(line)
            prompt = record['prompt']
            prompt_len = len(model.tokenizer.encode(prompt))
            text = model.generate_text(prompt, prompt_len + 150)

            # Stop words are meant to end the *generated* part. If the text
            # starts with the prompt verbatim, skip it so that a stop word
            # occurring inside the prompt does not truncate the prompt itself.
            search_from = len(prompt) if text.startswith(prompt) else 0

            # Pick the stop word that starts earliest. Starts are compared
            # with starts (not with a previous candidate's end position).
            earliest_start = None
            cut = len(text)
            for stop_word in stop_words:
                start = text.find(stop_word, search_from)
                if start >= 0 and (earliest_start is None or start < earliest_start):
                    earliest_start = start
                    cut = start + len(stop_word)
            text = text[:cut]

            # prompt_len - 1 is the position of the last prompt token, i.e.
            # the position whose prediction is the first generated token.
            fullresults.append(model.logits_all_layers(text, prompt_len - 1))
    return fullresults

def _count_layer_votes(layers, position, alternatives, target_token):
    """Count layers whose top token at *position* is in *alternatives* vs. equal to *target_token*."""
    py = 0
    tgt = 0
    for layer in layers:
        token = layer[position][0][0].strip()
        if token in alternatives:
            py += 1
        elif token == target_token:
            tgt += 1
    return py, tgt


def getworklang(fullresults, last=-1, function_map=None, lang=None):
    """Count how often intermediate layers predict a mapped alternative vs. the final token.

    Each element of ``fullresults`` is indexed as
    ``results[layer][position][0][0]`` to obtain a layer's top token string at
    a position. For every position whose final-layer token (or the
    concatenation of the final-layer tokens at that position and the next one)
    is a key of ``function_map[lang]``, each layer in ``results[:last]`` is
    inspected at that position:

    * its stripped top token is in ``function_map[lang][key]`` -> counted in ``py``
      (presumably the Python counterparts of ``key`` -- confirm against
      ``functions.json``);
    * otherwise, it equals the final-layer token at that position -> counted
      in ``tgt``.

    Args:
        fullresults: Iterable of per-prompt results, each a sequence of layers.
        last: Slice end applied to the layer list; only ``results[:last]`` are
            inspected.
        function_map: Mapping ``lang -> {token: alternatives}``. Defaults to
            the module-level ``functions`` set by the script entry point.
        lang: Key into ``function_map``. Defaults to the module-level
            ``args.lang``.

    Returns:
        Tuple ``(py, tgt, total)`` where ``total == py + tgt``.
    """
    # Fall back to the globals the script entry point defines, so existing
    # call sites keep working unchanged.
    if function_map is None:
        function_map = functions
    if lang is None:
        lang = args.lang
    lang_map = function_map[lang]

    py = 0
    tgt = 0
    for results in fullresults:
        if not results:
            # No layers recorded for this prompt; nothing to count.
            continue
        inspected_layers = results[:last]
        final_tokens = [
            results[-1][i][0][0].strip() for i in range(len(results[0]))
        ]

        # Single-token keys.
        for i, token in enumerate(final_tokens):
            if token in lang_map:
                hit_py, hit_tgt = _count_layer_votes(
                    inspected_layers, i, lang_map[token], token
                )
                py += hit_py
                tgt += hit_tgt

        # Keys split over two consecutive tokens. The layers are inspected at
        # the first token's position and compared with that first token.
        for i in range(len(final_tokens) - 1):
            first = final_tokens[i]
            second = final_tokens[i + 1]
            if first and second and first + second in lang_map:
                hit_py, hit_tgt = _count_layer_votes(
                    inspected_layers, i, lang_map[first + second], first
                )
                py += hit_py
                tgt += hit_tgt
    return py, tgt, py + tgt

class BlockOutputWrapper(torch.nn.Module):
    """Wrap a block and remember its output projected through ``norm`` and ``unembed_matrix``.

    The wrapper is transparent to callers: ``forward`` returns exactly what the
    wrapped block returns. As a side effect, the first element of that return
    value is passed through ``norm`` and then ``unembed_matrix`` and stored in
    ``block_output_unembedded`` (``None`` until the first forward pass).
    """

    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm
        # Filled in by forward(); holds the most recent projection only.
        self.block_output_unembedded = None

    def forward(self, *args, **kwargs):
        """Run the wrapped block, cache its projected first output, return the raw result."""
        block_result = self.block(*args, **kwargs)
        normalized = self.norm(block_result[0])
        self.block_output_unembedded = self.unembed_matrix(normalized)
        return block_result

class Llama7BHelper:
    """Load a causal LM and expose per-layer top-token predictions.

    Every block in ``model.transformer.h`` is replaced by a
    ``BlockOutputWrapper`` that projects the block output through
    ``model.transformer.ln_f`` and ``model.lm_head``, so the checkpoint must
    expose those attributes.
    """

    def __init__(self, plm):
        """Load tokenizer and model from *plm* and wrap every transformer block.

        Args:
            plm: Model name or path accepted by ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the checkpoint -- only load checkpoints from trusted sources.
        self.tokenizer = AutoTokenizer.from_pretrained(plm,trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(plm,trust_remote_code=True).to(self.device).eval()
        for i, layer in enumerate(self.model.transformer.h):
            self.model.transformer.h[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.transformer.ln_f)

    def generate_text(self, prompt, max_length=100):
        """Greedily generate from *prompt* and return the decoded sequence.

        Args:
            prompt: Input text.
            max_length: Forwarded to ``model.generate`` as ``max_length``.

        Returns:
            The first decoded sequence, with special tokens skipped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(
            input_ids=inputs.input_ids.to(self.device),
            attention_mask=inputs.attention_mask.to(self.device),
            do_sample=False,
            max_length=max_length,
        )
        decoded = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return decoded[0]

    def get_logits(self, prompt):
        """Run one forward pass over *prompt* without gradients and return the logits."""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(inputs.input_ids.to(self.device)).logits
        return logits

    def logits_all_layers(self, text, index):
        """Return each layer's top-1 token and probability for positions ``index`` onward.

        Args:
            text: Text to run through the model.
            index: First token position to report. Negative values are
                clamped to 0 so they cannot wrap around to the end of the
                sequence.

        Returns:
            ``results[layer][k]`` is a one-element list ``[(token, percent)]``
            for position ``max(index, 0) + k``, where ``percent`` is the
            softmax probability times 100 truncated to ``int``.
        """
        # The forward pass is run for its side effect: every wrapped block
        # stores its projected output in ``block_output_unembedded``.
        self.get_logits(text)
        start = max(index, 0)
        results = []
        for layer in self.model.transformer.h:
            # Index 0 selects the only sequence in the batch.
            decoded_activations = layer.block_output_unembedded.detach()[0]
            layerresult = []
            for position in range(start, len(decoded_activations)):
                softmaxed = torch.nn.functional.softmax(decoded_activations[position], dim=-1)
                values, indices = torch.topk(softmaxed, 1)
                probs_percent = [int(v * 100) for v in values.tolist()]
                # unsqueeze so each id is decoded as its own one-token sequence.
                tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
                layerresult.append(list(zip(tokens, probs_percent)))
            results.append(layerresult)
        return results

import argparse, json, os

if __name__ == "__main__":
    # Command-line driver: generate completions for questions/<lang>.json,
    # record per-layer top tokens, and write the raw results plus a one-row
    # summary CSV to the output directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    # Optional; defaults to "<model_path>_result".
    parser.add_argument("--output_path", type=str, default=None)
    parser.add_argument("--lang", type=str, required=True)

    # NOTE: `args` and `functions` must remain module-level names because
    # getworklang reads them as globals.
    args = parser.parse_args()

    with open('functions.json', encoding='utf-8') as functions_file:
        functions = json.load(functions_file)

    # Honor --output_path when given (it used to be parsed and then ignored).
    output_path = args.output_path or args.model_path + "_result"

    inputpath = f'questions/{args.lang}.json'

    os.makedirs(output_path, exist_ok=True)

    stop_words = ['\n}'] if args.lang in ('php', 'cpp') else []
    outputpath = os.path.join(output_path, 'worklangs.json')
    model = Llama7BHelper(args.model_path)

    results = getresult(model, inputpath, stop_words)

    # The model is not needed for the counting below; release it early.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    py, tgt, total = getworklang(results, last=-5)
    if total > 0:
        pyrate = py / total
        tgtrate = tgt / total
    else:
        pyrate = 0
        tgtrate = 0

    # Open the output only once the results exist, so a failed run does not
    # leave an empty file behind. ensure_ascii=False emits raw non-ASCII
    # tokens, hence the explicit UTF-8 encoding.
    with open(outputpath, 'w', encoding='utf-8') as f:
        f.write(json.dumps(results, ensure_ascii=False) + '\n')

    print(pyrate, tgtrate)

    # The first column used an undefined name `step`; the model path
    # identifies the run instead.
    finalresults = [[args.model_path, py, tgt, total, pyrate, tgtrate]]
    df = pd.DataFrame(
        finalresults,
        columns=["model", "py", "tgt", "total", "pyrate", "tgtrate"],
    )
    df.to_csv(os.path.join(output_path, 'result.csv'))

