import os
import os.path as osp
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import index_to_mask
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch_geometric.utils import from_networkx
from MegaGNN.datasets.temporal_dataset import TemporalDataset
from MegaGNN.graphgym.config import cfg
from typing import List
from MegaGNN.cpp import ports_cpp


def z_norm(data):
    """Standardize ``data`` along dim 0 (zero mean, unit standard deviation).

    Args:
        data: floating-point tensor; statistics are taken along dim 0.

    Returns:
        Tensor of the same shape as ``data`` with the per-column mean
        subtracted and divided by the per-column standard deviation.
        Columns whose standard deviation is zero or undefined (NaN, e.g.
        when ``data`` has a single row) are only mean-centred.
    """
    mean = data.mean(0).unsqueeze(0)
    std = data.std(0).unsqueeze(0)
    # A constant column has std == 0 and a single-row input has std == NaN
    # (unbiased estimator); dividing by either would poison the output, so
    # fall back to a divisor of 1 in both cases.
    degenerate = (std == 0) | torch.isnan(std)
    std = torch.where(degenerate, torch.ones_like(std), std)
    return (data - mean) / std


def find_parallel_edges(edge_index):
    """Map every edge to the id of its (source, target) pair.

    Edges sharing the same endpoints in the same direction receive the same
    id. Ids are handed out consecutively from 0 in order of first
    appearance.

    Args:
        edge_index: tensor with one column per edge (the columns are
            compared as whole tuples).

    Returns:
        ``torch.LongTensor`` with one group id per edge.
    """
    group_of_pair = {}
    # setdefault evaluates len() before inserting, so an unseen pair gets the
    # next free id while a known pair returns its existing one.
    group_ids = [
        group_of_pair.setdefault(tuple(pair), len(group_of_pair))
        for pair in edge_index.T.tolist()
    ]
    return torch.LongTensor(group_ids)


def ports_with_cpp(graph):
    """Compute in/out port numbers for the edges of ``graph`` via ``ports_cpp``.

    Args:
        graph: heterogeneous graph exposing ``edge_index`` under the
            ``('node', 'to', 'node')`` edge type and a ``num_nodes``
            attribute.

    Returns:
        Tuple ``(ports_in, ports_out)`` of tensors wrapping the two arrays
        returned by ``ports_cpp.assign_ports``.
    """
    edge_index = graph['node', 'to', 'node'].edge_index
    # Every edge gets the same placeholder timestamp of 1. It is created in
    # edge_index's dtype on purpose: a float32 column would make torch.cat
    # promote the node ids to float32, which cannot represent integers above
    # 2**24 exactly and would silently corrupt large ids.
    timestamp = torch.ones(edge_index.shape[1], dtype=edge_index.dtype)

    # One row per edge: (source, target, timestamp)
    edges = torch.cat([edge_index.T, timestamp.reshape((-1, 1))], dim=1).numpy().astype('int')

    ports_in, ports_out = ports_cpp.assign_ports(edges, edge_index.numpy().astype('int'), graph.num_nodes)

    return torch.from_numpy(ports_in), torch.from_numpy(ports_out)


class CybersecurityDataset(TemporalDataset):
    """
    Dataset class for the cybersecurity network data.
    Splits edges into train/val/test sets with 60/20/20 ratio.

    Every CSV row (a flow) becomes one directed ``('node', 'to', 'node')``
    edge between its source and destination address. One ``HeteroData``
    graph is stored per split: the train graph holds the train edges, the
    val graph the train + val edges and the test graph all edges. The
    ``split_mask`` edge attribute marks the edges that belong to the split
    itself.
    """
    # Maps the dataset name to the CSV file expected inside ``raw_dir``.
    csv_names = {
        'NF-BoT-IoT': 'NF-BoT-IoT.csv',
        'NF-ToN-IoT': 'NF-ToN-IoT.csv',
    }

    def __init__(
        self,
        root: str,
        name: str,
        reverse_mp: bool = False,
        add_ports: bool = False,
        multi_edge_agg: bool = False,
        transform = None,
        pre_transform = None
    ):
        """Load the processed graphs (and optionally the port numbers).

        Args:
            root: directory under which ``<name>/raw`` and
                ``<name>/processed`` live.
            name: dataset name; must be a key of ``csv_names``.
            reverse_mp: keep the ``('node', 'rev_to', 'node')`` edge type;
                when False it is deleted from every split.
            add_ports: append the stored in/out port numbers as two extra
                edge-attribute columns.
            multi_edge_agg: stored on the instance; not used in this class.
            transform: forwarded to the base class.
            pre_transform: forwarded to the base class; applied to every
                split graph in ``process``.
        """
        self.name = name  # key into csv_names, e.g. 'NF-BoT-IoT'
        self.reverse_mp = reverse_mp
        self.add_ports = add_ports
        self.multi_edge_agg = multi_edge_agg
        # Read before the base constructor runs because process() needs it
        # (presumably the base constructor triggers processing -- confirm in
        # TemporalDataset).
        self.ports_as_separate_nodes = cfg.dataset.ports_as_separate_nodes

        super().__init__(root, transform, pre_transform)
        self.data_dict = torch.load(self.processed_paths[0])

        if not reverse_mp:
            for split in ['train', 'val', 'test']:
                del self.data_dict[split]['node', 'rev_to', 'node']

        if add_ports:
            self.ports_dict = torch.load(self.processed_paths[1])
            for split in ['train', 'val', 'test']:
                self.data_dict[split] = self.add_ports_func(self.data_dict[split], self.ports_dict[split])

    def add_ports_func(self, data, ports):
        """Append port numbers to the edge attributes of ``data``.

        Args:
            data: graph of one split.
            ports: pair ``(in_ports, out_ports)`` with one entry per edge.

        Returns:
            ``data`` with two extra ``edge_attr`` columns (in, out). The
            reverse edge type, when in use, receives them swapped (out, in).
        """
        in_ports, out_ports = ports
        ports_arr = torch.stack([in_ports, out_ports], dim=1)
        data['node', 'to', 'node'].edge_attr = torch.cat([data['node', 'to', 'node'].edge_attr, ports_arr], dim=1)
        if self.reverse_mp:
            data['node', 'rev_to', 'node'].edge_attr = torch.cat([data['node', 'rev_to', 'node'].edge_attr, ports_arr[:, [1, 0]]], dim=1)
        return data

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw CSV: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed files: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """No raw files are declared; ``process`` reads the CSV directly."""
        return []

    @property
    def processed_file_names(self) -> List[str]:
        """Graph dictionary first, port dictionary second."""
        return ['data.pt', 'ports.pt']

    def process(self):
        """Build the per-split graphs from the raw CSV and save them.

        Writes two files: the dictionary of split graphs to
        ``processed_paths[0]`` and the dictionary of per-split
        ``[in_ports, out_ports]`` to ``processed_paths[1]``.
        """
        # Load the CSV data and shuffle the rows reproducibly
        raw_file_path = osp.join(self.raw_dir, self.csv_names[self.name])
        df = pd.read_csv(raw_file_path)
        df = df.sample(frac=1, random_state=42).reset_index(drop=True)
        # Invert the binary label (0 <-> 1)
        df['Label'] = 1 - df['Label']

        # Preprocess the IP addresses and ports
        df['IPV4_SRC_ADDR'] = df.IPV4_SRC_ADDR.apply(str)
        df['L4_SRC_PORT'] = df.L4_SRC_PORT.apply(str)
        df['IPV4_DST_ADDR'] = df.IPV4_DST_ADDR.apply(str)
        df['L4_DST_PORT'] = df.L4_DST_PORT.apply(str)

        if self.ports_as_separate_nodes:
            # Use "address:port" as the node identity instead of the address alone
            df['IPV4_SRC_ADDR'] = df['IPV4_SRC_ADDR'] + ':' + df['L4_SRC_PORT']
            df['IPV4_DST_ADDR'] = df['IPV4_DST_ADDR'] + ':' + df['L4_DST_PORT']

        # Remove unneeded columns
        df.drop(columns=['L4_SRC_PORT', 'L4_DST_PORT'], inplace=True)

        if 'Attack' in df.columns:
            df.drop(columns=['Attack'], inplace=True)

        # Rename label column for consistency
        if 'Label' in df.columns:
            df.rename(columns={"Label": "label"}, inplace=True)

        # Step 1: Assign unique integers to source and destination IPs
        unique_src_ips = df['IPV4_SRC_ADDR'].unique()
        unique_dst_ips = df['IPV4_DST_ADDR'].unique()
        unique_ips = np.union1d(unique_src_ips, unique_dst_ips)
        ip_to_idx = {ip: idx for idx, ip in enumerate(unique_ips)}

        # Add from_id and to_id columns
        df['from_id'] = df['IPV4_SRC_ADDR'].map(ip_to_idx)
        df['to_id'] = df['IPV4_DST_ADDR'].map(ip_to_idx)

        # Step 2: Get edge feature columns
        edge_features = ['IN_BYTES', 'FLOW_DURATION_MILLISECONDS', 'OUT_PKTS',
                         'IN_PKTS', 'OUT_BYTES', 'TCP_FLAGS', 'L7_PROTO', 'PROTOCOL']

        # Step 3: Create edge_index, edge_attr and node features
        max_n_id = max(df['from_id'].max(), df['to_id'].max()) + 1
        # Create a simple node feature (all ones)
        df_nodes = pd.DataFrame({'NodeID': np.arange(max_n_id), 'Feature': np.ones(max_n_id)})

        y = torch.tensor(df['label'].values, dtype=torch.long) if 'label' in df.columns else torch.zeros(len(df), dtype=torch.long)

        print(f"Number of nodes (unique IPs) = {df_nodes.shape[0]}")
        print(f"Number of flows = {df.shape[0]}")
        if 'label' in df.columns:
            print(f"Innocent Atack ratio = {y.sum()} / {len(y)} = {y.sum() / len(y) * 100:.2f}%")

        # Create node features, edge index and edge attributes
        node_features = ['Feature']
        x = torch.tensor(df_nodes[node_features].values, dtype=torch.float)
        edge_index = torch.tensor(df[['from_id', 'to_id']].values.T, dtype=torch.long)
        edge_attr = torch.tensor(df[edge_features].values, dtype=torch.float)

        print(f'Edge features being used: {edge_features}')
        print(f'Node features being used: {node_features} ("Feature" is a placeholder feature of all 1s)')

        # Step 4: Find parallel edges
        simp_edge_batch = find_parallel_edges(edge_index)

        # Step 5: Create train/val/test indices based on percentage splits
        num_edges = edge_index.size(1)
        indices = torch.arange(num_edges)

        # Simple percentage-based split: 60% train, 20% val, 20% test
        train_size = int(0.6 * num_edges)
        val_size = int(0.2 * num_edges)

        train_inds = indices[:train_size]
        val_inds = indices[train_size:train_size + val_size]
        test_inds = indices[train_size + val_size:]

        # Edges whose labels belong to each split
        split_inds = {'train': train_inds, 'val': val_inds, 'test': test_inds}
        # Edges present in each split's graph: every split also sees the
        # edges of the splits that precede it.
        graph_edges = {
            'train': train_inds,
            'val': torch.cat([train_inds, val_inds]),
            'test': torch.cat([train_inds, val_inds, test_inds]),
        }

        # Step 6: Create data dictionaries for different splits
        self.ports_dict = {}
        self.data_dict = {}

        for split in ['train', 'val', 'test']:
            inds = split_inds[split]
            e_mask = graph_edges[split]

            masked_edge_index = edge_index[:, e_mask]
            # Normalisation statistics come from the edges visible in this split only
            masked_edge_attr = z_norm(edge_attr[e_mask])
            masked_y = y[e_mask]

            graph = HeteroData()
            graph['node'].x = z_norm(x)
            graph['node'].num_nodes = int(x.shape[0])
            graph['node', 'to', 'node'].edge_index = masked_edge_index
            graph['node', 'to', 'node'].edge_attr = masked_edge_attr
            graph['node', 'to', 'node'].y = masked_y

            # Reverse edges: same attributes, source and target rows swapped
            graph['node', 'rev_to', 'node'].edge_index = masked_edge_index.flipud()
            graph['node', 'rev_to', 'node'].edge_attr = masked_edge_attr

            # For megagnn
            masked_simp_edge_batch = simp_edge_batch[e_mask]
            graph['node', 'to', 'node'].simp_edge_batch = masked_simp_edge_batch
            graph['node', 'rev_to', 'node'].simp_edge_batch = masked_simp_edge_batch

            # Define the labels in the training/validation/test sets
            graph['node', 'to', 'node'].split_mask = index_to_mask(inds, size=masked_edge_index.shape[1])

            in_ports, out_ports = ports_with_cpp(graph)

            # Apply the pre-transform to each split and keep its result, so
            # the saved graphs actually reflect it.
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)

            self.ports_dict[split] = [in_ports, out_ports]
            self.data_dict[split] = graph

        torch.save(self.data_dict, self.processed_paths[0])
        torch.save(self.ports_dict, self.processed_paths[1])

    def __repr__(self) -> str:
        return f'CybersecurityDataset(name={self.name})'