"""
Create GearNet epiformer graph dataset with 7 edge relation types.

GearNet Graph Type:
- Epiformer with atom, residue, and edge graphs
- Residue graph has 7 edge relation types, a simplified variant of the
  multi-relational residue graph described in the GearNet paper
- Features: geometric + PLM embeddings
- Output: gearnet_epiformer_dataset.pkl

This follows the same pattern as graph_construction.py but creates GearNet-style residue graphs.
"""

import os
import sys
import torch
import numpy as np
from typing import Optional

# Add paths for imports
sys.path.append(os.path.join(os.getcwd(), '../../'))


from construct_epiformer_graphs import EpiformerDatasetCreator, ResidueGraphBuilder


class GearNetEpiformerDatasetCreator(EpiformerDatasetCreator):
    """Epiformer dataset creator whose residue graphs use GearNet relations.

    Every constructor argument is forwarded untouched to
    ``EpiformerDatasetCreator``. After the parent is initialised, its
    ``residue_builder`` attribute is swapped for a
    ``GearNetResidueGraphBuilder`` instance.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        gearnet_builder = GearNetResidueGraphBuilder()
        # Replace whatever builder the parent configured with the GearNet one.
        self.residue_builder = gearnet_builder


class GearNetResidueGraphBuilder(ResidueGraphBuilder):
    """Build residue graphs with 7 GearNet-style edge relation types.

    The parent ``ResidueGraphBuilder`` produces the base residue graph. This
    subclass then replaces its edges with seven relation types computed from
    ``residue_data.pos``, which the original code names ``ca_positions``
    (presumably one C-alpha coordinate per residue; confirm against the
    parent builder):

    0. index neighbours at offset 1 (both directions)
    1. index neighbours at offset 2 (both directions)
    2. the ``K_NN`` spatially nearest residues not already linked by 0-1
    3. pairs with 0 < distance < ``SPATIAL_CUTOFF`` not already linked by 0-2
    4. consecutive-index "backbone" links (both directions)
    5. pairs 2-5 indices apart (i < j) with 5 < distance < 8
    6. pairs 2-5 indices apart (i < j) with 8 <= distance < 10

    Sequence separation is approximated by residue index, so chain breaks
    and multiple chains are not distinguished.
    """

    # Number of relation types produced by ``_build_gearnet_relations``.
    NUM_RELATIONS = 7
    # Length of each edge feature vector; only the first 20 entries are filled.
    EDGE_FEATURE_DIM = 100
    # Number of spatial neighbours per residue for relation 2.
    K_NN = 10
    # Distance cutoff for relation 3 (same units as ``residue_data.pos``).
    SPATIAL_CUTOFF = 10.0

    def build_residue_graph(self, pdb_path: str):
        """Build the parent residue graph and replace its edges with GearNet relations.

        Args:
            pdb_path: Path to the structure file; passed unchanged to the parent.

        Returns:
            The object returned by the parent builder. When its ``pos`` is not
            None, the following attributes are overwritten:

            - ``edge_lists``: relation id -> list of ``[src, dst]``
            - ``edge_features``: ``(src, dst, relation)`` -> feature tensor
            - ``edge_index``: long tensor of shape ``(2, E)``
            - ``edge_attr``: tensor of shape ``(E, EDGE_FEATURE_DIM)``

            When ``pos`` is None, the parent's graph is returned unchanged.
        """
        residue_data = super().build_residue_graph(pdb_path)
        if residue_data.pos is None:
            return residue_data

        ca_positions = residue_data.pos
        gearnet_edge_lists = self._build_gearnet_relations(ca_positions, residue_data)

        # Compute each edge's features once and reuse them for both the keyed
        # dictionary and the flattened PyG-style tensors (all relations concatenated).
        gearnet_edge_features = {}
        all_edges = []
        all_edge_attrs = []
        for rel_type, edges in gearnet_edge_lists.items():
            for src, dst in edges:
                feat = self._create_edge_features(src, dst, ca_positions, residue_data)
                gearnet_edge_features[(src, dst, rel_type)] = feat
                all_edges.append([src, dst])
                all_edge_attrs.append(feat)

        residue_data.edge_lists = gearnet_edge_lists
        residue_data.edge_features = gearnet_edge_features

        if all_edges:
            residue_data.edge_index = torch.tensor(all_edges, dtype=torch.long).t().contiguous()
            residue_data.edge_attr = torch.stack(all_edge_attrs)
        else:
            residue_data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            residue_data.edge_attr = torch.zeros((0, self.EDGE_FEATURE_DIM))

        return residue_data

    def _build_gearnet_relations(self, ca_positions, residue_data):
        """Return a dict mapping relation id (0-6) to a list of ``[src, dst]`` pairs.

        Args:
            ca_positions: Tensor of residue coordinates, one row per residue
                (must be accepted by ``torch.cdist``).
            residue_data: Accepted for interface compatibility; not read here.

        Edge order within each relation is deterministic. An already-linked
        pair is skipped by relations 2 and 3, so those two never duplicate an
        earlier relation.
        """
        n_residues = ca_positions.shape[0]
        edge_lists = {r: [] for r in range(self.NUM_RELATIONS)}
        dist_matrix = torch.cdist(ca_positions, ca_positions)

        # Relations 0-1: index offset 1 and 2 (index distance stands in for
        # sequence separation). For each i, the lower neighbour is emitted first.
        for i in range(n_residues):
            for rel, gap in ((0, 1), (1, 2)):
                for j in (i - gap, i + gap):
                    if 0 <= j < n_residues:
                        edge_lists[rel].append([i, j])

        # Set of directed pairs already linked; O(1) membership instead of list scans.
        linked = {(i, j) for r in (0, 1) for i, j in edge_lists[r]}

        # Relation 2: k nearest spatial neighbours. Self is removed by index
        # rather than by assuming it is topk's first hit, so ties at distance
        # zero cannot displace a real neighbour.
        k = min(self.K_NN + 1, n_residues)
        for i in range(n_residues):
            _, indices = torch.topk(-dist_matrix[i], k=k)
            neighbours = [j for j in indices.tolist() if j != i][: self.K_NN]
            for j in neighbours:
                if (i, j) not in linked:
                    edge_lists[2].append([i, j])
        linked.update((i, j) for i, j in edge_lists[2])

        # Relation 3: all pairs inside the spatial cutoff (self excluded via > 0).
        within_cutoff = (dist_matrix < self.SPATIAL_CUTOFF) & (dist_matrix > 0)
        for i, j in torch.nonzero(within_cutoff).tolist():
            if (i, j) not in linked:
                edge_lists[3].append([i, j])

        # Relation 4: consecutive-index backbone links, both directions.
        for i in range(n_residues - 1):
            edge_lists[4].append([i, i + 1])
            edge_lists[4].append([i + 1, i])

        # Relations 5-6: simplified distance bins for pairs 2-5 indices apart.
        for i in range(n_residues):
            for j in range(i + 2, min(i + 6, n_residues)):
                dist = dist_matrix[i, j].item()
                if 5.0 < dist < 8.0:
                    edge_lists[5].append([i, j])
                elif 8.0 <= dist < 10.0:
                    edge_lists[6].append([i, j])

        return edge_lists

    def _create_edge_features(self, src, dst, ca_coords, residue_data):
        """Return an ``EDGE_FEATURE_DIM``-long feature vector for edge ``src -> dst``.

        Layout:
            [0]      Euclidean distance between the two positions
            [1:17]   Gaussian RBF of the distance, 16 centres evenly spaced
                     on [0, 20], width 1.0
            [17:20]  unit direction vector from src to dst (1e-8 guards /0)
            [20:]    zeros (placeholder for angles, dihedrals, AA types, ...)

        ``residue_data`` is currently unused.
        """
        edge_feat = torch.zeros(self.EDGE_FEATURE_DIM)

        direction = ca_coords[dst] - ca_coords[src]
        distance = torch.norm(direction)
        edge_feat[0] = distance

        rbf_centers = torch.linspace(0, 20, 16)
        edge_feat[1:17] = torch.exp(-0.5 * ((distance - rbf_centers) / 1.0) ** 2)

        edge_feat[17:20] = direction / (distance + 1e-8)

        return edge_feat


def create_gearnet_epiformer_dataset(asep_data_path: str,
                                       ag_pdb_dir: str,
                                       ab_pdb_dir: str,
                                       output_path: str,
                                       max_examples: Optional[int] = None,
                                       antiberty_path: Optional[str] = None):
    """Create the GearNet epiformer dataset and return it.

    Args:
        asep_data_path: Path to the preprocessed AsEP data file.
        ag_pdb_dir: Directory containing antigen structure files.
        ab_pdb_dir: Directory containing antibody structure files.
        output_path: Destination path for the dataset; its parent directory
            is created if missing.
        max_examples: Optional cap on the number of examples processed,
            forwarded to ``create_dataset``.
        antiberty_path: AntiBERTy embeddings file. When None, defaults to
            ``<dirname(output_path)>/../../../data/asep/antibody/antiberty_embeddings/asep_antiberty_embeddings.pt``
            (the layout used by graph_construction.py).

    Returns:
        Whatever ``GearNetEpiformerDatasetCreator.create_dataset`` returns;
        it must support ``len()``.
    """
    output_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when output_path actually has one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if antiberty_path is None:
        # Follows the graph_construction.py convention: resolved relative to
        # the output directory.
        antiberty_path = os.path.join(
            output_dir,
            "../../../data/asep/antibody/antiberty_embeddings/asep_antiberty_embeddings.pt",
        )

    dataset_creator = GearNetEpiformerDatasetCreator(
        asep_data_path=asep_data_path,
        ag_pdb_dir=ag_pdb_dir,
        ab_pdb_dir=ab_pdb_dir,
        output_path=output_path,
        antiberty_path=antiberty_path
    )

    print("Creating GearNet epiformer dataset...")
    dataset = dataset_creator.create_dataset(max_examples=max_examples)

    print(f"Successfully created GearNet epiformer dataset with {len(dataset)} complexes")
    return dataset


if __name__ == "__main__":

    import argparse
    import logging

    parser = argparse.ArgumentParser(description='Construct hierarchical gearnet graph dataset')
    # Input locations (AsEP data, antigen/antibody PDB dirs) are not CLI
    # options; they are derived from proj_dir below.
    parser.add_argument('--filename', type=str,
                        default='gearnet_epiformer_dataset.pkl',
                        help='dataset filename (used when --output_path is not given)')
    parser.add_argument('--output_path', type=str, default=None,
                        help='Full output path for the dataset; overrides --filename. '
                             'Defaults to <proj_dir>/data/asep/m3epi/<filename>')
    parser.add_argument('--max_examples', type=int, default=None,
                        help='Maximum number of examples to process')
    parser.add_argument('--log_level', type=str, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'])

    args = parser.parse_args()

    # Previously parsed but never applied.
    logging.basicConfig(level=args.log_level)

    # Extra import search paths, resolved relative to the working directory.
    sys.path.append(os.path.join(os.getcwd(), 'm3epi'))
    sys.path.append(os.path.join(os.getcwd(), '../../'))
    sys.path.append(os.path.join(os.getcwd(), '../../../walle'))

    # Default configuration for standalone execution (relative to the CWD).
    proj_dir = os.path.join(os.getcwd(), '../../../../')
    asep_data_path = os.path.join(proj_dir, "data/asep/processed/dict_pre_cal_esm2_esm2.pt")
    ag_pdb_dir = os.path.join(proj_dir, "data/asep/antigen/atmseq2surf")
    ab_pdb_dir = os.path.join(proj_dir, "data/asep/antibody/atmseq2cdr")

    if args.output_path is not None:
        output_path = args.output_path
    else:
        output_path = os.path.join(proj_dir, "data/asep/m3epi/", args.filename)

    dataset = create_gearnet_epiformer_dataset(
        asep_data_path=asep_data_path,
        ag_pdb_dir=ag_pdb_dir,
        ab_pdb_dir=ab_pdb_dir,
        output_path=output_path,
        max_examples=args.max_examples
    )


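"""
example usage:

python data/construct_gearnet_epiformer_graphs.py \
    --max_examples 5 \
    --filename "gearnet_epiformer_dataset_test.pkl"

Note: the input data paths are resolved relative to the current working
directory (os.getcwd()), so run the script from the directory those
relative paths assume.
"""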