import glob
import json
import os
import random
from collections import defaultdict
from copy import deepcopy
import time
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from setup_utils import set_seed
from src.model import EdgeRefineModelWithEdgeTransformer


def check_predecessor_balance(src_list, dst_list, x_n_list, rho):
    """Check that an edge list forms a single, label-balanced DAG.

    The edges are ``zip(src_list, dst_list)``. For each destination node ``v``
    the predecessors whose label ``x_n_list[u - 1]`` is 0 or 1 are counted
    (``n0`` and ``n1``; any other label value is ignored). When
    ``n = n0 + n1 > 0`` the imbalance ``floor(|n0 - n1| / 2) / (n / 2)`` must
    not exceed ``rho``. The graph built from the edges must also be acyclic,
    have at most one weakly connected component, and contain exactly
    ``len(x_n_list)`` nodes.

    NOTE(review): labels are looked up with ``u - 1``, i.e. node ids appear to
    be 1-based relative to ``x_n_list`` -- confirm against the data format.

    Returns:
        True if every check passes, False at the first failing check.
    """
    # Group predecessors by destination; dict insertion order means the
    # destinations are checked in the order they first appear in the edges.
    preds_by_dst = defaultdict(list)
    for u, v in zip(src_list, dst_list):
        preds_by_dst[v].append(u)

    for preds in preds_by_dst.values():
        labels = [x_n_list[u - 1] for u in preds]
        zeros = labels.count(0)
        ones = labels.count(1)
        counted = zeros + ones
        if counted == 0:
            continue
        if np.floor(abs(zeros - ones) / 2) / (counted / 2) > rho:
            return False

    graph = nx.DiGraph()
    graph.add_edges_from(zip(src_list, dst_list))
    # Short-circuits in order: acyclicity, connectivity, node coverage.
    return (
        nx.is_directed_acyclic_graph(graph)
        and nx.number_weakly_connected_components(graph) <= 1
        and len(graph.nodes) == len(x_n_list)
    )

def load_data(filenames):
    """Build a ``LayerDAGEdgeRefineDataset`` from JSON-lines files.

    Each non-blank line must be a JSON object with the keys ``input_src``,
    ``input_dst``, ``noisy_src``, ``noisy_dst``, ``noisy_src_old``,
    ``noisy_dst_old``, ``input_x_n`` and ``t``. Records that serialize
    identically (with sorted keys) are kept only once, in first-seen order.

    Args:
        filenames: Iterable of file paths, read in the given order.

    Returns:
        The populated dataset. Its ``num_categories`` is one more than the
        largest value in the concatenated ``input_x_n`` lists, or 0 when no
        records were found.
    """
    records = []
    seen = set()
    for filename in filenames:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(filename, 'r', encoding='utf-8') as file:
            for line in file:
                if not line.strip():
                    # Skip blank lines (e.g. a trailing empty line), which
                    # json.loads would reject.
                    continue
                data = json.loads(line)
                key = json.dumps(data, sort_keys=True)
                if key in seen:
                    continue
                seen.add(key)
                records.append(data)

    if records:
        all_x_n = torch.cat([torch.tensor(data['input_x_n']) for data in records])
        num_categories = all_x_n.max(dim=0).values + 1
    else:
        # torch.cat([]) would raise; an empty dataset has no categories.
        num_categories = 0

    dataset = LayerDAGEdgeRefineDataset(num_categories)
    for data in records:
        dataset.add_data(data['input_src'],
                         data['input_dst'],
                         data['noisy_src'],
                         data['noisy_dst'],
                         data['noisy_src_old'],
                         data['noisy_dst_old'],
                         data['input_x_n'],
                         data['t'])
    return dataset

class LayerDAGEdgeRefineDataset(Dataset):
    """Edge-refinement samples: one item per destination node, plus raw graphs.

    ``add_data`` expands a graph with ``construct_data`` into one row per
    destination node of the old edge set. ``__getitem__`` returns such a row
    as ``(src_features, target_labels, noisy_labels)``. The raw edge lists of
    each added graph are kept separately for graph-level evaluation via
    ``graph_len`` / ``get_graph``.
    """

    def __init__(self, num_categories):
        """Create an empty dataset.

        Args:
            num_categories: Number of node-label categories. Not used by this
                class itself; stored as ``self.num_categories`` for callers.
        """
        super().__init__()
        self.num_categories = num_categories
        # Flattened per-destination rows (indexed by __getitem__).
        self.src_list = []
        self.tgt_list = []
        self.noisy_list = []

        # Raw per-graph inputs (indexed by get_graph).
        self.input_src_list = []
        self.input_dst_list = []
        self.noisy_src_old_list = []
        self.noisy_dst_old_list = []
        self.input_x_n_list = []
        self.noisy_src_list = []
        self.noisy_dst_list = []

    def __len__(self):
        """Return the number of per-destination rows."""
        return len(self.src_list)

    def __getitem__(self, index):
        """Return row ``index`` as ``(src_features, target_labels, noisy_labels)``."""
        return self.src_list[index], self.tgt_list[index], self.noisy_list[index]

    def graph_len(self):
        """Return the number of graphs added via ``add_data``."""
        return len(self.input_x_n_list)

    def get_graph(self, index):
        """Return the raw inputs of graph ``index``.

        Returns:
            ``(input_src, input_dst, noisy_src, noisy_dst, noisy_src_old,
            noisy_dst_old, input_x_n)`` exactly as passed to ``add_data``.
        """
        return (self.input_src_list[index], self.input_dst_list[index], self.noisy_src_list[index], self.noisy_dst_list[index],
                self.noisy_src_old_list[index], self.noisy_dst_old_list[index], self.input_x_n_list[index])

    @staticmethod
    def construct_data(noisy_src, noisy_dst, noisy_src_old, noisy_dst_old, input_x_n):
        """Expand one graph into per-destination query rows.

        For every destination of ``noisy_dst_old`` (ascending), one row is
        built over all sources of ``noisy_src_old`` (ascending):

        * ``src_list`` row: ``input_x_n[s]`` for each candidate source ``s``;
        * ``tgt_list`` row: 1 if ``(s, dst)`` is an edge of
          ``zip(noisy_src, noisy_dst)``, else 0;
        * ``noisy_list`` row: 1 if ``(s, dst)`` is an edge of
          ``zip(noisy_src_old, noisy_dst_old)``, else 0;
        * ``total_query_src_list`` / ``total_query_dst_list`` rows: the
          candidate sources, and ``dst`` repeated once per source.

        Raises:
            AssertionError: (when assertions are enabled) if the two edge sets
                have different destination sets, or the target sources are
                not a subset of the old sources.

        Returns:
            ``(src_list, tgt_list, noisy_list, total_query_src_list,
            total_query_dst_list)``.
        """
        assert set(noisy_dst) == set(noisy_dst_old)
        assert set(noisy_src).issubset(set(noisy_src_old))

        # Hash sets give O(1) membership tests; scanning edge lists made this
        # O(|sources| * |destinations| * |edges|).
        target_edges = set(zip(noisy_src, noisy_dst))
        old_edges = set(zip(noisy_src_old, noisy_dst_old))
        candidate_srcs = sorted(set(noisy_src_old))

        src_list, tgt_list, noisy_list = [], [], []
        total_query_src_list, total_query_dst_list = [], []
        for dst in sorted(set(noisy_dst_old)):
            src_list.append([input_x_n[s] for s in candidate_srcs])
            tgt_list.append([1 if (s, dst) in target_edges else 0 for s in candidate_srcs])
            noisy_list.append([1 if (s, dst) in old_edges else 0 for s in candidate_srcs])
            total_query_src_list.append(list(candidate_srcs))
            total_query_dst_list.append([dst] * len(candidate_srcs))
        return src_list, tgt_list, noisy_list, total_query_src_list, total_query_dst_list

    def add_data(self, input_src, input_dst, noisy_src, noisy_dst, noisy_src_old, noisy_dst_old, input_x_n, t):
        """Append one graph: its expanded query rows and its raw inputs.

        ``t`` is accepted for interface compatibility but is not used.
        """
        src_list, tgt_list, noisy_list, _, _ = self.construct_data(noisy_src, noisy_dst, noisy_src_old, noisy_dst_old, input_x_n)
        self.src_list.extend(src_list)
        self.tgt_list.extend(tgt_list)
        self.noisy_list.extend(noisy_list)

        self.input_src_list.append(input_src)
        self.input_dst_list.append(input_dst)
        self.noisy_src_list.append(noisy_src)
        self.noisy_dst_list.append(noisy_dst)
        self.noisy_src_old_list.append(noisy_src_old)
        self.noisy_dst_old_list.append(noisy_dst_old)
        self.input_x_n_list.append(input_x_n)


def collate_edge_refine(data, pad_to=128):
    """Collate ``(src, tgt, noisy)`` rows into padded teacher-forcing tensors.

    Per row:
      * ``src_ids`` = ``[3] + src + [4] + noisy``
      * ``tgt_ids`` = ``[3] + tgt[:-1]`` (targets shifted right by one step)
      * ``labels``  = ``tgt``
    Every sequence is right-padded with id 2 to a common width of
    ``max(pad_to, longest sequence in the batch)``. NOTE(review): ``src``
    values share the id space with 2/3/4, so they are presumably 0/1 -- confirm.

    Args:
        data: Non-empty sequence of ``(src, tgt, noisy)`` list triples, e.g.
            items of ``LayerDAGEdgeRefineDataset``.
        pad_to: Minimum padded width (default 128). A batch with a longer
            sequence is padded to that length instead of producing ragged rows
            that ``torch.stack`` cannot combine.

    Returns:
        Dict with ``'src_ids'``, ``'tgt_ids'`` and ``'labels'``, each a
        ``(B, width)`` ``torch.long`` tensor.
    """
    src_rows, tgt_rows, noisy_rows = map(list, zip(*data))
    triples = [
        ([3] + src + [4] + noisy, [3] + tgt[:-1], tgt)
        for src, tgt, noisy in zip(src_rows, tgt_rows, noisy_rows)
    ]
    width = max([pad_to] + [len(seq) for triple in triples for seq in triple])

    def pad(seq):
        # Builds a new list; the caller's sequences are not modified.
        return torch.tensor(seq + [2] * (width - len(seq)), dtype=torch.long)

    return {
        'src_ids': torch.stack([pad(src_ids) for src_ids, _, _ in triples]),
        'tgt_ids': torch.stack([pad(tgt_ids) for _, tgt_ids, _ in triples]),
        'labels': torch.stack([pad(lbl) for _, _, lbl in triples]),
    }

def graph_check(pred_list, query_src_list, query_dst_list, tgt_list, noisy_list, input_src, input_dst, input_x_n, rho):
    """Score one graph's predicted, original noisy and ground-truth edges.

    Every list argument except the ``input_*`` ones holds one row per
    destination node, as produced by
    ``LayerDAGEdgeRefineDataset.construct_data``. Row ``j`` pairs the candidate
    edges ``zip(query_src_list[j], query_dst_list[j])`` with three indicator
    rows: the model's prediction (``pred_list``), the target (``tgt_list``)
    and the noisy input (``noisy_list``). An edge is selected where the
    indicator equals 1. Each selected edge set is appended to the fixed
    ``input_src`` / ``input_dst`` edges and validated with
    ``check_predecessor_balance``.

    Returns:
        ``(error, refine_correct, original_correct, answer_correct)``, each 0
        or 1. ``error`` is 1 when some prediction row selects no edge, and in
        that case ``refine_correct`` is 0. ``pred_list`` is not modified.
    """
    assert len(pred_list) == len(query_src_list) == len(query_dst_list) == len(tgt_list) == len(noisy_list)
    pred_src, pred_dst = [], []
    answer_src, answer_dst = [], []
    noisy_src, noisy_dst = [], []
    missing_edge = False
    for pred, query_src, query_dst, label, noisy in zip(pred_list, query_src_list, query_dst_list, tgt_list, noisy_list):
        assert len(query_src) == len(query_dst) == len(pred) == len(label) == len(noisy)
        # Look for an explicit 1 rather than testing sum(pred) == 0: decoded
        # rows can contain other vocabulary ids (e.g. padding 2), which add to
        # the sum without selecting an edge.
        if 1 not in pred:
            missing_edge = True
        for src, dst, p, a, n in zip(query_src, query_dst, pred, label, noisy):
            if p == 1:
                pred_src.append(src)
                pred_dst.append(dst)
            if a == 1:
                answer_src.append(src)
                answer_dst.append(dst)
            if n == 1:
                noisy_src.append(src)
                noisy_dst.append(dst)

    error = 1 if missing_edge else 0
    refine_correct = 0
    if not missing_edge and check_predecessor_balance(src_list=input_src + pred_src, dst_list=input_dst + pred_dst,
                                                      x_n_list=input_x_n, rho=rho):
        refine_correct = 1
    original_correct = 1 if check_predecessor_balance(src_list=input_src + noisy_src, dst_list=input_dst + noisy_dst,
                                                      x_n_list=input_x_n, rho=rho) else 0
    answer_correct = 1 if check_predecessor_balance(src_list=input_src + answer_src, dst_list=input_dst + answer_dst,
                                                    x_n_list=input_x_n, rho=rho) else 0
    return error, refine_correct, original_correct, answer_correct

@torch.no_grad()
def test_edge_refine(device, data_set, model, rho, batch_graph_size=50):
    """Evaluate edge refinement graph by graph over the whole dataset.

    Graphs are processed in chunks of ``batch_graph_size``:

    1. Each graph is expanded into per-destination queries with
       ``LayerDAGEdgeRefineDataset.construct_data``.
    2. All queries of the chunk are decoded together by ``generate_batch``.
    3. The predictions are split back per graph and scored by ``graph_check``.

    The final partial chunk is included, so every graph counts toward the totals.

    Args:
        device: Unused; decoding runs on the model's own device. Kept for API
            compatibility.
        data_set: A ``LayerDAGEdgeRefineDataset``.
        model: Edge-refinement model passed to ``generate_batch``.
        rho: Balance threshold forwarded to ``graph_check``.
        batch_graph_size: Number of graphs decoded per ``generate_batch`` call.

    Returns:
        ``(refine_acc, original_acc, answer_acc, error_ratio)``. The first three
        are the fractions of graphs whose predicted, original noisy, or
        ground-truth edges pass ``graph_check``. ``error_ratio`` is the fraction
        of graphs that ``graph_check`` flags as errors. All four are 0.0 for an
        empty dataset.
    """
    model.eval()
    num_graphs = data_set.graph_len()
    error = refine_correct = original_correct = answer_correct = 0

    for chunk_start in tqdm(range(0, num_graphs, batch_graph_size)):
        chunk_stop = min(chunk_start + batch_graph_size, num_graphs)
        src_list, tgt_list, noisy_list = [], [], []
        query_src_list, query_dst_list = [], []
        graph_inputs = []  # (input_src, input_dst, input_x_n) for each graph
        split_list = []  # number of query rows contributed by each graph
        for index in range(chunk_start, chunk_stop):
            (input_src, input_dst, noisy_src, noisy_dst,
             noisy_src_old, noisy_dst_old, input_x_n) = data_set.get_graph(index)
            src, tgt, noisy, query_src, query_dst = LayerDAGEdgeRefineDataset.construct_data(
                noisy_src, noisy_dst, noisy_src_old, noisy_dst_old, input_x_n)
            src_list.extend(src)
            tgt_list.extend(tgt)
            noisy_list.extend(noisy)
            query_src_list.extend(query_src)
            query_dst_list.extend(query_dst)
            graph_inputs.append((input_src, input_dst, input_x_n))
            split_list.append(len(src))

        # Nothing to decode if the whole chunk produced no query rows.
        pred_list = generate_batch(src_list, noisy_list, model) if src_list else []

        offset = 0
        for (input_src, input_dst, input_x_n), length in zip(graph_inputs, split_list):
            rows = slice(offset, offset + length)
            offset += length
            pred = pred_list[rows]
            query_src = query_src_list[rows]
            query_dst = query_dst_list[rows]
            tgt = tgt_list[rows]
            noisy = noisy_list[rows]
            assert len(pred) == len(query_src) == len(query_dst) == len(tgt) == len(noisy)
            answer = graph_check(pred, query_src, query_dst, tgt, noisy, input_src, input_dst, input_x_n, rho=rho)
            error += answer[0]
            refine_correct += answer[1]
            original_correct += answer[2]
            answer_correct += answer[3]

    if num_graphs == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (refine_correct / num_graphs, original_correct / num_graphs,
            answer_correct / num_graphs, error / num_graphs)

def main_edge_refine(device, train_set, val_set, test_set, model, batch_size, num_workers, num_epochs, patience, rho):
    """Train the edge-refinement model and return the selected state dict.

    Each epoch trains on ``train_set`` with teacher forcing (batches from
    ``collate_edge_refine``), then evaluates on ``val_set`` with
    ``test_edge_refine``. From the 6th trained epoch on, a copy of the weights
    is kept whenever the refined-graph accuracy improves. Training stops early
    once ``patience`` consecutive evaluations have not improved it.

    Args:
        device: Device the training batches are moved to (the model is
            presumably already there; it is not moved here).
        train_set: Dataset of per-destination training rows.
        val_set: Dataset used for evaluation and model selection.
        test_set: Unused; kept for API compatibility so selection does not
            look at test data.
        model: Called as ``model(src, tgt)``; assumed to return
            ``(T, B, V)`` logits with ``V == 5``.
        batch_size: Training batch size.
        num_workers: Number of DataLoader worker processes.
        num_epochs: Maximum number of training epochs.
        patience: Early-stopping patience in evaluations, or ``None`` to
            disable early stopping.
        rho: Balance threshold forwarded to the graph validity check.

    Returns:
        The selected ``state_dict``. If no epoch qualified for selection
        (e.g. ``num_epochs`` shorter than the warm-up), the final weights.
    """
    print('Training...')
    train_loader = DataLoader(train_set,
                              shuffle=False,
                              collate_fn=collate_edge_refine,
                              batch_size=batch_size,
                              num_workers=num_workers)
    # Token id 2 is padding and must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
    vocab_size = 5
    warmup_epochs = 6  # only select weights trained for at least this many epochs
    best_correct = float('-inf')
    best_state_dict = None
    num_patient_epochs = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_loader):
            src = batch['src_ids'].transpose(0, 1).to(device)  # (S, B)
            tgt = batch['tgt_ids'].transpose(0, 1).to(device)  # (T, B)
            lbl = batch['labels'].transpose(0, 1).to(device)  # (T, B)
            optimizer.zero_grad()
            logits = model(src, tgt)
            # reshape (not view) also works if the model returns a
            # non-contiguous tensor.
            loss = criterion(logits.reshape(-1, vocab_size), lbl.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch:2d}  loss = {total_loss / max(len(train_loader), 1):.4f}")

        # Evaluate after training so the last epoch's update is also considered.
        refine_correct, original_correct, answer_correct, error_ratio = test_edge_refine(device, val_set, model, rho=rho)
        print(
            f'\033[32mval accuracy is {refine_correct} with only {original_correct} before and ground truth is {answer_correct} with error ratio is {error_ratio}\033[0m')
        if epoch + 1 >= warmup_epochs and refine_correct > best_correct:
            best_state_dict = deepcopy(model.state_dict())
            best_correct = refine_correct
            num_patient_epochs = 0
        else:
            num_patient_epochs += 1
        if (patience is not None) and (num_patient_epochs == patience):
            break

    if best_state_dict is None:
        best_state_dict = deepcopy(model.state_dict())
    return best_state_dict


@torch.no_grad()
def generate_batch(src_lists, noisy_lists, model, max_src_len=128):
    """Greedily decode one edge-label token per candidate source for each query.

    Each encoder input is ``[3] + src + [4] + noisy`` (3 is the start token,
    4 sits between the two parts). It is right-padded with 2 to
    ``max(max_src_len, longest prefix)`` rows, laid out as ``(S, B)``.
    Decoding starts from the start token and runs for the longest ``src`` in
    the batch, appending the argmax token of the last step each time.

    For each query, the first ``len(src)`` generated tokens are returned. If
    none of them is ``1``, one random position is set to 1 so that every
    query selects at least one edge.

    Args:
        src_lists: Per-query source-node features (lists of ints).
        noisy_lists: Per-query noisy edge indicators (lists of ints).
        model: Called as ``model(src_ids, dec_inp)``; assumed to return
            ``(T, B, V)`` logits, of which only ``logits[-1]`` is used.
        max_src_len: Minimum padded encoder length. It grows to fit a longer
            prefix instead of failing on the slice assignment.

    Returns:
        A list of token lists, one per entry of ``src_lists``; ``[]`` for an
        empty batch.
    """
    model.eval()
    if not src_lists:
        return []
    device = next(model.parameters()).device

    batch = len(src_lists)
    prefixes = [
        [3] + src + [4] + noisy
        for src, noisy in zip(src_lists, noisy_lists)
    ]
    src_len = max([max_src_len] + [len(p) for p in prefixes])

    pad_id = 2
    src_ids = torch.full((src_len, batch), pad_id, dtype=torch.long, device=device)
    for col, prefix in enumerate(prefixes):
        src_ids[:len(prefix), col] = torch.tensor(prefix, dtype=torch.long, device=device)

    dec_inp = torch.full((1, batch), 3, dtype=torch.long, device=device)  # <sos>
    gen_steps = max(len(src) for src in src_lists)
    for _ in range(gen_steps):
        logits = model(src_ids, dec_inp)
        next_tok = logits[-1].argmax(dim=-1, keepdim=True)  # (B, 1)
        dec_inp = torch.cat([dec_inp, next_tok.transpose(0, 1)], dim=0)  # (T + 1, B)

    gens = dec_inp[1:]  # drop <sos>: (gen_steps, B)
    results = []
    for col, src in enumerate(src_lists):
        seq = gens[:len(src), col].tolist()
        # Check for an explicit 1: argmax may emit any vocabulary id, and a
        # row like [2, 2] has a non-zero sum but selects no edge.
        if seq and 1 not in seq:
            seq[random.randint(0, len(seq) - 1)] = 1
        results.append(seq)
    return results


def main(args):
    """Train (or reuse) the edge-refinement model and report test accuracy.

    Reads ``./data/train_*.jsonl`` and ``./data/val_*.jsonl``. If no checkpoint
    exists at ``save_path``, the model is trained with ``main_edge_refine`` and
    saved there, creating the directory if needed. The checkpoint is then
    loaded onto the current device and evaluated with ``test_edge_refine``.

    Args:
        args: Parsed CLI namespace; ``num_threads``, ``seed`` and ``rho`` are read.

    Raises:
        SystemExit: with status 1 when no training or validation files exist.
    """
    torch.set_num_threads(args.num_threads)
    device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_str)
    set_seed(args.seed)

    # Sorted so the sample order (training uses shuffle=False) is reproducible
    # regardless of the filesystem's directory listing order.
    train_files = sorted(glob.glob('./data/train_*.jsonl'))
    val_files = sorted(glob.glob('./data/val_*.jsonl'))
    if not (train_files and val_files):
        print("No training or validation files found.")
        # Non-zero status so calling scripts / CI notice the missing data.
        raise SystemExit(1)
    print('loading latent preferential dataset')
    train_dataset = load_data(train_files)
    val_dataset = load_data(val_files)
    # NOTE(review): the "test" set is currently loaded from the validation files.
    test_dataset = load_data(val_files)

    batch_size = 256
    num_workers = 4
    num_epochs = 30
    patience = 100

    model = EdgeRefineModelWithEdgeTransformer().to(device)
    save_path = './model/model_latent_preferential_smt_ar.pth'
    if not os.path.exists(save_path):
        edge_pred_state_dict = main_edge_refine(device, train_dataset, val_dataset, test_dataset,
                                                model, batch_size=batch_size, num_workers=num_workers,
                                                num_epochs=num_epochs, patience=patience, rho=args.rho)
        model.load_state_dict(edge_pred_state_dict)
        # Create the directory first so a finished training run is not lost
        # to a FileNotFoundError in torch.save.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save({
            'dataset': 'latent_preferential',
            'model_state_dict': model.state_dict()
        }, save_path)
    # map_location lets a checkpoint written on a GPU machine load on CPU.
    ckpt = torch.load(save_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt['model_state_dict'])
    refine_correct, original_correct, answer_correct, error_ratio = test_edge_refine(device, test_dataset, model, rho=args.rho)
    print(
        f'\033[32m==========  test accuracy is {refine_correct} with only {original_correct} before and ground truth is {answer_correct} with error ratio is {error_ratio} ==========\033[0m')


if __name__ == '__main__':
    from argparse import ArgumentParser

    # Wall-clock timer covering argument parsing, training and evaluation.
    t0 = time.perf_counter()
    cli = ArgumentParser()
    # (flag, type, default); --config_file is accepted but main() does not read it.
    cli_options = (
        ("--config_file", str, 'configs/LayerDAG/latent_preferential.yaml'),
        ("--num_threads", int, 16),
        ("--seed", int, 0),
        ("--rho", float, 0.5),
    )
    for flag, flag_type, flag_default in cli_options:
        cli.add_argument(flag, type=flag_type, default=flag_default)
    args = cli.parse_args()
    print("start training smt model")
    main(args)
    elapsed = time.perf_counter() - t0
    print(f'use time: {elapsed:.2f}s for train smt model')