import os
import numpy as np
import torch
import string
import torch.nn.functional as F
import random 
import itertools
from graph_data_utils import Dataset

 
def set_seed(seed: int = 42) -> None:
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Also switches cuDNN to deterministic, non-autotuned kernels, stores the
    seed in the ``PYTHONHASHSEED`` environment variable and reports it on
    stdout.

    Args:
        seed: value used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN needs two extra switches: deterministic algorithms, and no
    # benchmark-based autotuning (which may pick different kernels per run).
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences child processes started afterwards, not str hashing here.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

 

class PhoneBookTokenizer:
    """Character-level tokenizer driven by explicit char <-> id lookup tables.

    Args:
        TO_TOKEN: mapping from single characters to integer ids. It must
            contain '$', '.' and '*', whose ids are exposed as the BOS, EOS
            and padding ids.
        TO_CHAR: inverse mapping from ids back to characters.
    """

    def __init__(self, TO_TOKEN, TO_CHAR):
        self.TO_TOKEN, self.TO_CHAR = TO_TOKEN, TO_CHAR

        # Special-token ids; a missing symbol raises KeyError here.
        self.bos_token_id = TO_TOKEN['$']
        self.eos_token_id = TO_TOKEN['.']
        self.pad_token_id = TO_TOKEN['*']

    def __call__(self, x):
        """Encode the characters of ``x`` into a list of ids.

        Raises:
            KeyError: if ``x`` contains a character missing from TO_TOKEN.
        """
        lookup = self.TO_TOKEN.__getitem__
        return list(map(lookup, x))

    def decode(self, x):
        """Turn a sequence of ids (list, array or torch tensor) back into a string.

        Ids with no entry in TO_CHAR are rendered with ``str(id)``.
        """
        ids = x.detach().cpu().numpy() if torch.is_tensor(x) else x
        table = self.TO_CHAR
        return ''.join(table[t] if t in table else str(t) for t in ids)

    def __len__(self):
        """Vocabulary size, i.e. the number of entries in TO_TOKEN."""
        return len(self.TO_TOKEN)



def get_phonebook_tokenizer(args):
    """Build the phonebook character vocabulary and a tokenizer over it.

    Ids are assigned in this order: digits '0'-'9' -> 0-9, then the symbols
    '$' -> 10, '.' -> 11, '|' -> 12, '*' -> 13, then lowercase 'a'-'z' -> 14-39.

    Args:
        args: unused; accepted for interface compatibility.

    Returns:
        A tuple ``(tokenizer, TO_TOKEN, TO_CHAR)``, where TO_CHAR is the
        inverse of TO_TOKEN.
    """
    vocab = string.digits + '$.|*' + string.ascii_lowercase
    TO_TOKEN = {ch: idx for idx, ch in enumerate(vocab)}
    TO_CHAR = {idx: ch for ch, idx in TO_TOKEN.items()}
    return PhoneBookTokenizer(TO_TOKEN, TO_CHAR), TO_TOKEN, TO_CHAR


def arr_to_str_sp(x):
    """Join ``str()`` of every element of ``x`` with '-' (e.g. [1, 2] -> '1-2')."""
    return '-'.join(map(str, x))

def arr_to_str(x):
    """Concatenate ``str()`` of every element of ``x`` (e.g. [1, 2] -> '12')."""
    return ''.join(map(str, x))

  


# Fixed shape of the synthetic entries: 5 lowercase letters per name and
# 8 decimal digits per phone number.
_NAME_LENGTH = 5
_PHONE_LENGTH = 8


def _generate_names_and_numbers(count):
    """Return the first ``count`` names and phone numbers of the phonebook.

    Names are the first ``count`` of all ``_NAME_LENGTH``-letter lowercase
    strings in lexicographic order ('aaaaa', 'aaaab', ...). Phone numbers are
    the first ``count`` entries of a ``random.shuffle`` of all
    ``_PHONE_LENGTH``-digit strings, so the caller must seed ``random`` first.
    """
    import array  # local import: only needed for the compact shuffle buffer

    num_names = len(string.ascii_lowercase) ** _NAME_LENGTH
    num_phones = 10 ** _PHONE_LENGTH

    print("START GENERATE PEOPLE NAMES + NUMS", flush=True)
    # Names are consumed in order and never shuffled, so only the prefix that
    # is actually used is materialized instead of all ~11.9M combinations.
    all_names = itertools.product(string.ascii_lowercase, repeat=_NAME_LENGTH)
    names = [''.join(chars) for chars in itertools.islice(all_names, count)]

    # The k-th element of itertools.product('0123456789', repeat=_PHONE_LENGTH)
    # is k written with zero-padded digits, and the permutation produced by
    # random.shuffle depends only on the sequence length. Shuffling a compact
    # integer array therefore yields exactly the same numbers (and leaves
    # `random` in exactly the same state) as shuffling a list of all 10**8
    # strings, without allocating 10**8 str objects.
    # Typecode 'l' (C long) is at least 32 bits, enough for 10**8 - 1.
    phone_ids = array.array('l', range(num_phones))
    print(f"ORIGIN PPL NAMES {num_names}", flush=True)
    print(f"ORIGIN PHO {num_phones}", flush=True)
    print("DONE GENERATE PEOPLE NAMES + NUMS", flush=True)
    random.shuffle(phone_ids)
    numbers = [f"{k:0{_PHONE_LENGTH}d}" for k in phone_ids[:count]]
    return names, numbers


def create_phonebook_to_save(model_name, tokenizer, num_examples_train, num_examples_test, sequence_length, label_sequence_length, seed):
    """Generate a seeded synthetic phonebook dataset (name -> phone number).

    Each example pairs a distinct 5-letter lowercase name with a distinct
    8-digit phone number.

    For ``model_name == "gmlp"`` the input is the bare name padded with '*' to
    ``sequence_length``, and the target (key ``'label_ids'``) is the unpadded
    token ids of the number. For any other model the input is
    ``"$<name>|<number>."`` padded with '*' to ``sequence_length``, and the
    target (key ``'mask'``) is 1 for real tokens and 0 for padding.

    The test split is a random subset of the *training* examples, not a
    held-out set, and is kept in training order.

    Args:
        model_name: "gmlp" selects the name/label format; anything else selects
            the single-sequence + mask format.
        tokenizer: callable mapping a str to a list of ids, with a ``TO_TOKEN``
            dict containing '*' and a ``decode`` method (e.g. PhoneBookTokenizer).
        num_examples_train: number of training examples to generate.
        num_examples_test: number of training examples sampled into the test split.
        sequence_length: length every input is padded to.
        label_sequence_length: not used by this function.
            TODO(review): presumably meant for padding gmlp labels; confirm.
        seed: passed to ``set_seed`` before any randomness is consumed.

    Returns:
        ``(train_data, test_data)`` dicts with keys ``'input'`` (decoded
        strings), ``'input_ids'`` and ``'label_ids'`` or ``'mask'``. Train
        values are tuples in shuffled order; test values are lists.

    Raises:
        ValueError: if a count is negative, ``num_examples_train`` exceeds the
            number of distinct names/numbers, ``num_examples_test`` exceeds
            ``num_examples_train``, or an example does not fit in
            ``sequence_length``.
    """
    # Validate up front: generating the numbers is expensive, so fail fast.
    if num_examples_train < 0 or num_examples_test < 0:
        raise ValueError(
            "num_examples_train and num_examples_test must be non-negative, got "
            f"{num_examples_train} and {num_examples_test}"
        )
    max_examples = min(len(string.ascii_lowercase) ** _NAME_LENGTH, 10 ** _PHONE_LENGTH)
    if num_examples_train > max_examples:
        raise ValueError(
            f"num_examples_train={num_examples_train} exceeds the {max_examples} "
            "distinct names available"
        )
    if num_examples_test > num_examples_train:
        raise ValueError(
            f"num_examples_test={num_examples_test} cannot exceed "
            f"num_examples_train={num_examples_train}: the test split is sampled "
            "from the training examples"
        )

    set_seed(seed)
    names, numbers = _generate_names_and_numbers(num_examples_train)

    is_gmlp = model_name == "gmlp"
    # Per-example target: token ids of the number (gmlp) or the padding mask.
    extra_key = "label_ids" if is_gmlp else "mask"
    pad_id = tokenizer.TO_TOKEN["*"]

    inputs, input_ids, extras = [], [], []
    for name, phone_number in zip(names, numbers):
        if is_gmlp:
            example_ids = tokenizer(name)
        else:
            example_ids = tokenizer(f"${name}|{phone_number}.")

        num_pad = sequence_length - len(example_ids)
        if num_pad < 0:
            # Refuse rather than emit an example longer than sequence_length.
            raise ValueError(
                f"sequence_length={sequence_length} is too short for an example "
                f"of {len(example_ids)} tokens"
            )
        full_tokenized = example_ids + [pad_id] * num_pad

        inputs.append(tokenizer.decode(full_tokenized))
        input_ids.append(full_tokenized)
        if is_gmlp:
            extras.append(tokenizer(phone_number))
        else:
            extras.append([1] * len(example_ids) + [0] * num_pad)

    # Shuffle all columns in unison. Permuting an index list draws from
    # `random` exactly as shuffling the zipped records would, and it also
    # handles num_examples_train == 0 (where zip(*records) cannot be unpacked).
    order = list(range(num_examples_train))
    random.shuffle(order)
    train_data = {
        'input': tuple(inputs[k] for k in order),
        'input_ids': tuple(input_ids[k] for k in order),
        extra_key: tuple(extras[k] for k in order),
    }

    # Test split: a random subset of the shuffled training examples, kept in
    # training order. Sorting the sampled indices avoids a per-index list scan.
    test_indices = sorted(random.sample(range(num_examples_train), num_examples_test))
    test_data = {
        key: [values[i] for i in test_indices]
        for key, values in train_data.items()
    }

    return train_data, test_data


 
def get_phonebook_to_save(args):
    """Build the phonebook tokenizer and generate the train/test splits.

    Reads ``args.model``, ``args.num_examples_train``,
    ``args.num_examples_test``, ``args.sequence_length``,
    ``args.label_sequence_length`` and ``args.seed``, and forwards them to
    ``create_phonebook_to_save``.

    Returns:
        ``(train_data, test_data)`` as produced by ``create_phonebook_to_save``.
    """
    # Reuse the shared vocabulary builder instead of duplicating it here, so the
    # saved data and any tokenizer built via get_phonebook_tokenizer cannot drift apart.
    tokenizer, _, _ = get_phonebook_tokenizer(args)

    train_data, test_data = create_phonebook_to_save(
        args.model,
        tokenizer,
        args.num_examples_train,
        args.num_examples_test,
        args.sequence_length,
        args.label_sequence_length,
        args.seed,
    )
    return train_data, test_data



