from collections import defaultdict
import time
import networkx as nx
import numpy as np
import torch
import json
import random
from setup_utils import load_yaml
from src.dataset import load_dataset, LayerDAGEdgePredDataset
from src.model import EdgeDiscreteDiffusion


def check_predecessor_balance(src_list, dst_list, x_n_list, rho):
    predecessors = defaultdict(list)
    for src, dst in zip(src_list, dst_list):
        predecessors[dst].append(src)
    #print(src_list, dst_list, x_n_list)
    for v, preds in predecessors.items():
        n_v_0 = sum(1 for u in preds if x_n_list[u-1] == 0)
        n_v_1 = sum(1 for u in preds if x_n_list[u-1] == 1)
        n_v = n_v_0 + n_v_1
        if n_v > 0: 
            imbalance_ratio = np.floor(abs(n_v_0 - n_v_1) / 2) / (n_v / 2)
            if imbalance_ratio > rho:
                return False
    G = nx.DiGraph()
    G.add_edges_from(zip(src_list, dst_list))
    if not nx.is_directed_acyclic_graph(G):
        return False
    elif nx.number_weakly_connected_components(G) > 1:
        return False
    elif not len(G.nodes) == len(x_n_list):
        return False
    else:
        return True


def process_data(input_src, input_dst, noisy_src, noisy_dst, input_x_n, noisy_src_old, noisy_dst_old, t, rho):
    """Convert one edge-prediction sample into a JSON-serialisable record.

    A record is produced only when the sample carries an "old" noisy edge set
    (both ``noisy_src_old`` and ``noisy_dst_old`` are not None). For such
    samples the graph ``input_* + noisy_*`` must pass
    ``check_predecessor_balance`` and the graph ``input_* + noisy_*_old`` must
    fail it.

    Args:
        input_src, input_dst: Input edges (objects with ``.tolist()``,
            presumably tensors).
        noisy_src, noisy_dst: Current noisy edges.
        input_x_n: Node attributes; the first entry is dropped before use
            (presumably a dummy node — confirm against the dataset).
        noisy_src_old, noisy_dst_old: Previous noisy edges, or None.
        t: Timestep value(s) (object with ``.tolist()``).
        rho: Imbalance threshold forwarded to ``check_predecessor_balance``.

    Returns:
        dict | None: The record with keys ``input_src``, ``input_dst``,
        ``noisy_src``, ``noisy_dst``, ``noisy_src_old``, ``noisy_dst_old``,
        ``input_x_n`` and ``t`` (all plain lists), or None when the sample has
        no old noisy edges.

    Raises:
        AssertionError: If the current graph fails the check or the old graph
            passes it.
    """
    # Samples without an "old" noisy graph are discarded. Skip the (costly)
    # graph checks entirely for them, since their result would be unused.
    # (An earlier experiment kept ~10% of valid samples by reusing the
    # current noisy edges as the "old" ones.)
    if noisy_src_old is None or noisy_dst_old is None:
        return None

    input_src = input_src.tolist()
    input_dst = input_dst.tolist()
    noisy_src = noisy_src.tolist()
    noisy_dst = noisy_dst.tolist()
    noisy_src_old = noisy_src_old.tolist()
    noisy_dst_old = noisy_dst_old.tolist()
    input_x_n = input_x_n.tolist()[1:]

    check = check_predecessor_balance(src_list=input_src + noisy_src,
                                      dst_list=input_dst + noisy_dst,
                                      x_n_list=input_x_n,
                                      rho=rho)
    check_old = check_predecessor_balance(src_list=input_src + noisy_src_old,
                                          dst_list=input_dst + noisy_dst_old,
                                          x_n_list=input_x_n,
                                          rho=rho)
    # Explicit raise (not `assert`) so the invariant survives `python -O`.
    if not check or check_old:
        raise AssertionError(
            f'expected new graph valid and old graph invalid, '
            f'got check={check}, check_old={check_old}')

    return {'input_src': input_src,
            'input_dst': input_dst,
            'noisy_src': noisy_src,
            'noisy_dst': noisy_dst,
            'noisy_src_old': noisy_src_old,
            'noisy_dst_old': noisy_dst_old,
            'input_x_n': input_x_n,
            't': t.tolist(),
            }

def _select_sample_fields(data):
    """Return the eight fields of an edge-pred sample that ``process_data`` uses."""
    # Samples come in a 13-field layout or a 14-field layout with an extra
    # ``input_y`` after ``t``; every field used here sits at the same position
    # in both, so the trailing fields are simply ignored.
    if len(data) not in (13, 14):
        raise ValueError(f'unexpected edge-pred sample with {len(data)} fields')
    (input_src, input_dst, noisy_src, noisy_dst, noisy_src_old, noisy_dst_old,
     input_x_n, _abs_level, _rel_level, t, *_rest) = data
    return (input_src, input_dst, noisy_src, noisy_dst, input_x_n,
            noisy_src_old, noisy_dst_old, t)


def _dump_split(dataset, split, epoch, rho, count):
    """Append every accepted sample of ``dataset`` to ./data/{split}_{epoch}.jsonl.

    ``count`` is the running number of records written so far in this epoch;
    the updated value is returned. The file is opened once for the whole split
    (in append mode, so an empty file is created even if nothing is accepted).
    """
    with open(f'./data/{split}_{epoch}.jsonl', "a") as f:
        # Index explicitly: the dataset is only known to support len/getitem.
        for idx in range(len(dataset)):
            record = process_data(*_select_sample_fields(dataset[idx]), rho=rho)
            if record is None:
                continue
            count += 1
            json.dump(record, f)
            f.write('\n')
            if count % 1000 == 0:
                print(f'generate {count} data in epoch {epoch}')
    print(f'finish {split} epoch {epoch}')
    return count


def main(args, iter):
    """Generate one epoch of filtered edge-prediction records.

    Loads the config and dataset, wraps the train/val splits in
    ``LayerDAGEdgePredDataset`` with a shared ``EdgeDiscreteDiffusion``, and
    appends every record accepted by ``process_data`` to
    ``./data/train_{iter}.jsonl`` and ``./data/val_{iter}.jsonl``. Files are
    opened in append mode, so re-running with the same ``iter`` adds to them.

    Args:
        args: Parsed CLI namespace; reads ``num_threads``, ``config_file`` and
            ``rho``.
        iter: Epoch index used in output file names and progress messages
            (name shadows the builtin but is kept for caller compatibility).
    """
    torch.set_num_threads(args.num_threads)
    config = load_yaml(args.config_file)
    train_set, val_set, _ = load_dataset(config['general']['dataset'], rho=args.rho)
    conditional = config['general']['conditional']
    train_edge_pred_dataset = LayerDAGEdgePredDataset(train_set, conditional=conditional, constrain=True)
    val_edge_pred_dataset = LayerDAGEdgePredDataset(val_set, conditional=conditional, constrain=True)

    # Both splits share one diffusion process, parameterised by train stats.
    edge_diffusion = EdgeDiscreteDiffusion(
        avg_in_deg=train_edge_pred_dataset.avg_in_deg,
        T=config['edge_pred']['T'],
        rho=args.rho,
    )
    train_edge_pred_dataset.edge_diffusion = edge_diffusion
    val_edge_pred_dataset.edge_diffusion = edge_diffusion

    epoch = int(iter)
    # The running count is carried across splits, so progress messages report
    # the total number of records produced in this epoch.
    count = _dump_split(train_edge_pred_dataset, 'train', epoch, args.rho, 0)
    _dump_split(val_edge_pred_dataset, 'val', epoch, args.rho, count)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--config_file", type=str, default='configs/LayerDAG/latent_preferential.yaml')
    parser.add_argument("--num_threads", type=int, default=16)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--rho", type=float, default=0.5)
    parser.add_argument("--iter", type=int, default=10)
    args = parser.parse_args()
    # --seed used to be parsed but never applied. Seed the RNGs imported by
    # this script once, before the epoch loop, so whole runs are reproducible
    # while successive epochs still draw different random streams.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    print("start generating smt data")
    for i in range(args.iter):
        start = time.perf_counter()
        main(args, i)
        end = time.perf_counter()
        print(f'use time: {end - start:.2f}s for generate data')