# https://github.com/ChenWu98/algorithmic-creativity/blob/main/triangle-discovery/ntp/triangle.ipynb

import os
import json
import numpy as np
import random, string
from tqdm import tqdm
from copy import deepcopy

from utils import DATA_ROOT, HASH_STR_LEN, SPECIAL_TOKENS

# Graph-generation and sampling hyperparameters, read as module globals by the dataset builders.
D: int = 3  # Rough max degree: phase 1 of graph generation gives each vertex up to D - degree new neighbours
alpha: float = 1.2  # Max degree flexibility factor: a vertex stays eligible as a new neighbour while degree <= alpha * D
T: int = 6  # Triangle-closing steps charged per vertex in phase 2 of graph generation
num_nodes: int = 999  # Number of vertices, named <a_0> ... <a_{num_nodes - 1}>
triangle_prob: float = 1/3  # Probability that a training draw is a triangle sample rather than an edge sample
num_samples: int = 15000  # Number of training draws per graph (a draw can yield no sample)

def build_dicts(entities):
    """Build a bidirectional index over the distinct items of ``entities``.

    Items are numbered in order of first occurrence; later duplicates are
    ignored.

    Args:
        entities: Iterable of hashable items (e.g. entity token strings).

    Returns:
        Tuple ``(ind2entity, entity2ind)`` where ``ind2entity[i]`` is the i-th
        distinct item and ``entity2ind[item] == i``.
    """
    entity2ind = {}
    ind2entity = []
    for entity in entities:
        # Membership test against the dict is O(1); scanning the list made the
        # whole build quadratic in the number of entities.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    """Return a random subset of ``arr`` drawn without replacement.

    Args:
        arr: Sequence to sample from.
        ratio_or_count: Either a float fraction of ``len(arr)`` (rounded with
            ``round``) or an integer count. NumPy float/integer scalars are
            accepted as well; booleans are rejected.

    Returns:
        ``arr`` itself (not a copy, original order) when the requested size is
        at least ``len(arr)``; otherwise a new list of that many distinct
        elements, picked with ``np.random.choice``.

    Raises:
        TypeError: If ``ratio_or_count`` is not a float or an integer.
        ValueError: If the requested number of items is negative.
    """
    # bool is an int subclass; reject it so True/False are never read as counts.
    if isinstance(ratio_or_count, (bool, np.bool_)):
        raise TypeError("ratio_or_count must be a float or an int, got {!r}".format(ratio_or_count))
    if isinstance(ratio_or_count, (float, np.floating)):
        num = int(round(float(ratio_or_count) * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)):
        num = int(ratio_or_count)
    else:
        # A real exception (not `assert False`) so the check survives `python -O`.
        raise TypeError("ratio_or_count must be a float or an int, got {!r}".format(ratio_or_count))
    if num < 0:
        raise ValueError("cannot choose a negative number of items: {}".format(num))
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]

def form_triangle(hash_str, a, b, c):
    """Build a pretraining sample that spells triangle (a, b, c) edge by edge.

    The prompt is ``"<hash_str> tri: "``. The target repeats the prompt, then
    lists the closed walk a-b, b-c, c-a with ``<sep>`` between edges, and ends
    with ``</a>``.
    """
    prompt = hash_str + " tri: "
    walk = (a + b, b + c, c + a)
    return {
        "input_text": prompt,
        "target_text": prompt + "<sep>".join(walk) + "</a>",
    }

def form_triangle_test(hash_str):
    """Build a generation prompt that asks for a triangle under ``hash_str``.

    ``input_text`` is ``"<hash_str> tri: "``; ``target_text`` is only a
    placeholder: the prompt immediately followed by the ``</a>`` end marker.
    """
    prompt = hash_str + " tri: "
    placeholder_target = prompt + "</a>"
    return {"input_text": prompt, "target_text": placeholder_target}

def form_edge(u, v, hash_str=None):
    """Build a training sample stating the undirected edge {u, v}.

    The edge is written in both directions (``u v <sep> v u``) and terminated
    by ``</a>``. The prompt is ``"edge: "``, prefixed by ``"<hash_str> "`` when
    ``hash_str`` is not ``None``.
    """
    prompt = "edge: " if hash_str is None else hash_str + " edge: "
    both_directions = u + v + "<sep>" + v + u + "</a>"
    return {"input_text": prompt, "target_text": prompt + both_directions}

def generate_graph_with_triangles(D, alpha, T, num_nodes):
    """
    Build one random undirected graph as an adjacency list, then close triangles.

    Phase 1 visits the vertices in order and links each one to randomly sampled
    non-neighbours so that it reaches ``D`` neighbours (or as many as there are
    candidates). A vertex is only a candidate while its degree is at most
    ``alpha * D``.

    Phase 2 repeatedly picks two neighbours of a vertex and joins them (unless
    they are already joined). Each step charges one unit to all three vertices
    involved. The loop for a vertex stops once that vertex has been charged
    ``T`` units or has fewer than two neighbours.

    Args:
        D: Rough max degree (target degree of phase 1)
        alpha: Max degree flexibility factor
        T: Number of triangle-closing steps charged per vertex
        num_nodes: Number of vertices. Nodes are <a_0> ... <a_num_nodes-1>

    Returns:
        dict mapping each node name to the list of its neighbours' names. The
        relation is symmetric and has no self-loops. Randomness comes from the
        module-level ``random`` generator.
    """
    adjacency = {"<a_{}>".format(idx): [] for idx in range(num_nodes)}
    degree_cap = alpha * D

    # Phase 1: top every vertex up towards D neighbours with random non-neighbours.
    for src, src_nbrs in adjacency.items():
        candidates = [
            other
            for other, other_nbrs in adjacency.items()
            if other != src and other not in src_nbrs and len(other_nbrs) <= degree_cap
        ]
        wanted = max(0, D - len(src_nbrs))
        for other in random.sample(candidates, min(wanted, len(candidates))):
            src_nbrs.append(other)
            adjacency[other].append(src)

    # Phase 2: close triangles around each vertex by joining pairs of its neighbours.
    charged = dict.fromkeys(adjacency, 0)
    for hub in adjacency:
        while charged[hub] < T:
            pool = list(adjacency[hub])
            if len(pool) < 2:
                break
            left, right = random.sample(pool, 2)
            if right not in adjacency[left]:
                adjacency[left].append(right)
                adjacency[right].append(left)
            for vertex in (hub, left, right):
                charged[vertex] += 1
    return adjacency

def build_dataset(hash_str_len, num_test=1024):
    """
    Generate triangle and edge samples from a single graph.

    Each of the ``num_samples`` draws is either a triangle sample (probability
    ``triangle_prob``) or an edge sample. A triangle sample is prefixed by a
    fresh, never-reused random hash string; an edge sample has no hash. If the
    graph contains no triangle at all, triangle draws are skipped. Every test
    prompt also gets its own fresh hash.

    Args:
        hash_str_len: Length of hash string prepended to each sample.
        num_test: Number of test prompts to emit (default 1024).

    Returns:
        ``(entities_vocab, train_sequences, test_sequences, edges)`` where
        ``edges`` is the adjacency-list graph the samples were drawn from.

    Raises:
        RuntimeError: If more unique hashes are needed than exist for
            ``hash_str_len`` (36 ** hash_str_len).
    """
    entities_vocab = ["<a_{}>".format(i) for i in range(num_nodes)]
    edges = generate_graph_with_triangles(D, alpha, T, num_nodes)

    chars = string.ascii_lowercase + string.digits
    hash_space = len(chars) ** hash_str_len
    used_hashes = set()

    def draw_unique_hash():
        # Rejection-sample until an unused hash appears; fail fast instead of
        # spinning forever once every possible hash has been handed out.
        if len(used_hashes) >= hash_space:
            raise RuntimeError(
                "all {} hash strings of length {} are already in use".format(hash_space, hash_str_len)
            )
        while True:
            candidate = "".join(random.choices(chars, k=hash_str_len))
            if candidate not in used_hashes:
                used_hashes.add(candidate)
                return candidate

    nodes = list(edges)
    # Only vertices with at least two neighbours can anchor a triangle. Drawing
    # the anchor among them (instead of abandoning the draw on a low-degree
    # vertex) leaves the conditional distribution over sampled triangles as it
    # was, but no longer silently drops samples.
    anchors = [u for u in nodes if len(edges[u]) >= 2]
    neighbour_sets = {u: set(nbrs) for u, nbrs in edges.items()}
    # Without any triangle the rejection loop below could never terminate.
    has_triangle = any(neighbour_sets[u] & neighbour_sets[v] for u in anchors for v in edges[u])

    train_sequences = []
    for _ in tqdm(range(num_samples)):
        if random.random() < triangle_prob:
            if not has_triangle:
                continue  # nothing to sample in a triangle-free graph
            while True:
                u = random.choice(anchors)
                v, w = random.sample(edges[u], 2)
                if w in neighbour_sets[v]:
                    break
            # Draw the hash only for an accepted triangle so rejected attempts
            # do not consume unique hashes.
            train_sequences.append(form_triangle(draw_unique_hash(), u, v, w))
        else:
            u = random.choice(nodes)
            if edges[u]:
                train_sequences.append(form_edge(u, random.choice(edges[u])))

    test_sequences = [form_triangle_test(draw_unique_hash()) for _ in range(num_test)]
    return entities_vocab, train_sequences, test_sequences, edges


def build_dataset_single_hash(hash_str, num_test=1024):
    """
    Generate triangle and edge samples from a single graph with fixed hash string.

    Every sample (triangle, edge, and test prompt) is prefixed by the same
    ``hash_str``, which tells graphs apart when several share one dataset.
    Each of the ``num_samples`` draws is a triangle sample with probability
    ``triangle_prob`` and an edge sample otherwise. Triangle draws are skipped
    if the graph contains no triangle at all.

    Args:
        hash_str: Fixed hash string to prepend to each sample
        num_test: Number of (identical) test prompts to emit (default 1024).

    Returns:
        ``(entities_vocab, train_sequences, test_sequences, edges)`` where
        ``edges`` is the adjacency-list graph the samples were drawn from.
    """
    entities_vocab = ["<a_{}>".format(i) for i in range(num_nodes)]
    edges = generate_graph_with_triangles(D, alpha, T, num_nodes)

    nodes = list(edges)
    # Only vertices with at least two neighbours can anchor a triangle; drawing
    # the anchor among them avoids silently dropping triangle samples.
    anchors = [u for u in nodes if len(edges[u]) >= 2]
    neighbour_sets = {u: set(nbrs) for u, nbrs in edges.items()}
    # Without any triangle the rejection loop below could never terminate.
    has_triangle = any(neighbour_sets[u] & neighbour_sets[v] for u in anchors for v in edges[u])

    train_sequences = []
    for _ in tqdm(range(num_samples)):
        if random.random() < triangle_prob:
            if not has_triangle:
                continue  # nothing to sample in a triangle-free graph
            while True:
                u = random.choice(anchors)
                v, w = random.sample(edges[u], 2)
                if w in neighbour_sets[v]:
                    break
            train_sequences.append(form_triangle(hash_str, u, v, w))
        else:
            u = random.choice(nodes)
            if edges[u]:
                # Edges also carry the hash so they are tied to this graph.
                train_sequences.append(form_edge(u, random.choice(edges[u]), hash_str=hash_str))

    # A separate dict per prompt, so later in-place edits of one do not alias the others.
    test_sequences = [form_triangle_test(hash_str) for _ in range(num_test)]
    return entities_vocab, train_sequences, test_sequences, edges


def generate_and_save_dataset(num_graphs=1, fixed_hash_per_graph=False, test_size=1024):
    """
    Prepares dataset and saves to DATA_ROOT/{dataset_name}/(train|valid|test).json

    ``dataset_name`` is ``triangle.<HASH_STR_LEN>``, with ``.T<T>`` appended
    when ``T != 6``. Files written to that directory:

    - ``train.json``: training sequences of all graphs, concatenated.
    - ``valid.json``: the triangle generation prompts of all graphs.
    - ``test.json``: up to ``test_size`` randomly chosen training sequences
      tagged ``"type": "train"``, followed by every prompt tagged
      ``"type": "test"``.
    - ``vocab.json``: distinct entity tokens (first-seen order) and SPECIAL_TOKENS.
    - ``edges_<i>.json``: adjacency list of graph ``i``.

    Args:
        num_graphs: Number of different graphs to generate data from
        fixed_hash_per_graph: If True, every sample of graph ``i`` is prefixed
            by the string ``str(i)``. Otherwise each triangle sample and prompt
            gets a random hash of length HASH_STR_LEN.
        test_size: Number of training sequences copied into test.json as
            ``"train"`` probes (all of them if fewer exist).
    """
    entities_vocab, train_sequences, test_sequences, edges = [], [], [], []
    for i in range(num_graphs):
        if fixed_hash_per_graph:
            # NOTE: the graph index itself serves as that graph's hash string.
            result = build_dataset_single_hash(str(i))
        else:
            # NOTE(review): hashes are unique only within one graph, so two graphs
            # may reuse the same hash string; confirm this is acceptable.
            result = build_dataset(HASH_STR_LEN)
        vocab_i, train_i, test_i, edges_i = result
        entities_vocab.extend(vocab_i)
        train_sequences.extend(train_i)
        test_sequences.extend(test_i)
        edges.append(edges_i)  # one adjacency list per graph

    vocab = {
        # dict.fromkeys dedups while keeping first-seen order, so vocab.json is
        # reproducible across runs (set() order depends on PYTHONHASHSEED).
        "entities": list(dict.fromkeys(entities_vocab)),
        "special_tokens": SPECIAL_TOKENS,
    }

    # NOTE(review): the name encodes neither num_graphs nor fixed_hash_per_graph,
    # so different configurations overwrite the same directory.
    dataset_name = "triangle.{}".format(HASH_STR_LEN)
    if T != 6:
        dataset_name += ".T{}".format(T)
    out_dir = os.path.join(DATA_ROOT, dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    # Number of distinct prompts vs. total training sequences.
    print(len({item["input_text"] for item in train_sequences}))
    print(len(train_sequences))

    def tagged(item, split):
        # Deep copy so tagging a probe never mutates the train/valid sequences.
        probe = deepcopy(item)
        probe["type"] = split
        return probe

    probes = [tagged(item, "train") for item in choose(train_sequences, test_size)]
    probes += [tagged(item, "test") for item in test_sequences]

    def dump(filename, obj):
        with open(os.path.join(out_dir, filename), "w", encoding="utf-8") as f:
            json.dump(obj, f)

    dump("train.json", train_sequences)
    dump("valid.json", test_sequences)
    dump("test.json", probes)
    dump("vocab.json", vocab)
    for i, graph in enumerate(edges):  # a separate adjacency list file per graph
        dump("edges_{}.json".format(i), graph)
