import os
import json
import argparse
import logging
import time
import numpy as np
import torch
import dgl
from torch.utils.data import Dataset, RandomSampler, DataLoader
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from gensim.models import Word2Vec
from utils.cpp_tokenizer import tokenize_c
from utils.java_tokenizer import tokenize_java
from tqdm import tqdm

# logger
logger = logging.getLogger(__name__)


# Node labels recognised for Java graphs. IntegratedDGLDataset assigns every
# label in this set an integer id; while processing a sample, a node label that
# is not in the set is mapped to the id of 'OTHERS'.
ALL_JAVA_NODES_LABELS = {
    'OTHERS', 'BINDING', 'CALL', 'METHOD_PARAMETER_OUT', 'TYPE_DECL', 'ANNOTATION', 'CONTROL_STRUCTURE'
    , 'RETURN', 'META_DATA', 'MEMBER', 'METHOD_PARAMETER_IN', 'METHOD'
}

# Node labels recognised for C graphs; used the same way as the Java set
# above, selected when args.lang == 'c'.
ALL_C_NODES_LABELS = {
    'IMPORT', 'CONTROL_STRUCTURE', 'MEMBER', 'JUMP_TARGET', 'TYPE_REF', 'META_DATA', 'METHOD', 'RETURN', 
    'OTHERS', 'TYPE_DECL', 'CALL', 'METHOD_PARAMETER_OUT', 'METHOD_REF', 'BINDING', 'METHOD_PARAMETER_IN'
}

# Edge labels shared by both languages. An edge label that is not in this set
# is mapped to the id of 'OTHERS' while processing a sample.
ALL_EDGE_LABELS = {
    'AST', 'CDG', 'CALL', 'CFG', 'NATURE_SEQUENCE','OTHERS'
}

class IntegratedDGLDataset(Dataset):
    """
        IntegratedDGLDataset
    """
    def __init__(self, args, input_file, sample_percent=1.0):
        """Load graph samples from a JSON-lines file and pre-process all of them.

        Args:
            args: Configuration object. This constructor reads ``num_nodes``,
                ``encoder_type``, ``word2vec_path``, ``block_size``,
                ``feature_dim_size`` and ``lang`` from it.
            input_file: Path of a text file holding one JSON sample per line.
                Blank lines are ignored.
            sample_percent: Fraction of the lines to keep. When it selects
                fewer lines than the file holds, a random subset is drawn with
                ``np.random.permutation`` (no seed is set here).

        Raises:
            ValueError: If ``args.lang`` is neither ``'c'`` nor ``'java'``.

        NOTE(review): ``convert_code_to_token`` reads ``self.tokenizer``, which
        is not assigned here; it has to be set by the caller before that
        method is used.
        """
        self.args = args
        self.examples = []
        self.num_nodes = args.num_nodes
        self.encoder_type = args.encoder_type

        # Fail early on an unsupported language instead of leaving
        # ``node_token_index`` undefined until the first sample is processed.
        if args.lang == 'c':
            node_label_set = ALL_C_NODES_LABELS
        elif args.lang == 'java':
            node_label_set = ALL_JAVA_NODES_LABELS
        else:
            raise ValueError(
                f"Unsupported language {args.lang!r}; expected 'c' or 'java'"
            )

        # load Word2Vec
        self.word2vec_model = Word2Vec.load(args.word2vec_path)

        self.block_size = args.block_size
        self.feature_dim = args.feature_dim_size

        # build token index
        # The label collections are sets of strings, whose iteration order is
        # not guaranteed to be the same in two Python processes. Sorting makes
        # the label -> id mapping reproducible between runs (e.g. between
        # training and evaluation).
        self.node_token_index = {
            label: index for index, label in enumerate(sorted(node_label_set))
        }
        self.edge_token_index = {
            label: index for index, label in enumerate(sorted(ALL_EDGE_LABELS))
        }

        # Read everything first so the file handle is released before the
        # (potentially slow) per-sample processing starts.
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = [line for line in f if line.strip()]

        total_len = len(lines)
        num_keep = int(sample_percent * total_len)

        if num_keep < total_len:
            indices = np.random.permutation(total_len)[:num_keep]
            lines = [lines[i] for i in indices]

        # process all samples
        for line in tqdm(lines, desc="Process sample"):
            self.examples.append(self.process_sample(json.loads(line)))

        logger.info("*** Dataset statistics ***")
        logger.info("\tTotal sample size: %s", total_len)
        logger.info("\tThe number of processed samples: %s", len(self.examples))

        if self.examples:
            logger.info("*** Sample example ***")
            example = self.examples[0]
            logger.info("\tNumber of nodes: %s", len(example['node_features']))
            logger.info("\tNumber of edges: %s", len(example['edge_types']))
            logger.info(
                "\tedge example: src=%s... dst=%s...",
                example['edges'][0][:5],
                example['edges'][1][:5],
            )

    def __len__(self):
        """Return the number of processed samples held by the dataset."""
        processed = self.examples
        return len(processed)

    def convert_code_to_token(self, code):
        """Encode ``code`` as a list of exactly ``self.block_size`` token ids.

        Whitespace runs are collapsed to single spaces, the text is tokenized
        with ``self.tokenizer``, truncated so that the CLS and SEP tokens still
        fit, converted to ids and right-padded with the pad token id.

        NOTE(review): relies on ``self.tokenizer`` being set by the caller.
        """
        tok = self.tokenizer
        normalized = ' '.join(code.split())

        # Reserve two positions for the CLS / SEP markers.
        pieces = tok.tokenize(normalized)[:self.block_size - 2]
        token_ids = tok.convert_tokens_to_ids(
            [tok.cls_token, *pieces, tok.sep_token]
        )

        missing = self.block_size - len(token_ids)
        return token_ids + [tok.pad_token_id] * missing

    def preprocess_edges(self, edges, num_nodes):
        """Drop edges whose endpoints fall outside ``[0, num_nodes)``.

        Args:
            edges: Iterable of ``(src, dst)`` pairs of node indices.
            num_nodes: Number of valid nodes; valid indices are
                ``0 .. num_nodes - 1``.

        Returns:
            ``[src_list, dst_list]`` holding the endpoints of the kept edges,
            in their original order.
        """
        # Negative indices are rejected too: they are not valid node ids and
        # would otherwise slip through a plain ``< num_nodes`` comparison.
        kept = [
            (s, d)
            for s, d in edges
            if 0 <= s < num_nodes and 0 <= d < num_nodes
        ]
        return [[s for s, _ in kept], [d for _, d in kept]]
    
    def padding_cids(self, data):
        """Append zero rows to 2-D ``data`` until it has ``self.num_nodes`` rows.

        Only the first axis is extended; the second axis is left untouched.
        """
        missing_rows = self.num_nodes - len(data)
        row_pad = (0, missing_rows)   # nothing before, zeros after
        col_pad = (0, 0)              # columns are not padded
        return np.pad(data, pad_width=(row_pad, col_pad), mode='constant')

    def padding_types(self, data):
        """Append zeros to ``data`` until its length reaches ``self.num_nodes``.

        The pad width ``(0, n)`` is handed to ``np.pad`` as-is, i.e. nothing is
        added in front and ``n`` zeros are added at the end.
        """
        missing = self.num_nodes - len(data)
        return np.pad(data, pad_width=(0, missing), mode='constant')

    def padding_feat(self, node_features):
        """Return the node features as a float32 array with ``self.num_nodes`` rows.

        Args:
            node_features: Sequence of equally long feature rows.

        Returns:
            ``np.ndarray`` of dtype float32. Extra rows are cut off when more
            than ``self.num_nodes`` rows are given; otherwise zero rows are
            appended. The row width is taken from the data itself, or from
            ``self.feature_dim`` when ``node_features`` is empty, instead of
            being fixed to 129.
        """
        current_num = len(node_features)

        if current_num >= self.num_nodes:
            return np.array(node_features[:self.num_nodes], dtype=np.float32)

        # Derive the width from the data so the padding always matches the
        # real rows, whatever feature size the dataset was configured with.
        width = len(node_features[0]) if current_num else self.feature_dim

        padded_features = np.zeros((self.num_nodes, width), dtype=np.float32)
        if current_num:
            padded_features[:current_num] = np.asarray(
                node_features, dtype=np.float32
            )

        return padded_features

    def process_sample(self, sample):
        """Turn one raw JSON sample into node features, edges and edge types.

        Args:
            sample: Mapping with the keys ``"nodes"``, ``"edges"`` (pairs of
                node indices), ``"nodes_codes"``, ``"nodes_label"``,
                ``"edges_label"`` and, optionally, ``"target"``.

        Returns:
            A dict with

            * ``"node_features"``: one list per kept node, made of the node
              type id followed by the sum of the Word2Vec vectors of the
              node's code tokens (zeros when the node has no tokens),
            * ``"edges"``: ``[src_list, dst_list]`` of the kept edges,
            * ``"edge_types"``: the edge type id of every kept edge,
            * ``"target"``: ``sample["target"]``, or ``0`` when absent.

        Raises:
            ValueError: If ``self.args.lang`` is neither ``'c'`` nor
                ``'java'`` and the sample has at least one node.
        """
        nodes = sample["nodes"]
        raw_edges = sample["edges"]
        nodes_codes = sample["nodes_codes"]
        node_labels = sample["nodes_label"]
        edge_labels = sample["edges_label"]

        # Only the first ``self.num_nodes`` nodes are kept.
        num_nodes = min(len(nodes), self.num_nodes)
        # One slot of the feature vector is taken by the node type id.
        vector_dim = self.args.feature_dim_size - 1
        lang = self.args.lang

        node_features = []
        for i in range(num_nodes):
            # Type embedding
            node_type = self.node_token_index.get(
                node_labels[i], self.node_token_index["OTHERS"]
            )

            if lang == 'c':
                code_tokens = tokenize_c(nodes_codes[i])
            elif lang == 'java':
                code_tokens = tokenize_java(nodes_codes[i])
            else:
                raise ValueError(
                    f"Unsupported language {lang!r}; expected 'c' or 'java'"
                )

            if len(code_tokens) == 0:
                summed = torch.zeros(vector_dim, dtype=torch.float64)
            else:
                vectors = []
                for token in code_tokens:
                    try:
                        vectors.append(self.word2vec_model.wv[token])
                    except KeyError:
                        # Out-of-vocabulary token: contributes nothing.
                        vectors.append(np.zeros(vector_dim))
                summed = torch.sum(torch.tensor(np.array(vectors)), 0)

            node_features.append([float(node_type)] + summed.tolist())

        # process edge
        # Edges and their labels are filtered together. Filtering the edges
        # first and then reading labels by position would pair the kept edges
        # with the labels of other edges as soon as one edge is dropped.
        src, dst, edge_types = [], [], []
        for idx, (s, d) in enumerate(raw_edges):
            if 0 <= s < num_nodes and 0 <= d < num_nodes:
                src.append(s)
                dst.append(d)
                edge_types.append(
                    self.edge_token_index.get(
                        edge_labels[idx], self.edge_token_index["OTHERS"]
                    )
                )

        return {
            "node_features": node_features,
            "edges": [src, dst],
            "edge_types": edge_types,
            "target": sample.get("target", 0)
        }
    
    def postprocess_graph(self, g):
        """Attach direction-aware edge features to ``g`` and return it.

        For every edge ``(u, v)`` of the graph the direction value is
        ``has_edge(u, v) - has_edge(v, u)`` as a float. Two entries are
        written to ``g.edata``:

        * ``'edge_direction'``: the direction value with a trailing axis,
        * ``'edge_feature'``: the direction value concatenated with
          ``g.edata['edge_types']`` along the last axis.
        """
        u, v = g.edges()
        forward = g.has_edges_between(u, v).float()
        backward = g.has_edges_between(v, u).float()

        # Add a trailing axis so both columns can be concatenated.
        direction_col = (forward - backward).unsqueeze(-1)
        type_col = g.edata['edge_types'].unsqueeze(-1)

        # Add edge features
        g.edata['edge_direction'] = direction_col
        g.edata['edge_feature'] = torch.cat([direction_col, type_col], dim=-1)

        return g

    def padding_feat_tensor(self, node_features):
        """Return the node features as a float32 tensor with ``self.num_nodes`` rows.

        Args:
            node_features: Sequence of equally long feature rows.

        Returns:
            ``torch.Tensor`` of dtype float32. Rows beyond ``self.num_nodes``
            are cut off (matching ``padding_feat``); missing rows are filled
            with zeros. The row width is taken from the data, or from
            ``self.feature_dim`` when ``node_features`` is empty, instead of
            being fixed to 129.
        """
        rows = node_features[:self.num_nodes]
        current_num = len(rows)
        features = torch.tensor(rows, dtype=torch.float32)

        if current_num >= self.num_nodes:
            return features

        # Derive the width from the data so the zero rows always line up with
        # the real ones.
        width = features.shape[1] if current_num else self.feature_dim

        padded_features = torch.zeros(
            (self.num_nodes, width), dtype=torch.float32
        )
        if current_num:
            padded_features[:current_num] = features
        return padded_features

    def __getitem__(self, idx):
        """Build the DGL graph of sample ``idx`` and return ``(graph, target)``.

        The graph is created from the stored edge lists and extended with
        isolated nodes until it has ``self.num_nodes`` nodes. Node features
        (padded by ``padding_feat_tensor``) are stored in ``ndata['feat']`` and
        edge type ids in ``edata['edge_types']``. When ``self.args.directed``
        is truthy the graph is additionally passed through
        ``postprocess_graph``. The target is returned as a scalar long tensor.
        """
        record = self.examples[idx]
        node_feats = record['node_features']
        edge_pair = record['edges']
        edge_type_ids = record['edge_types']
        label = int(record['target'])

        # Every edge needs exactly one type id.
        n_edges = len(edge_pair[0])
        n_edge_types = len(edge_type_ids)
        assert n_edges == n_edge_types, \
            f"The number of edges ({n_edges}) does not match the number of edge features ({n_edge_types})!"

        # build DGL
        endpoints = (
            torch.tensor(edge_pair[0], dtype=torch.long),
            torch.tensor(edge_pair[1], dtype=torch.long),
        )
        graph = dgl.graph(endpoints)

        # Top the graph up with isolated nodes so it matches the padded
        # feature matrix.
        shortfall = self.num_nodes - graph.num_nodes()
        if shortfall > 0:
            graph.add_nodes(shortfall)

        graph.ndata['feat'] = self.padding_feat_tensor(node_feats)
        graph.edata['edge_types'] = torch.tensor(edge_type_ids, dtype=torch.long)

        if self.args.directed:
            graph = self.postprocess_graph(graph)

        # Graph-level label
        return graph, torch.tensor(label, dtype=torch.long)