import os
import torch

from pprint import pprint
from tqdm import tqdm

from setup_utils import set_seed
from src.dataset import load_dataset, DAGDataset
from src.eval import TPUTileEvaluator
from src.model import DiscreteDiffusion, EdgeDiscreteDiffusion, LayerDAG

import networkx as nx
from collections import defaultdict
import numpy as np
import time

def check_predecessor_balance(src_list, dst_list, x_n_list, rho):
    """Check a sampled graph against the balance and structure constraints.

    Args:
        src_list: Source node index of every edge.
        dst_list: Destination node index of every edge, aligned with
            ``src_list``.
        x_n_list: Per-node attribute, indexed by node id. Only the values
            ``0`` and ``1`` are counted by the balance check.
        rho: Largest tolerated imbalance ratio for any node.

    Returns:
        ``True`` only if all of the following hold, ``False`` otherwise:

        * for every node with at least one 0/1-valued predecessor,
          ``floor(|n0 - n1| / 2) / ((n0 + n1) / 2) <= rho``, where ``n0`` and
          ``n1`` count its predecessors whose attribute is 0 and 1;
        * the edges form a directed acyclic graph;
        * the graph has at most one weakly connected component;
        * the number of nodes appearing in the edges equals
          ``len(x_n_list)``.
    """
    # Group the source nodes of all incoming edges by destination node.
    incoming = defaultdict(list)
    for parent, child in zip(src_list, dst_list):
        incoming[child].append(parent)

    for parents in incoming.values():
        parent_attrs = [x_n_list[parent] for parent in parents]
        num_zero = parent_attrs.count(0)
        num_one = parent_attrs.count(1)
        num_counted = num_zero + num_one
        # Skip nodes none of whose predecessors carries a 0/1 attribute;
        # this also avoids dividing by zero below.
        if num_counted <= 0:
            continue
        imbalance = np.floor(abs(num_zero - num_one) / 2) / (num_counted / 2)
        if imbalance > rho:
            return False

    # Structural checks on the graph induced by the edge list.
    graph = nx.DiGraph()
    graph.add_edges_from(zip(src_list, dst_list))
    if not nx.is_directed_acyclic_graph(graph):
        return False
    if nx.number_weakly_connected_components(graph) > 1:
        return False
    # Every node must be touched by at least one edge.
    return len(graph.nodes) == len(x_n_list)

def sample_tpu_subset(args, device, dummy_category, model, subset):
    """Sample one labelled synthetic graph per label in ``subset.y``.

    Labels are consumed in order and grouped into batches of
    ``args.batch_size``; the last batch holds whatever labels remain.

    Args:
        args: Namespace providing ``batch_size`` and the
            ``min/max_num_steps_n/e`` values forwarded to ``model.sample``.
        device: Device handed to ``model.sample``.
        dummy_category: Passed through to the ``DAGDataset`` constructor.
        model: Object whose ``sample`` method returns
            ``(batch_edge_index, batch_x_n, batch_y)``.
        subset: Dataset split exposing the labels as ``subset.y``.

    Returns:
        A ``DAGDataset`` (``label=True``) holding the sampled graphs.
    """
    syn_set = DAGDataset(dummy_category, label=True)

    def sample_and_store(labels):
        # Draw len(labels) graphs conditioned on the given labels.
        edge_indices, node_attrs, graph_labels = model.sample(
            device, len(labels), labels,
            min_num_steps_n=args.min_num_steps_n,
            max_num_steps_n=args.max_num_steps_n,
            min_num_steps_e=args.min_num_steps_e,
            max_num_steps_e=args.max_num_steps_e)

        for k, edge_index in enumerate(edge_indices):
            # The first row of the edge index is unpacked as the destination.
            dst, src = edge_index.cpu()
            syn_set.add_data(src, dst, node_attrs[k].cpu(), graph_labels[k])

    pending_labels = []
    for idx, label in enumerate(tqdm(subset.y)):
        pending_labels.append(label)
        batch_is_full = len(pending_labels) == args.batch_size
        is_last_label = idx == len(subset.y) - 1
        if not batch_is_full and not is_last_label:
            continue

        sample_and_store(pending_labels)
        # Rebind instead of clearing so the list handed to the model is
        # never mutated afterwards.
        pending_labels = []

    return syn_set

def sample_latent_preferential_subset(args, device, dummy_category, model, subset, mode='train'):
    """Sample unlabelled synthetic graphs, one per graph in ``subset``.

    Sampling runs in batches of ``args.batch_size``; the final batch only
    requests the remaining ``len(subset) % args.batch_size`` graphs, so the
    total number of requested graphs equals ``len(subset)``.

    Args:
        args: Namespace providing ``batch_size``, the
            ``min/max_num_steps_n/e`` values and the ``check``, ``solve``,
            ``refine``, ``inner_refine``, ``refine_with_transformer`` and
            ``check_with_refine`` switches forwarded to ``model.sample``.
        device: Device handed to ``model.sample``.
        dummy_category: Passed through to the ``DAGDataset`` constructor.
        model: Object whose ``sample`` method returns
            ``(batch_edge_index, batch_x_n)`` when called with ``y=None``.
        subset: Dataset split; only its length is used.
        mode: Forwarded to ``model.sample`` unchanged.

    Returns:
        A ``DAGDataset`` (``label=False``) holding the sampled graphs.

    Raises:
        ValueError: If ``args.batch_size`` is not positive.
    """
    if args.batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {args.batch_size}')

    syn_set = DAGDataset(dummy_category, label=False)
    num_graphs = len(subset)

    for start in tqdm(range(0, num_graphs, args.batch_size)):
        # Shrink the last batch so that exactly num_graphs graphs are
        # requested overall instead of a full batch every time.
        cur_batch_size = min(args.batch_size, num_graphs - start)
        batch_edge_index, batch_x_n = model.sample(
            device, batch_size=cur_batch_size, y=None,
            min_num_steps_n=args.min_num_steps_n,
            max_num_steps_n=args.max_num_steps_n,
            min_num_steps_e=args.min_num_steps_e,
            max_num_steps_e=args.max_num_steps_e,
            check=args.check, solve=args.solve, refine=args.refine,
            inner_refine=args.inner_refine,
            refine_with_transformer=args.refine_with_transformer,
            check_with_refine=args.check_with_refine, mode=mode)

        for j in range(len(batch_edge_index)):
            # The first row of the edge index is unpacked as the destination.
            dst_j, src_j = batch_edge_index[j].cpu()
            syn_set.add_data(src_j, dst_j, batch_x_n[j].squeeze(dim=-1).cpu())

    return syn_set

def dump_to_file(syn_set, file_name, sample_dir, label=True):
    """Serialize a synthetic dataset to ``sample_dir/file_name`` with ``torch.save``.

    Args:
        syn_set: Indexable collection supporting ``len()``. Each item is
            ``(src, dst, x_n, y)`` when ``label`` is true and
            ``(src, dst, x_n)`` otherwise.
        file_name: Name of the output file inside ``sample_dir``.
        sample_dir: Directory the file is written to; it must already exist.
        label: Whether the samples carry a graph label. Defaults to ``True``,
            which reproduces the labelled four-field layout.

    The saved object is a dict with the keys ``'src_list'``, ``'dst_list'``
    and ``'x_n_list'``, plus ``'y_list'`` only when ``label`` is true. Each
    value is a list with one entry per sample.
    """
    file_path = os.path.join(sample_dir, file_name)
    data_dict = {
        'src_list': [],
        'dst_list': [],
        'x_n_list': [],
    }
    if label:
        data_dict['y_list'] = []

    # Index explicitly rather than iterating: only __len__ and __getitem__
    # are required of syn_set.
    for i in range(len(syn_set)):
        if label:
            src_i, dst_i, x_n_i, y_i = syn_set[i]
            data_dict['y_list'].append(y_i)
        else:
            src_i, dst_i, x_n_i = syn_set[i]

        data_dict['src_list'].append(src_i)
        data_dict['dst_list'].append(dst_i)
        data_dict['x_n_list'].append(x_n_i)

    torch.save(data_dict, file_path)

def eval_tpu_tile(args, device, model):
    """Sample synthetic graphs for the ``tpu_tile`` splits, evaluate and save them.

    Synthetic counterparts of the train and validation splits are sampled
    with ``sample_tpu_subset``, passed to ``TPUTileEvaluator.eval`` and then
    written to ``tpu_tile_samples/train.pth`` and
    ``tpu_tile_samples/val.pth``.

    Args:
        args: Parsed command-line namespace forwarded to the sampler.
        device: Device forwarded to the sampler.
        model: Model forwarded to the sampler.
    """
    out_dir = 'tpu_tile_samples'
    os.makedirs(out_dir, exist_ok=True)

    evaluator = TPUTileEvaluator()
    train_set, val_set, _unused_split = load_dataset('tpu_tile')

    # Both splits are sampled with the dummy category of the training set.
    synthetic = {}
    for file_name, real_split in (('train.pth', train_set),
                                  ('val.pth', val_set)):
        synthetic[file_name] = sample_tpu_subset(
            args, device, train_set.dummy_category, model, real_split)

    evaluator.eval(synthetic['train.pth'], synthetic['val.pth'])

    # Persist the samples only after the evaluation has run.
    for file_name, syn_split in synthetic.items():
        dump_to_file(syn_split, file_name, out_dir)
    
def eval_latent_preferential_tile(args, device, model):
    """Sample, validate and save graphs for the ``latent_preferential`` dataset.

    Synthetic counterparts of the train and validation splits are sampled,
    the wall-clock time of each sampling run is printed, the number of
    samples accepted by ``check_predecessor_balance`` is printed per split,
    and both synthetic sets are dumped (with ``label=False``) to
    ``./latent_preferential_samples/samples_rho_<rho>/``.

    Args:
        args: Parsed command-line namespace; ``args.rho`` selects the
            dataset variant, the output directory and the balance threshold.
        device: Device forwarded to the sampler.
        model: Model forwarded to the sampler.
    """
    sample_dir = './latent_preferential_samples/samples_rho_{}'.format(args.rho)
    os.makedirs(sample_dir, exist_ok=True)
    train_set, val_set, _ = load_dataset('latent_preferential', rho=args.rho)

    def count_valid(syn_set):
        # Number of samples that satisfy the balance/structure check.
        num_valid = 0
        for idx in range(len(syn_set)):
            src, dst, x_n = syn_set[idx]
            if check_predecessor_balance(src.tolist(), dst.tolist(),
                                         x_n.tolist(), rho=args.rho):
                num_valid += 1
        return num_valid

    # Time the two sampling runs separately.
    t_start = time.perf_counter()
    train_syn_set = sample_latent_preferential_subset(
        args, device, train_set.dummy_category, model, train_set, mode='train')
    t_train_done = time.perf_counter()
    val_syn_set = sample_latent_preferential_subset(
        args, device, train_set.dummy_category, model, val_set, mode='val')
    t_val_done = time.perf_counter()
    print(f'use time: {t_train_done - t_start:.2f}s for train')
    print(f'use time: {t_val_done - t_train_done:.2f}s for val')

    count = count_valid(train_syn_set)
    print(f'we have {count} samples successfully done for train')
    count = count_valid(val_syn_set)
    print(f'we have {count} samples successfully done for val')

    # The samples carry no graph label, hence label=False.
    dump_to_file(train_syn_set, 'train.pth', sample_dir, label=False)
    dump_to_file(val_syn_set, 'val.pth', sample_dir, label=False)

def main(args):
    """Rebuild a LayerDAG model from a checkpoint and run its dataset's evaluation.

    The checkpoint at ``args.model_path`` supplies the dataset name, the
    node/edge diffusion configs, the model config and the model weights.
    After the weights are loaded the model is put in eval mode, the random
    seed is set from ``args.seed`` and the evaluation routine matching the
    dataset name is invoked.

    Args:
        args: Parsed command-line namespace.
    """
    torch.set_num_threads(args.num_threads)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ckpt = torch.load(args.model_path)
    dataset = ckpt['dataset']

    node_diffusion = DiscreteDiffusion(**ckpt['node_diffusion_config'])
    edge_diffusion = EdgeDiscreteDiffusion(**ckpt['edge_diffusion_config'])

    model_config = ckpt['model_config']
    model = LayerDAG(device=device,
                     node_diffusion=node_diffusion,
                     edge_diffusion=edge_diffusion,
                     **model_config,
                     load_refine_model=args.refine,
                     refine_with_transformer=args.refine_with_transformer,
                     rho=args.rho)
    pprint(model_config)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    set_seed(args.seed)

    # NOTE(review): a dataset name other than the two below falls through
    # without running anything.
    evaluators = (
        ('tpu_tile', eval_tpu_tile),
        ('latent_preferential', eval_latent_preferential_tile),
    )
    for dataset_name, run_eval in evaluators:
        if dataset == dataset_name:
            run_eval(args, device, model)
            break

if __name__ == '__main__':
    from argparse import ArgumentParser, ArgumentTypeError

    def _str2bool(value):
        """Convert a command-line string such as ``True`` or ``false`` to a bool.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because every non-empty string is truthy.
        """
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string stays falsy, as it was under ``bool``.
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Path to the model.", default='./model/model_latent_preferential.pth')
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_threads", type=int, default=24)
    parser.add_argument("--min_num_steps_n", type=int, default=None)
    parser.add_argument("--min_num_steps_e", type=int, default=None)
    parser.add_argument("--max_num_steps_n", type=int, default=None)
    parser.add_argument("--max_num_steps_e", type=int, default=None)
    parser.add_argument("--seed", type=int, default=0)
    # Boolean switches take an explicit value, e.g. ``--check True``.
    parser.add_argument("--check", type=_str2bool, default=False)
    parser.add_argument("--refine", type=_str2bool, default=False)
    parser.add_argument('--solve', type=_str2bool, default=False)
    parser.add_argument('--inner_refine', type=_str2bool, default=False)
    parser.add_argument('--refine_with_transformer', type=_str2bool, default=False)
    parser.add_argument('--check_with_refine', type=_str2bool, default=False)
    parser.add_argument("--rho", type=float, default=0.5)
    args = parser.parse_args()

    main(args)
