"""
Create base graph dataset for epiformer encoder experiments.

Base Graph Type:
- Simple residue-only graph (no hierarchy)
- Proximity-based edges (residue positions closer than 4.5 Å)
- Features: geometric + PLM embeddings
- Output: base_dataset.pkl (a list of HeteroData graphs written with torch.save)

This serves as a baseline to compare against epiformer approaches.
"""

import os
import sys
import torch
from typing import Optional
from torch_geometric.data import HeteroData
import logging

# Add paths for imports
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Import from existing epiformer dataset creator
from construct_epiformer_graphs import EpiformerDatasetCreator

logger = logging.getLogger(__name__)


class BaseDatasetCreator(EpiformerDatasetCreator):
    """Create base graphs (residue-only, no hierarchy) using a HeteroData structure.

    Each complex becomes one ``HeteroData`` with:

    * node types ``'ag_res'`` (antigen) and ``'ab_res'`` (antibody), each with
      ``x``, ``plm``, ``pos`` and ``y`` (epitope / paratope labels);
    * one ``(res, 'connects', res)`` edge type per chain, built from a distance
      threshold on residue positions (see ``_create_proximity_edges``);
    * ``('ag_res', 'interacts', 'ab_res')`` edges copied from the AsEP record.

    AsEP loading, PLM/label extraction and per-residue features
    (``self.residue_builder``) are inherited from ``EpiformerDatasetCreator``.
    """

    # Complexes the pipeline skips (marked as problematic in the original code).
    _EXCLUDED_COMPLEX_IDS = ("5nj6_0P", "5ies_0P")

    # Width of every per-edge feature vector (matches the RAAD edge format).
    _EDGE_FEATURE_DIM = 100

    def create_dataset(self, max_examples=None):
        """Build base graphs for all AsEP complexes and save them to ``self.output_path``.

        Overrides the parent to create base graphs instead of epiformer graphs.

        Args:
            max_examples: If not None, only the first ``max_examples`` complexes
                (after removing excluded IDs) are processed.

        Returns:
            list: the successfully built HeteroData graphs (also written with
            ``torch.save``). Complexes that fail are logged and skipped.
        """
        asep_graphs = self.load_asep_data()

        # AntiBERTy embeddings are optional; an empty dict is passed on when unavailable.
        antiberty_embeddings = {}
        if self.antiberty_path and os.path.exists(self.antiberty_path):
            try:
                # map_location keeps GPU-saved tensors loadable on CPU-only hosts.
                antiberty_embeddings = torch.load(self.antiberty_path, map_location="cpu")
                print(f"Loaded AntiBERTy embeddings for {len(antiberty_embeddings)} complexes")
            except Exception as e:
                logger.warning("Could not load AntiBERTy embeddings: %s", e)

        excluded = set(self._EXCLUDED_COMPLEX_IDS)
        complex_ids = [cid for cid in asep_graphs.keys() if cid not in excluded]

        # `is not None` so that max_examples=0 means "none", not "all".
        if max_examples is not None:
            complex_ids = complex_ids[:max_examples]

        total = len(complex_ids)
        print(f"Processing {total} complexes for base dataset...")

        dataset = []
        for i, complex_id in enumerate(complex_ids, start=1):
            print(f"Processing {i}/{total}: {complex_id}")
            try:
                base_data = self._create_base_graph(
                    complex_id, asep_graphs, antiberty_embeddings
                )
            except Exception:
                # Best effort: one bad complex must not abort the run, but keep the traceback.
                logger.exception("Error processing %s", complex_id)
                continue
            if base_data is not None:
                dataset.append(base_data)

        print(f"Successfully processed: {len(dataset)}/{total} complexes")

        print(f"Saving base dataset to {self.output_path}")
        torch.save(dataset, self.output_path)

        return dataset

    def _create_base_graph(self, complex_id, asep_graphs, antiberty_embeddings):
        """Build the residue-only HeteroData graph for one complex.

        Returns:
            HeteroData, or None when PLM embeddings are unavailable or either
            PDB file (``{id}_surf.pdb`` / ``{id}_cdr.pdb``) is missing.
        """
        asep_data = asep_graphs[complex_id]

        plm_data = self.extract_plm_embeddings(asep_graphs, complex_id, antiberty_embeddings)
        if plm_data is None:
            return None

        labels_data = self.extract_labels_and_interactions(asep_graphs, complex_id)

        ag_pdb_path = os.path.join(self.ag_pdb_dir, f"{complex_id}_surf.pdb")
        ab_pdb_path = os.path.join(self.ab_pdb_dir, f"{complex_id}_cdr.pdb")
        if not (os.path.exists(ag_pdb_path) and os.path.exists(ab_pdb_path)):
            logger.warning("PDB files not found for %s", complex_id)
            return None

        # Residue graphs whose edges are the distance-threshold proximity edges.
        ag_residue_data = self._build_proximity_residue_graph(ag_pdb_path, plm_data['antigen'])
        ab_residue_data = self._build_proximity_residue_graph(ab_pdb_path, plm_data['antibody'])

        hetero_data = HeteroData()
        hetero_data.complex_id = complex_id

        for node_type, residue_data, labels in (
            ('ag_res', ag_residue_data, labels_data['epitope_labels']),
            ('ab_res', ab_residue_data, labels_data['paratope_labels']),
        ):
            hetero_data[node_type].x = residue_data.x
            hetero_data[node_type].plm = residue_data.plm
            hetero_data[node_type].pos = residue_data.pos
            hetero_data[node_type].y = labels

        self._add_simple_edges(hetero_data, ag_residue_data, ab_residue_data)

        # NOTE(review): the "bg" suffix may denote antibody(b) -> antigen(g) row order;
        # if so, rows are swapped relative to this ag_res -> ab_res edge type. Stored
        # unchanged here — confirm against the AsEP data format.
        if "edge_index_bg" in asep_data:
            interaction_edges = asep_data["edge_index_bg"]
        else:
            interaction_edges = torch.zeros((2, 0), dtype=torch.long)
        hetero_data['ag_res', 'interacts', 'ab_res'].edge_index = interaction_edges

        return hetero_data

    def _build_proximity_residue_graph(self, pdb_path: str, plm_embeddings: torch.Tensor):
        """Build a residue graph whose edges come from a distance threshold.

        Node features and positions come from ``self.residue_builder``. When
        positions are available, ``edge_index`` / ``edge_attr`` are replaced by
        the output of ``_create_proximity_edges``. The PLM embeddings are always
        attached as ``.plm``.
        """
        residue_data = self.residue_builder.build_residue_graph(pdb_path)

        if residue_data.pos is not None:
            edge_index, edge_attr = self._create_proximity_edges(residue_data.pos)
            residue_data.edge_index = edge_index
            residue_data.edge_attr = edge_attr

        # Attached unconditionally: _create_base_graph reads `.plm` even without positions.
        residue_data.plm = plm_embeddings
        return residue_data

    def _create_proximity_edges(self, ca_positions, threshold=4.5):
        """Connect every ordered pair of distinct residues closer than ``threshold``.

        Args:
            ca_positions: (N, D) tensor of residue coordinates (presumably Cα in Å,
                matching the 4.5 default — TODO confirm).
            threshold: strict upper bound on the pairwise distance.

        Returns:
            edge_index: (2, E) long tensor; both directions of each pair appear.
            edge_attr: (E, 100) float tensor laid out as
                [0] distance, [1:17] Gaussian RBF of the distance (16 centres on
                0..10, width 1.0), [17:20] unit vector src -> dst, [20] proximity
                flag (1.0), [21:100] zero padding.
        """
        # Exact distances: the default cdist mode switches to a matmul formula for
        # >25 rows, which leaves small non-zero values on the diagonal that the
        # `> 0` test below would then let through as self-loops.
        dist_matrix = torch.cdist(
            ca_positions, ca_positions, compute_mode='donot_use_mm_for_euclid_dist'
        )

        # `> 0` drops self-loops (and coincident residues, whose direction is undefined).
        edge_mask = (dist_matrix < threshold) & (dist_matrix > 0)
        src_nodes, dst_nodes = torch.nonzero(edge_mask, as_tuple=True)

        num_edges = src_nodes.numel()
        edge_attr = torch.zeros((num_edges, self._EDGE_FEATURE_DIM), device=ca_positions.device)
        if num_edges == 0:
            return torch.zeros((2, 0), dtype=torch.long), edge_attr

        distances = dist_matrix[src_nodes, dst_nodes]
        edge_attr[:, 0] = distances

        rbf_centers = torch.linspace(0, 10, 16, dtype=distances.dtype, device=distances.device)
        edge_attr[:, 1:17] = torch.exp(-0.5 * ((distances[:, None] - rbf_centers) / 1.0) ** 2)

        direction = ca_positions[dst_nodes] - ca_positions[src_nodes]
        edge_attr[:, 17:20] = direction / (direction.norm(dim=1, keepdim=True) + 1e-8)

        edge_attr[:, 20] = 1.0  # proximity edge type indicator

        edge_index = torch.stack([src_nodes, dst_nodes], dim=0)
        return edge_index, edge_attr

    def _add_simple_edges(self, hetero_data, ag_residue_data, ab_residue_data):
        """Attach one homogeneous ``(res, 'connects', res)`` edge type per chain."""
        for node_type, residue_data in (('ag_res', ag_residue_data), ('ab_res', ab_residue_data)):
            edge_index, edge_attr = self._collect_residue_edges(residue_data)
            hetero_data[node_type, 'connects', node_type].edge_index = edge_index
            hetero_data[node_type, 'connects', node_type].edge_attr = edge_attr

    def _collect_residue_edges(self, residue_data):
        """Return aligned ``(edge_index (2, E), edge_attr (E, F))`` for one chain.

        Prefers the ``edge_index``/``edge_attr`` set by
        ``_build_proximity_residue_graph``. Otherwise it flattens the builder's
        multi-relational ``edge_lists``, keeping only edges that have a feature
        vector so the two tensors stay the same length.
        """
        edge_index = getattr(residue_data, 'edge_index', None)
        edge_attr = getattr(residue_data, 'edge_attr', None)
        if edge_index is not None and edge_attr is not None:
            return edge_index.long(), edge_attr

        edges, attrs = [], []
        if hasattr(residue_data, 'edge_lists'):
            for rel_type, edge_list in residue_data.edge_lists.items():
                for src, dst in edge_list:
                    edge_key = (src, dst, rel_type)
                    if edge_key in residue_data.edge_features:
                        edges.append((src, dst))
                        attrs.append(residue_data.edge_features[edge_key])

        if not edges:
            return (torch.zeros((2, 0), dtype=torch.long),
                    torch.zeros((0, self._EDGE_FEATURE_DIM)))
        return torch.tensor(edges).T.long(), torch.stack(attrs)


def create_base_dataset(asep_data_path: str,
                        ag_pdb_dir: str,
                        ab_pdb_dir: str,
                        output_path: str,
                        max_examples: Optional[int] = None,
                        antiberty_path: Optional[str] = None):
    """Create the base (residue-only) HeteroData dataset and save it to ``output_path``.

    Args:
        asep_data_path: AsEP preprocessed data file handed to ``BaseDatasetCreator``.
        ag_pdb_dir: Directory containing the antigen PDB files.
        ab_pdb_dir: Directory containing the antibody PDB files.
        output_path: Destination of the saved dataset; its parent directory is
            created if it does not exist.
        max_examples: Optional cap on the number of complexes processed.
        antiberty_path: AntiBERTy embeddings file. When None, it is derived from
            ``output_path`` following the pattern of the other dataset constructors:
            ``<dirname(output_path)>/../../../data/asep/antibody/antiberty_embeddings/asep_antiberty_embeddings.pt``.

    Returns:
        The list of graphs produced by ``BaseDatasetCreator.create_dataset``.
    """
    output_dir = os.path.dirname(output_path)
    # A bare filename has no directory component, and os.makedirs('') raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if antiberty_path is None:
        antiberty_path = os.path.join(
            output_dir,
            "../../../data/asep/antibody/antiberty_embeddings/asep_antiberty_embeddings.pt",
        )

    dataset_creator = BaseDatasetCreator(
        asep_data_path=asep_data_path,
        ag_pdb_dir=ag_pdb_dir,
        ab_pdb_dir=ab_pdb_dir,
        output_path=output_path,
        antiberty_path=antiberty_path
    )

    print("Creating base dataset...")
    dataset = dataset_creator.create_dataset(max_examples=max_examples)

    print(f"Successfully created base dataset with {len(dataset)} complexes")
    return dataset



if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description='Construct base graph dataset')
    parser.add_argument('--asep_data', type=str, default=None,
                        help='Path to AsEP preprocessed data file '
                             '(default: <proj_dir>/data/asep/processed/dict_pre_cal_esm2_esm2.pt)')
    parser.add_argument('--ag_pdb_dir', type=str, default=None,
                        help='Directory containing antigen PDB files '
                             '(default: <proj_dir>/data/asep/antigen/atmseq2surf)')
    parser.add_argument('--ab_pdb_dir', type=str, default=None,
                        help='Directory containing antibody PDB files '
                             '(default: <proj_dir>/data/asep/antibody/atmseq2cdr)')
    parser.add_argument('--filename', type=str,
                        default='base_dataset.pkl',
                        help='Dataset filename, written under <proj_dir>/data/asep/m3epi/ '
                             'when --output_path is not given')
    parser.add_argument('--output_path', type=str, default=None,
                        help='Full output path for base dataset (overrides --filename)')
    parser.add_argument('--max_examples', type=int, default=None,
                        help='Maximum number of examples to process')
    parser.add_argument('--log_level', type=str, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'])

    args = parser.parse_args()

    # --log_level was previously parsed but never applied.
    logging.basicConfig(level=getattr(logging, args.log_level))

    # NOTE(review): these run after the module-level imports, so they cannot affect
    # `construct_epiformer_graphs`; presumably kept for modules imported lazily later
    # (e.g. while unpickling AsEP data) — confirm before removing.
    sys.path.append(os.path.join(os.getcwd(), 'm3epi'))
    sys.path.append(os.path.join(os.getcwd(), '../../'))
    sys.path.append(os.path.join(os.getcwd(), '../../../walle'))

    # Defaults resolve relative to the current working directory.
    proj_dir = os.path.join(os.getcwd(), '../../../../')
    asep_data_path = args.asep_data or os.path.join(
        proj_dir, "data/asep/processed/dict_pre_cal_esm2_esm2.pt")
    ag_pdb_dir = args.ag_pdb_dir or os.path.join(proj_dir, "data/asep/antigen/atmseq2surf")
    ab_pdb_dir = args.ab_pdb_dir or os.path.join(proj_dir, "data/asep/antibody/atmseq2cdr")
    output_path = args.output_path or os.path.join(proj_dir, "data/asep/m3epi/", args.filename)

    create_base_dataset(
        asep_data_path=asep_data_path,
        ag_pdb_dir=ag_pdb_dir,
        ab_pdb_dir=ab_pdb_dir,
        output_path=output_path,
        max_examples=args.max_examples
    )


"""
example usage:

python data/construct_base_graphs.py \
    --max_examples 5 \
    --filename "base_dataset_test.pkl"  
"""

