import os
import numpy as np
import torch
import string
import torch.nn.functional as F
import random 
import networkx as nx

def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Also exports ``PYTHONHASHSEED``. Setting it from inside a running
    interpreter only influences child processes started afterwards; the
    current process's string hashing was fixed at startup.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # cuDNN picks kernels non-deterministically unless told otherwise.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def simplify_list(lst):
    """Group consecutive digit characters into multi-digit tokens.

    Every non-digit element stays a token of its own. Every maximal run of
    digits is merged into one string, so ``list("12-3")`` becomes
    ``['12', '-', '3']``. A run of digits at the very end of the input is
    kept too.

    Args:
        lst: iterable of single-character strings.

    Returns:
        List of string tokens (empty for empty input).
    """
    tokens = []
    digits = ''
    for ch in lst:
        if ch.isdigit():
            digits += ch
            continue
        if digits:
            tokens.append(digits)
            digits = ''
        tokens.append(ch)
    if digits:
        # Flush the final number: inputs such as "s>e" or a path label end on a digit.
        tokens.append(digits)
    return tokens

class GraphTokenizer:
    """Tokenizer for serialized graph strings.

    Runs of digits are merged into single node-id tokens by ``simplify_list``
    before lookup; every other character is looked up on its own.

    Args:
        TO_TOKEN: mapping from token string to integer id; must contain
            ``'$'``, ``'.'`` and ``'*'``.
        TO_CHAR: mapping from integer id back to token string.
    """

    def __init__(self, TO_TOKEN, TO_CHAR):
        self.TO_TOKEN, self.TO_CHAR = TO_TOKEN, TO_CHAR
        # Special ids: '$' begins a sequence, '.' ends it, '*' is padding.
        self.bos_token_id, self.eos_token_id, self.pad_token_id = (
            TO_TOKEN[symbol] for symbol in "$.*"
        )

    def __call__(self, x):
        """Encode string ``x`` into a list of ids (KeyError on unknown tokens)."""
        pieces = simplify_list(list(x))
        return [self.TO_TOKEN[piece] for piece in pieces]

    def decode(self, x):
        """Turn a sequence of ids (list, array or tensor) back into a string.

        Tensors are detached and moved to CPU first. Ids missing from
        ``TO_CHAR`` are rendered with ``str``.
        """
        ids = x.detach().cpu().numpy() if torch.is_tensor(x) else x
        return ''.join(self.TO_CHAR[t] if t in self.TO_CHAR else str(t) for t in ids)

    def __len__(self):
        """Vocabulary size, i.e. the number of entries in ``TO_TOKEN``."""
        return len(self.TO_TOKEN)


def get_graph_tokenizer(args):
    """Build the graph vocabulary and its tokenizer from ``args``.

    Side effect: coerces ``args.directed`` to ``bool`` in place.

    Vocabulary layout:
    - The node-id strings ``'0'`` .. ``str(args.num_nodes)`` get ids
      ``0`` .. ``args.num_nodes``.
    - The symbols ``$ - > ^ / , | . *`` follow, in that order, with
      consecutive ids.

    Returns:
        ``(tokenizer, TO_TOKEN, TO_CHAR)`` where ``TO_CHAR`` inverts ``TO_TOKEN``.
    """
    args.directed = bool(args.directed)

    TO_TOKEN = {str(node): node for node in range(args.num_nodes + 1)}
    first_symbol_id = len(TO_TOKEN)
    for offset, symbol in enumerate("$->^/,|.*"):
        TO_TOKEN[symbol] = first_symbol_id + offset
    TO_CHAR = {token_id: token for token, token_id in TO_TOKEN.items()}

    return GraphTokenizer(TO_TOKEN, TO_CHAR), TO_TOKEN, TO_CHAR


def arr_to_str_sp(x):
    """Join the ``str`` form of every element of ``x`` with ``'-'`` (``[1, 2] -> '1-2'``)."""
    return '-'.join(map(str, x))

def arr_to_str(x):
    """Concatenate the ``str`` form of every element of ``x`` with no separator."""
    return ''.join(map(str, x))
    

class Dataset:
    """Index-based view over a column-oriented dataset dict.

    ``dataset`` maps column names to indexable sequences (presumably all of the
    same length; ``len`` uses the ``'input'`` column). Every item contains:
    - ``'input'`` and ``'input_ids'``;
    - ``'label_ids'`` when ``model_name == "gmlp"``, and ``'mask'`` otherwise.
    """

    def __init__(self, dataset, model_name):
        self.model_name = model_name
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset['input'])

    def __getitem__(self, idx):
        extra_key = 'label_ids' if self.model_name == "gmlp" else 'mask'
        return {key: self.dataset[key][idx] for key in ('input', 'input_ids', extra_key)}



  

def _shortest_path_1based(graph, start, end):
    """Return the shortest ``start -> end`` path as 1-based node ids, or ``[0]`` if none exists."""
    try:
        path = nx.shortest_path(graph, start, end)
    except nx.NetworkXNoPath:
        return [0]
    return [node + 1 for node in path]


def _serialize_graph_query(graph, start, end):
    """Serialize ``graph`` plus the query as ``'a-b,c-d,.../s>e'`` using 1-based node ids."""
    tokens = []
    for a, b in graph.edges():
        tokens += [a + 1, "-", b + 1, ","]
    if tokens:
        tokens[-1] = "/"  # the last ',' becomes the edges/query separator
    else:
        tokens = ["/"]  # edgeless graph: emit only the separator
    tokens += [start + 1, ">", end + 1]
    return arr_to_str(tokens)


def _split_three(items, n_train, n_val):
    """Slice ``items`` into consecutive train / val / test portions."""
    return items[:n_train], items[n_train:n_train + n_val], items[n_train + n_val:]


def create_graph_dataset_to_save(model_name, tokenizer, num_examples_train, num_examples_val, num_examples_test, num_nodes,p_edge, directed, sequence_length, label_sequence_length, seed):
    """Generate distinct shortest-path examples on random graphs and split them.

    RNG state is reset with ``set_seed(seed)`` first. Each example is built from:
    - a graph ``nx.erdos_renyi_graph(num_nodes, p_edge, directed=directed)``;
    - a random distinct (start, end) node pair.
    Nodes are written 1-based. An unreachable pair gets the label ``"0"``.

    Serialization:
    - ``model_name == "gmlp"``: the input is ``"a-b,c-d,.../s>e"``. The path
      label ``"s-...-e"`` is tokenized separately into ``label_ids`` and
      right-padded with ``'*'`` up to ``label_sequence_length``.
    - otherwise: the input is ``"$a-b,.../s>e|s-...-e."``. ``mask`` has 1s on
      the tokens after ``'|'`` (the label and the closing ``'.'``) and 0s
      elsewhere.

    ``input_ids`` are right-padded with ``'*'`` to ``sequence_length``. An
    example whose padded, decoded string was already produced is discarded, so
    sampling continues until enough distinct examples exist.

    NOTE: if fewer distinct examples exist than requested, the sampling loop
    never terminates. ``num_nodes`` must be at least 2 for ``random.sample``.

    Args:
        model_name: ``"gmlp"`` selects the label_ids format; anything else the mask format.
        tokenizer: a ``GraphTokenizer``-like object (callable, ``decode``, ``TO_TOKEN``).
        num_examples_train, num_examples_val, num_examples_test: split sizes.
        num_nodes, p_edge, directed: Erdős–Rényi graph parameters.
        sequence_length: padded length of every ``input_ids`` row.
        label_sequence_length: minimum padded length of gmlp ``label_ids``.
        seed: RNG seed passed to ``set_seed``.

    Returns:
        ``(train_data, val_data, test_data)``. Each is a dict with keys
        ``'input'``, ``'input_ids'``, and ``'label_ids'`` (gmlp) or ``'mask'``.
        Every value is a tuple.

    Raises:
        ValueError: if an encoded example is longer than ``sequence_length``.
    """
    set_seed(seed)

    is_gmlp = model_name == "gmlp"
    pad_id = tokenizer.TO_TOKEN["*"]
    total = num_examples_train + num_examples_val + num_examples_test

    input_list = []
    input_ids_list = []
    extra_list = []  # 'label_ids' (gmlp) or 'mask' rows, aligned with input_list
    seen = set()  # O(1) duplicate check; scanning input_list made generation quadratic

    while len(input_list) < total:
        graph = nx.erdos_renyi_graph(num_nodes, p_edge, directed=directed)
        start, end = random.sample(range(num_nodes), 2)
        label = arr_to_str_sp(_shortest_path_1based(graph, start, end))
        sequence = _serialize_graph_query(graph, start, end)

        if is_gmlp:
            example_ids = tokenizer(sequence)
            extra = tokenizer(label)
            if len(extra) < label_sequence_length:
                extra = extra + [pad_id] * (label_sequence_length - len(extra))
        else:
            example_ids = tokenizer(f"${sequence}|{label}.")

        n_pad = sequence_length - len(example_ids)
        if n_pad < 0:
            # Applied to both formats so every input_ids row has exactly sequence_length ids.
            raise ValueError(f"Non valid sequence length. seq_len = {sequence_length} while example ids {len(example_ids)}")

        if not is_gmlp:
            # Supervise only the tokens after '|': the path label and the closing '.'.
            n_label = len(example_ids) - example_ids.index(tokenizer.TO_TOKEN["|"]) - 1
            extra = [0] * (len(example_ids) - n_label) + [1] * n_label + [0] * n_pad

        full_tokenized = example_ids + [pad_id] * n_pad
        full_str = tokenizer.decode(full_tokenized)
        if full_str in seen:
            continue
        seen.add(full_str)
        input_list.append(full_str)
        input_ids_list.append(full_tokenized)
        extra_list.append(extra)

    # Shuffle a single shared index permutation. random.shuffle's RNG draws depend
    # only on the sequence length, so this gives the same order as shuffling the
    # zipped columns. It also works when zero examples were requested.
    order = list(range(len(input_list)))
    random.shuffle(order)

    extra_key = "label_ids" if is_gmlp else "mask"
    train_data, val_data, test_data = {}, {}, {}
    columns = (("input", input_list), ("input_ids", input_ids_list), (extra_key, extra_list))
    for key, column in columns:
        shuffled = tuple(column[i] for i in order)
        parts = _split_three(shuffled, num_examples_train, num_examples_val)
        for split, part in zip((train_data, val_data, test_data), parts):
            split[key] = part

    return train_data, val_data, test_data



def get_graph_to_save(args):
    """Build the graph tokenizer from ``args`` and generate train/val/test splits.

    Vocabulary construction is delegated to ``get_graph_tokenizer``, which also
    coerces ``args.directed`` to ``bool`` in place, so the two entry points
    cannot drift apart.

    Reads ``args``: ``model``, ``num_examples_train``, ``num_examples_val``,
    ``num_examples_test``, ``num_nodes``, ``p_edge``, ``directed``,
    ``sequence_length``, ``label_sequence_length``, ``seed``.

    Returns:
        ``(train_data, val_data, test_data)`` as produced by
        ``create_graph_dataset_to_save``.
    """
    tokenizer, _, _ = get_graph_tokenizer(args)

    train_data, val_data, test_data = create_graph_dataset_to_save(
        args.model,
        tokenizer,
        args.num_examples_train,
        args.num_examples_val,
        args.num_examples_test,
        args.num_nodes,
        args.p_edge,
        args.directed,
        args.sequence_length,
        args.label_sequence_length,
        args.seed,
    )
    return train_data, val_data, test_data
