import sys
import os
import time
from collections import defaultdict
import random

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from gpt2_crypten import GPT2
from gpt2 import GPT2 as PlainGPT2
import torch
from tqdm import tqdm
import nltk
import numpy as np
import torch.nn.functional as F

import crypten
import crypten.communicator as comm
from crypten.config import cfg
from utils import encrypt_tensor, encrypt_model
import os
from presidio_analyzer import AnalyzerEngine
from presidio_analyzer.nlp_engine import NlpEngineProvider
import spacy
import torch.distributed as dist
from einops import repeat
import math

# Fetch the NLTK 'punkt' resource when the script starts.
# NOTE(review): no other visible line of this file uses nltk — confirm this
# download is still required before keeping the start-up cost.
nltk.download('punkt')

def vanilla_copy_weight(src_model_state_dict):
    """Create a float16 ``PlainGPT2`` and populate it from a source state dict.

    The model is built from the ``AutoConfig`` of the module-level
    ``model_name``. Each of its parameters is copied (under ``torch.no_grad``)
    from ``src_model_state_dict``:

    * a parameter whose name contains ``'lm_head'`` (and is not a
      ``c_fc`` / ``c_proj`` / ``c_attn`` weight) is read under its own name;
    * every other parameter is read under ``'transformer.' + name``.

    Entries are ``pop``-ed as they are consumed, so the dict passed in is
    mutated. A missing key raises ``KeyError``.

    Returns the populated model.
    """
    remaining = src_model_state_dict
    plain = PlainGPT2(AutoConfig.from_pretrained(model_name)).to(torch.float16)
    projection_tags = ('c_fc', 'c_proj', 'c_attn')

    with torch.no_grad():
        for param_name, target in plain.named_parameters():
            is_projection_weight = (
                'weight' in param_name
                and any(tag in param_name for tag in projection_tags)
            )
            # Projection weights take precedence over the lm_head rule, exactly
            # as in an if / elif chain that tests them first.
            if 'lm_head' in param_name and not is_projection_weight:
                source_key = param_name
            else:
                source_key = 'transformer.' + param_name
            target.copy_(remaining.pop(source_key))

    return plain

# 2PC setting: the party rank and the rendezvous port come from the command
# line (argv[1] and argv[2]); the world size is fixed at two parties on localhost.
rank = sys.argv[1]
os.environ.update(
    {
        "RANK": str(rank),
        "WORLD_SIZE": str(2),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": sys.argv[2],
        "RENDEZVOUS": "env://",
    }
)

# argv[3] is parsed as an integer; it is not read again in the visible code.
part = int(sys.argv[3])

# Fixed seeds for Python's and torch's random number generators.
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(42)
del seed_fn

# Presidio analyzer backed by the small English spaCy model.
configuration = {
    "nlp_engine_name": "spacy",
    "models": [{"lang_code": "en", "model_name": "en_core_web_sm"}],
}
provider = NlpEngineProvider(nlp_configuration=configuration)
nlp_engine = provider.create_engine()
analyzer = AnalyzerEngine(nlp_engine=nlp_engine, supported_languages=["en"])

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def mask_pii(text, name_list=None, suffix_list=None, n_mask=0):
    """Replace detected PII spans in ``text`` with indexed ``[MASK_*_k]`` tags.

    Spans are found with the module-level Presidio ``analyzer`` for the entity
    types PERSON, URL, PHONE_NUMBER, EMAIL_ADDRESS, LOCATION and DATE_TIME.
    Each distinct surface form (compared case-insensitively) of an entity type
    gets a stable index, so the same value always maps to the same tag.

    Args:
        text: The string to mask.
        name_list: Optional dict mapping entity type -> list of lower-cased
            values already seen; the position in the list is the tag index.
            Updated in place and returned. A fresh dict is created when None.
        suffix_list: Optional list of ``"[MASK_X_k]=original, "`` strings
            accumulated so far. Extended in place and returned. A fresh list
            is created when None.
        n_mask: Number of distinct masked values counted so far.

    Returns:
        ``(masked_text, suffix_list, name_list, n_mask)`` where ``n_mask`` has
        been increased by the number of values first seen in this call.
    """
    entity_type_list = ["PERSON", "URL", "PHONE_NUMBER", "EMAIL_ADDRESS", "LOCATION", "DATE_TIME"]
    detections = analyzer.analyze(text=text, language='en', entities=entity_type_list)

    # Keep a left-to-right set of non-overlapping spans. ``prev_end`` must only
    # advance when a span is accepted: advancing it on a rejected (nested) span
    # would let a later span through that still overlaps the last accepted one,
    # and the slice replacement below would then corrupt the text.
    results = []
    prev_end = 0
    for result in sorted(detections, key=lambda x: x.start):
        if prev_end <= result.start:
            results.append(result)
            prev_end = result.end

    natural_entity_list = ["Person", "Website", "Phone", "Email", "Place", "Date"]
    natural_entity_dict = dict(zip(entity_type_list, natural_entity_list))

    if name_list is None:
        name_list = {entity_type: [] for entity_type in entity_type_list}

    if suffix_list is None:
        suffix_list = []

    def _tag(result, cur_name):
        # Tag text for a span, e.g. "[MASK_PERSON_0]".
        label = natural_entity_dict[result.entity_type].upper()
        index = name_list[result.entity_type].index(cur_name)
        return f"[MASK_{label}_{index}]"

    # Register values not seen before and record their "tag=original" suffix.
    cur_suffix_list = []
    for result in results:
        original = text[result.start:result.end]
        cur_name = original.lower()
        known = name_list[result.entity_type]
        if cur_name not in known:
            known.append(cur_name)
            cur_suffix_list.append(_tag(result, cur_name) + "=" + original + ", ")

    suffix_list += cur_suffix_list

    # Substitute from right to left so earlier character offsets stay valid.
    masked_text = text
    for result in reversed(results):
        start, end = result.start, result.end
        cur_name = text[start:end].lower()
        masked_text = masked_text[:start] + _tag(result, cur_name) + masked_text[end:]

    n_mask = len(cur_suffix_list) + n_mask

    return masked_text, suffix_list, name_list, n_mask
    
# Start CrypTen, enable verbose communicator accounting and snapshot the
# communication statistics right after initialisation.
crypten.init()
cfg.communicator.verbose = True
commInit = comm.get().get_communication_stats()

# Tokenizer, config and the encrypted GPT-2 used for the secure forward pass.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
m = GPT2(config)
model = encrypt_model(m, GPT2, config).eval()

# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.config.pad_token_id = tokenizer.pad_token_id

# Plaintext model: rank 0 copies the pretrained weights into it, rank 1 keeps
# a freshly constructed PlainGPT2.
if int(rank) == 0:
    src_model_state_dict = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16
    ).state_dict()
    plain_model = vanilla_copy_weight(src_model_state_dict).eval().to(device)
    del src_model_state_dict
elif int(rank) == 1:
    plain_model = PlainGPT2(config).to(torch.float16).eval().to(device)
plain_model.config.pad_token_id = tokenizer.pad_token_id

# Token id of each answer letter; `.item()` requires every choice string to
# tokenize to exactly one id.
choices = [" A", " B", " C", " D"]
choices_idx_list = [
    tokenizer(letter, return_tensors="pt")['input_ids'].reshape(-1).long().item()
    for letter in choices
]

dataset = load_dataset("cais/mmlu", "all", split="test")

def compute_ppi_mask(masked_text: str, n_mask: int, suffix_list: list) -> torch.Tensor:
    """Build a 2-D int32 mask over the tokens of ``masked_text``.

    ``masked_text`` is tokenized with the module-level ``tokenizer`` and must
    contain the literal "Masked information:" exactly twice (asserted below).
    For each of the first ``n_mask`` entries of ``suffix_list`` (strings of
    the form ``"[MASK_X_k]=value, "``), one mask row per token of that entry
    is produced; the row is all ones except for zeros from the placeholder's
    token position in the text up to the end of the second
    "Masked information:" header. Rows of all ones are added for gaps between
    entries and for the tokens after the last processed entry.

    Args:
        masked_text: Full prompt text, including both "Masked information:" lists.
        n_mask: Number of leading ``suffix_list`` entries to process. Must be
            at least 1, because the collected rows are concatenated with
            ``torch.cat``.
        suffix_list: ``"placeholder=value, "`` strings.

    Returns:
        An int32 tensor with ``total_len`` columns (one per token of
        ``masked_text``); the rows are the concatenation described above.
    """

    encoding = tokenizer(masked_text, return_offsets_mapping=True)

    total_len = len(encoding["input_ids"])
    offset = encoding["offset_mapping"]
    # NOTE(review): token_list is never read afterwards — looks like debug residue.
    token_list = [tokenizer.decode([ids]) for ids in encoding["input_ids"]]

    # Character start/end offset of every token, as tensors for vectorised comparisons.
    offset_start_list = []
    offset_end_list = []
    for (start, end) in offset:
        offset_start_list.append(start)
        offset_end_list.append(end)
    offset_start_list = torch.tensor(offset_start_list)
    offset_end_list = torch.tensor(offset_end_list)

    # Exactly two "Masked information:" headers are expected (split -> 3 pieces).
    assert len(masked_text.split("Masked information:")) == 3
    # Character position just past the second header: the first two pieces plus one
    # header length.
    # NOTE(review): only one header length is added although two headers precede this
    # point — confirm the intended boundary.
    end_len = sum([len(text_ele) for text_ele in masked_text.split("Masked information:")[:2]]) + len("Masked information:")
    # Number of tokens whose start offset lies before that position.
    token_end_len = (offset_start_list < end_len).sum().item()

    ppi_mask_list = []
    pos_id_list = []
    # Starts at total_len so the gap branch below cannot fire on the first entry.
    prev_end = total_len
    for idx in range(n_mask):
        ppi_mask = torch.ones(total_len, dtype=torch.int32)
        if suffix_list[idx].split("=")[0] == "[MASK_PLACE_0]":
            # Character offset of the THIRD occurrence of "[MASK_PLACE_0]".
            # Presumably the first two occurrences belong to the one-shot example
            # prepended by the caller — TODO confirm.
            pre_len = sum([len(text_ele) for text_ele in masked_text.split(suffix_list[idx].split("=")[0])[:3]]) + 2*len(suffix_list[idx].split("=")[0])
        else:
            # Character offset of the first occurrence of the placeholder.
            pre_len = len(masked_text.split(suffix_list[idx].split("=")[0])[0])

        # Count of tokens that end at or before the placeholder's start.
        token_pre_len = (offset_end_list - pre_len <= 0).sum().item()
        # Zero out the columns from the placeholder up to the end of the header.
        ppi_mask[token_pre_len:token_end_len] = 0

        suffix = suffix_list[idx]
        # Leading space added, trailing ", " dropped: " [MASK_X_k]=value".
        suffix_adjust = " " + suffix[:-2]

        suffix_len = len(suffix_adjust)
        # Character offset of the first occurrence of that suffix text.
        suffix_pre_len = len(masked_text.split(suffix_adjust)[0])

        # Token index range [start_len, end_len) covering the suffix text.
        # Note: end_len is reused here and no longer holds the header position.
        start_len = (offset_start_list - suffix_pre_len < 0).sum().item()
        end_len = (offset_end_list - suffix_pre_len-suffix_len <= 0).sum().item()

        n_token = end_len-start_len

        # One identical mask row per token of this suffix entry.
        ppi_mask = repeat(ppi_mask, 'n -> k n', k=n_token)
        ppi_mask_list.append(ppi_mask)

        # NOTE(review): pos_id_list is filled but never returned or read.
        pos_id = torch.arange(token_pre_len,token_pre_len+n_token)
        pos_id_list.append(pos_id)

        # All-ones rows for tokens between the previous entry and this one.
        # NOTE(review): these rows are appended AFTER this entry's rows although the
        # gap tokens precede it in the text — confirm the intended row order.
        if prev_end < start_len:
            ppi_mask_list.append(torch.ones((start_len-prev_end, total_len), dtype=torch.int32))
            pos_id_list.append(torch.arange(prev_end,start_len))
            
        if idx == 0:
            # NOTE(review): initial_start_pt is assigned but never used.
            initial_start_pt = start_len

        prev_end = end_len

    ppi_mask_list = torch.cat(ppi_mask_list, dim=0)
    # All-ones rows for every token after the last processed suffix entry.
    end_ppi_mask = torch.ones((total_len-prev_end, total_len), dtype=torch.int32)
    ppi_mask_list = torch.cat((ppi_mask_list,end_ppi_mask), dim=0)

    return ppi_mask_list


# Evaluate the encrypted model on each selected MMLU subject. For every test
# item: mask PII in the question and choices, build a one-shot prompt, run the
# plaintext model to obtain a KV cache, share that cache with CrypTen, run the
# encrypted model, and record accuracy, wall-clock time and communication cost.
for subject in ["professional_medicine"]:
    subject_dataset = dataset.filter(lambda x: x["subject"] == subject)
    if subject == "professional_law":
        # Only reached when "professional_law" is added to the subject list above.
        subject_dataset = subject_dataset.shuffle(seed=42).select(range(200))

    # The one-shot demonstration is identical for every item, so it is built
    # once here rather than re-concatenated inside the timed per-item loop.
    example_choices = ["Abdominal", "Cranial", "Pleural", "Spinal"]
    example_prompt = "Question 1. Choose the body cavity containing the [MASK_PLACE_0]."
    for j in range(4):
        example_prompt += "\n{}) {}".format(choices[j], example_choices[j])
    example_prompt += "\nMasked information: [MASK_PLACE_0]=pituitary gland."
    example_prompt += "\nAnswer:"
    example_prompt += " B\n\n"

    total_time_list = []
    comm_time_list = []
    comm_byte_list = []
    cors = []
    all_probs = []
    # tqdm wraps the dataset itself (not the enumerate object) so it can read
    # the dataset length and display a total / ETA.
    for d_idx, item in enumerate(tqdm(subject_dataset, desc=f"Testing on {subject}")):

        # Per-item timing and communication baselines.
        t0 = time.time()
        comm0 = comm.get().get_communication_stats()

        # Truncate the question to at most 640 tokens before masking.
        question = item["question"]
        input_ids = tokenizer(question, return_tensors="pt", truncation=True, max_length=640)['input_ids']
        if input_ids.size(1) == 640:
            print("question length is over 640")
            truncated_question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        else:
            truncated_question = question
        masked_input_text, suffix_list, name_list, n_mask = mask_pii(truncated_question)
        prompt = example_prompt + "Question 2. " + masked_input_text

        # Truncate each choice to at most 32 tokens and mask it, sharing the
        # same name/suffix bookkeeping as the question so tags stay consistent.
        all_choice_prompt = ""
        for j in range(len(item["choices"])):
            choice_prompt = item["choices"][j]
            input_ids = tokenizer(choice_prompt, return_tensors="pt", truncation=True, max_length=32)['input_ids']
            if input_ids.size(1) == 32:
                print("choice length is over 32")
                truncated_choice = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            else:
                truncated_choice = choice_prompt
            masked_choice_prompt, suffix_list, name_list, n_mask = mask_pii(truncated_choice, name_list, suffix_list, n_mask)
            all_choice_prompt += "\n{}) {}".format(choices[j], masked_choice_prompt)
        prompt += all_choice_prompt

        # Items without any masked entity are skipped entirely: they contribute
        # neither to the accuracy nor to the timing/communication averages.
        if n_mask == 0:
            continue

        # "tag=value, tag=value, " -> "tag=value, tag=value."
        suffix_text = "".join(suffix_list)
        suffix_text = suffix_text[:-2] + '.'
        prompt += "\nMasked information:"

        prompt_remainder = " " + suffix_text
        prompt_remainder += "\nAnswer:"

        # compute_ppi_mask is asked to handle only the first n_mask - n_skip
        # suffix entries; with nothing left to handle no mask is passed.
        n_skip = 1
        if n_mask > n_skip:
            ppi_mask_list = compute_ppi_mask(prompt+prompt_remainder, n_mask-n_skip, suffix_list)
        else:
            ppi_mask_list = None

        # NOTE(review): only `prompt` (without `prompt_remainder`) is tokenized
        # for the two generate calls below — confirm that is intended.
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).input_ids.to(device)
        n_non_private = input_ids.size(1)

        # Plaintext pass; only its second return value is used.
        with torch.no_grad():
            _, kv_cache_list, _ = plain_model.generate(
                input_ids,
                max_new_tokens=1,
                attention_mask=None,
            )

        # Secret-share each layer's two cache tensors (presumably key and value —
        # TODO confirm) from party 0, cast to float32 first.
        kv_cache = []
        for layer_cache in kv_cache_list:
            kv_cache.append([
                crypten.cryptensor(layer_cache[0].to(torch.float32), src=0),
                crypten.cryptensor(layer_cache[1].to(torch.float32), src=0),
            ])
        del kv_cache_list

        label = item["answer"]

        # The encrypted model takes one-hot token inputs.
        input_ids = F.one_hot(input_ids.to(device), config.vocab_size)
        input_ids = encrypt_tensor(input_ids)

        with crypten.no_grad():
            _, _, ans_probs = model.generate(
                input_ids,
                max_new_tokens=1,
                target_ids=choices_idx_list,
                kv_cache=kv_cache,
                attn_mask = ppi_mask_list,
            )

        # Close the measurement window right after the encrypted pass.
        comm1 = comm.get().get_communication_stats()
        t1 = time.time()
        cur_total_time = t1-t0
        cur_comm_time = comm1["time"] - comm0["time"]
        cur_comm_byte = comm1["bytes"] - comm0["bytes"]

        comm_time_list.append(cur_comm_time)
        comm_byte_list.append(cur_comm_byte)
        total_time_list.append(cur_total_time)

        # One score per answer letter, in the order of `choices`.
        lprobs = [ans_probs[0, 0, c_idx].item() for c_idx in range(len(choices_idx_list))]

        pred = np.argmax(lprobs)
        cor = pred == label
        cors.append(cor)

    acc = np.mean(cors)
    cors = np.array(cors)

    total_time = np.mean(total_time_list)
    comm_time = np.mean(comm_time_list)
    comm_byte = np.mean(comm_byte_list)

    print("Average accuracy {:.4f} - {}".format(acc, subject))
    print("Average time: {:.4f}".format(total_time))
    print("Average comm. time: {:.4f}".format(comm_time))
    print("Average comm. bytes: {:.4f}".format(comm_byte))