import os
import numpy as np
import pandas as pd
import torch
import argparse
from llm_brain_asym import compute_logprobs, make_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable the HF tokenizers library's internal parallelism and let the PyTorch
# CUDA caching allocator use expandable segments.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

parser = argparse.ArgumentParser()
# --model is required: otherwise model_name is None and the failure would only
# surface at model loading, after all datasets have already been generated.
parser.add_argument('--model', type=str, required=True, help='model name')
parser.add_argument('--revision', type=str, help='details of the revision')
parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--seed', type=int, default=12345, help='random seed for the creation of the dataset')
parser.add_argument('--output_folder', type=str, default='dyck', help='output folder')

args = parser.parse_args()

model_name = args.model
revision = args.revision
device = args.device
batch_size = args.batch_size
seed = args.seed
output_folder = args.output_folder

# Inference only: no autograd graphs are needed.
torch.set_grad_enabled(False)
# The dataset generators below draw from numpy's global RNG, so this seed
# fixes the generated corpora.
np.random.seed(seed)
torch.manual_seed(seed)

make_dir(output_folder)

# Dataset hyper-parameters.
seq_len = 32       # tokens per full sequence; first half is the prompt, second half the continuation
n_examples = 1024  # number of (correct, shuffled) sentence pairs per Dyck-n variant
p_open = 0.25      # probability of opening a new bracket while at least one is still open
p_shuffle = 0.2    # per-position probability of swapping adjacent tokens in the continuation

def generate_dyck_sequence(n, seq_len, p):
    """Sample a random Dyck-``n`` token sequence of length ``seq_len``.

    Token ids ``0..n-1`` are opening brackets and ``n..2n-1`` are closing
    brackets, where opener ``k`` is closed by ``n + k``. At each step:

    * if no bracket is open, a uniformly random opener is emitted;
    * otherwise, with probability ``p`` a new random opener is emitted, and
      with probability ``1 - p`` the most recently opened bracket is closed.

    Every closing bracket therefore matches its opener, but brackets may
    still be open when the sequence ends, so the result is a well-formed
    prefix rather than necessarily a complete Dyck word.

    Args:
        n: number of bracket types.
        seq_len: number of tokens to emit.
        p: probability of opening a new bracket when the stack is non-empty.

    Returns:
        ``np.ndarray`` of token ids.
    """
    openers = np.arange(n)
    closers = np.arange(n, 2 * n)
    pending = []  # stack of still-open brackets, innermost last
    tokens = []
    for _ in range(seq_len):
        # Short-circuit: no random draw is made when nothing is open, which
        # keeps the numpy RNG stream identical to a two-branch formulation.
        if not pending or np.random.rand() <= p:
            token = np.random.choice(openers)
            pending.append(token)
        else:
            token = closers[pending.pop()]
        tokens.append(token)
    return np.array(tokens)

# Bracket vocabularies per Dyck-n: the first n entries are openers, the last n
# the matching closers in the same order (id k is closed by id n + k).
_PAREN_VOCAB = {
    1: ['(', ')'],
    2: ['(', '[', ')', ']'],
    3: ['(', '[', '{', ')', ']', '}'],
}


def s2str_paren(s, n=1):
    """Return the bracket character for token id ``s`` in a Dyck-``n`` vocabulary.

    Ids ``0..n-1`` map to the opening brackets ``(``, ``[``, ``{`` (the first
    ``n`` of them). Ids ``n..2n-1`` map to the matching closing brackets in
    the same order.

    Args:
        s: integer token id (Python or numpy integer).
        n: number of bracket types (1, 2 or 3).

    Raises:
        ValueError: if ``n`` is not 1, 2 or 3.
        IndexError: if ``s`` is outside ``[0, 2n)``.
    """
    try:
        paren = _PAREN_VOCAB[n]
    except (KeyError, TypeError):
        # ValueError is an Exception subclass, so broad existing handlers still catch it.
        raise ValueError('Unsupported number of parentheses (max is 3)') from None
    # Reject out-of-range ids explicitly: a negative id would otherwise wrap
    # around and silently return a closing bracket.
    if not 0 <= s < len(paren):
        raise IndexError('token id {} out of range for n={}'.format(s, n))
    return paren[s]

def seq2seqofstr(seq, s2str=s2str_paren, n=1):
    """Render a sequence of token ids as a space-separated bracket string.

    Each id is converted with ``s2str(token, n)``. The result always starts
    with one leading space (an empty ``seq`` yields ``' '``), so two rendered
    pieces can be concatenated directly into one space-separated string.
    """
    symbols = map(lambda token: s2str(token, n), seq)
    return ' ' + ' '.join(symbols)

def local_shuffling(list_original, p=0.3):
    """Return a copy of ``list_original`` with random adjacent swaps applied.

    Positions are visited left to right. At each position ``i`` except the
    last, the elements at ``i`` and ``i + 1`` of the copy are swapped with
    probability ``p``, using one ``np.random.rand()`` draw per position.
    Because the sweep works on the already-modified copy, one element can be
    carried several places to the right.

    The input is not modified. The result can equal the input, either because
    no swap fired or because only equal neighbours were swapped.

    Args:
        list_original: sequence supporting ``.copy()`` and item assignment
            (a list or 1-D numpy array).
        p: probability of swapping at each position.
    """
    shuffled = list_original.copy()
    for right in range(1, len(list_original)):
        left = right - 1
        if np.random.rand() <= p:
            shuffled[left], shuffled[right] = shuffled[right], shuffled[left]
    return shuffled

def create_dyck_corpus(n_examples, n, seq_len, p_open, p_shuffle):
    """Build paired correct / locally-shuffled Dyck-``n`` sentences.

    For each example, a sequence of ``seq_len`` token ids is sampled. Its
    first ``seq_len // 2`` tokens are the prompt and the rest the
    continuation. The corrupted sentence keeps the same prompt but replaces
    the continuation with ``local_shuffling(continuation, p=p_shuffle)``.
    Both are rendered with ``seq2seqofstr``.

    NOTE(review): the shuffled continuation is not guaranteed to differ from
    the original (no swap may fire, or only identical neighbours may be
    swapped), so some pairs can be identical strings.

    Returns:
        ``(sentences, sentences_shuffled)``: two lists of ``n_examples``
        strings, aligned by index.
    """
    half = seq_len // 2
    correct = []
    corrupted = []
    for _ in range(n_examples):
        # Draw order per example (sequence, then shuffle) fixes the RNG stream.
        full_seq = generate_dyck_sequence(n, seq_len, p=p_open)
        head, tail = full_seq[:half], full_seq[half:]
        tail_shuffled = local_shuffling(tail, p=p_shuffle)
        head_str = seq2seqofstr(head, s2str=s2str_paren, n=n)
        correct.append(head_str + seq2seqofstr(tail, s2str=s2str_paren, n=n))
        corrupted.append(head_str + seq2seqofstr(tail_shuffled, s2str=s2str_paren, n=n))
    return correct, corrupted

# Build all corpora before loading the model, so the numpy RNG stream used for
# data generation is independent of anything model loading might do.
sentences_dyck1, sentences_dyck1_shuffled = create_dyck_corpus(n_examples, 1, seq_len, p_open, p_shuffle)
sentences_dyck2, sentences_dyck2_shuffled = create_dyck_corpus(n_examples, 2, seq_len, p_open, p_shuffle)
sentences_dyck3, sentences_dyck3_shuffled = create_dyck_corpus(n_examples, 3, seq_len, p_open, p_shuffle)

# Load the tokenizer from the same revision as the model so both come from the
# same checkpoint snapshot.
tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision)
# Half precision is kept for accelerators; on CPU fp16 is slow and some ops may
# lack Half kernels, so fall back to fp32 there.
model_dtype = torch.float32 if device == 'cpu' else torch.float16
model = AutoModelForCausalLM.from_pretrained(model_name, revision=revision, torch_dtype=model_dtype).to(device)


def _dyck_accuracy(good_sentences, bad_sentences):
    """Return the fraction of pairs where the correct sentence scores strictly higher.

    Ties count as failures because the comparison is strict.
    """
    # np.asarray makes '>' elementwise even if compute_logprobs returns a plain
    # list (NOTE(review): its return type is not visible here; presumably an array).
    logprobs_good = np.asarray(compute_logprobs(model, tokenizer, good_sentences, batch_size=batch_size))
    logprobs_bad = np.asarray(compute_logprobs(model, tokenizer, bad_sentences, batch_size=batch_size))
    return np.mean(logprobs_good > logprobs_bad)


accuracy_dyck1 = _dyck_accuracy(sentences_dyck1, sentences_dyck1_shuffled)
accuracy_dyck2 = _dyck_accuracy(sentences_dyck2, sentences_dyck2_shuffled)
accuracy_dyck3 = _dyck_accuracy(sentences_dyck3, sentences_dyck3_shuffled)

df_results = pd.DataFrame([{'model':model_name,
                            'revision':revision,
                            'accuracy':(accuracy_dyck1+accuracy_dyck2+accuracy_dyck3)/3,
                            'accuracy_dyck1':accuracy_dyck1,
                            'accuracy_dyck2':accuracy_dyck2,
                            'accuracy_dyck3':accuracy_dyck3}])
filename = os.path.join(output_folder, '{}_{}_dyck.csv'.format(model_name.replace('/','_'), revision))
df_results.to_csv(filename)