import ast
import re
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATConv, GraphConv, global_max_pool, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, Subset
from torch_geometric.loader import DataLoader
import logging

# Select the compute device once at import time: CUDA when present, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info(f'Using device: {device}')

class CodeGraphDataset(Dataset):
    """Dataset mapping Python source strings to AST graphs with a scaled score.

    Rows come from a DataFrame with a ``code`` column (Python source text) and
    a ``score`` column (numeric target). Each item is a ``Data`` graph whose
    nodes are AST nodes (feature: integer node-type id), whose edges run from
    parent to child, and whose ``y`` holds the score transformed by ``scaler``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, dataframe, scaler=None, node_type_vocab=None):
        """Store the data and prepare the scaler and node-type vocabulary.

        Args:
            dataframe: DataFrame with ``code`` and ``score`` columns.
            scaler: Scaler to reuse as-is (it is not re-fitted here), e.g. the
                one fitted on the training split. When ``None`` a new
                ``MinMaxScaler`` is fitted on this dataframe's ``score`` column.
            node_type_vocab: Mapping from AST node-type name to integer id.
                When ``None`` it is built from this dataframe's code.
        """
        # Number of __getitem__ calls that hit unparsable code.
        self.invalid_count = 0
        self.dataframe = dataframe.reset_index(drop=True)
        if scaler is None:
            # Only a dataset created without a scaler (the training set) fits one.
            self.scaler = MinMaxScaler()
            self.scaler.fit(self.dataframe['score'].values.reshape(-1, 1))
        else:
            self.scaler = scaler

        if node_type_vocab is None:
            self.node_type_vocab = self.build_node_type_vocab()
            logging.info(f'Built node type vocabulary with size: {len(self.node_type_vocab)}')
        else:
            self.node_type_vocab = node_type_vocab
            logging.info(f'Reusing node type vocabulary with size: {len(self.node_type_vocab)}')

    def build_node_type_vocab(self):
        """Return a ``{node type name: id}`` mapping for all parsable code rows.

        Id 0 is reserved for ``"UNK"``; the remaining ids follow the sorted
        order of the node-type names. Rows that fail to parse are logged and
        skipped.
        """
        node_types = set()
        for idx, code in enumerate(self.dataframe['code']):
            try:
                tree = ast.parse(code)
            except Exception as e:
                logging.warning(f"Error parsing code at index {idx}: {e}")
                continue
            node_types.update(type(node).__name__ for node in ast.walk(tree))

        node_type_to_id = {"UNK": 0}
        for idx, nt in enumerate(sorted(node_types), start=1):
            node_type_to_id[nt] = idx
        return node_type_to_id

    def ast_to_graph(self, code):
        """Convert source text into a graph, or return ``None`` if it cannot be parsed.

        Nodes are numbered in depth-first pre-order starting at the module
        root. ``x`` has shape ``[num_nodes, 1]`` (node-type id, 0 for types
        missing from the vocabulary) and ``edge_index`` has shape
        ``[2, num_edges]`` with one parent -> child edge per non-root node.
        """
        try:
            tree = ast.parse(code)
        except Exception as e:
            logging.warning(f"Error parsing code: {e}")
            return None

        node_features = []
        edges = []
        # An explicit stack replaces recursion so that deeply nested ASTs
        # cannot exhaust the interpreter's recursion limit.
        stack = [(tree, None)]
        while stack:
            node, parent_id = stack.pop()
            current_id = len(node_features)
            node_type_id = self.node_type_vocab.get(type(node).__name__, 0)  # 0 == "UNK"
            node_features.append([node_type_id])
            if parent_id is not None:
                edges.append((parent_id, current_id))
            # Push children reversed so they are popped (and numbered) in
            # source order, which keeps the pre-order numbering.
            children = list(ast.iter_child_nodes(node))
            stack.extend((child, current_id) for child in reversed(children))

        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            # A graph with a single node still needs a well-formed edge tensor.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        x = torch.tensor(node_features, dtype=torch.long)
        return Data(x=x, edge_index=edge_index)

    def __len__(self):
        """Return the number of rows, including rows whose code may not parse."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the graph for row ``idx``, or ``None`` when its code cannot be parsed.

        Callers are expected to drop ``None`` items before batching (the rest
        of this file does so with ``if data is not None``). Each ``None``
        result increments ``invalid_count``.

        Raises:
            IndexError: If ``idx`` is out of range (raised by ``DataFrame.iloc``);
                this is also what terminates plain ``for item in dataset`` loops.
        """
        row = self.dataframe.iloc[idx]

        graph = self.ast_to_graph(row['code'])
        if graph is None:
            self.invalid_count += 1
            logging.debug(f"Skipping index {idx} due to parsing error.")
            return None

        # Scale the target with the (training) scaler.
        score_normalized = self.scaler.transform([[row['score']]]).flatten()[0]
        graph.y = torch.tensor([score_normalized], dtype=torch.float)
        return graph
    
class GNNModel(nn.Module):
    """Graph regressor: node-type embedding, two GAT layers, pooled MLP head.

    The graph-level representation is the concatenation of max- and
    mean-pooled node features, hence ``fc1`` takes ``hidden_dim * 2`` inputs.
    """

    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128, scaler=None):
        """Build the layers.

        Args:
            num_node_types: Size of the node-type vocabulary (embedding rows).
            embed_dim: Dimension of the node-type embedding.
            hidden_dim: Output dimension of both GAT layers and of ``fc1``.
            scaler: Stored on the model as ``self.scaler``; it is not used in
                ``forward``.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_node_types, embed_dim)
        self.conv1 = GATConv(embed_dim, hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # input is max-pool ++ mean-pool
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.scaler = scaler

    def forward(self, data):
        """Predict one scalar per graph.

        Args:
            data: Batched graph object exposing ``x`` (integer node-type ids),
                ``edge_index`` and ``batch``.

        Returns:
            1-D tensor with one prediction per graph in the batch. The result
            stays 1-D (shape ``[1]``) even when the batch holds one graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Flatten the id column to 1-D. A bare squeeze() would produce a 0-d
        # tensor when the batch contains exactly one node.
        x = self.embedding(x.reshape(-1))
        x = self.conv1(x, edge_index)
        x = self.dropout(F.relu(x))
        x = self.conv2(x, edge_index)
        x = self.dropout(F.relu(x))
        x = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Drop only the trailing feature dimension so a single-graph batch
        # still yields shape [1], matching batch.y.
        return x.squeeze(-1)

class PassRatePredictor():
    """Collects (code, score) samples and trains/evaluates a GNN score regressor.

    Scores may be plain numbers or strings with a time unit (e.g. ``"12 ms"``);
    they are converted to seconds by ``clean_score_data`` before training.
    """

    # Leading number followed by an optional alphabetic unit suffix.
    _SCORE_PATTERN = re.compile(r"^(\d+\.?\d*)\s*([a-z]*)?")

    # unit suffix -> (multiplier, divisor) converting the number to seconds.
    # A divisor is kept separately so milliseconds are computed as num / 1000.
    _UNIT_FACTORS = {
        **dict.fromkeys(('s', 'sec', 'secs', 'second', 'seconds'), (1, 1)),
        **dict.fromkeys(('ms', 'msec', 'msecs', 'millisecond', 'milliseconds'), (1, 1000)),
        **dict.fromkeys(('m', 'min', 'mins', 'minute', 'minutes'), (60, 1)),
        **dict.fromkeys(('h', 'hr', 'hrs', 'hour', 'hours'), (3600, 1)),
    }

    def __init__(self, ini_data=None, model=None):
        """Create a predictor.

        Args:
            ini_data: Optional initial samples, in any form accepted by
                ``add_data``. When ``None`` the sample store starts empty.
            model: Optional already-trained model used by ``predict_score``.
        """
        self.model = model
        # Always start from an empty store so self.data is defined even when
        # initial data is supplied.
        self.data = pd.DataFrame(columns=["code", "score"])
        self.scaler = MinMaxScaler()
        self.trained = False
        self.node_type_vocab = None
        if ini_data is not None:
            self.add_data(ini_data)

    def add_data(self, new_data, use_pass_rate=False):
        """Append samples whose ``code`` is not already stored.

        Args:
            new_data: Either a DataFrame with at least a ``code`` column, or a
                dict of records (one record per key) that is converted to a
                DataFrame and reduced to the ``code`` and score columns.
            use_pass_rate: For dict input only, take the score from the
                ``pass_rate`` field instead of ``score``.
        """
        if isinstance(new_data, dict):
            new_data = pd.DataFrame.from_dict(new_data, orient='index').reset_index(drop=True)
            # Keep only the 'code' column and the selected score column.
            score_column = 'pass_rate' if use_pass_rate else 'score'
            new_data = new_data[['code', score_column]].rename(columns={score_column: 'score'})

        if self.data is None:
            self.data = new_data
        else:
            # Drop samples whose code is already stored.
            new_data = new_data[~new_data['code'].isin(self.data['code'])]
            self.data = pd.concat([self.data, new_data], ignore_index=True)

    def predict_score(self, new_code_samples, model=None, scaler=None):
        """Predict scores for new code samples.

        Args:
            new_code_samples: Sequence of Python source strings.
            model: Model to use; defaults to ``self.model``.
            scaler: Scaler handed to the dataset; defaults to ``self.scaler``.

        Returns:
            A list with one entry per input sample, in input order. Entries
            are the raw model outputs (on the scaled target range; no inverse
            transform is applied). Samples whose code cannot be parsed get NaN.

        Raises:
            RuntimeError: If no model was passed and none has been trained.
        """
        if model is None:
            model = self.model
        if model is None:
            raise RuntimeError("No model available: call train_model() first or pass `model`.")
        if scaler is None:
            scaler = self.scaler

        # Wrap the samples in a DataFrame; the score is only a placeholder.
        new_df = pd.DataFrame({
            "code": new_code_samples,
            "score": ["0s"] * len(new_code_samples),
        })
        df_clean, _ = self.clean_score_data(new_df)
        # Remove unparsable code up front so the dataset only sees valid rows
        # and predictions can be mapped back to their input positions.
        df_valid, unparsable = self.filter_invalid_ast(df_clean)

        model.eval()
        preds = [np.nan] * len(df_clean)
        if df_valid.empty:
            return preds

        dataset = CodeGraphDataset(df_valid, scaler=scaler, node_type_vocab=self.node_type_vocab)
        loader = DataLoader(
            [data for data in dataset if data is not None],
            batch_size=32
        )

        valid_preds = []
        with torch.no_grad():
            for batch in loader:
                # reshape(-1) keeps the output iterable even if the model
                # returns a 0-d tensor for a single-graph batch.
                valid_preds.extend(model(batch).reshape(-1).cpu().numpy())

        skipped = set(unparsable)
        valid_positions = (pos for pos in range(len(df_clean)) if pos not in skipped)
        for pos, value in zip(valid_positions, valid_preds):
            preds[pos] = value
        return preds

    def test_model(self, model, dataframe, train_scaler=None):
        """Evaluate ``model`` on ``dataframe`` and report MAE / RMSE in score units.

        Args:
            model: Model to evaluate.
            dataframe: DataFrame with ``code`` and ``score`` columns.
            train_scaler: Scaler used during training. When ``None`` the
                scaler stored on the model (``model.scaler``) is used if
                present; otherwise a new one is fitted on the cleaned scores
                of ``dataframe``.

        Returns:
            ``{"mae": ..., "rmse": ...}`` computed after inverse-transforming
            predictions and labels.

        Raises:
            ValueError: If no row survives AST and score cleaning.
        """
        # Drop unparsable code and normalise scores before anything is fitted.
        df_clean = self.preprocess_data(dataframe)
        if df_clean.empty:
            raise ValueError("No valid samples to evaluate.")

        if train_scaler is None:
            train_scaler = getattr(model, 'scaler', None)
        if train_scaler is None:
            train_scaler = MinMaxScaler().fit(df_clean['score'].values.reshape(-1, 1))

        # Reuse the training vocabulary so node-type ids match the model's
        # embedding table (falls back to a fresh vocabulary when none is stored).
        test_dataset = CodeGraphDataset(
            df_clean, scaler=train_scaler, node_type_vocab=self.node_type_vocab
        )
        test_loader = DataLoader(
            [data for data in test_dataset if data is not None],
            batch_size=32
        )

        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in test_loader:
                pred = model(batch)
                all_preds.extend(pred.reshape(-1).cpu().numpy())
                all_labels.extend(batch.y.reshape(-1).cpu().numpy())

        # Map predictions and labels back to the original score scale.
        preds = test_dataset.scaler.inverse_transform(np.array(all_preds).reshape(-1, 1)).flatten()
        labels = test_dataset.scaler.inverse_transform(np.array(all_labels).reshape(-1, 1)).flatten()

        mae = np.mean(np.abs(preds - labels))
        rmse = np.sqrt(np.mean((preds - labels)**2))
        print(f"Test MAE: {mae:.4f}, Test RMSE: {rmse:.4f}")
        return {"mae": mae, "rmse": rmse}

    def train_model(self, dataframe=None, epochs=50, batch_size=32, lr=0.001,
                    checkpoint_path="best_gnn_model.pth"):
        """Train a ``GNNModel`` and keep the weights with the best validation loss.

        Args:
            dataframe: Training data with ``code`` and ``score`` columns;
                defaults to the samples collected via ``add_data``.
            epochs: Number of training epochs.
            batch_size: Mini-batch size for both loaders.
            lr: Adam learning rate.
            checkpoint_path: File the best state dict is written to whenever
                the validation loss improves. Pass ``None`` to skip writing.

        Returns:
            The trained model, loaded with the best-validation weights when at
            least one epoch improved on the initial (infinite) loss. It is also
            stored as ``self.model``.

        Raises:
            ValueError: If no sample survives preprocessing.
        """
        if dataframe is None:
            dataframe = self.data

        df_preprocessed = self.preprocess_data(dataframe)
        if df_preprocessed.empty:
            raise ValueError("无有效数据可供训练")

        train_df, val_df = train_test_split(df_preprocessed, test_size=0.2, random_state=42)

        # The scaler is fitted on the training split only and shared with validation.
        train_scaler = MinMaxScaler().fit(train_df['score'].values.reshape(-1, 1))
        self.scaler = train_scaler
        train_dataset = CodeGraphDataset(train_df, scaler=train_scaler)
        val_dataset = CodeGraphDataset(
            val_df, scaler=train_scaler, node_type_vocab=train_dataset.node_type_vocab
        )

        # Materialise the graphs once, dropping any invalid sample.
        train_loader = DataLoader(
            [data for data in train_dataset if data is not None],
            batch_size=batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            [data for data in val_dataset if data is not None],
            batch_size=batch_size
        )

        model = GNNModel(
            num_node_types=len(train_dataset.node_type_vocab),
            embed_dim=64,
            hidden_dim=128,
            scaler=train_scaler
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss()

        best_val_loss = float('inf')
        best_state = None
        for epoch in range(epochs):
            model.train()
            train_loss = []
            for batch in train_loader:
                optimizer.zero_grad()
                pred = model(batch)
                # Flatten both sides so a single-graph batch cannot trigger
                # shape broadcasting inside the loss.
                loss = criterion(pred.reshape(-1), batch.y.reshape(-1))
                loss.backward()
                optimizer.step()
                train_loss.append(loss.item())

            model.eval()
            val_loss = []
            with torch.no_grad():
                for batch in val_loader:
                    pred = model(batch)
                    loss = criterion(pred.reshape(-1), batch.y.reshape(-1))
                    val_loss.append(loss.item())

            avg_train_loss = np.mean(train_loss)
            avg_val_loss = np.mean(val_loss)
            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # Snapshot the weights: state_dict() returns live references
                # that later optimizer steps would overwrite.
                best_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                if checkpoint_path is not None:
                    torch.save(best_state, checkpoint_path)

        # Return the best-validation weights rather than the last-epoch ones.
        if best_state is not None:
            model.load_state_dict(best_state)

        self.model = model
        self.node_type_vocab = train_dataset.node_type_vocab
        self.trained = True
        return model

    def filter_invalid_ast(self, df):
        """Split ``df`` into rows whose ``code`` parses and the positions that do not.

        Returns:
            ``(df_valid, invalid_indices)`` where ``df_valid`` has a fresh
            0..n-1 index and ``invalid_indices`` are positional row numbers
            in ``df``.
        """
        valid_indices = []
        invalid_indices = []

        for idx, code in enumerate(df['code']):
            try:
                ast.parse(code)
            except Exception as e:
                logging.warning(f"索引 {idx} 的代码无法解析AST: {e}")
                invalid_indices.append(idx)
            else:
                valid_indices.append(idx)

        df_valid = df.iloc[valid_indices].reset_index(drop=True)
        return df_valid, invalid_indices

    def _parse_score(self, value, idx):
        """Convert one raw score to seconds.

        Strings are parsed as ``<number><optional unit>``; an unknown unit is
        logged and treated as seconds. Non-strings are converted with
        ``float``. Raises ``ValueError`` / ``TypeError`` when no usable number
        can be obtained (including NaN).
        """
        if not isinstance(value, str):
            seconds = float(value)
            if np.isnan(seconds):
                raise ValueError("score is NaN")
            return seconds

        match = self._SCORE_PATTERN.match(value.strip().lower())
        if not match:
            raise ValueError("无法提取数值")
        number = float(match.group(1))
        unit = match.group(2) or 's'  # no suffix means seconds
        factors = self._UNIT_FACTORS.get(unit)
        if factors is None:
            logging.warning(f"索引 {idx} 的未知单位 '{unit}'，假设为秒")
            return number
        multiply, divide = factors
        return number * multiply / divide

    def clean_score_data(self, df):
        """Normalise the ``score`` column to float seconds and drop unusable rows.

        Returns:
            ``(df_clean, invalid_indices)`` where ``df_clean`` is a copy of
            ``df`` with converted scores, invalid rows removed and a fresh
            index, and ``invalid_indices`` are the index labels of ``df``
            whose score could not be converted (including NaN scores).
        """
        cleaned_scores = []
        invalid_indices = []

        for idx, value in df['score'].items():
            try:
                cleaned_scores.append(self._parse_score(value, idx))
            except Exception as e:
                logging.warning(f"索引 {idx} 的score值 '{value}' 处理失败: {e}")
                invalid_indices.append(idx)
                cleaned_scores.append(None)

        # Replace the column, then drop the rows marked invalid above.
        df_clean = df.copy()
        df_clean['score'] = cleaned_scores
        df_clean = df_clean.dropna(subset=['score']).reset_index(drop=True)
        return df_clean, invalid_indices

    def preprocess_data(self, df):
        """Return ``df`` restricted to parsable code with scores converted to seconds."""
        # Step 1: drop samples whose code cannot be parsed into an AST.
        df_ast_valid, ast_invalid = self.filter_invalid_ast(df)
        logging.info(f"过滤 {len(ast_invalid)} 个无效AST样本")

        # Step 2: clean the score column.
        df_clean, score_invalid = self.clean_score_data(df_ast_valid)
        logging.info(f"过滤 {len(score_invalid)} 个无效score样本")

        return df_clean

###############################################################
# # Debug / usage example (kept commented out)

# # 1. Load the data
# df = pd.read_csv(r"E:\python_project_new\AI4SLCDP\leetcode_data\leetcode Median of Two Sorted Arrays.csv")

# # Rename the "runtime" column to "score"
# df.rename(columns={"runtime": "score"}, inplace=True)
# print(df.head())

# pass_rate_predictor = PassRatePredictor()
# pass_rate_predictor.add_data(df)

# # 2. Train the model
# model = pass_rate_predictor.train_model(epochs=100)

# # 3. Evaluate the model
# # test_df = pd.read_csv("test_data.csv")
# # test_metrics = pass_rate_predictor.test_model(model, test_df)

# # 4. Predict scores for new samples
# new_samples = [
#     "def square(x):\n    return x ** 2",
#     "def div(a, b):\n    return a / b"
# ]

# predictions = pass_rate_predictor.predict_score(new_samples)
# print(f"Predicted scores: {predictions}")