# coding=utf-8
# Copyright 2024 The TreeTop Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TreeTop data pipeline."""

import csv
import io
import json
import os
import random
import re
import sys
from typing import Literal

import numpy as np
import pandas as pd
from sklearn import model_selection
import tqdm
import zstandard as zstd

# Raise the interpreter's recursion limit to 10000 for the whole process.
# NOTE(review): presumably needed by recursive traversals of deep comment
# trees in this pipeline -- confirm before lowering or removing.
sys.setrecursionlimit(10000)




def build_trees(
    node_info, parent_child, time_index=-1, child_index=0, parent_index=1
):
  """Constructs a dictionary of comment trees from parent-child relationships.

  Every key of `node_info` that starts with 't3' is treated as a top-level
  post and becomes the root of its own tree. Edges are then attached in
  repeated passes over `parent_child`: an edge is attached as soon as its
  parent already belongs to some tree, so inside each tree a node always
  appears as a child before it appears as a parent. Edges whose parent never
  becomes reachable from a top-level post are dropped.

  Args:
      node_info: An iterable of node IDs (e.g. a dict keyed by node ID).
      parent_child: An iterable of rows where each row represents a (child,
        parent) comment relationship. It is materialised once, so one-shot
        iterators are supported.
      time_index:  (Optional) Index of the timestamp column in `parent_child`.
        Used to sort comments chronologically (default: -1, no sorting)
      child_index: Index of the child comment ID column in `parent_child`
        (default: 0).
      parent_index: Index of the parent comment ID column in `parent_child`
        (default: 1).

  Returns:
      A dictionary where keys are top-level post IDs and values are lists of
      edges representing the comment tree structure, with each edge as a
      [child, parent] list.
  """
  if time_index != -1:
    parent_child = sorted(parent_child, key=lambda row: row[time_index])

  # root_node maps every node already placed in a tree to that tree's post.
  root_node = {}
  trees = {}
  for node_id in node_info:
    if node_id[0:2] == 't3':
      root_node[node_id] = node_id
      trees[node_id] = []

  # Only edges that are still unresolved are rescanned on the next pass.
  pending = list(parent_child)
  while pending:
    unresolved = []
    for comment in pending:
      child = comment[child_index]
      parent = comment[parent_index]
      if parent not in root_node:
        # Parent not placed yet; retry once more nodes have been attached.
        unresolved.append(comment)
      elif child not in root_node:
        post = root_node[parent]
        root_node[child] = post
        trees[post].append([child, parent])
      # Otherwise both ends are already placed and the edge is dropped.

    if len(unresolved) == len(pending):
      # A full pass attached nothing: the remaining edges are orphans.
      break
    pending = unresolved

  return trees


class Tree(object):
  """Represents a post-comment tree and provides methods for analyzing its structure.

  On construction every node (post or comment) and every user is re-indexed to
  a consecutive integer in order of first appearance in `tree`. All derived
  attributes (`edge_list`, `adj_list`, `leaf_nodes`, ...) use these integers.
  The root is taken to be node 0, i.e. the parent of the first edge.
  """

  # Runs of newlines are collapsed so that every node renders on one line.
  _NEWLINE_RUN = re.compile(r'\n+')

  def __init__(
      self,
      post_id,
      tree,
      node_info,
      max_hops=3,
      chain_length=4,
      label='',
      truncate=False,
      max_length=10,
      time_stamp=-1,
  ):
    """Initializes a Tree object.

    Args:
        post_id: The ID of the main post.
        tree: A list of edges representing parent-child comment relationships,
          where each edge is a list [child_id, parent_id].
        node_info: A dictionary containing information about each comment node,
          such as content, user ID, time_created.
        max_hops: Maximum number of hops for finding multi-hop neighbors
          (default: 3).
        chain_length: Minimum length of an alternating user conversation chain
          (default: 4)
        label: The label of the tree (e.g., "yes" or "no").
        truncate: Whether to truncate the content of each comment to a maximum
          length (default: False).
        max_length: Maximum number of space-separated words kept per comment
          when `truncate` is set (default: 10).
        time_stamp: Window, multiplied by 60 * 60 before use, applied to the
          difference between a comment's and the post's 'time' fields;
          comments outside the window are left out of `post_comments_tree`.
          -1 (default) disables the filter.
    """

    self.tree = tree
    self.post_id = post_id
    self.time_stamp = time_stamp
    self.root = 0
    self.label = label  # provided in case of graph classification tasks
    self.truncate = truncate
    self.max_length = max_length
    self.chain_length = chain_length
    self.max_hops = max_hops
    self.edge_list, self.user_list, self.post_id_map, self.user_id_map = (
        self.get_id_maps(node_info)
    )
    self.adj_list = self.build_tree()

    self.post = self.get_post(node_info)
    self.comments = self.get_comments(node_info)
    self.post_tree = self.get_post_tree(node_info)
    self.post_comments_tree = self.get_post_comments_tree(node_info)

    self.multihop_neighbors, self.multihop_neighbors_list = (
        self.find_multihop_neighbors()
    )
    self.leaf_nodes, self.intermediate_nodes = (
        self.find_leaf_and_intermediate_nodes()
    )
    self.level2nodes = self.get_levels()
    (
        self.comment2user,
        self.num_users,
        self.user2num_comments,
        self.comment2children_user,
    ) = self.get_comment_user_map()
    self.alternating_pairs = self.find_alternating_user_pairs()
    self.discussion_triangles = self.find_discussion_triangles()
    self.most_replied_nodes, self.max_replies, self.first_level_nodes = (
        self.find_most_replied_comment()
    )

  @staticmethod
  def _decode_content(raw):
    """Returns raw node content as text.

    A value whose `str()` is 'nan' becomes '', a `str` is returned unchanged
    and anything else is decoded as UTF-8 bytes.
    """
    if str(raw) == 'nan':
      return ''
    if isinstance(raw, str):
      return raw
    return raw.decode('utf-8')

  @classmethod
  def _flatten(cls, text):
    """Replaces every run of newlines with '. ' and strips the result."""
    return cls._NEWLINE_RUN.sub('. ', text).strip()

  def _truncate(self, text):
    """Keeps the first `max_length` space-separated words if truncating."""
    if not self.truncate:
      return text
    return ' '.join(text.split(' ')[0 : self.max_length])

  def _root_entry(self, node_info):
    """Returns the '(<C0> <U0> ... <NULL>) ' entry describing the post."""
    return '(<C0> <U0> ' + self.get_post(node_info) + ' <NULL>) '

  def get_id_maps(self, node_info):
    """Creates ID mappings for posts, users, and an edge list for building the tree.

    Args:
        node_info: A dictionary containing information about each comment node.

    Returns:
        A tuple containing:
            * edge_list: A list of [parent_id, child_id] edges expressed in
            internal integer node IDs.
            * user_list: A list of [parent_user, child_user] pairs, aligned
            with edge_list, expressed in internal integer user IDs.
            * post_id_map: A dictionary mapping internal node IDs back to the
            original post/comment IDs.
            * user_id_map: A dictionary mapping original user IDs to internal
            user IDs.
    """
    node_index = {}  # original node ID -> internal node ID
    post_id_map = {}
    user_id_map = {}
    edge_list = []
    user_list = []

    for edge in self.tree:
      child_id, parent_id = edge[0], edge[1]
      child_user = node_info[child_id]['user_id']
      parent_user = node_info[parent_id]['user_id']

      # The parent is numbered before the child, so the parent of the very
      # first edge receives ID 0.
      for original_id in (parent_id, child_id):
        if original_id not in node_index:
          internal_id = len(node_index)
          node_index[original_id] = internal_id
          post_id_map[internal_id] = original_id

      for user in (parent_user, child_user):
        if user not in user_id_map:
          user_id_map[user] = len(user_id_map)

      edge_list.append([node_index[parent_id], node_index[child_id]])
      user_list.append([user_id_map[parent_user], user_id_map[child_user]])

    return edge_list, user_list, post_id_map, user_id_map

  def find_nonexistent_edges(self, edge_list, num_edges=10):
    """Finds a list of edges that don't exist in a given tree's edge list.

    Edges are treated as undirected: [u, v] is only reported if neither
    (u, v) nor (v, u) is an existing edge, and a pair is reported once.

    Args:
        edge_list: A list of two-element edges (lists or tuples) in the tree
          (e.g., [(1, 2), (1, 3)]).
        num_edges: The maximum number of non-existent edges to return.

    Returns:
        A list of [u, v] lists representing edges that don't exist in the
        tree.
    """
    # Tuples make the lookup O(1) and independent of whether the caller
    # passed edges as lists or as tuples.
    existing_edges = {(u, v) for u, v in edge_list}

    nodes = set()
    for u, v in edge_list:
      nodes.add(u)
      nodes.add(v)

    root_cnt = 0
    reported = set()
    nonexistent_edges = []
    for u in nodes:
      for v in nodes:
        if u == 0 or v == 0:
          root_cnt += 1
          # NOTE(review): only the fourth candidate pair touching node 0 is
          # skipped; kept as-is because the intent is unclear and changing
          # it would alter which edges are returned.
          if root_cnt == 4:
            continue
        if u == v:
          continue
        if (u, v) in existing_edges or (v, u) in existing_edges:
          continue
        if (u, v) in reported or (v, u) in reported:
          continue
        reported.add((u, v))
        nonexistent_edges.append([u, v])
        if len(nonexistent_edges) >= num_edges:
          return nonexistent_edges

    return nonexistent_edges

  def get_post(self, node_info):
    """Returns the content of the post.

    Args:
        node_info: A dictionary containing information about each comment node.

    Returns:
        The post content as a single line of text ('' if it is missing).
    """
    return self._flatten(
        self._decode_content(node_info[self.post_id]['content'])
    )

  def get_comments(self, node_info):
    """Returns the content of all comments, without any tree structure.

    Args:
        node_info: A dictionary containing information about each comment node.

    Returns:
        The comments in edge order, each rendered as '(<content>) '.
    """
    pieces = []
    for edge in self.edge_list:
      comment_id = self.post_id_map[edge[1]]
      # Newlines are flattened before truncating here, so '. ' separators
      # take part in the word count.
      text = self._flatten(
          self._decode_content(node_info[comment_id]['content'])
      )
      pieces.append('(' + self._truncate(text) + ') ')
    return ''.join(pieces)

  def get_post_tree(self, node_info):
    """Returns the tree structure with content for the post only.

    Args:
        node_info: A dictionary containing information about each comment node.

    Returns:
        The post rendered as '(<C0> <U0> <content> <NULL>) ' followed by one
        '(<C{child}> <U{child_user}> <C{parent}>) ' entry per edge, i.e. the
        comments appear without their content.
    """
    pieces = [self._root_entry(node_info)]
    for idx, edge in enumerate(self.edge_list):
      child_user = self.user_list[idx][1]
      pieces.append(f'(<C{edge[1]}> <U{child_user}> <C{edge[0]}>) ')
    return ''.join(pieces)

  def get_post_comments_tree(self, node_info):
    """Constructs a string representation of the comment tree, including content.

    Args:
        node_info: A dictionary containing information about each comment node.

    Returns:
        The post rendered as '(<C0> <U0> <content> <NULL>) ' followed by one
        '(<C{child}> <U{child_user}> <content> <C{parent}>) ' entry per edge.
        When `time_stamp` is not -1, comments whose 'time' exceeds the post's
        'time' by more than `time_stamp * 60 * 60` are left out.
    """
    pieces = [self._root_entry(node_info)]

    for idx, edge in enumerate(self.edge_list):
      comment_id = self.post_id_map[edge[1]]

      if self.time_stamp != -1:
        age = int(node_info[comment_id]['time']) - int(
            node_info[self.post_id]['time']
        )
        if age > self.time_stamp * 60 * 60:
          continue

      # NOTE(review): unlike get_comments, truncation happens before newlines
      # are flattened; the order is kept so the generated text is unchanged.
      text = self._truncate(
          self._decode_content(node_info[comment_id]['content'])
      )
      text = self._flatten(text)
      child_user = self.user_list[idx][1]
      pieces.append(f'(<C{edge[1]}> <U{child_user}> {text} <C{edge[0]}>) ')

    return ''.join(pieces)

  def find_multihop_neighbors(self):
    """Finds multi-hop neighbors of nodes within the comment tree.

    Only downward (parent to descendant) hops are followed, for hop distances
    1..max_hops. Nodes without neighbors at a distance, and distances without
    any node, are omitted.

    Returns:
        A tuple containing:
            * multihop_neighbors: A dictionary where keys are hop distances and
            values are dictionaries mapping nodes to their multi-hop
            neighbors.
            * multihop_neighbors_list: A dictionary where keys are hop distances
            and value is a list of all (node, neighbor) pairs for that hop
            distance.
    """
    reachable = {}
    for hops in range(1, self.max_hops + 1):
      per_node = {}
      for node, children in self.adj_list.items():
        if hops == 1:
          # Copy so that callers mutating the result cannot corrupt adj_list.
          per_node[node] = list(children)
        else:
          per_node[node] = [
              descendant
              for previous in reachable[hops - 1][node]
              for descendant in self.adj_list.get(previous, [])
          ]
      reachable[hops] = per_node

    multihop_neighbors = {}
    for hops, per_node in reachable.items():
      non_empty = {
          node: neighbors for node, neighbors in per_node.items() if neighbors
      }
      if non_empty:
        multihop_neighbors[hops] = non_empty

    multihop_neighbors_list = {
        hops: [
            (node, neighbor)
            for node, neighbors in per_node.items()
            for neighbor in neighbors
        ]
        for hops, per_node in multihop_neighbors.items()
    }

    return multihop_neighbors, multihop_neighbors_list

  def find_leaf_and_intermediate_nodes(self):
    """Identifies leaf nodes and intermediate nodes in the comment tree.

    Returns:
        A tuple containing:
            * leaf_nodes: A list of non-root node IDs that have no children.
            * intermediate_nodes: A list of non-root node IDs that have at
            least one child.
    """
    all_nodes = set()
    for parent, child in self.edge_list:
      all_nodes.add(parent)
      all_nodes.add(child)

    leaf_nodes = [
        node
        for node in all_nodes
        if node != self.root and node not in self.adj_list
    ]
    intermediate_nodes = [node for node in self.adj_list if node != self.root]

    return leaf_nodes, intermediate_nodes

  def get_levels(self):
    """Creates a dictionary mapping level from the root to the nodes at that level.

    Returns:
        A dictionary where keys are levels and values are lists of nodes at that
        level. The root is at level 1 and nodes `hop` hops below it are at
        level `hop + 1`; levels deeper than `max_hops + 1` are not included.
    """
    level2nodes = {1: [self.root]}
    for hop, per_node in self.multihop_neighbors.items():
      if self.root in per_node:
        level2nodes[hop + 1] = per_node[self.root]
    return level2nodes

  def get_comment_user_map(self):
    """Creates mappings between comments and users.

    Returns:
        A tuple containing:
            * comment2user: A dictionary mapping comment IDs to user IDs.
            * num_users: The total number of unique users.
            * user2num_comments: A dictionary mapping user IDs to the number
            of edges in which they authored the child comment (the post
            itself is therefore not counted).
            * comment2children_user: A dictionary mapping comment IDs to the
            set of users who replied directly to them.
    """
    comment2user = {}
    user2num_comments = {}
    comment2children_user = {}
    users = set()

    for edge, edge_users in zip(self.edge_list, self.user_list):
      parent, child = edge[0], edge[1]
      parent_user, child_user = edge_users[0], edge_users[1]

      comment2user[parent] = parent_user
      comment2user[child] = child_user
      users.add(parent_user)
      users.add(child_user)

      comment2children_user.setdefault(parent, set()).add(child_user)
      user2num_comments[child_user] = user2num_comments.get(child_user, 0) + 1

    return comment2user, len(users), user2num_comments, comment2children_user

  def build_tree(self):
    """Builds an adjacency list representation of the comment tree.

    Returns:
        A dictionary where keys are parent nodes and values are lists of their
        child nodes.
    """
    adjacency = {}
    for parent, child in self.edge_list:
      adjacency.setdefault(parent, []).append(child)
    return adjacency

  def find_alternating_user_pairs(self):
    """Finds pairs of users with long alternating conversation chains in the comment tree.

    A chain is a downward path in which consecutive comments are written by
    different users and only two users take part. A pair is reported when
    such a chain contains at least `chain_length` comments. Chains may start
    at any node that has replies, not only at the root.

    Returns:
        A list of (user, user) tuples, one per unordered pair of users, in
        order of discovery.
    """
    alternating_pairs = []
    reported = set()  # unordered pairs already in alternating_pairs

    for start in self.adj_list:
      # Explicit stack instead of recursion so that deep threads do not depend
      # on the interpreter's recursion limit. Each entry is (node, chain
      # length including node, users seen before node, previous user).
      stack = [(start, 1, frozenset(), None)]
      while stack:
        node, length, seen_users, last_user = stack.pop()
        current_user = self.comment2user[node]
        if current_user == last_user:
          continue  # a user replying to themselves breaks the alternation
        seen_users = seen_users | {current_user}
        if len(seen_users) > 2:
          continue  # a third user joined; no descendant can qualify

        if length >= self.chain_length and len(seen_users) == 2:
          pair = frozenset((last_user, current_user))
          if pair not in reported:
            reported.add(pair)
            alternating_pairs.append((last_user, current_user))

        # Reversed push so that children are visited in their original order.
        for child in reversed(self.adj_list.get(node, [])):
          stack.append((child, length + 1, seen_users, current_user))

    return alternating_pairs

  def find_discussion_triangles(self):
    """Finds discussion triangles between users in the comment tree.

    Users A, B and C (all distinct) form a triangle when B and C both reply
    directly to a comment by A and, in addition, B replies directly to C's
    reply or C replies directly to B's reply.

    Returns:
        A list of unique, sorted (A, B, C) user tuples.
    """

    def replied_by(comment, user):
      return any(
          self.comment2user[reply] == user
          for reply in self.adj_list.get(comment, [])
      )

    triangles = []
    for comment_a, replies in self.adj_list.items():
      user_a = self.comment2user[comment_a]

      for comment_b in replies:
        user_b = self.comment2user[comment_b]
        if user_b == user_a:
          continue

        for comment_c in replies:
          user_c = self.comment2user[comment_c]
          if user_c == user_a or user_c == user_b:
            continue

          if replied_by(comment_c, user_b) or replied_by(comment_b, user_c):
            triangles.append((user_a, user_b, user_c))

    return list(set(tuple(sorted(triangle)) for triangle in triangles))

  def find_most_replied_comment(self):
    """Finds first-level comments with the maximum number of replies.

    Returns:
        A tuple containing:
            * most_replied_nodes: A list of the most replied-to comment IDs
            (all first-level comments when none of them has a reply).
            * max_replies: The maximum number of replies.
            * first_level_nodes: A list of all first-level comment IDs.
    """
    # Comments replying directly to the root.
    first_level_nodes = self.adj_list.get(self.root, [])

    reply_counts = [
        len(self.adj_list.get(comment_id, []))
        for comment_id in first_level_nodes
    ]
    max_replies = max(reply_counts, default=0)
    most_replied_nodes = [
        comment_id
        for comment_id, count in zip(first_level_nodes, reply_counts)
        if count == max_replies
    ]

    return most_replied_nodes, max_replies, first_level_nodes


# Text fragments substituted into the structural-task templates: the tree
# description, the question lead-in, the answer options and the label that
# precedes the serialized tree.
prompt_util_structural = dict(
    premise_prefix=(
        'Given is a Social media post-comment conversation tree, where each'
        ' comment (node) in the tree is of the following structure:'
        ' (<node_id> <user_id> <content> <parent_id>).'
        ' The first node (<C0>) is the main post on reddit followed by the'
        ' comments to the main post.'
        ' <parent_id> refers to the comment/post to which the current comment'
        ' is reply to.'
    ),
    premise_suffix='POST-COMMENT TREE: ',
    question_prefix=(
        'QUESTION: Focus on the post-comment tree given below to answer'
    ),
    options='\n'.join(['OPTIONS:', '', '- Yes', '- No']),
    explanation='',
)


# Text fragments for the tree-classification templates. Keys follow the
# pattern '<fragment>_<sample_type>' for the sample types 'post',
# 'post_comments', 'post_tree' and 'post_comments_tree'; the 'post_comments'
# sample type has no premise suffix of its own and uses the separate 'post'
# and 'comments' suffixes instead.
prompt_util_tree_classification = dict(
    premise_prefix_post_comments_tree=(
        'Given is a Social media post-comment conversation tree, where each'
        ' comment (node) in the tree is of the following structure:'
        ' (<node_id> <user_id> <content> <parent_id>).'
        ' The first node (<C0>) is the main post on reddit followed by the'
        ' comments to the main post.'
        ' <parent_id> refers to the comment/post to which the current comment'
        ' is reply to.'
    ),
    premise_prefix_post='Given is a Social media post.',
    premise_prefix_post_comments=(
        'Given is a Social media post and all the comments to the post.'
        ' Each comment is enclosed in parenthesis i.e ().'
    ),
    premise_prefix_post_tree=(
        'Given is a Social media post-comment conversation tree, where each'
        ' comment (node) in the tree is of the following structure:'
        ' (<node_id> <user_id> <content> <parent_id>).'
        ' The first node (<C0>) is the main post on reddit followed by the'
        ' comments to the main post.'
        ' <parent_id> refers to the comment/post to which the current comment'
        ' is reply to.'
        ' Note here only the content of main post is provided.'
    ),
    premise_suffix_post_comments_tree='POST-COMMENT TREE: ',
    premise_suffix_post='POST: ',
    premise_suffix_comments='COMMENTS: ',
    premise_suffix_post_tree='POST-COMMENT TREE: ',
    question_prefix_post_comments_tree=(
        'QUESTION: Focus on the structure and the content of'
        ' post-comment tree given below to answer'
    ),
    question_prefix_post=(
        'QUESTION: Focus on the content of the post given below to answer'
    ),
    question_prefix_post_comments=(
        'QUESTION: Focus on the content of the post and the comments'
        ' given below to answer'
    ),
    question_prefix_post_tree=(
        'QUESTION: Focus on the content of the post and the structure of'
        ' the post-comment tree given below to answer'
    ),
    options='\n'.join(['OPTIONS:', '', '- Yes', '- No']),
    explanation='',
)


# Prompt templates, keyed by task name.
#
# Structural tasks map to a list of alternative phrasings of the same
# question. Their shared placeholders ({premise_prefix}, {explanation},
# {question_prefix}, {options}, {premise_suffix}, {content_tree}) are filled
# in by `replace_prompt_utils_structural`; the remaining task-specific
# placeholders (e.g. {node1}, {user1}, {num_comments}, {depth}, {level},
# {chain_length}) are left for the caller.
#
# The classification tasks ('fakeddit', 'controversy', 'cmv') map a sample
# type ('post', 'post_comments', 'post_tree', 'post_comments_tree') to a list
# of templates whose placeholders carry the sample type as a suffix.
PATTERNS = {
    '1_hop': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a direct comment to <C{node1}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a one-hop neighbor to <C{node1}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    '2_hop': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a comment to one of the direct comment to'
            ' <C{node1}> ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a two-hop neighbor to <C{node1}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    '3_hop': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a three-hop neighbor to <C{node1}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # NOTE(review): this key uses a hyphen while '1_hop'..'3_hop' use an
    # underscore -- confirm which spelling the callers look up.
    '4-hop': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node2}> is a four-hop neighbor to <C{node1}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'subtree': [
        '{premise_prefix} {explanation}\n\n{question_prefix} whether <C{node2}>'
        ' is in the subtree rooted at <C{node1}>'
        ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
    ],
    'num_children': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_comments} direct comments to <C{node}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_comments} children of <C{node}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_comments} one-hop neighbors of <C{node}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # The first template asks for an exact depth, the second for a lower
    # bound ("more than").
    'depth': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether the'
            ' depth of the tree is {depth} ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether the'
            ' depth of the tree is more than {depth}'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # NOTE(review): the two templates below are identical; possibly a second
    # phrasing was intended.
    'level_detection': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> is at level {level} ? (Assuming the root node is at'
            ' level 0) .\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> is at level {level} ? (Assuming the root node is at'
            ' level 0) .\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # The first three templates ask "more than N", the last three ask for
    # exactly N.
    'num_leaf_nodes': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_leaf_nodes} leaf nodes in the given tree'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_leaf_nodes} nodes in the given tree that have'
            ' zero children ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_leaf_nodes} nodes in the given tree that have'
            ' zero replies ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are {num_leaf_nodes} leaf nodes in the given tree'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are {num_leaf_nodes} nodes in the given tree that have zero'
            ' children ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are {num_leaf_nodes} nodes in the given tree that have zero'
            ' replies ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # The first template names the two users, the second asks about any two
    # users.
    'long_chain_detection': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether users'
            ' <U{user1}> and <U{user2}> are involved in a long chain of to and'
            ' fro discussion of atleast length {chain_length} between each'
            ' other ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' is a long chain of to and fro discussion of atleast length'
            ' {chain_length} between two users'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # The first template names the three users, the second asks about any
    # three users.
    'triangle_detection': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether users'
            ' <U{user1}>, <U{user2}> and <U{user3}> are involved in a'
            ' triangular discussion between each'
            ' other?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' is a triangular discussion between three users'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'leaf_node_detection': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> is a leaf node ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> has zero children ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> has no replies ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
    ],
    # Only the first template uses {num_comments}.
    'num_comments_by_user': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are more than {num_comments} comments by <U{user}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' are multiple comments by <U{user}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'most_first_level_replies': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> attracts the maximum number of direct replies among the'
            ' first level comments ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> has the maximum number of direct children among the'
            ' first level nodes ?\n\n{options}\n\n{premise_suffix}'
            ' {content_tree}'
        ),
    ],
    'pair_user_interaction': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether users'
            ' <U{user1}> and <U{user2}> interact with each other i.e. one of'
            " them replies to other's comment"
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'nodes_at_same_level': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node1}> and <C{node2}> are at the same level'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'level_and_num_children': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether'
            ' <C{node}> is at level {level} and has more than {num_comments}'
            ' children ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # Only the second template uses {node}.
    'two_users_reply_same_comment': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether users'
            ' <U{user1}> and <U{user2}> reply to a same comment'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether users'
            ' <U{user1}> and <U{user2}> reply to <C{node}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    # Only the second template uses {user}.
    'user_reply_two_users': [
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether there'
            ' is a user that replies to both user <U{user1}> and <U{user2}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
        (
            '{premise_prefix} {explanation}\n\n{question_prefix} whether user'
            ' <U{user}> replies to both user <U{user1}> and <U{user2}>'
            ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
        ),
    ],
    'fighting_trait': [
        '{premise_prefix} {explanation}\n\n{question_prefix} whether user'
        ' <U{user1}> replies to user <U{user2}> more than two times'
        ' ?\n\n{options}\n\n{premise_suffix} {content_tree}'
    ],
    # Classification tasks: sample type -> list of templates.
    'fakeddit': {
        'post_comments_tree': [
            '{premise_prefix_post_comments_tree}'
            ' {explanation}\n\n{question_prefix_post_comments_tree} whether the'
            ' post i.e. <C0> is a Fake news'
            ' ?\n\n{options}\n\n{premise_suffix_post_comments_tree}'
            ' {post_comments_tree}'
        ],
        'post': [
            '{premise_prefix_post} {explanation}\n\n{question_prefix_post}'
            ' whether the post is a Fake news'
            ' ?\n\n{options}\n\n{premise_suffix_post} {post}'
        ],
        'post_tree': [
            '{premise_prefix_post_tree}'
            ' {explanation}\n\n{question_prefix_post_tree} whether the post'
            ' i.e. <C0> is a Fake news'
            ' ?\n\n{options}\n\n{premise_suffix_post_tree} {post_tree}'
        ],
        'post_comments': [
            '{premise_prefix_post_comments}'
            ' {explanation}\n\n{question_prefix_post_comments} whether the post'
            ' is a Fake news ?\n\n{options}\n\n{premise_suffix_post}'
            ' {post}\n\n{premise_suffix_comments} {comments}'
        ],
    },
    'controversy': {
        'post_comments_tree': [
            '{premise_prefix_post_comments_tree}'
            ' {explanation}\n\n{question_prefix_post_comments_tree} whether the'
            ' post i.e. <C0> is a Controversial post'
            ' ?\n\n{options}\n\n{premise_suffix_post_comments_tree}'
            ' {post_comments_tree}'
        ],
        'post': [
            '{premise_prefix_post} {explanation}\n\n{question_prefix_post}'
            ' whether the post is a Controversial post'
            ' ?\n\n{options}\n\n{premise_suffix_post} {post}'
        ],
        'post_tree': [
            '{premise_prefix_post_tree}'
            ' {explanation}\n\n{question_prefix_post_tree} whether the post'
            ' i.e. <C0> is a Controversial post'
            ' ?\n\n{options}\n\n{premise_suffix_post_tree} {post_tree}'
        ],
        'post_comments': [
            '{premise_prefix_post_comments}'
            ' {explanation}\n\n{question_prefix_post_comments} whether the post'
            ' is a Controversial post ?\n\n{options}\n\n{premise_suffix_post}'
            ' {post}\n\n{premise_suffix_comments} {comments}'
        ],
    },
    # Only the 'post_comments_tree' sample type is defined for this task.
    'cmv': {
        'post_comments_tree': [
            '{premise_prefix_post_comments_tree}'
            ' {explanation}\n\n{question_prefix_post_comments_tree} whether'
            ' user <U{user1}> was able to change the view of user <U0> about'
            ' the post i.e. <C0>'
            ' ?\n\n{options}\n\n{premise_suffix_post_comments_tree}'
            ' {post_comments_tree}'
        ],
    },
}


def replace_prompt_utils_structural(prompt, content_tree, explanation=''):
  """Fills the placeholders of a structural-task prompt template.

  Args:
      prompt: Template text that may contain the `{premise_prefix}`,
        `{question_prefix}`, `{options}`, `{premise_suffix}`,
        `{content_tree}` and `{explanation}` placeholders.
      content_tree: Text substituted for `{content_tree}`.
      explanation: Optional text substituted for `{explanation}`. When empty,
        `prompt_util_structural['explanation']` is used instead.

  Returns:
      The prompt with every placeholder listed above substituted.
  """

  # Each boilerplate placeholder is named after its key in the lookup table.
  boilerplate_keys = (
      'premise_prefix',
      'question_prefix',
      'options',
      'premise_suffix',
  )
  for key in boilerplate_keys:
    prompt = prompt.replace('{' + key + '}', prompt_util_structural[key])

  prompt = prompt.replace('{content_tree}', content_tree)

  # The table default is only looked up when no explanation was supplied.
  filler = explanation or prompt_util_structural['explanation']
  return prompt.replace('{explanation}', filler)


def replace_prompt_utils_tree_classification(
    prompt,
    tree,
    sample_type,
    explanation='',
):
  """Fills the placeholders of a tree-classification prompt template.

  Args:
      prompt: Template text containing placeholders such as
        `{premise_prefix_<sample_type>}`, `{question_prefix_<sample_type>}`,
        `{options}` and `{explanation}`.
      tree: Object exposing the `post`, `comments`, `post_tree` and
        `post_comments_tree` attributes; only the ones needed for
        `sample_type` are read.
      sample_type: 'post', 'post_comments' or 'post_tree'. Any other value
        has its content filled from `tree.post_comments_tree`.
      explanation: Optional text substituted for `{explanation}`. When empty,
        `prompt_util_structural['explanation']` is used instead.

  Returns:
      The prompt with the placeholders for `sample_type` substituted.
  """

  utils = prompt_util_tree_classification

  for stem in ('premise_prefix_', 'question_prefix_'):
    key = stem + sample_type
    prompt = prompt.replace('{' + key + '}', utils[key])
  prompt = prompt.replace('{options}', utils['options'])

  if sample_type == 'post_comments':
    # The post and its comments are two separate sections here, so there are
    # two suffix placeholders and two content placeholders to fill.
    for key in ('premise_suffix_comments', 'premise_suffix_post'):
      prompt = prompt.replace('{' + key + '}', utils[key])
    prompt = prompt.replace('{post}', tree.post)
    prompt = prompt.replace('{comments}', tree.comments)
  else:
    key = 'premise_suffix_' + sample_type
    prompt = prompt.replace('{' + key + '}', utils[key])

    if sample_type in ('post', 'post_tree'):
      content_field = sample_type
    else:
      content_field = 'post_comments_tree'
    prompt = prompt.replace(
        '{' + content_field + '}', getattr(tree, content_field)
    )

  # NOTE(review): the default explanation comes from the structural table,
  # not from `prompt_util_tree_classification` -- confirm this is intended.
  filler = explanation or prompt_util_structural['explanation']
  return prompt.replace('{explanation}', filler)


class Task(object):
  """Class to store information related to a given task."""

  def __init__(self, tree):
    """Initializes a Task object.

    Args:
        tree: A Tree object representing the comment post.
    """

    self.tree = tree

  def get_k_hop_samples(self, k):
    """Generates positive and negative samples for the k-hop task.

    Args:
        k: The hop distance for which to generate samples.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.

    Explanation: A 'hop' in this context refers to the distance between comments
    in the tree
                 structure. For example, a 2-hop relationship means a comment is
                 replying to
                 another comment that is itself a reply to the original post.
    """

    positive_samples = []
    negative_samples = []
    task_name = str(k) + '_hop'
    self.hop = k

    if self.hop not in self.tree.multihop_neighbors_list:
      return positive_samples, negative_samples

    samples = random.sample(
        self.tree.multihop_neighbors_list[self.hop],
        k=min(len(self.tree.multihop_neighbors_list[self.hop]), 2),
    )

    # positive sample
    node1 = str(samples[0][0])
    node2 = str(samples[0][1])

    positive_sample = random.sample(PATTERNS[str(self.hop) + '_hop'], 1)[0]
    positive_sample = positive_sample.replace('{node1}', node1)
    positive_sample = positive_sample.replace('{node2}', node2)
    positive_sample = replace_prompt_utils_structural(
        positive_sample, self.tree.post_comments_tree
    )

    positive_samples.append(
        [self.tree.post_id, task_name, positive_sample, 'Yes']
    )

    # negative sample
    if len(samples) == 2:

      node1 = str(samples[1][0])
      node2 = str(samples[1][1])

      new_hop = self.hop
      if self.hop < 3:
        new_hop += 1
      else:
        new_hop -= 1

      negative_sample = random.sample(PATTERNS[str(new_hop) + '_hop'], 1)[0]
      negative_sample = negative_sample.replace('{node1}', node1)
      negative_sample = negative_sample.replace('{node2}', node2)
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )

      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_subtree_samples(self):
    """Generates positive and negative samples for the subtree task.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    k = random.sample([1, 2, 3], 1)[0]
    task_name = 'subtree'
    self.hop = k

    if self.hop not in self.tree.multihop_neighbors_list:
      return positive_samples, negative_samples

    samples = random.sample(
        self.tree.multihop_neighbors_list[self.hop],
        k=min(len(self.tree.multihop_neighbors_list[self.hop]), 1),
    )

    # positive sample
    node1 = str(samples[0][0])
    node2 = str(samples[0][1])

    positive_sample = random.sample(PATTERNS['subtree'], 1)[0]
    positive_sample = positive_sample.replace('{node1}', node1)
    positive_sample = positive_sample.replace('{node2}', node2)
    positive_sample = replace_prompt_utils_structural(
        positive_sample, self.tree.post_comments_tree
    )

    positive_samples.append(
        [self.tree.post_id, task_name, positive_sample, 'Yes']
    )

    # negative sample
    level = random.sample([1, 2, 3], 1)[0]
    if level in self.tree.level2nodes and len(self.tree.level2nodes[level]) > 2:
      nodes = random.sample(self.tree.level2nodes[level], 2)
      node1 = str(nodes[0])
      node2 = str(nodes[1])

      negative_sample = random.sample(PATTERNS['subtree'], 1)[0]
      negative_sample = negative_sample.replace('{node1}', node1)
      negative_sample = negative_sample.replace('{node2}', node2)
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )

      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_num_children_samples(self):
    """Generates pos/neg samples for the task of reply counting.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'num_children'

    if 1 in self.tree.multihop_neighbors:
      for comment in self.tree.multihop_neighbors[1]:
        num_children = len(self.tree.multihop_neighbors[1][comment])

        # positive_sample

        positive_sample = random.sample(PATTERNS['num_children'], 1)[0]
        node = str(comment)
        positive_sample = positive_sample.replace('{node}', node)

        if num_children > 1:
          if num_children <= 3:
            num_comments = 'one'
            positive_sample = positive_sample.replace(
                '{num_comments}', num_comments
            )
          else:
            num_comments = 'three'
            positive_sample = positive_sample.replace(
                '{num_comments}', num_comments
            )

          positive_sample = replace_prompt_utils_structural(
              positive_sample, self.tree.post_comments_tree
          )
          positive_samples.append(
              [self.tree.post_id, task_name, positive_sample, 'Yes']
          )

        # negative_sample
        negative_sample = random.sample(PATTERNS['num_children'], 1)[0]
        node = str(comment)
        negative_sample = negative_sample.replace('{node}', node)

        if num_children <= 1:
          num_comments = 'one'
          negative_sample = negative_sample.replace(
              '{num_comments}', num_comments
          )
          negative_sample = replace_prompt_utils_structural(
              negative_sample, self.tree.post_comments_tree
          )
          negative_samples.append(
              [self.tree.post_id, task_name, negative_sample, 'No']
          )

        elif num_children <= 3:
          num_comments = 'three'
          negative_sample = negative_sample.replace(
              '{num_comments}', num_comments
          )
          negative_sample = replace_prompt_utils_structural(
              negative_sample, self.tree.post_comments_tree
          )
          negative_samples.append(
              [self.tree.post_id, task_name, negative_sample, 'No']
          )

    return positive_samples, negative_samples

  def get_depth_samples(self):
    """Builds one positive and one negative sample about the tree depth.

    Depth refers to the maximum number of 'hops' needed to follow a chain of
    replies from the root post to the end of any branch in the comment tree.
    It is computed here as `len(tree.multihop_neighbors) + 1`.

    For the positive sample, a template containing 'more than' is filled with
    `depth - 1` and any other template with `depth`. The negative sample is
    always filled with `depth + 1`.

    Returns:
        A tuple containing two lists:
            * positive_samples: A one-element list with the 'Yes' sample.
            * negative_samples: A one-element list with the 'No' sample.
    """

    task_name = 'depth'
    depth = len(self.tree.multihop_neighbors) + 1

    def finish(question, answer):
      """Wraps a filled question into a complete sample row."""
      prompt = replace_prompt_utils_structural(
          question, self.tree.post_comments_tree
      )
      return [self.tree.post_id, task_name, prompt, answer]

    # True statement.
    template = random.sample(PATTERNS['depth'], 1)[0]
    stated_depth = depth - 1 if 'more than' in template else depth
    positive_samples = [
        finish(template.replace('{depth}', str(stated_depth)), 'Yes')
    ]

    # False statement.
    template = random.sample(PATTERNS['depth'], 1)[0]
    negative_samples = [
        finish(template.replace('{depth}', str(depth + 1)), 'No')
    ]

    return positive_samples, negative_samples

  def get_level_detection_samples(self):
    """Builds one positive and one negative sample about a comment's level.

    The level of comment refers to its distance (in hops) from the root post.
    Level 0 is the root post itself, level 1 comments are direct replies to
    the root, etc. Keys of `tree.level2nodes` are shifted down by one before
    being written into the prompt.

    The positive sample states the shifted level of a randomly chosen node.
    The negative sample states a level that is off by one for another
    randomly chosen node; a node at shifted level 0 is always moved up.

    Returns:
        A tuple containing two lists:
            * positive_samples: A one-element list with the 'Yes' sample.
            * negative_samples: A one-element list with the 'No' sample.
    """

    task_name = 'level_detection'
    available_levels = list(self.tree.level2nodes.keys())

    def finish(question, answer):
      """Wraps a filled question into a complete sample row."""
      prompt = replace_prompt_utils_structural(
          question, self.tree.post_comments_tree
      )
      return [self.tree.post_id, task_name, prompt, answer]

    # True statement.
    picked = random.sample(available_levels, 1)[0]
    node = str(random.sample(self.tree.level2nodes[picked], 1)[0])
    true_level = picked - 1  # making it zero indexed
    template = random.sample(PATTERNS['level_detection'], 1)[0]
    question = template.replace('{node}', node)
    question = question.replace('{level}', str(true_level))
    positive_samples = [finish(question, 'Yes')]

    # False statement.
    picked = random.sample(available_levels, 1)[0]
    node = str(random.sample(self.tree.level2nodes[picked], 1)[0])
    true_level = picked - 1  # making it zero indexed
    template = random.sample(PATTERNS['level_detection'], 1)[0]
    question = template.replace('{node}', node)

    if true_level != 0:
      shift = random.sample([-1, 1], 1)[0]
    else:
      # Never state a negative level.
      shift = 1

    question = question.replace('{level}', str(true_level + shift))
    negative_samples = [finish(question, 'No')]

    return positive_samples, negative_samples

  def get_num_leaf_nodes_samples(self):
    """Builds one positive and one negative sample about the leaf-node count.

    A leaf node is a comment that has no further replies. The count used is
    `len(tree.leaf_nodes)`.

    For the positive sample, a template containing 'more than' is filled with
    `count - 1` and any other template with `count`. For the negative sample,
    a 'more than' template is filled with `count` plus 1, 2 or 3, and any
    other template with `count` plus one of -1, 1, 2 or 3.

    Returns:
        A tuple containing two lists:
            * positive_samples: A one-element list with the 'Yes' sample.
            * negative_samples: A one-element list with the 'No' sample.
    """

    task_name = 'num_leaf_nodes'
    leaf_count = len(self.tree.leaf_nodes)

    def finish(template, stated_count, answer):
      """Fills the count into the template and wraps it into a sample row."""
      question = template.replace('{num_leaf_nodes}', str(stated_count))
      prompt = replace_prompt_utils_structural(
          question, self.tree.post_comments_tree
      )
      return [self.tree.post_id, task_name, prompt, answer]

    # True statement.
    template = random.sample(PATTERNS['num_leaf_nodes'], 1)[0]
    stated = leaf_count - 1 if 'more than' in template else leaf_count
    positive_samples = [finish(template, stated, 'Yes')]

    # False statement. Both offsets are always drawn; the template wording
    # decides which one is applied.
    template = random.sample(PATTERNS['num_leaf_nodes'], 1)[0]
    any_offset = random.sample([-1, 1, 2, 3], 1)[0]
    upward_offset = random.sample([1, 2, 3], 1)[0]
    offset = upward_offset if 'more than' in template else any_offset
    negative_samples = [finish(template, leaf_count + offset, 'No')]

    return positive_samples, negative_samples

  def get_long_chain_detection_samples(self):
    """Generates samples for detecting long chains of alternating replies between two users.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'long_chain_detection'
    explanation = (
        'Also we define a long chain of to and fro discussion between two users'
        ' A and B when there exists an instance where User B comments on a'
        " comment by User A followed by User A commenting on User B's comment"
        ' to his comment and so on.'
    )

    # positive sample
    for user_pair in self.tree.alternating_pairs:
      user1, user2 = str(user_pair[0]), str(user_pair[1])
      positive_sample = random.sample(PATTERNS['long_chain_detection'], 1)[0]
      positive_sample = positive_sample.replace('{user1}', user1)
      positive_sample = positive_sample.replace('{user2}', user2)
      positive_sample = positive_sample.replace(
          '{chain_length}', str(self.tree.chain_length)
      )

      positive_sample = replace_prompt_utils_structural(
          positive_sample, self.tree.post_comments_tree, explanation
      )
      positive_samples.append(
          [self.tree.post_id, task_name, positive_sample, 'Yes']
      )

    # negative sample
    if self.tree.num_users > 2 and not self.tree.alternating_pairs:
      user1, user2 = str(self.tree.num_users - 2), str(self.tree.num_users - 1)
      negative_sample = random.sample(PATTERNS['long_chain_detection'], 1)[0]
      negative_sample = negative_sample.replace('{user1}', user1)
      negative_sample = negative_sample.replace('{user2}', user2)
      negative_sample = negative_sample.replace(
          '{chain_length}',
          str(self.tree.chain_length),
      )

      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree, explanation
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_triangle_detection_samples(self):
    """Generates samples for detecting discussion triangles between three users.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'triangle_detection'
    explanation = (
        'Also we define a discussion between three users A, B and C as'
        ' triangular if there exists an instance where lets say User B comments'
        ' on a comment by User A and User C also comments on the same comment'
        " by User A and one of User B or C comments on each other's comment on"
        " User A's comment."
    )

    # positive sample
    for user_tuple in self.tree.discussion_triangles:
      user1, user2, user3 = (
          str(user_tuple[0]),
          str(user_tuple[1]),
          str(user_tuple[2]),
      )
      positive_sample = random.sample(PATTERNS['triangle_detection'], 1)[0]
      positive_sample = positive_sample.replace('{user1}', user1)
      positive_sample = positive_sample.replace('{user2}', user2)
      positive_sample = positive_sample.replace('{user3}', user3)

      positive_sample = replace_prompt_utils_structural(
          positive_sample, self.tree.post_comments_tree, explanation
      )
      positive_samples.append(
          [self.tree.post_id, task_name, positive_sample, 'Yes']
      )

    # negative sample
    if self.tree.num_users > 3 and not self.tree.discussion_triangles:
      user1, user2, user3 = (
          str(self.tree.num_users - 3),
          str(self.tree.num_users - 2),
          str(self.tree.num_users - 1),
      )
      negative_sample = random.sample(PATTERNS['triangle_detection'], 1)[0]
      negative_sample = negative_sample.replace('{user1}', user1)
      negative_sample = negative_sample.replace('{user2}', user2)
      negative_sample = negative_sample.replace('{user3}', user3)

      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree, explanation
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_leaf_node_detection_samples(self):
    """Generates samples for the task of detecting a node as leaf node in the comment tree.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'leaf_node_detection'

    # positive sample
    if self.tree.leaf_nodes:

      node = str(random.sample(self.tree.leaf_nodes, 1)[0])
      positive_sample = random.sample(PATTERNS['leaf_node_detection'], 1)[0]
      positive_sample = positive_sample.replace('{node}', node)
      positive_sample = replace_prompt_utils_structural(
          positive_sample, self.tree.post_comments_tree
      )
      positive_samples.append(
          [self.tree.post_id, task_name, positive_sample, 'Yes']
      )

    # negative sample
    if self.tree.intermediate_nodes:

      node = str(random.sample(self.tree.intermediate_nodes, 1)[0])
      negative_sample = random.sample(PATTERNS['leaf_node_detection'], 1)[0]
      negative_sample = negative_sample.replace('{node}', node)
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_num_comments_by_user_samples(self):
    """Generates samples for the task of detecting the number of comments by a user.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'num_comments_by_user'

    users = list()
    all_users = set(self.tree.user2num_comments.keys())
    min_comments = 2

    for user in list(self.tree.user2num_comments.keys()):
      if self.tree.user2num_comments[user] > min_comments:
        users.append(user)

    if self.tree.num_users == 1 or not users:
      return positive_samples, negative_samples

    # positive sample
    user = str(random.sample(users, 1)[0])
    positive_sample = random.sample(PATTERNS['num_comments_by_user'], 1)[0]
    positive_sample = positive_sample.replace('{user}', user)
    positive_sample = positive_sample.replace(
        '{num_comments}', str(min_comments)
    )
    positive_sample = replace_prompt_utils_structural(
        positive_sample, self.tree.post_comments_tree
    )
    positive_samples.append(
        [self.tree.post_id, task_name, positive_sample, 'Yes']
    )

    # negative_sample
    rem_users = list(all_users - set(users))

    if rem_users:
      user = random.sample((rem_users), 1)[0]
      if self.tree.user2num_comments[user] > 1:
        negative_sample = PATTERNS['num_comments_by_user'][0]
      else:
        negative_sample = PATTERNS['num_comments_by_user'][1]

      negative_sample = negative_sample.replace('{user}', str(user))
      negative_sample = negative_sample.replace(
          '{num_comments}', str(min_comments)
      )
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_most_first_level_replies_samples(self):
    """Generates samples for detecting comments with the most first-level replies.

    Returns:
        A tuple containing two lists:
            * positive_samples: A list of positive samples.
            * negative_samples: A list of negative samples.
    """

    positive_samples = []
    negative_samples = []
    task_name = 'most_first_level_replies'

    rem_nodes = list(
        set(self.tree.first_level_nodes) - set(self.tree.most_replied_nodes)
    )

    if self.tree.max_replies > 3:

      # positive sample
      node = str(random.sample(self.tree.most_replied_nodes, 1)[0])
      positive_sample = random.sample(PATTERNS['most_first_level_replies'], 1)[
          0
      ]
      positive_sample = positive_sample.replace('{node}', node)
      positive_sample = replace_prompt_utils_structural(
          positive_sample, self.tree.post_comments_tree
      )
      positive_samples.append(
          [self.tree.post_id, task_name, positive_sample, 'Yes']
      )

      # negative sample
      if rem_nodes:
        node = str(random.sample(rem_nodes, 1)[0])
        negative_sample = random.sample(
            PATTERNS['most_first_level_replies'], 1
        )[0]
        negative_sample = negative_sample.replace('{node}', node)
        negative_sample = replace_prompt_utils_structural(
            negative_sample, self.tree.post_comments_tree
        )
        negative_samples.append(
            [self.tree.post_id, task_name, negative_sample, 'No']
        )

    return positive_samples, negative_samples

  def get_pair_user_interaction_samples(self):
    """Generates samples for the task of detecting user interaction."""
    positive_samples = []
    negative_samples = []
    task_name = 'pair_user_interaction'

    # positive sample
    if len(self.tree.user_list) > 1:
      node_pair = random.sample(self.tree.user_list, 1)[0]
      positive_sample = random.sample(PATTERNS['pair_user_interaction'], 1)[0]
      positive_sample = positive_sample.replace('{user1}', str(node_pair[0]))
      positive_sample = positive_sample.replace('{user2}', str(node_pair[1]))
      positive_sample = replace_prompt_utils_structural(
          positive_sample, self.tree.post_comments_tree
      )
      positive_samples.append(
          [self.tree.post_id, task_name, positive_sample, 'Yes']
      )

    # negative sample
    non_existent_user_pairs = self.tree.find_nonexistent_edges(
        self.tree.user_list
    )

    if non_existent_user_pairs:

      node_pair = random.sample(non_existent_user_pairs, 1)[0]
      negative_sample = random.sample(PATTERNS['pair_user_interaction'], 1)[0]
      negative_sample = negative_sample.replace('{user1}', str(node_pair[0]))
      negative_sample = negative_sample.replace('{user2}', str(node_pair[1]))
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_nodes_at_same_level_samples(self):
    """Generates samples for the task of detecting nodes at the same level."""
    positive_samples = []
    negative_samples = []
    task_name = 'nodes_at_same_level'

    if len(self.tree.level2nodes) > 3:
      levels = random.sample(list(self.tree.level2nodes), 3)
    else:
      return positive_samples, negative_samples

    # positive_sample
    for i in range(1, 3):
      if len(self.tree.level2nodes[levels[i]]) > 2:
        nodes = random.sample(self.tree.level2nodes[levels[i]], 2)
        positive_sample = random.sample(PATTERNS['nodes_at_same_level'], 1)[0]
        positive_sample = positive_sample.replace('{node1}', str(nodes[0]))
        positive_sample = positive_sample.replace('{node2}', str(nodes[1]))
        positive_sample = replace_prompt_utils_structural(
            positive_sample, self.tree.post_comments_tree
        )
        positive_samples.append(
            [self.tree.post_id, task_name, positive_sample, 'Yes']
        )

    # negative_sample
    if self.tree.level2nodes[levels[2]]:
      node1 = random.sample(self.tree.level2nodes[levels[1]], 1)[0]
      node2 = random.sample(self.tree.level2nodes[levels[2]], 1)[0]
      if (node1 == 0 or node2 == 0) and random.sample([0, 1], 1)[0] == 0:
        for unused_i in range(1, 5):
          node1 = random.sample(self.tree.level2nodes[levels[1]], 1)[0]
          node2 = random.sample(self.tree.level2nodes[levels[2]], 1)[0]

          if node1 != 0 and node2 != 0:
            break

      negative_sample = random.sample(PATTERNS['nodes_at_same_level'], 1)[0]
      negative_sample = negative_sample.replace('{node1}', str(node1))
      negative_sample = negative_sample.replace('{node2}', str(node2))
      negative_sample = replace_prompt_utils_structural(
          negative_sample, self.tree.post_comments_tree
      )
      negative_samples.append(
          [self.tree.post_id, task_name, negative_sample, 'No']
      )

    return positive_samples, negative_samples

  def get_level_and_num_children_samples(self):
    """Gets samples for the level and number of children task."""

    positive_samples = []
    negative_samples = []
    task_name = 'level_and_num_children'

    # self.tree.level2nodes.keys()

    if 1 in self.tree.multihop_neighbors:
      for comment in self.tree.multihop_neighbors[1]:
        num_children = len(self.tree.multihop_neighbors[1][comment])
        node_level = -1
        for level in self.tree.level2nodes:
          if comment in self.tree.level2nodes[level]:
            node_level = level
            break

        if node_level == -1:
          continue

        node_level -= 1  # making it zero indexed
        # positive_sample

        positive_sample = random.sample(PATTERNS['level_and_num_children'], 1)[
            0
        ]
        node = str(comment)
        positive_sample = positive_sample.replace('{node}', node)
        positive_sample = positive_sample.replace('{level}', str(node_level))

        if num_children > 1:
          if num_children <= 3:
            num_comments = 'one'
            positive_sample = positive_sample.replace(
                '{num_comments}', num_comments
            )
          else:
            num_comments = 'three'
            positive_sample = positive_sample.replace(
                '{num_comments}', num_comments
            )

          positive_sample = replace_prompt_utils_structural(
              positive_sample, self.tree.post_comments_tree
          )
          positive_samples.append(
              [self.tree.post_id, task_name, positive_sample, 'Yes']
          )

        # negative_sample
        negative_sample = random.sample(PATTERNS['level_and_num_children'], 1)[
            0
        ]
        node = str(comment)
        negative_sample = negative_sample.replace('{node}', node)
        negative_sample = negative_sample.replace('{level}', str(node_level))

        if num_children <= 1:
          num_comments = 'one'
          negative_sample = negative_sample.replace(
              '{num_comments}', num_comments
          )
          negative_sample = replace_prompt_utils_structural(
              negative_sample, self.tree.post_comments_tree
          )
          negative_samples.append(
              [self.tree.post_id, task_name, negative_sample, 'No']
          )

        elif num_children <= 3:
          num_comments = 'three'
          negative_sample = negative_sample.replace(
              '{num_comments}', num_comments
          )
          negative_sample = replace_prompt_utils_structural(
              negative_sample, self.tree.post_comments_tree
          )
          negative_samples.append(
              [self.tree.post_id, task_name, negative_sample, 'No']
          )

    return positive_samples, negative_samples

  def get_two_users_reply_same_comment_samples(self):
    """Builds Yes/No prompts asking whether two users replied to one comment.

    Every key of `self.tree.comment2children_user` other than `0` is visited.
    A 'Yes' row is produced when more than one user is recorded for the
    comment (two of them are drawn at random). A 'No' row pairs one recorded
    user with a random id from `range(self.tree.num_users)` that is not
    recorded for the comment; it is skipped when no such id exists.

    Returns:
        A `(positive_samples, negative_samples)` tuple. Each sample is a list
        `[post_id, task_name, prompt, answer]` with answer 'Yes' or 'No'.
    """
    task_name = 'two_users_reply_same_comment'
    tree = self.tree
    yes_rows = []
    no_rows = []

    def render(template, first_user, second_user, target):
      # Fill the three placeholders, then hand the text to the shared
      # structural prompt helper together with the post/comment tree.
      filled = template.replace('{user1}', str(first_user))
      filled = filled.replace('{user2}', str(second_user))
      filled = filled.replace('{node}', str(target))
      return replace_prompt_utils_structural(filled, tree.post_comments_tree)

    for comment in tree.comment2children_user:
      if comment == 0:
        continue
      repliers = tree.comment2children_user[comment]

      # 'Yes' case: both users come from the set recorded for this comment.
      if len(repliers) > 1:
        first_user, second_user = random.sample(list(repliers), 2)
        template = random.sample(
            PATTERNS['two_users_reply_same_comment'], 1
        )[0]
        yes_rows.append([
            tree.post_id,
            task_name,
            render(template, first_user, second_user, comment),
            'Yes',
        ])

      # 'No' case: one recorded user plus one user that is not recorded.
      if not repliers:
        continue
      outsiders = set(range(0, tree.num_users)) - repliers
      insider = random.sample(list(repliers), 1)[0]
      if not outsiders:
        continue
      outsider = random.sample(list(outsiders), 1)[0]
      template = random.sample(
          PATTERNS['two_users_reply_same_comment'], 1
      )[0]
      no_rows.append([
          tree.post_id,
          task_name,
          render(template, insider, outsider, comment),
          'No',
      ])

    return yes_rows, no_rows

  def get_user_reply_two_users_samples(self):
    """Builds Yes/No prompts asking whether one user replied to two others.

    Every ordered pair of keys `(comment1, comment2)` of
    `self.tree.comment2children_user` whose authors (per
    `self.tree.comment2user`) differ is considered:

    * 'Yes' row: a random user recorded for both comments.
    * 'No' row: a random user recorded for `comment1` but not `comment2`,
      kept only if that user is not recorded for any comment authored by
      `comment2`'s author.

    Returns:
        A `(positive_samples, negative_samples)` tuple. Each sample is a list
        `[post_id, task_name, prompt, answer]` with answer 'Yes' or 'No'.
    """
    positive_samples = []
    negative_samples = []
    task_name = 'user_reply_two_users'

    children_users = self.tree.comment2children_user
    comment2user = self.tree.comment2user

    # author -> all users recorded under any comment written by that author.
    # Built once so the negative-sample check below is a set lookup instead
    # of a rescan of every comment for each comment pair.
    repliers_by_author = {}
    for comment, repliers in children_users.items():
      repliers_by_author.setdefault(comment2user[comment], set()).update(
          repliers
      )

    for comment1 in children_users:
      user1 = comment2user[comment1]
      replied_users1 = children_users[comment1]

      for comment2 in children_users:
        user2 = comment2user[comment2]
        if user1 == user2:
          continue

        replied_users2 = children_users[comment2]
        common_users = replied_users1.intersection(replied_users2)
        only_comment1_users = replied_users1 - replied_users2

        # positive sample
        if common_users:
          positive_sample = random.sample(
              PATTERNS['user_reply_two_users'], 1
          )[0]
          user = random.sample(list(common_users), 1)[0]
          positive_sample = positive_sample.replace('{user}', str(user))
          positive_sample = positive_sample.replace('{user1}', str(user1))
          positive_sample = positive_sample.replace('{user2}', str(user2))
          positive_sample = replace_prompt_utils_structural(
              positive_sample, self.tree.post_comments_tree
          )
          positive_samples.append(
              [self.tree.post_id, task_name, positive_sample, 'Yes']
          )

        # negative sample
        if only_comment1_users:
          # The pattern is drawn before the validity check on purpose: it
          # keeps the sequence of random draws the same as before.
          negative_sample = random.sample(
              PATTERNS['user_reply_two_users'], 1
          )[0]
          user = random.sample(list(only_comment1_users), 1)[0]

          # Reject the candidate if it is recorded under some other comment
          # by user2, otherwise the 'No' answer would be wrong.
          if user not in repliers_by_author[user2]:
            negative_sample = negative_sample.replace('{user}', str(user))
            negative_sample = negative_sample.replace('{user1}', str(user1))
            negative_sample = negative_sample.replace('{user2}', str(user2))
            negative_sample = replace_prompt_utils_structural(
                negative_sample, self.tree.post_comments_tree
            )
            negative_samples.append(
                [self.tree.post_id, task_name, negative_sample, 'No']
            )

    return positive_samples, negative_samples

  def get_fighting_trait_samples(self):
    """Builds Yes/No prompts for the 'fighting_trait' task.

    Entries of `self.tree.user_list` are counted as tuples. A pair seen more
    than twice yields a 'Yes' row; any other pair whose first element is not
    `0` yields a 'No' row. Pairs are unpacked as `(user2, user1)`.

    Returns:
        A `(positive_samples, negative_samples)` tuple. Each sample is a list
        `[post_id, task_name, prompt, answer]` with answer 'Yes' or 'No'.
    """
    task_name = 'fighting_trait'
    yes_rows = []
    no_rows = []

    # Occurrence count per pair, in first-seen order.
    occurrences = {}
    for raw_pair in self.tree.user_list:
      key = tuple(raw_pair)
      occurrences[key] = occurrences.get(key, 0) + 1

    for pair, seen in occurrences.items():
      user2, user1 = pair

      if seen > 2:
        answer, bucket = 'Yes', yes_rows
      elif pair[0] != 0:
        answer, bucket = 'No', no_rows
      else:
        # Infrequent pair whose first element is 0: no sample at all.
        continue

      prompt = random.sample(PATTERNS['fighting_trait'], 1)[0]
      prompt = prompt.replace('{user1}', str(user1))
      prompt = prompt.replace('{user2}', str(user2))
      prompt = replace_prompt_utils_structural(
          prompt, self.tree.post_comments_tree
      )
      bucket.append([self.tree.post_id, task_name, prompt, answer])

    return yes_rows, no_rows

  def get_fakeddit_samples(
      self,
      sample_type,
  ):
    """Builds at most one Fakeddit prompt for this tree.

    Nothing is produced unless the tree has more than 5 nodes (leaf nodes +
    intermediate nodes + 1). Otherwise one prompt is drawn from
    `PATTERNS['fakeddit'][sample_type]`; it is a 'Yes' row when
    `self.tree.label == 0` and a 'No' row for any other label.

    Args:
        sample_type: Key into `PATTERNS['fakeddit']`, also forwarded to
          `replace_prompt_utils_tree_classification`.

    Returns:
        A `(positive_samples, negative_samples)` tuple; together the two lists
        hold zero or one `[post_id, task_name, prompt, answer]` row.
    """
    task_name = 'fakeddit'
    yes_rows = []
    no_rows = []
    tree = self.tree

    node_total = len(tree.leaf_nodes) + len(tree.intermediate_nodes) + 1
    if node_total <= 5:
      return yes_rows, no_rows

    if tree.label == 0:
      answer, bucket = 'Yes', yes_rows
    else:
      answer, bucket = 'No', no_rows

    template = random.sample(PATTERNS['fakeddit'][sample_type], 1)[0]
    prompt = replace_prompt_utils_tree_classification(
        template, tree, sample_type=sample_type
    )
    bucket.append([tree.post_id, task_name, prompt, answer])

    return yes_rows, no_rows


  def get_controversy_samples(
      self,
      sample_type,
  ):
    """Builds exactly one controversy prompt for this tree.

    The task name is 'controversy' when `self.tree.time_stamp == -1`, and
    'controversy_<time_stamp>_hr' otherwise. The row is a 'Yes' sample when
    `self.tree.label == 1` and a 'No' sample for any other label.

    Args:
        sample_type: Key into `PATTERNS['controversy']`, also forwarded to
          `replace_prompt_utils_tree_classification`.

    Returns:
        A `(positive_samples, negative_samples)` tuple; exactly one of the two
        lists holds a single `[post_id, task_name, prompt, answer]` row.
    """
    yes_rows = []
    no_rows = []
    tree = self.tree

    task_name = (
        'controversy'
        if tree.time_stamp == -1
        else 'controversy_' + str(tree.time_stamp) + '_hr'
    )

    if tree.label == 1:
      answer, bucket = 'Yes', yes_rows
    else:
      answer, bucket = 'No', no_rows

    template = random.sample(PATTERNS['controversy'][sample_type], 1)[0]
    prompt = replace_prompt_utils_tree_classification(
        template, tree, sample_type=sample_type
    )
    bucket.append([tree.post_id, task_name, prompt, answer])

    return yes_rows, no_rows

  def get_cmv_samples(
      self,
      sample_type,
  ):
    """Builds one 'Yes' and one 'No' CMV prompt for this tree, if eligible.

    `self.tree.label` is iterated as a collection of user names; names found
    in `self.tree.user_id_map` form the positive user ids. Candidate ids are
    `1 .. num_users - 1`; the negative ids are the candidates that are not
    positive. Samples are only produced when the tree has more than 5 nodes,
    there is at least one positive id, and the number of candidates exceeds
    the number of positive ids by more than one.

    Args:
        sample_type: Key into `PATTERNS['cmv']`, also forwarded to
          `replace_prompt_utils_tree_classification`.

    Returns:
        A `(positive_samples, negative_samples)` tuple. Each list is empty or
        holds one `[post_id, task_name, prompt, answer]` row.
    """
    task_name = 'cmv'
    yes_rows = []
    no_rows = []
    tree = self.tree

    id_of = tree.user_id_map
    winners = {id_of[name] for name in tree.label if name in id_of}
    winner_total = len(winners)
    candidates = set(range(1, tree.num_users))
    non_winners = candidates - winners

    node_total = len(tree.leaf_nodes) + len(tree.intermediate_nodes) + 1

    eligible = (
        node_total > 5
        and len(candidates) > winner_total + 1
        and winner_total > 0
    )
    if not eligible:
      return yes_rows, no_rows

    # The positive row is built first, then the negative one.
    for pool, answer, bucket in (
        (winners, 'Yes', yes_rows),
        (non_winners, 'No', no_rows),
    ):
      template = random.sample(PATTERNS['cmv'][sample_type], 1)[0]
      chosen_user = random.sample(list(pool), 1)[0]
      prompt = template.replace('{user1}', str(chosen_user))
      prompt = replace_prompt_utils_tree_classification(
          prompt, tree, sample_type=sample_type
      )
      bucket.append([tree.post_id, task_name, prompt, answer])

    return yes_rows, no_rows


def get_train_test_split(data, train_ratio=0.7):
  """Randomly partitions `data` into train, test and validation parts.

  `data` is first split into a training part and a held-out remainder of
  relative size `1 - train_ratio`; the remainder is then halved into
  validation and test parts. Both splits shuffle.

  Args:
      data: The collection to split.
      train_ratio: Fraction of `data` kept for training.

  Returns:
      A `(train, test, val)` tuple -- note that test comes before val.
  """
  holdout_fraction = 1 - train_ratio
  train_part, holdout = model_selection.train_test_split(
      data, test_size=holdout_fraction, shuffle=True
  )
  # The first half of the held-out remainder is validation, the second test.
  val_part, test_part = model_selection.train_test_split(
      holdout, test_size=0.5, shuffle=True
  )
  return train_part, test_part, val_part


def write_to_csv(filename, data):
  """Appends a header row followed by `data` to a CSV file.

  The file is opened in append mode, so calling this twice on the same path
  writes the header twice.

  Args:
      filename: Path of the CSV file to append to (created if missing).
      data: A sequence of rows. When the first row has exactly two fields the
        header is `inputs,targets`; otherwise (including when `data` is empty)
        it is `post_id,task_name,inputs,targets`.
  """
  if data and len(data[0]) == 2:
    header = ['inputs', 'targets']
  else:
    # Also used for empty `data`, so an empty split yields a header-only file
    # instead of an IndexError.
    header = ['post_id', 'task_name', 'inputs', 'targets']

  # newline='' is what the csv module requires for files it writes; the
  # context manager closes the handle even if a row fails to serialize.
  with open(filename, 'a', newline='', encoding='utf-8') as f:
    csv_output = csv.writer(f)
    csv_output.writerow(header)
    csv_output.writerows(data)


# Pushshift data preparation.
# Stream the zstd-compressed submissions dump line by line (one JSON object
# per line) and keep self-posts that pass the filters below.
top_posts = {}
cnt = 0  # Number of posts kept so far.
with open('RS_2019-04.zst', 'rb') as fh:
  dctx = zstd.ZstdDecompressor()
  stream_reader = dctx.stream_reader(fh)
  text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8')
  for line in text_stream:
    # Stop reading once more than 100000 posts have been kept.
    if cnt > 100000:
      break
    obj = json.loads(line)
    fields = set(obj.keys())
    if 'selftext' in fields:
      # Keep posts with 3..20 comments and more than 10 whitespace-separated
      # words of self text.
      if (
          obj['num_comments'] >= 3
          and len(obj['selftext'].split()) > 10
          and obj['num_comments'] <= 20
      ):
        # NOTE(review): 'author_fullname' is read unconditionally; a record
        # without that key would raise KeyError -- confirm the dump always
        # carries it.
        post = {
            'id': 't3_' + obj['id'],
            'content': obj['title'] + '\n' + obj['selftext'],
            'time': obj['created_utc'],
            'user_id': obj['author_fullname'],
        }
        top_posts[post['id']] = post
        cnt += 1


# Collect every comment in the comments dump whose 'link_id' is one of the
# posts selected above.
top_posts_id = set(top_posts.keys())
comments = []
with open('RC_2019-04.zst', 'rb') as fh:
  dctx = zstd.ZstdDecompressor()
  stream_reader = dctx.stream_reader(fh)
  text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8')

  for line in tqdm.tqdm(text_stream):
    obj = json.loads(line)
    if obj['link_id'] in top_posts_id:
      # Field order must match the DataFrame column names assigned below:
      # parent, node_id, content, time, user_id.
      comment = [
          obj['parent_id'],
          't1_' + obj['id'],
          obj['body'],
          obj['created_utc'],
          obj['author_fullname'],
      ]
      comments.append(comment)

# Turn the comment rows into a {node_id: {parent, content, time, user_id}}
# dictionary.
df = pd.DataFrame(comments)
df.columns = ['parent', 'node_id', 'content', 'time', 'user_id']
df = df.set_index('node_id')
df_dict = df.to_dict('index')

# Re-shape the posts into the same per-node schema; posts get the sentinel
# parent 'NULL'.
top_posts_new = {}

for i in top_posts:
  top_posts_new[i] = {
      'parent': 'NULL',
      'time': top_posts[i]['time'],
      'user_id': top_posts[i]['user_id'],
      'content': top_posts[i]['content'],
  }

df_dict.update(top_posts_new)

# Persist the merged node dictionary, then load it back. np.save stores the
# dict as a pickled object array, hence allow_pickle=True and .item() on load.
if not os.path.isdir('structural'):
  os.mkdir('structural')
np.save('./structural/node_info.npy', df_dict)

node_info = np.load('./structural/node_info.npy', allow_pickle=True).item()

# Rows are [time, child_id, parent_id]; the index arguments passed to
# build_trees below describe this layout.
parent_child = []

for i in node_info:
  parent_child.append([node_info[i]['time'], i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

already_taken_posts = set(trees.keys())

# Task names; get_samples resolves each one to Task.get_<name>_samples.
all_tasks = [
    'k_hop',
    'subtree',
    'num_children',
    'depth',
    'level_detection',
    'num_leaf_nodes',
    'long_chain_detection',
    'triangle_detection',
    'leaf_node_detection',
    'num_comments_by_user',
    'most_first_level_replies',
    'pair_user_interaction',
    'nodes_at_same_level',
    'two_users_reply_same_comment',
    'user_reply_two_users',
    'fighting_trait',
]


def get_samples(trees, node_info, all_tasks):
  """Generates positive and negative samples for every post and task.

  For each post a `Tree` and a `Task` are built, and for each task name the
  method `Task.get_<task_name>_samples` is called.

  Args:
      trees: A dictionary keyed by post id; each value is handed to `Tree` as
        that post's tree.
      node_info: Node information handed to `Tree`.
      all_tasks: A list of task names. The special name 'k_hop' is run for
        hops 1, 2 and 3 and stored under '1_hop', '2_hop' and '3_hop'.

  Returns:
      A dictionary mapping each task name to
      `{'positive': [...], 'negative': [...]}`, accumulated over all posts.
  """
  samples = {}

  for post in trees:
    tree = Tree(post, trees[post], node_info, max_hops=10, truncate=False)
    tasks = Task(tree)

    for task_name in all_tasks:
      task_func = getattr(tasks, 'get_' + task_name + '_samples')

      if task_name == 'k_hop':
        # The k-hop task takes the hop count and gets one bucket per hop.
        for hop in range(1, 4):
          hop_task_name = str(hop) + '_' + 'hop'
          bucket = samples.setdefault(
              hop_task_name, {'positive': [], 'negative': []}
          )
          task_positive_samples, task_negative_samples = task_func(hop)
          bucket['positive'].extend(task_positive_samples)
          bucket['negative'].extend(task_negative_samples)
      else:
        bucket = samples.setdefault(
            task_name, {'positive': [], 'negative': []}
        )
        task_positive_samples, task_negative_samples = task_func()
        bucket['positive'].extend(task_positive_samples)
        bucket['negative'].extend(task_negative_samples)

  return samples


samples = get_samples(trees, node_info, all_tasks)


# Balance each task (same number of positives and negatives, at most
# max_samples in total), split each class separately into train/test/val and
# pool the splits across tasks.
max_samples = 10000
num_class = 2
train_samples = []
test_samples = []
val_samples = []

for task_name in samples:
  random.shuffle(samples[task_name]['positive'])
  random.shuffle(samples[task_name]['negative'])

  # Per-class budget: half of max_samples, capped by the smaller class.
  num_samples_per_class = min(
      int(max_samples / num_class),
      min(
          len(samples[task_name]['negative']),
          len(samples[task_name]['positive']),
      ),
  )

  positive_samples = random.sample(
      samples[task_name]['positive'], num_samples_per_class
  )
  negative_samples = random.sample(
      samples[task_name]['negative'], num_samples_per_class
  )

  # NOTE(review): if a task has no positives or no negatives the lists passed
  # here are empty -- confirm get_train_test_split tolerates that.
  train, test, val = get_train_test_split(positive_samples)
  train_samples.extend(train)
  test_samples.extend(test)
  val_samples.extend(val)

  train, test, val = get_train_test_split(negative_samples)
  train_samples.extend(train)
  test_samples.extend(test)
  val_samples.extend(val)


random.shuffle(train_samples)
random.shuffle(test_samples)
random.shuffle(val_samples)

write_to_csv('./structural/train.csv', train_samples)
write_to_csv('./structural/test.csv', test_samples)
write_to_csv('./structural/val.csv', val_samples)


# Fakeddit data preparation.
# Reads the comment table and the three official post splits, keeps the posts
# that have at least one direct comment, and stores per-node info and labels.

df = pd.read_csv('fakeddit_all_samples/all_comments.tsv', delimiter='\t')
df_train = pd.read_csv('fakeddit_all_samples/all_train.tsv', delimiter='\t')
df_test = pd.read_csv(
    'fakeddit_all_samples/all_test_public.tsv', delimiter='\t'
)
df_val = pd.read_csv('fakeddit_all_samples/all_validate.tsv', delimiter='\t')

# Parent ids starting with 't3' are collected as post ids; missing values
# (which stringify to 'nan') are skipped.
top_posts = set()
for i in df['parent_id'].to_list():
  if str(i) != 'nan' and i[0:2] == 't3':
    top_posts.add(i)

# Drop the first three characters (the 't3_' prefix) so the ids can be
# compared with the 'id' column of the post tables.
top_post_ids = []
for i in top_posts:
  top_post_ids.append(i[3:])

# These three id sets are not referenced again in the Fakeddit section.
train_ids = set(df_train['id'].to_list())
test_ids = set(df_test['id'].to_list())
val_ids = set(df_val['id'].to_list())

mask = df_train['id'].isin(top_post_ids)
train_top_posts = df_train[mask]

mask = df_test['id'].isin(top_post_ids)
test_top_posts = df_test[mask]


mask = df_val['id'].isin(top_post_ids)
val_top_posts = df_val[mask]

# Keep the same column subset for all three splits.
train_top_posts = train_top_posts[[
    'clean_title',
    'id',
    'title',
    '2_way_label',
    '3_way_label',
    '6_way_label',
    'author',
    'created_utc',
]]
test_top_posts = test_top_posts[[
    'clean_title',
    'id',
    'title',
    '2_way_label',
    '3_way_label',
    '6_way_label',
    'author',
    'created_utc',
]]
val_top_posts = val_top_posts[[
    'clean_title',
    'id',
    'title',
    '2_way_label',
    '3_way_label',
    '6_way_label',
    'author',
    'created_utc',
]]

# The official splits are merged here.
frames = [train_top_posts, test_top_posts, val_top_posts]
top_posts = pd.concat(frames)

df_comments = df[['id', 'body', 'parent_id', 'author']]

top_posts = top_posts.set_index('id')
top_posts_dict = top_posts.to_dict('index')

# Duplicate comment ids are dropped, keeping the first occurrence.
df_comments = df_comments.set_index('id')
df_comments = df_comments[~df_comments.index.duplicated(keep='first')]
comments_dict = df_comments.to_dict('index')

# Posts are keyed 't3_<id>' with parent 'NULL'; comments are keyed 't1_<id>'.
# No 'time' field is stored for Fakeddit nodes.
node_info = {}
for i in top_posts_dict:
  node_info['t3_' + i] = {
      'parent': 'NULL',
      'content': top_posts_dict[i]['title'],
      'user_id': top_posts_dict[i]['author'],
  }

for i in comments_dict:
  if str(i) != 'nan':
    node_info['t1_' + i] = {
        'parent': comments_dict[i]['parent_id'],
        'content': comments_dict[i]['body'],
        'user_id': comments_dict[i]['author'],
    }


labels = {}
for i in top_posts_dict:
  labels['t3_' + i] = {
      '2_way_label': top_posts_dict[i]['2_way_label'],
      '3_way_label': top_posts_dict[i]['3_way_label'],
      '6_way_label': top_posts_dict[i]['6_way_label'],
      'created_utc': top_posts_dict[i]['created_utc'],
  }


if not os.path.isdir('fakeddit'):
  os.mkdir('fakeddit')

np.save('./fakeddit/node_info.npy', node_info)
np.save('./fakeddit/labels.npy', labels)


# Reload the dictionaries saved above (pickled object arrays, hence
# allow_pickle=True and .item()).
fakeddit_node_info = np.load(
    './fakeddit/node_info.npy', allow_pickle=True
).item()
fakeddit_labels = np.load('./fakeddit/labels.npy', allow_pickle=True).item()

# Rows are [child_id, parent_id], matching build_trees' default indices; no
# time column is supplied, so build_trees does not sort.
fakeddit_parent_child = []

for i in fakeddit_node_info:
  fakeddit_parent_child.append([i, fakeddit_node_info[i]['parent']])

fakeddit_trees = build_trees(fakeddit_node_info, fakeddit_parent_child)

# One bucket per prompt variant.
samples = {}
samples['post'] = {'positive': [], 'negative': []}
samples['post_comments'] = {'positive': [], 'negative': []}
samples['post_comments_tree'] = {'positive': [], 'negative': []}
cnt = 0

for post in fakeddit_trees:
  cnt += 1
  # The 2-way label is the classification target for this task.
  tree = Tree(
      post,
      fakeddit_trees[post],
      fakeddit_node_info,
      label=fakeddit_labels[post]['2_way_label'],
  )
  tasks = Task(tree)

  post_positive_samples, post_negative_samples = tasks.get_fakeddit_samples(
      sample_type='post'
  )

  samples['post']['positive'].extend(post_positive_samples)
  samples['post']['negative'].extend(post_negative_samples)

  post_comments_positive_samples, post_comments_negative_samples = (
      tasks.get_fakeddit_samples(sample_type='post_comments')
  )

  samples['post_comments']['positive'].extend(post_comments_positive_samples)
  samples['post_comments']['negative'].extend(post_comments_negative_samples)

  post_comments_tree_positive_samples, post_comments_tree_negative_samples = (
      tasks.get_fakeddit_samples(sample_type='post_comments_tree')
  )

  samples['post_comments_tree']['positive'].extend(
      post_comments_tree_positive_samples
  )
  samples['post_comments_tree']['negative'].extend(
      post_comments_tree_negative_samples
  )


# Collect the creation time of every post that produced a 'post' sample;
# sample[0] is the sample's post id.
time_stamps = []
for sample in samples['post']['positive']:
  time_stamps.append(fakeddit_labels['t3_' + sample[0][3:]]['created_utc'])

for sample in samples['post']['negative']:
  time_stamps.append(fakeddit_labels['t3_' + sample[0][3:]]['created_utc'])


time_stamps.sort()

# Chronological cut-offs at the 70% and 85% positions of the sorted times;
# write_samples uses them to assign train / val / test.
# NOTE(review): indexing raises IndexError when no samples were produced.
train_time = time_stamps[int((len(time_stamps) * 70) / 100)]
val_time = time_stamps[int((len(time_stamps) * 85) / 100)]


def write_samples(
    sample_type,
):
  """Splits the Fakeddit samples of one type by post time and writes CSVs.

  Reads the module-level `samples`, `fakeddit_labels`, `train_time` and
  `val_time`. A sample goes to train when its post's 'created_utc' is at most
  `train_time`, to val when it is at most `val_time`, and to test otherwise.
  The positive and negative lists in `samples` are shuffled in place.

  Args:
      sample_type: Key into `samples`; also the name of the output directory
        under './fakeddit/'.
  """
  by_label = samples[sample_type]
  random.shuffle(by_label['positive'])
  random.shuffle(by_label['negative'])

  train_rows = []
  test_rows = []
  val_rows = []

  # Positives are routed first, then negatives.
  for polarity in ('positive', 'negative'):
    for row in by_label[polarity]:
      created = fakeddit_labels['t3_' + row[0][3:]]['created_utc']
      if created <= train_time:
        train_rows.append(row)
      elif created <= val_time:
        val_rows.append(row)
      else:
        test_rows.append(row)

  for rows in (train_rows, test_rows, val_rows):
    random.shuffle(rows)

  out_dir = f'./fakeddit/{sample_type}'
  if not os.path.isdir(out_dir):
    os.mkdir(out_dir)
  write_to_csv(f'{out_dir}/train.csv', train_rows)
  write_to_csv(f'{out_dir}/test.csv', test_rows)
  write_to_csv(f'{out_dir}/val.csv', val_rows)


# Write the three Fakeddit prompt variants.
write_samples('post')
write_samples('post_comments')
write_samples('post_comments_tree')

# Controversy data preparation.

subreddits = [
    'AskMen',
    'AskWomen',
    'Fitness',
    'LifeProTips',
    'personalfinance',
    'relationships',
]

# Counters for posts with a truthy / falsy controversy label.
cntp = 0
cntn = 0

top_posts = {}
labels = {}

# Load the labelled posts of every subreddit (one JSON object per line).
for subreddit in subreddits:
  filepath = (
      'controversy_data_release/posts_from_paper/'
      + subreddit
      + '_posts_from_paper.jsonlist'
  )
  jsonObj = pd.read_json(path_or_buf=filepath, lines=True)

  # i is the row index (unused); j is the row.
  for i, j in jsonObj.iterrows():
    top_posts[j['id']] = {
        'content': str(j['title']) + '\n' + str(j['selftext']),
        'user_id': j['author'],
        'time': j['created_utc'],
        'parent': 'NULL',
    }

    labels[j['id']] = j['controversy_label']

    if j['controversy_label']:
      cntp += 1
    else:
      cntn += 1

top_posts_id = set(top_posts.keys())


def dfs(child, node_info):
  """Records a comment and all of its descendants in `node_info`.

  Args:
      child: A comment dictionary with the keys 'id', 'body', 'author',
        'created_utc', 'parent_id', 'controversiality' and 'children' (a list
        of comment dictionaries of the same shape).
      node_info: Dictionary that is filled in place, keyed by comment id.

  Returns:
      The same `node_info` object that was passed in.
  """
  node_info[child['id']] = {
      'content': str(child['body']),
      'user_id': child['author'],
      'time': child['created_utc'],
      'parent': child['parent_id'],
      'controversiality': child['controversiality'],
  }

  # The recursive call writes straight into the shared dictionary, so its
  # return value does not need to be merged back.
  for reply in child['children']:
    dfs(reply, node_info)

  return node_info


# Build node_info for the controversy data: every labelled post found in the
# comment files, plus all comments below it.
node_info = {}
for subreddit in subreddits:
  filepath = (
      'controversy_data_release/15_comment_filtered/'
      + subreddit
      + '.jsonlist-15_comments'
  )
  # The context manager closes each subreddit file before the next is opened.
  with open(filepath, 'r') as f:
    data = [json.loads(line) for line in f]

  for post in data:
    # Only posts that have a controversy label are kept.
    if post['id'] in top_posts_id:
      node_info[post['id']] = {
          'content': str(post['title']) + '\n' + str(post['selftext']),
          'user_id': post['author'],
          'time': post['created_utc'],
          'parent': 'NULL',
      }

      # Flatten each top-level comment subtree and merge it in.
      for child in post['children']:
        child_info = dfs(child, {})
        node_info.update(child_info)

# Count the node ids that start with 't3'.
cnt = 0
for i in node_info:
  if i[0:2] == 't3':
    cnt += 1


# NOTE(review): this reloads the AskMen comment file into `data`, but `data`
# is not read again before it is reassigned in the CMV section -- looks like
# leftover exploratory code; confirm before removing.
with open(
    'controversy_data_release/15_comment_filtered/AskMen.jsonlist-15_comments'
) as f:
  data = [json.loads(line) for line in f]


if not os.path.isdir('controversy'):
  os.mkdir('controversy')


# Persist node info and labels, then load them back (pickled object arrays,
# hence allow_pickle=True and .item()).
np.save('./controversy/node_info.npy', node_info)
np.save('./controversy/labels.npy', labels)


node_info = np.load('./controversy/node_info.npy', allow_pickle=True).item()
labels = np.load('./controversy/labels.npy', allow_pickle=True).item()

# Rows are [time as float, child_id, parent_id].
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

# Generate controversy samples for four prompt variants and write a
# train/test/val split per variant, using the same post assignment for all.
sample_types = ['post', 'post_comments', 'post_comments_tree', 'post_tree']

samples = {}
samples['post'] = {'positive': [], 'negative': []}
samples['post_comments'] = {'positive': [], 'negative': []}
samples['post_comments_tree'] = {'positive': [], 'negative': []}
samples['post_tree'] = {'positive': [], 'negative': []}
cnt = 0

for post in trees:
  cnt += 1
  tree = Tree(post, trees[post], node_info, label=labels[post], truncate=True)
  tasks = Task(tree)

  for sample_type in sample_types:
    task_positive_samples, task_negative_samples = (
        tasks.get_controversy_samples(sample_type=sample_type)
    )
    samples[sample_type]['positive'].extend(task_positive_samples)
    samples[sample_type]['negative'].extend(task_negative_samples)

# Post ids per split; filled by the first sample type and reused afterwards.
train_ids = []
val_ids = []
test_ids = []

path = './controversy'

for sample_type in sample_types:
  train_samples = []
  test_samples = []
  val_samples = []

  positive_samples = samples[sample_type]['positive']
  negative_samples = samples[sample_type]['negative']

  if not train_ids:
    # First pass (or train_ids still empty): split randomly, positives and
    # negatives separately, and remember which post went where.

    train, test, val = get_train_test_split(positive_samples)
    train_samples.extend(train)
    test_samples.extend(test)
    val_samples.extend(val)

    train, test, val = get_train_test_split(negative_samples)
    train_samples.extend(train)
    test_samples.extend(test)
    val_samples.extend(val)

    # i[0] is the sample's post id.
    train_ids = set([i[0] for i in train_samples])
    val_ids = set([i[0] for i in val_samples])
    test_ids = set([i[0] for i in test_samples])

  else:
    # Later passes: route each sample by its post id; samples whose post is
    # in none of the three id sets are dropped.

    for sample in samples[sample_type]['positive']:
      if sample[0] in train_ids:
        train_samples.append(sample)
      elif sample[0] in val_ids:
        val_samples.append(sample)
      elif sample[0] in test_ids:
        test_samples.append(sample)
      else:
        pass

    for sample in samples[sample_type]['negative']:
      if sample[0] in train_ids:
        train_samples.append(sample)
      elif sample[0] in val_ids:
        val_samples.append(sample)
      elif sample[0] in test_ids:
        test_samples.append(sample)
      else:
        pass

  random.shuffle(train_samples)
  random.shuffle(test_samples)
  random.shuffle(val_samples)

  subdir = f'./controversy/{sample_type}'
  if not os.path.isdir(subdir):
    os.mkdir(subdir)
  write_to_csv(f'{subdir}/train.csv', train_samples)
  write_to_csv(f'{subdir}/test.csv', test_samples)
  write_to_csv(f'{subdir}/val.csv', val_samples)


# Build controversy samples for several time_stamp values, restricted to the
# posts listed in the post_comments_tree val.csv written above.
node_info = np.load('./controversy/node_info.npy', allow_pickle=True).item()
labels = np.load('./controversy/labels.npy', allow_pickle=True).item()

# Rows are [time as float, child_id, parent_id].
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)


# NOTE(review): write_to_csv opens files in append mode and writes a header
# on every call, so a val.csv left over from an earlier run would contain
# repeated header rows -- confirm the output directory is clean.
df = pd.read_csv('./controversy/post_comments_tree/val.csv')
post_ids = df['post_id'].tolist()

# Bare expression with no effect when run as a script.
len(post_ids)

sample_types = ['post_comments_tree']

samples = {}
samples['post_comments_tree'] = {'positive': [], 'negative': []}

# Each value is passed to Tree as time_stamp; get_controversy_samples then
# names the task 'controversy_<time>_hr'.
for time in [0, 1, 2, 4, 6, 12, 24]:
  cnt = 0
  for post in post_ids:
    cnt += 1
    tree = Tree(
        post,
        trees[post],
        node_info,
        label=labels[post],
        truncate=True,
        time_stamp=time,
    )
    tasks = Task(tree)

    for sample_type in sample_types:
      task_positive_samples, task_negative_samples = (
          tasks.get_controversy_samples(sample_type=sample_type)
      )
      samples[sample_type]['positive'].extend(task_positive_samples)
      samples[sample_type]['negative'].extend(task_negative_samples)

# Bare expression with no effect when run as a script.
len(samples['post_comments_tree']['positive'])

positive_samples = samples['post_comments_tree']['positive']
negative_samples = samples['post_comments_tree']['negative']

# all_samples aliases the positive list, so extend() also grows
# samples['post_comments_tree']['positive'] in place.
all_samples = positive_samples
all_samples.extend(negative_samples)
random.shuffle(all_samples)

write_to_csv('./controversy/post_comments_tree/test_ed.csv', all_samples)

# CMV dataset preparation.


def dfs1(child, node_info):
  """Records a single comment in `node_info` under the key 't1_<id>'.

  Despite the name, this does not recurse into replies.

  Args:
      child: A comment dictionary. When it has a 'body' key it must also have
        'id', 'author', 'created_utc' and 'parent_id'.
      node_info: Dictionary that receives the entry (modified in place).

  Returns:
      `node_info` itself when the comment has a 'body'; otherwise a new empty
      dictionary (the passed-in `node_info` is left untouched).
  """
  if 'body' not in child:
    return {}

  entry = {
      'content': str(child['body']),
      'user_id': child['author'],
      'time': child['created_utc'],
      'parent': child['parent_id'],
  }
  node_info['t1_' + child['id']] = entry
  return node_info


# Load the CMV training-period threads (one JSON object per line).
filepath = 'cmv/all/train_period_data.jsonlist'
# The context manager closes the file as soon as it has been parsed.
with open(filepath, 'r') as f:
  data = [json.loads(line) for line in f]

node_info = {}

cnt = 0  # Number of posts read.

for post in data:
  cnt += 1
  # Posts are keyed 't3_<id>' and use the sentinel parent 'NULL'.
  node_info['t3_' + post['id']] = {
      'content': str(post['title']) + '\n' + str(post['selftext']),
      'user_id': post['author'],
      'time': post['created_utc'],
      'parent': 'NULL',
  }

  # dfs1 adds one 't1_<id>' entry per comment that has a 'body' and returns
  # an empty dict otherwise.
  for child in post['comments']:
    child_info = dfs1(child, {})
    node_info.update(child_info)

# Rows are [time as float, child_id, parent_id].
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

# Bare expression with no effect when run as a script.
len(trees)

# Derive labels from DeltaBot comments: the first '/u/<name>' mention in a
# DeltaBot comment is recorded as a label for the post. DeltaBot comments and
# their parent comments are then removed from the data.
nodes_to_remove = set()
labels = {}
for post_id in trees:
  if post_id not in labels:
    labels[post_id] = []

  # Each entry of trees[post_id] is an edge; element 0 is the child id.
  for comments in trees[post_id]:
    comment_id = comments[0]
    if node_info[comment_id]['user_id'] == 'DeltaBot':
      if '/u/' in node_info[comment_id]['content']:
        # Text after the first '/u/' up to the next whitespace.
        # NOTE(review): raises IndexError if nothing but whitespace follows
        # '/u/'.
        label = node_info[comment_id]['content'].split('/u/')[1].split()[0]
        # Strip one trailing period.
        if '.' == label[-1]:
          label = label[:-1]
        if label:
          labels[post_id].append(label)
        else:
          pass
      nodes_to_remove.add(comment_id)
      nodes_to_remove.add(node_info[comment_id]['parent'])

# Count the posts that received at least one label.
cnt = 0
for i in labels:
  if labels[i]:
    cnt += 1


# Posts without any label are dropped together with all of their comments.
# Iterating over a copy of the keys allows deleting from `labels`.
keys = list(labels.keys())
for i in keys:
  if not labels[i]:
    del labels[i]
    nodes_to_remove.add(i)
    for comments in trees[i]:
      nodes_to_remove.add(comments[0])

# NOTE(review): raises KeyError if an id in nodes_to_remove (e.g. a DeltaBot
# comment's parent) is not a key of node_info -- confirm that cannot happen.
for node_id in nodes_to_remove:
  del node_info[node_id]

# Rebuild the trees from the filtered nodes.
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

# x / y track the smallest / largest tree size (number of edges); they are
# not used afterwards within this part of the script.
x = 1000000000
y = 0
for i in trees:
  x = min(x, len(trees[i]))
  y = max(y, len(trees[i]))


np.save('./cmv/train_node_info.npy', node_info)
np.save('./cmv/train_labels.npy', labels)

# Load the CMV held-out-period threads (one JSON object per line).
filepath = 'cmv/all/heldout_period_data.jsonlist'
# The context manager closes the file as soon as it has been parsed.
with open(filepath, 'r') as f:
  data = [json.loads(line) for line in f]

node_info = {}

cnt = 0  # Number of posts read.

for post in data:
  cnt += 1
  # Posts are keyed 't3_<id>' and use the sentinel parent 'NULL'.
  node_info['t3_' + post['id']] = {
      'content': str(post['title']) + '\n' + str(post['selftext']),
      'user_id': post['author'],
      'time': post['created_utc'],
      'parent': 'NULL',
  }

  # dfs1 adds one 't1_<id>' entry per comment that has a 'body' and returns
  # an empty dict otherwise.
  for child in post['comments']:
    child_info = dfs1(child, {})
    node_info.update(child_info)

# Rows are [time as float, child_id, parent_id].
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

# Bare expression with no effect when run as a script.
len(trees)

# Same label extraction as for the training period: the first '/u/<name>'
# mention in a DeltaBot comment becomes a label of the post, and DeltaBot
# comments plus their parent comments are removed.
nodes_to_remove = set()
labels = {}
for post_id in trees:
  if post_id not in labels:
    labels[post_id] = []

  # Each entry of trees[post_id] is an edge; element 0 is the child id.
  for comments in trees[post_id]:
    comment_id = comments[0]
    if node_info[comment_id]['user_id'] == 'DeltaBot':
      if '/u/' in node_info[comment_id]['content']:
        # Text after the first '/u/' up to the next whitespace.
        # NOTE(review): raises IndexError if nothing but whitespace follows
        # '/u/'.
        label = node_info[comment_id]['content'].split('/u/')[1].split()[0]
        # Strip one trailing period.
        if '.' == label[-1]:
          label = label[:-1]
        if label != '':
          labels[post_id].append(label)
        else:
          pass
      nodes_to_remove.add(comment_id)
      nodes_to_remove.add(node_info[comment_id]['parent'])

# Count the posts that received at least one label.
cnt = 0
for i in labels:
  if labels[i]:
    cnt += 1

# Posts without any label are dropped together with all of their comments.
# Iterating over a copy of the keys allows deleting from `labels`.
label_keys = list(labels.keys())
for i in label_keys:
  if not labels[i]:
    del labels[i]
    nodes_to_remove.add(i)
    for comments in trees[i]:
      nodes_to_remove.add(comments[0])

# NOTE(review): raises KeyError if an id in nodes_to_remove (e.g. a DeltaBot
# comment's parent) is not a key of node_info -- confirm that cannot happen.
for node_id in nodes_to_remove:
  del node_info[node_id]

# Rebuild the trees from the filtered nodes.
parent_child = []

for i in node_info:
  parent_child.append([float(node_info[i]['time']), i, node_info[i]['parent']])

trees = build_trees(
    node_info, parent_child, time_index=0, child_index=1, parent_index=2
)

# x / y track the smallest / largest tree size (number of edges); they are
# not used afterwards within this part of the script.
x = 1000000000
y = 0
for i in trees:
  x = min(x, len(trees[i]))
  y = max(y, len(trees[i]))


np.save('./cmv/test_node_info.npy', node_info)
np.save('./cmv/test_labels.npy', labels)


train_node_info = np.load('./cmv/train_node_info.npy', allow_pickle=True).item()
test_node_info = np.load('./cmv/test_node_info.npy', allow_pickle=True).item()
train_labels = np.load('./cmv/train_labels.npy', allow_pickle=True).item()
test_labels = np.load('./cmv/test_labels.npy', allow_pickle=True).item()

# One [time, node_id, parent_id] row per node of the train split; build_trees
# sorts the rows by the time column (time_index=0) before building the trees.
train_parent_child = [
    [float(info['time']), node_id, info['parent']]
    for node_id, info in train_node_info.items()
]

train_trees = build_trees(
    train_node_info, train_parent_child,
    time_index=0, child_index=1, parent_index=2,
)

# One [time, node_id, parent_id] row per node of the test split; build_trees
# sorts the rows by the time column (time_index=0) before building the trees.
test_parent_child = [
    [float(info['time']), node_id, info['parent']]
    for node_id, info in test_node_info.items()
]

test_trees = build_trees(
    test_node_info, test_parent_child,
    time_index=0, child_index=1, parent_index=2,
)

sample_types = ['post_comments_tree']


def _collect_cmv_samples(split_trees, split_node_info, split_labels, types):
  """Generates positive/negative CMV samples for every post of one split.

  Args:
      split_trees: Mapping from post id to that post's edge list, as returned
        by build_trees.
      split_node_info: Node table belonging to the same split.
      split_labels: Mapping from post id to the labels of that post.
      types: Sample type names to generate for each post.

  Returns:
      A dict keyed by sample type; each value is a dict with 'positive' and
      'negative' lists holding the samples of all posts in the split.
  """
  collected = {kind: {'positive': [], 'negative': []} for kind in types}
  for post, edges in split_trees.items():
    tree = Tree(
        post,
        edges,
        split_node_info,
        label=split_labels[post],
        truncate=True,
        max_length=15,
    )
    tasks = Task(tree)
    for kind in types:
      positives, negatives = tasks.get_cmv_samples(sample_type=kind)
      collected[kind]['positive'].extend(positives)
      collected[kind]['negative'].extend(negatives)
  return collected


# Each split is built strictly from its own trees, node table and labels, so
# that no test post ends up in the training samples and vice versa.
train_samples = _collect_cmv_samples(
    train_trees, train_node_info, train_labels, sample_types
)
test_samples = _collect_cmv_samples(
    test_trees, test_node_info, test_labels, sample_types
)


# The output directory may already exist from an earlier run; do not fail.
os.makedirs('./cmv/post_comments_tree', exist_ok=True)

# The test file is made of the positive and negative samples held in
# test_samples only.
pct_test_samples = [
    *test_samples['post_comments_tree']['positive'],
    *test_samples['post_comments_tree']['negative'],
]

# Shuffle the flat sample list (not the test_samples dict) so positives and
# negatives are interleaved in the written file.
random.shuffle(pct_test_samples)
write_to_csv('./cmv/post_comments_tree/test.csv', pct_test_samples)

# Hold out 10% of the positives and 10% of the negatives (split separately)
# for validation, then write the train and validation files.
pct_train = train_samples['post_comments_tree']
positive_samples = pct_train['positive']
negative_samples = pct_train['negative']

train_positive_samples, val_positive_samples = (
    model_selection.train_test_split(
        positive_samples, test_size=0.1, shuffle=True
    )
)
train_negative_samples, val_negative_samples = (
    model_selection.train_test_split(
        negative_samples, test_size=0.1, shuffle=True
    )
)

# Positives come first, followed by negatives, in both output lists.
train_samples = [*train_positive_samples, *train_negative_samples]
val_samples = [*val_positive_samples, *val_negative_samples]

write_to_csv('./cmv/post_comments_tree/train.csv', train_samples)
write_to_csv('./cmv/post_comments_tree/val.csv', val_samples)
