import re
import random
import string

# Captures the text between the keywords FROM and WHERE (case-sensitive,
# greedy). `.` does not match a newline, so the query has to be flattened
# onto a single line before this pattern is applied.
from_pattern = re.compile("FROM(.*)WHERE")

# Captures everything that follows the first WHERE keyword on the line
# (case-sensitive).
where_pattern = re.compile("WHERE(.*)")

# Hard-coded row count of each base table, keyed by table name. It is used as
# the cardinality of a relation that has no filter predicate, and is reported
# as "unfilteredCardinality" for every relation.
# NOTE(review): the table names look like the IMDB schema and the counts were
# presumably measured on the target database - TODO confirm.
unfilteredCardinalities = {
    'cast_info': 36244344,
    'movie_info': 14835720,
    'movie_keyword': 4523930,
    'name': 4167491,
    'char_name': 3140339,
    'person_info': 2963664,
    'movie_companies': 2609129,
    'title': 2528312,
    'movie_info_idx': 1380035,
    'aka_name': 901343,
    'aka_title': 361472,
    'company_name': 234997,
    'complete_cast': 135086,
    'keyword': 134170,
    'movie_link': 29997,
    'info_type': 113,
    'link_type': 18,
    'role_type': 12,
    'kind_type': 7,
    'comp_cast_type': 4,
    'company_type': 4,
}

def is_connected_graph(nodes, edges):
    """Return True when the sub-graph induced by ``nodes`` is connected.

    Args:
        nodes: Sequence of hashable node identifiers. Only edges whose two
            endpoints both appear in ``nodes`` are considered.
        edges: Iterable of ``(node1, node2)`` pairs, treated as undirected.

    Returns:
        True if every node can be reached from the first one through the
        retained edges, False otherwise. An empty ``nodes`` sequence is
        treated as (vacuously) connected instead of raising ``IndexError``.
    """
    # Nothing to traverse: every one of the zero nodes is trivially reached.
    if len(nodes) == 0:
        return True

    graph = build_graph(nodes, edges)
    visited = set()

    # Collect everything reachable from an arbitrary start node.
    dfs(graph, nodes[0], visited)

    # Compare against the number of *distinct* nodes (the graph's keys) so a
    # node listed twice in ``nodes`` does not make a connected graph look
    # disconnected.
    return len(visited) == len(graph)


def build_graph(nodes, edges):
    """Build an undirected adjacency list restricted to ``nodes``.

    Args:
        nodes: Iterable of hashable node identifiers.
        edges: Iterable of ``(node1, node2)`` pairs.

    Returns:
        A dict mapping every node to the list of its neighbours. An edge is
        recorded in both directions, and only when both of its endpoints are
        in ``nodes``; all other edges are ignored.
    """
    graph = {node: [] for node in nodes}

    for node1, node2 in edges:
        # Test membership against the dict rather than the ``nodes`` sequence:
        # same answer, but O(1) per lookup instead of a linear scan.
        if node1 in graph and node2 in graph:
            graph[node1].append(node2)
            graph[node2].append(node1)

    return graph


def dfs(graph, start, visited):
    """Add every node reachable from ``start`` to ``visited`` (in place).

    The traversal uses an explicit stack instead of recursion, so a long path
    cannot exceed Python's recursion limit.

    Args:
        graph: Dict mapping each node to an iterable of its neighbours.
        start: Node to start from; it is always added to ``visited``.
        visited: Set that is mutated to receive the reached nodes. Nodes
            already present in it are not expanded again.

    Raises:
        KeyError: If a node that gets expanded is not a key of ``graph``.
    """
    visited.add(start)
    pending = [start]

    while pending:
        node = pending.pop()
        for neighbor in graph[node]:
            if neighbor not in visited:
                # Mark on discovery so a node is never pushed twice.
                visited.add(neighbor)
                pending.append(neighbor)

from itertools import chain, combinations

def powerset(lst):
    """Return all subsets of ``lst`` that contain at least two elements.

    Subsets are returned as tuples, ordered by increasing size and, within one
    size, in the order produced by ``itertools.combinations``. The empty
    subset and the single-element subsets are not part of the result.
    """
    largest_size = len(lst)
    return [
        subset
        for size in range(2, largest_size + 1)
        for subset in combinations(lst, size)
    ]

import json
from collections import defaultdict, namedtuple

import pymonetdb

# Everything _parse_query extracts from one flattened query.
_ParsedQuery = namedtuple(
    "_ParsedQuery",
    [
        "join_tables",       # [(base_table, alias), ...] in FROM-clause order
        "join_edges",        # [(left_alias, right_alias), ...], one per join predicate
        "join_predicates",   # alias -> [(other_alias, predicate_text), ...]
        "unary_predicates",  # alias -> [predicate_text, ...]
        "join_columns",      # ordered, de-duplicated "alias.column" of join predicates
        "unary_columns",     # ordered, de-duplicated "alias.column" of filter predicates
        "join_expressions",  # ['{"left": "a.x", "right": "b.y"}', ...]
    ],
)


def _parse_query(sql_single_line):
    """Split a single-line query into its tables, join and filter predicates.

    The text between FROM and WHERE is scanned for ``<table> AS <alias>``
    pairs, and the text after WHERE is split on the keyword AND. A predicate
    of the exact form ``a.x = b.y`` is a join predicate; any other predicate
    is a filter attributed to the first ``alias.column`` it mentions.

    NOTE: the AND split is keyword-based, not a SQL parser, so a predicate
    that itself contains a standalone AND (``BETWEEN x AND y``, or a string
    literal containing `` AND ``) is still cut apart.

    Raises:
        ValueError: If the FROM ... WHERE part cannot be located, or a
            predicate references no ``alias.column``.
    """
    from_matches = from_pattern.findall(sql_single_line)
    where_matches = where_pattern.findall(sql_single_line)
    if not from_matches or not where_matches:
        # Fail loudly instead of carrying the previous file's tables over.
        raise ValueError(f"cannot locate FROM ... WHERE in query: {sql_single_line!r}")

    join_tables = re.findall(r'\b(\w+)\s+AS\s+(\w+)\b', from_matches[0].strip())

    join_edges = []
    join_predicates = defaultdict(list)
    unary_predicates = defaultdict(list)
    # Dicts are used as insertion-ordered sets so the emitted column lists are
    # the same on every run.
    join_columns = {}
    unary_columns = {}
    join_expressions = []

    # Word boundaries keep identifiers/literals such as BRAND from being split.
    for raw_predicate in re.split(r'\bAND\b', where_matches[0].strip()):
        predicate = raw_predicate.strip()
        join_match = re.search(r'^(\w+)\.(\w+) = (\w+)\.(\w+)$', predicate)
        if join_match:
            left_alias, left_column, right_alias, right_column = join_match.groups()
            left = f"{left_alias}.{left_column}"
            right = f"{right_alias}.{right_column}"
            join_edges.append((left_alias, right_alias))
            # Index the predicate under both aliases so it can be found from
            # either side of the join.
            join_predicates[left_alias].append((right_alias, predicate))
            join_predicates[right_alias].append((left_alias, predicate))
            join_columns[left] = None
            join_columns[right] = None
            join_expressions.append(f'{{"left": "{left}", "right": "{right}"}}')
            continue

        column_refs = re.findall(r'\(?(\w+?)\.(\w+) ', predicate)
        if not column_refs:
            raise ValueError(f"cannot find an alias.column reference in predicate: {predicate!r}")
        alias, column = column_refs[0]
        unary_predicates[alias].append(predicate)
        unary_columns[f"{alias}.{column}"] = None

    return _ParsedQuery(
        join_tables,
        join_edges,
        join_predicates,
        unary_predicates,
        join_columns,
        unary_columns,
        join_expressions,
    )


def _run_count(cursor, statement, log_file):
    """Log ``statement``, execute it and return the first column of its first row."""
    log_file.write(f"{statement}\n")
    cursor.execute(statement)
    return cursor.fetchone()[0]


def _json_array_section(key, entries):
    """Format ``entries`` (pre-rendered JSON values) as one ``"key": [...],`` section.

    Entries are separated by a comma and a newline; an empty list yields an
    empty but properly closed array.
    """
    body = ",\n".join(entries)
    if body:
        body += "\n"
    return f' "{key}": [\n{body} ],\n'


def _process_query_file(i, repeat_time, cursor):
    """Compute all cardinalities for query ``q{i}_{repeat_time}`` and write them out.

    Reads ``result2/q{i}_{repeat_time}.sql``, writes the JSON description to
    ``result3/q{i}_{repeat_time}.json`` and logs every executed statement to
    ``log/q{i}_{repeat_time}.sql``.
    """
    with open(f"result2/q{i}_{repeat_time}.sql", "r") as sql_file:
        sql = sql_file.read().strip()
    sql_single_line = sql.replace("\n", " ")
    # Parse before opening the outputs so a malformed query does not truncate
    # an existing result file.
    parsed = _parse_query(sql_single_line)
    aliases = [alias for _, alias in parsed.join_tables]

    with open(f"result3/q{i}_{repeat_time}.json", "w") as json_file, \
            open(f"log/q{i}_{repeat_time}.sql", "w") as log_file:
        json_file.write('{\n')
        json_file.write(f' "name": "{i} {repeat_time}",\n')

        # --- base relations: cardinality after applying each table's filters ---
        next_suffix = {}
        table_to_base_table = {}
        relation_entries = []
        for base_table_name, table_name in parsed.join_tables:
            table_to_base_table[table_name] = base_table_name
            # The first use of a base table keeps its name; later uses get a
            # numeric suffix starting at 2 (title, title2, title3, ...).
            if base_table_name in next_suffix:
                alias_table = f'{base_table_name}{next_suffix[base_table_name]}'
                next_suffix[base_table_name] += 1
            else:
                alias_table = base_table_name
                next_suffix[base_table_name] = 2

            unfiltered_cardinality = unfilteredCardinalities[base_table_name]
            table_filters = parsed.unary_predicates.get(table_name, [])
            if table_filters:
                unary_info = " AND ".join(table_filters)
                unary_cardinality = _run_count(
                    cursor,
                    f"select count(*) from {base_table_name} AS {table_name} where {unary_info};",
                    log_file,
                )
            else:
                # No filter on this table: no query needed.
                unary_cardinality = unfiltered_cardinality
            relation_entries.append(
                f'  {{"name": "{table_name}", "aliastable": "{alias_table}", "basetable": "{base_table_name}", "cardinality": {unary_cardinality}, "unfilteredCardinality": {unfiltered_cardinality}}}'
            )
        json_file.write(_json_array_section("relations", relation_entries))

        # --- join graph edges ---
        join_entries = [
            f'  {{"relations": ["{left_alias}", "{right_alias}"]}}'
            for left_alias, right_alias in parsed.join_edges
        ]
        json_file.write(_json_array_section("joins", join_entries))

        # --- result size of every connected sub-join (no cartesian products) ---
        valid_combinations = [
            combination
            for combination in powerset(aliases)
            if is_connected_graph(combination, parsed.join_edges)
        ]
        size_entries = []
        for valid_combination in valid_combinations:
            # Ordered de-duplication: a predicate is reachable from both of
            # its aliases but must appear only once in the statement.
            current_join_predicates = {}
            current_unary_predicates = []
            for table in valid_combination:
                for other_table, predicate in parsed.join_predicates.get(table, ()):
                    if other_table in valid_combination:
                        current_join_predicates[predicate] = None
                current_unary_predicates += parsed.unary_predicates.get(table, [])
            # A connected combination of two or more tables always has an edge.
            assert len(current_join_predicates) > 0

            relation_in_str = ", ".join(f'"{table}"' for table in valid_combination)
            table_info = " , ".join(
                f"{table_to_base_table[table]} AS {table}" for table in valid_combination
            )
            join_info = " AND ".join(current_join_predicates)
            if current_unary_predicates:
                unary_predicate_info = " AND ".join(current_unary_predicates)
                join_sql_statement = f"select count(*) from {table_info} where {join_info} AND {unary_predicate_info};"
            else:
                join_sql_statement = f"select count(*) from {table_info} where {join_info};"
            join_cardinality = _run_count(cursor, join_sql_statement, log_file)
            size_entries.append(
                f'  {{"relations": [{relation_in_str}], "cardinality": {join_cardinality}}}'
            )
        json_file.write(_json_array_section("sizes", size_entries))

        # --- query text and column summaries ---
        query_standard = re.sub(r'\s+', ' ', sql_single_line)
        # json.dumps escapes quotes/backslashes that would otherwise produce
        # invalid JSON; plain text is emitted unchanged.
        json_file.write(f' "query": {json.dumps(query_standard, ensure_ascii=False)},\n')
        join_columns_str = ", ".join(f'"{item}"' for item in parsed.join_columns)
        unary_columns_str = ", ".join(f'"{item}"' for item in parsed.unary_columns)
        join_expressions_str = ", ".join(parsed.join_expressions)
        json_file.write(f' "join columns": [{join_columns_str}],\n')
        json_file.write(f' "unary columns": [{unary_columns_str}],\n')
        json_file.write(f' "join expressions": [{join_expressions_str}]\n')
        json_file.write('}')


connection = pymonetdb.connect(username="monetdb", password="monetdb", hostname="localhost", database="imd_noidx")
cursor = connection.cursor()
try:
    for i in range(1, 34):
        for repeat_time in range(100):
            print("i", i)
            print("repeat_time:", repeat_time)
            _process_query_file(i, repeat_time, cursor)
finally:
    # Release the database resources even when a query or a file fails.
    cursor.close()
    connection.close()




