import os
import warnings
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear, to_hetero
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import softmax
from torch_scatter import scatter_softmax,scatter
from function_call_agent.graph_training.model import LinkPredictionModel

# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings(action='ignore')
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def data_preparation(api_df, params_df, model):
    """Add feature/embedding columns and build name -> node-index lookups.

    Both frames are modified in place and also returned. Each one receives
    three new columns:

    * ``feature_list``: ``log(x) + 0.01`` for every non-zero value of
      ``in_degree``, ``out_degree`` and ``succ_cnt`` (zeros stay ``0``),
      followed by the raw ``succ_rate``.
    * ``description_emb``: the normalised embedding of the description text
      concatenated with ``feature_list``.
    * ``index``: the row position, ``0 .. len(df) - 1``.

    Args:
        api_df: API table with the columns ``api_name_``, ``description``,
            ``in_degree``, ``out_degree``, ``succ_cnt`` and ``succ_rate``.
        params_df: Parameter table with the columns ``params_name_last``,
            ``description_last`` and the same four numeric columns.
        model: Object exposing ``encode(texts, normalize_embeddings=True)``
            that returns one vector per input text.

    Returns:
        ``(api_df, params_df, api_idx, param_idx)`` where the two dicts map
        ``api_name_`` / ``params_name_last`` to the ``index`` column. When a
        name occurs more than once the last row wins.
    """
    columns_to_log = ['in_degree', 'out_degree', 'succ_cnt']
    columns_to_keep = ['succ_rate']

    def _numeric_features(row):
        # Zero counts are kept at 0 so that log(0) is never evaluated.
        logged = [np.log(row[col]) + 0.01 if row[col] != 0 else 0 for col in columns_to_log]
        return logged + [row[col] for col in columns_to_keep]

    def _augment(df, text_column):
        df['feature_list'] = df.apply(_numeric_features, axis=1)
        df['description_emb'] = df.apply(
            lambda row: model.encode([row[text_column]], normalize_embeddings=True)[0].tolist()
            + row['feature_list'],
            axis=1,
        )
        # Node ids must be row positions: the feature matrix is later built
        # from the rows in order, so index *labels* (which may be non-contiguous
        # after filtering) would point at the wrong rows.
        df['index'] = np.arange(len(df))

    _augment(api_df, 'description')
    _augment(params_df, 'description_last')

    api_idx = api_df.set_index(['api_name_'])['index'].to_dict()
    param_idx = params_df.set_index(['params_name_last'])['index'].to_dict()

    return api_df, params_df, api_idx, param_idx


def data_insert_graph(api_df,params_df,api_idx,param_idx,rel_df):
    """Assemble the heterogeneous API/parameter graph.

    Node features come from the ``description_emb`` column of each frame.
    Edges are read from ``rel_df`` (columns ``rel``, ``source``, ``target``,
    ``all_cnt_rate`` and, for ``depends_on`` rows, ``source_type``); names are
    translated to node ids through ``api_idx`` / ``param_idx``.

    Returns:
        A ``HeteroData`` with ``x`` on the ``api`` and ``param`` node types and
        ``edge_index`` / ``edge_weight`` on four edge types:
        ``param-input-api``, ``api-output-param``, ``api-depends_on-api`` and
        ``param-depends_on-param``.
    """
    graph = HeteroData()

    graph['api'].x = torch.Tensor(api_df['description_emb'].tolist())
    graph['param'].x = torch.Tensor(params_df['description_emb'].tolist())

    name_to_id = {'api': api_idx, 'param': param_idx}

    # (edge type, value of `rel`, required `source_type` or None for "any")
    edge_specs = [
        (('param', 'input', 'api'), 'input', None),
        (('api', 'output', 'param'), 'output', None),
        (('api', 'depends_on', 'api'), 'depends_on', 'api'),
        (('param', 'depends_on', 'param'), 'depends_on', 'param'),
    ]

    for edge_type, rel_name, source_type in edge_specs:
        selector = rel_df['rel'] == rel_name
        if source_type is not None:
            selector = selector & (rel_df['source_type'] == source_type)
        rows = rel_df[selector]

        src_lookup = name_to_id[edge_type[0]]
        dst_lookup = name_to_id[edge_type[2]]
        src_ids = [src_lookup[name] for name in rows['source'].tolist()]
        dst_ids = [dst_lookup[name] for name in rows['target'].tolist()]

        graph[edge_type].edge_index = torch.tensor([src_ids, dst_ids], dtype=torch.long)
        graph[edge_type].edge_weight = torch.tensor(rows['all_cnt_rate'].tolist())

    return graph


def split_edge_label(data, edge_type):
    """Separate labelled edges of ``edge_type`` into positives and negatives.

    Reads ``edge_label`` and ``edge_label_index`` from ``data[edge_type]`` and
    returns ``(pos_edge_index, neg_edge_index)``: the columns of
    ``edge_label_index`` whose label equals 1 and 0 respectively.
    """
    store = data[edge_type]
    labels = store.edge_label
    candidates = store.edge_label_index

    positives = candidates[:, labels == 1]
    negatives = candidates[:, labels == 0]
    return positives, negatives


def data_split(data):
    """Split every edge type into training and validation link sets.

    ``RandomLinkSplit`` holds out 20% of the edges of each type for validation
    (no test split) and adds one negative sample per positive edge. Afterwards,
    for both splits and every edge type:

    * ``edge_label_index`` is narrowed to the positive edges,
    * ``neg_edge_label_index`` receives the negative edges,
    * ``edge_label`` is removed,
    * ``neg_edge_weight`` (all zeros) is added when ``edge_weight`` exists.

    The validation ``edge_weight`` is replaced by the weights of those original
    edges that do not occur among the training label edges, and is moved to the
    module-level ``device``.

    Returns:
        ``(train_data, val_data)``.
    """
    transform = RandomLinkSplit(
        num_val=0.2,
        num_test=0,
        edge_types=data.edge_types,
        is_undirected=False,
        add_negative_train_samples=True,
        neg_sampling_ratio=1,
    )

    train_data, val_data, _ = transform(data)

    for edge_type in data.edge_types:
        if not (hasattr(train_data[edge_type], 'edge_label_index')
                and hasattr(val_data[edge_type], 'edge_label_index')):
            continue

        train_edge_index = train_data[edge_type].edge_label_index
        original_edge_index = data[edge_type].edge_index
        original_edge_weight = data[edge_type].edge_weight

        # Hash the training (src, dst) pairs once so each original edge is
        # checked in O(1) instead of being compared against every training edge.
        train_pairs = set(zip(train_edge_index[0].tolist(), train_edge_index[1].tolist()))
        keep = [
            pair not in train_pairs
            for pair in zip(original_edge_index[0].tolist(), original_edge_index[1].tolist())
        ]
        val_mask = torch.tensor(keep, dtype=torch.bool)

        # NOTE(review): the weights are taken in the order of the original
        # edge list. RandomLinkSplit shuffles edges before splitting, so
        # confirm that consumers do not rely on this tensor being positionally
        # aligned with the validation `edge_label_index`.
        val_data[edge_type].edge_weight = original_edge_weight[val_mask].to(device)

    for data_ in [train_data, val_data]:
        for edge_type in data.edge_types:
            if 'edge_label' not in data_[edge_type]:
                continue

            pos_edge_index, neg_edge_index = split_edge_label(data_, edge_type)
            data_[edge_type].edge_label_index = pos_edge_index
            data_[edge_type].neg_edge_label_index = neg_edge_index
            del data_[edge_type].edge_label

            if hasattr(data_[edge_type], 'edge_weight'):
                # Negative (non-existent) edges carry a weight of zero.
                data_[edge_type].neg_edge_weight = torch.zeros(neg_edge_index.size(1), dtype=torch.float)

    return train_data, val_data


def edge_type_to_str(edge_type: tuple) -> str:
    """Return the components of ``edge_type`` joined by a double underscore."""
    separator = '__'
    return separator.join(edge_type)


def hybrid_loss(pos_pred, neg_pred, mu=0.5, edge_weight=None, margin=1.0):
    """Blend a cross-entropy term with a weight-scaled margin ranking term.

    Args:
        pos_pred: Logits of the positive edges. Any shape; it is flattened, so
            a 0-d tensor (single edge after ``squeeze()``) is accepted.
        neg_pred: Logits of the negative edges. Flattened as well; its element
            count must be a multiple of the number of positives.
        mu: Weight of the cross-entropy term; the margin term gets ``1 - mu``.
        edge_weight: Per-positive-edge weight, used both as the soft target of
            the positives in the cross-entropy term and to scale the margin.
            Defaults to all ones.
        margin: Base margin; the effective per-edge margin is
            ``margin * (1 + sigmoid(edge_weight))``.

    Returns:
        Scalar tensor ``mu * ce_loss + (1 - mu) * margin_loss``.
    """
    # Flatten so a single-edge (0-d) prediction behaves like a length-1 vector.
    pos_pred = pos_pred.reshape(-1)
    neg_pred = neg_pred.reshape(-1)
    if edge_weight is None:
        edge_weight = torch.ones_like(pos_pred)
    else:
        edge_weight = edge_weight.reshape(-1)

    # Cross-entropy: positives are pulled towards their weight, negatives to 0.
    all_pred = torch.cat([pos_pred, neg_pred])
    all_label = torch.cat([edge_weight, torch.zeros_like(neg_pred)])
    ce_loss = F.binary_cross_entropy_with_logits(all_pred, all_label)

    # Heavier edges must beat the negatives by a larger margin.
    m_uv = margin * (1 + torch.sigmoid(edge_weight))

    # (P, 1) against (K, P): with one negative per positive (K == 1) this
    # broadcasts to a (P, P) grid, i.e. every positive is compared with every
    # negative.
    pos_expanded = pos_pred.unsqueeze(1)
    neg_expanded = neg_pred.view(-1, pos_pred.size(0))
    margin_loss = (m_uv.unsqueeze(1) - pos_expanded + neg_expanded).clamp(min=0).mean()

    return mu * ce_loss + (1 - mu) * margin_loss


def train_one_epoch(data,mu):
    """Perform one full-graph optimisation step and return the mean loss.

    Uses the module-level ``model``, ``optimizer`` and ``device``. The model is
    run once over the whole graph; then, for every edge type, positive and
    negative label edges are scored with ``model.predict_link`` and fed to
    ``hybrid_loss`` (margin 1.5). The per-edge-type losses are averaged before
    the backward pass.

    Args:
        data: Graph whose edge types carry ``edge_index``,
            ``edge_label_index``, ``neg_edge_label_index`` and optionally
            ``edge_weight``.
        mu: Mixing coefficient forwarded to ``hybrid_loss``.

    Returns:
        The averaged loss as a Python float.
    """
    model.train()

    # Gather the message-passing structure on the compute device.
    edge_index_dict = {}
    edge_weight_dict = {}
    for edge_type in data.edge_types:
        store = data[edge_type]
        edge_index_dict[edge_type] = store.edge_index.to(device)
        if 'edge_weight' in store:
            edge_weight_dict[edge_type] = store.edge_weight.to(device)

    node_features = {node_type: data[node_type].x.to(device) for node_type in data.node_types}
    x_dict = model(
        x_dict=node_features,
        edge_index_dict=edge_index_dict,
        edge_weight_dict=edge_weight_dict or None,
    )

    running_loss = torch.tensor(0.0, device=device)
    for edge_type in data.edge_types:
        store = data[edge_type]
        pos_index = store.edge_label_index.to(device)
        neg_index = store.neg_edge_label_index.to(device)

        src_type, _, dst_type = edge_type
        src_emb = x_dict[src_type]
        dst_emb = x_dict[dst_type]

        pos_weight = store.edge_weight.to(device) if 'edge_weight' in store else None
        pos_pred = model.predict_link(src_emb[pos_index[0]], dst_emb[pos_index[1]], pos_weight)

        # Negative edges are always scored with a zero weight.
        neg_weight = torch.zeros(neg_index.size(1), dtype=torch.float, device=device)
        neg_pred = model.predict_link(src_emb[neg_index[0]], dst_emb[neg_index[1]], neg_weight)

        running_loss = running_loss + hybrid_loss(
            pos_pred.squeeze(),
            neg_pred.squeeze(),
            mu,
            edge_weight=pos_weight,
            margin=1.5,
        )

    optimizer.zero_grad()
    mean_loss = running_loss / len(data.edge_types)
    mean_loss.backward()
    optimizer.step()

    return mean_loss.item()



if __name__ == "__main__":

    # Text encoder (bge) used only to embed the descriptions. It is kept under
    # its own name so the module-level `model` below is the network being
    # trained.
    path = 'your_path/bge-large-en-v1.5'
    encoder = SentenceTransformer(path)

    # data
    api_df = pd.read_csv('./data/api_df_0408.csv')
    params_df = pd.read_csv('./data/params_df_0408.csv')
    rel_df = pd.read_csv('./data/params_rel_df_0408.csv')

    api_df,params_df,api_idx,param_idx = data_preparation(api_df,params_df,encoder)

    data = data_insert_graph(api_df,params_df,api_idx,param_idx,rel_df)
    train_data, val_data = data_split(data)

    # Loss schedule: mu decays geometrically, shifting weight from the
    # cross-entropy term to the margin term over time.
    mu_0 = 0.5
    gamma = 0.99
    epochs = 300

    # presumably 1024-d text embedding + the 4 numeric features appended in
    # data_preparation — confirm against the encoder in use
    in_features = 1028
    conv1_in = 1028
    conv1_out = 512
    conv2_in = 512
    conv2_out = 256

    # `train_one_epoch` reads `model` and `optimizer` as module globals.
    model = LinkPredictionModel(
        in_features=in_features,
        conv1_in=conv1_in,
        conv1_out=conv1_out,
        conv2_in=conv2_in,
        conv2_out=conv2_out,
        metadata=train_data.metadata()
    ).to(device)

    # The optimizer must be created from the link-prediction model's
    # parameters, i.e. after that model exists; building it earlier would bind
    # it to the text encoder and leave the trained network untouched.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    # NOTE(review): the scheduler is created but never stepped, so the learning
    # rate stays constant. Decide whether `scheduler.step(<metric>)` belongs in
    # the loop (the training loss itself drifts with `mu`).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    for epoch in range(epochs):
        mu = mu_0 * (gamma ** epoch)
        loss = train_one_epoch(train_data,mu)
        if epoch%50==0:
            print('mu:',mu)
            print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
