import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
from typing import Dict, List, Optional, Callable, Union

class FraudDataVisualizer:
    """Simplified visualization tool for fraud-simulation metric tables.

    Each data source is a per-timestep table (a CSV path or a DataFrame) that
    is expected to contain a ``timestep`` column used as the x axis. The
    derived indicators in ``custom_indicators`` are added to every source on
    load, after which line charts are drawn in one of the modes accepted by
    ``plot_fraud_data``.
    """

    # Modes accepted by plot_fraud_data; validated before any data is loaded.
    _VALID_MODES = ("compare_sources", "compare_indicators")

    def __init__(self):
        # Derived indicators: name -> function(df) returning one value per row.
        # Every ratio is guarded on its own denominator, so rows where the
        # denominator is 0 yield 0 instead of NaN (0/0) or inf (x/0).
        self.custom_indicators = {
            # private_transfer_money / (private_transfer_money + fraud_fail),
            # as a percentage; 0 when total_fraud is 0 or the denominator is 0.
            "fraud_success_rate": lambda df: np.where(
                (df["total_fraud"] > 0)
                & ((df["private_transfer_money"] + df["fraud_fail"]) > 0),
                df["private_transfer_money"]
                / (df["private_transfer_money"] + df["fraud_fail"]) * 100,
                0,
            ),
            # total_fraud divided by timestep; 0 at timestep 0.
            "fraud_intensity": lambda df: np.where(
                df["timestep"] != 0,
                df["total_fraud"] / df["timestep"],
                0,
            ),
            # bad_good_convos as a percentage of total_private_messages.
            "message_fraud_ratio": lambda df: np.where(
                df["total_private_messages"] > 0,
                df["bad_good_convos"] / df["total_private_messages"] * 100,
                0,
            ),
        }

        # Default chart styling; colors/linestyles are cycled per plotted line.
        self.default_style = {
            "figure_size": (10, 6),
            "colors": ['#91CAE8', 'orange', 'lightcoral', 'lightgreen', 'lightblue', 'gold'],
            "linestyles": ['--', '-.', ':', '-', '--', '-.'],
            "line_width": 2,
            "grid_alpha": 0.3,
            "font_size": 14,
        }

    def _process_dataframe(self, df: pd.DataFrame, name: str) -> pd.DataFrame:
        """Return a copy of ``df`` with every custom indicator added.

        Indicators whose inputs are missing (or otherwise fail) are skipped
        with a printed warning; the caller's DataFrame is never modified.
        """
        df_processed = df.copy()
        for indicator_name, calc_func in self.custom_indicators.items():
            try:
                df_processed[indicator_name] = calc_func(df_processed)
            except Exception as e:
                # Best-effort: a source lacking some columns still loads.
                print(f"⚠️  Warning: Failed to calculate {indicator_name} for {name}: {e}")
        return df_processed

    def _load_csv(self, csv_path: str, name: str) -> Optional[pd.DataFrame]:
        """Read ``csv_path`` and add the custom indicators.

        Returns:
            The processed DataFrame, or ``None`` if the file is missing or
            cannot be read (the reason is printed).
        """
        try:
            df = pd.read_csv(csv_path)
            df = self._process_dataframe(df, name)

            print(f"✅ Successfully loaded {name}: {len(df)} rows")
            return df

        except FileNotFoundError:
            print(f"❌ File not found: {csv_path}")
            return None
        except Exception as e:
            print(f"❌ Error loading {csv_path}: {e}")
            return None

    def _create_plot(self, title: str, xlabel: str = "Timestep", ylabel: str = "Value"):
        """Create a figure/axes pair with the standard title, labels and grid."""
        fig, ax = plt.subplots(figsize=self.default_style["figure_size"])
        # The title argument was previously accepted but never applied.
        ax.set_title(title, fontsize=self.default_style["font_size"])
        ax.set_xlabel(xlabel, fontsize=self.default_style["font_size"])
        ax.set_ylabel(ylabel, fontsize=self.default_style["font_size"])
        ax.grid(True, alpha=self.default_style["grid_alpha"])
        return fig, ax

    def _save_or_show(self, fig, output_path: Optional[str] = None):
        """Save ``fig`` to ``output_path`` (creating parent dirs), or show it if no path is given."""
        # Operate on the figure we were given, not pyplot's "current" figure.
        fig.tight_layout()
        if output_path:
            parent = os.path.dirname(output_path)
            # dirname() is "" for a bare filename, and os.makedirs("") raises.
            if parent:
                os.makedirs(parent, exist_ok=True)
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close(fig)
            print(f"💾 Plot saved: {output_path}")
        else:
            plt.show()

    def _line_style(self, idx: int):
        """Return the (color, linestyle) pair for the idx-th line, cycling the palette."""
        colors = self.default_style["colors"]
        linestyles = self.default_style["linestyles"]
        return colors[idx % len(colors)], linestyles[idx % len(linestyles)]

    @staticmethod
    def _annotate_final_value(ax, df: pd.DataFrame, indicator: str, color) -> None:
        """Label the last value of ``df[indicator]`` just right of its final point.

        Integers are printed as-is, other values with one decimal place.
        Nothing is drawn for an empty series or a NaN final value.
        """
        series = df[indicator]
        if series.empty:
            return
        final_value = series.iloc[-1]
        if pd.isna(final_value):
            return
        if isinstance(final_value, (int, np.integer)):
            annotation = f'{int(final_value)}'
        else:
            annotation = f'{final_value:.1f}'

        ax.annotate(annotation,
                    xy=(df['timestep'].iloc[-1], final_value),
                    xytext=(5, 0),
                    textcoords='offset points',
                    ha='left', va='center',
                    color=color)

    def _draw_series(self, ax, df: pd.DataFrame, indicator: str, label: str, idx: int) -> None:
        """Plot df[indicator] against df['timestep'] as the idx-th styled line and label its end."""
        color, linestyle = self._line_style(idx)
        ax.plot(df['timestep'], df[indicator],
                label=label,
                color=color,
                linestyle=linestyle,
                linewidth=self.default_style["line_width"])
        self._annotate_final_value(ax, df, indicator, color)

    def plot_fraud_data(self,
                        data_sources: Dict[str, Union[str, pd.DataFrame]],
                        indicators: List[str],
                        output_dir: Optional[str] = None,
                        mode: str = "compare_sources") -> None:
        """Unified plotting entry point for fraud metrics.

        Args:
            data_sources: {"experiment name": csv file path or pd.DataFrame}.
            indicators: Column / custom-indicator names to plot.
            output_dir: Directory for PNG output; ``None`` shows the charts.
            mode: "compare_sources" | "compare_indicators".

        Raises:
            ValueError: If ``mode`` is not one of the supported modes. This is
                checked before any data is loaded.
        """
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode: {mode}. Use 'compare_sources' or 'compare_indicators'")

        # Load every source; both CSV paths and DataFrames are accepted.
        all_data = {}
        for name, source in data_sources.items():
            df = None
            if isinstance(source, str):
                df = self._load_csv(source, name)
            elif isinstance(source, pd.DataFrame):
                df = self._process_dataframe(source, name)
                print(f"✅ Successfully loaded DataFrame '{name}': {len(df)} rows")
            else:
                print(f"❌ Invalid data source type for '{name}': {type(source)}")

            if df is not None:
                all_data[name] = df

        if not all_data:
            print("❌ No valid data found")
            return

        if mode == "compare_sources":
            self._plot_compare_sources(all_data, indicators, output_dir)
        else:
            self._plot_compare_indicators(all_data, indicators, output_dir)

    def _plot_compare_sources(self, all_data: Dict[str, pd.DataFrame],
                              indicators: List[str], output_dir: Optional[str]):
        """Source-comparison mode: one chart per indicator, one line per data source."""
        for indicator in indicators:
            # Keep only the sources that actually contain this indicator.
            valid_sources = []
            for source_name, df in all_data.items():
                if indicator in df.columns:
                    valid_sources.append(source_name)
                else:
                    print(f"⚠️  {indicator} not found in {source_name}")

            if not valid_sources:
                print(f"❌ {indicator} not found in any data source")
                continue

            pretty_name = indicator.replace('_', ' ').title()
            fig, ax = self._create_plot(
                title=f"{pretty_name} Comparison",
                ylabel=pretty_name
            )

            for idx, source_name in enumerate(valid_sources):
                self._draw_series(ax, all_data[source_name], indicator, source_name, idx)

            ax.legend(fontsize=self.default_style["font_size"])

            output_path = os.path.join(output_dir, f"{indicator}.png") if output_dir else None
            self._save_or_show(fig, output_path)

    def _plot_compare_indicators(self, all_data: Dict[str, pd.DataFrame],
                                 indicators: List[str], output_dir: Optional[str]):
        """Indicator-comparison mode: one chart per data source, one line per indicator."""
        for source_name, df in all_data.items():
            valid_indicators = [ind for ind in indicators if ind in df.columns]
            missing_indicators = [ind for ind in indicators if ind not in df.columns]

            if missing_indicators:
                print(f"⚠️  {source_name} missing indicators: {missing_indicators}")

            if not valid_indicators:
                print(f"❌ No valid indicators found in {source_name}")
                continue

            fig, ax = self._create_plot(
                title=f"{source_name} - Multiple Indicators",
                ylabel="Value"
            )

            for idx, indicator in enumerate(valid_indicators):
                self._draw_series(ax, df, indicator, indicator.replace('_', ' ').title(), idx)

            ax.legend(fontsize=self.default_style["font_size"])

            output_path = (
                os.path.join(output_dir, f"{source_name}_indicators.png") if output_dir else None
            )
            self._save_or_show(fig, output_path)

    def quick_fraud_overview(self, data_sources: Dict[str, Union[str, pd.DataFrame]], output_dir: str):
        """Generate source-comparison charts for a fixed set of key fraud indicators into ``output_dir``."""
        key_indicators = [
            "private_transfer_money",
            "bad_good_convos",
            "fraud_success_rate",
            "total_likes",
            "total_reposts",
            "total_good_comments"
        ]

        print("📊 Generating fraud data overview...")
        self.plot_fraud_data(
            data_sources=data_sources,
            indicators=key_indicators,
            output_dir=output_dir,
            mode="compare_sources"
        )
        print("✅ Overview complete!")

    def calculate_final_stats(self, csv_path: str) -> Dict[str, float]:
        """Return (and print) selected metrics from the last row of ``csv_path``.

        Missing columns fall back to 0 / 0.0. Returns an empty dict when the
        file cannot be loaded or contains no data rows.
        """
        df = self._load_csv(csv_path, "stats")
        if df is None:
            return {}
        if df.empty:
            # iloc[-1] on an empty frame would raise IndexError.
            print(f"❌ No rows in {csv_path}")
            return {}

        last_row = df.iloc[-1]
        stats = {
            "private_transfers": last_row.get('private_transfer_money', 0),
            "total_conversations": last_row.get('bad_good_convos', 0),
            "avg_message_depth": last_row.get('average_private_message_depth', 0.0),
            "fraud_success_rate": last_row.get('fraud_success_rate', 0.0),
        }

        print(f"\n📈 Final Statistics:")
        for key, value in stats.items():
            print(f"  {key.replace('_', ' ').title()}: {value}")

        return stats

# ============ 简化的使用接口 ============

def create_visualizer() -> FraudDataVisualizer:
    """Build a new FraudDataVisualizer with its default indicators and style."""
    visualizer = FraudDataVisualizer()
    return visualizer

def quick_plot(data_sources: Dict[str, Union[str, pd.DataFrame]],
               indicators: List[str],
               output_dir: Optional[str] = None,
               mode: str = "compare_sources"):
    """Plot ``indicators`` from ``data_sources`` with a default visualizer.

    Thin wrapper around ``FraudDataVisualizer.plot_fraud_data``.

    Args:
        data_sources: Mapping of experiment name to CSV path or DataFrame.
        indicators: Column / custom-indicator names to plot.
        output_dir: Directory to write PNG files into; ``None`` shows the
            charts interactively instead.
        mode: "compare_sources" (default, previous behavior) or
            "compare_indicators".
    """
    viz = create_visualizer()
    viz.plot_fraud_data(data_sources, indicators, output_dir=output_dir, mode=mode)

# ============ 使用示例 ============

if __name__ == "__main__":
    import sys

    # Experiments to compare are passed as NAME=CSV_PATH arguments, e.g.
    #   python <script>.py small=./small/stats.csv large=./large/stats.csv
    # Arguments without "=" are ignored.
    scale_comparison_csv = dict(
        arg.split("=", 1) for arg in sys.argv[1:] if "=" in arg
    )
    if not scale_comparison_csv:
        print(f"Usage: python {sys.argv[0]} NAME=CSV_PATH [NAME=CSV_PATH ...]")
        sys.exit(1)

    # Create the visualizer
    viz = create_visualizer()

    # Generate the scale-comparison charts
    viz.plot_fraud_data(
        data_sources=scale_comparison_csv,
        indicators=["private_transfer_money", "fraud_success_rate", "bad_good_convos", "total_fraud"],
        output_dir="./outputs/scale_comparison",
        mode="compare_sources"
    )
