# common/visualize.py - Panel-based prediction visualization and metric calculation

from __future__ import annotations
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from typing import Dict, List, Tuple
from datetime import datetime
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score


# Default font / style sizes. These write into matplotlib's global rcParams at
# import time, so they apply to every figure created afterwards in the process.
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16,
})


def calculate_cvrmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return CV(RMSE): the root-mean-square error divided by the mean of ``y_true``.

    Args:
        y_true: Observed values.
        y_pred: Predicted values, same length as ``y_true``.

    Returns:
        ``RMSE(y_true, y_pred) / mean(y_true)``, or ``float('inf')`` when the
        mean of ``y_true`` is exactly zero (the ratio is undefined there).
    """
    # Compute the error first so invalid inputs are rejected by sklearn
    # before the zero-mean shortcut is considered.
    root_mse = np.sqrt(mean_squared_error(y_true, y_pred))
    observed_mean = np.mean(y_true)
    if observed_mean == 0:
        return float('inf')
    return root_mse / observed_mean


class PanelPredictVisualize:
    """
    Visualize per-panel (distribution board) prediction results and compute metrics.

    Wraps the dictionary returned by ``predict_panel`` and provides a metrics
    table, per-subload plots, an aggregated (summed) plot, and a helper that
    writes all of them to disk.
    """

    def __init__(self, prediction_result: Dict):
        """
        Args:
            prediction_result: Return value of ``predict_panel``. Must contain the
                keys 'predictions', 'ground_truth', 'past_data', 'feature_names',
                'timestamps' (itself with 'past', 'future', 'start'), 'scaler'
                and 'panel_info'.
        """
        self.pred_result = prediction_result
        # Normalized arrays indexed as [timestep, feature]. The original notes
        # give [144, num_features] for predictions / ground truth and
        # [288, num_features] for past data -- TODO confirm against predict_panel.
        self.predictions = prediction_result['predictions']
        self.ground_truth = prediction_result['ground_truth']
        self.past_data = prediction_result['past_data']
        self.feature_names = prediction_result['feature_names']
        self.timestamps = prediction_result['timestamps']
        self.scaler = prediction_result['scaler']
        self.panel_info = prediction_result['panel_info']

        # Time axes used by the plots.
        self.past_times = self.timestamps['past']
        self.future_times = self.timestamps['future']
        self.start_time = self.timestamps['start']

    def denormalize_data(self, data: np.ndarray) -> np.ndarray:
        """Map normalized data back to the original scale via ``scaler.inverse_transform``."""
        return self.scaler.inverse_transform(data)

    @staticmethod
    def _metric_row(label: str, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
        """Build one metrics-table row (MSE, MAE, R2, CVRMSE) for a pair of series."""
        return {
            'Subload': label,
            'Loss(MSE)': mean_squared_error(y_true, y_pred),
            'MAE': mean_absolute_error(y_true, y_pred),
            'R2': r2_score(y_true, y_pred),
            'CVRMSE': calculate_cvrmse(y_true, y_pred),
        }

    def calculate_metrics(self) -> pd.DataFrame:
        """Compute Loss(MSE), MAE, R2 and CVRMSE per subload in original units.

        Returns:
            One row per subload followed by a final row labelled 'Average'.
            That row is computed on all subloads pooled together (flattened),
            not as the mean of the per-subload rows.
        """
        pred_denorm = self.denormalize_data(self.predictions)
        gt_denorm = self.denormalize_data(self.ground_truth)

        rows = [
            self._metric_row(feature_name, gt_denorm[:, i], pred_denorm[:, i])
            for i, feature_name in enumerate(self.feature_names)
        ]
        rows.append(self._metric_row('Average', gt_denorm.flatten(), pred_denorm.flatten()))
        return pd.DataFrame(rows)

    def _panel_header(self) -> tuple:
        """Return (name, energy_use, start_datetime) from ``panel_info`` for titles."""
        info = self.panel_info
        return info['name'], info['energy_use'], info['start_datetime']

    def _mark_prediction_window(self, ax) -> None:
        """Draw the prediction-start line and shade the prediction horizon on *ax*."""
        ax.axvline(self.start_time, color='green', linestyle=':',
                   linewidth=2, label='Prediction Start')
        # axvspan spans the full axes height regardless of the data limits, so
        # unlike a fill_betweenx over the current ylim it leaves no gaps at the
        # autoscale margins and does not widen the y-limits.
        ax.axvspan(self.start_time, self.future_times[-1],
                   color='orange', alpha=0.1, label='Prediction Horizon')

    @staticmethod
    def _format_time_axis(ax, hour_interval: int) -> None:
        """Place a major tick every *hour_interval* hours with 'MM-DD HH:MM' labels."""
        ax.xaxis.set_major_locator(mdates.HourLocator(interval=hour_interval))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
        ax.tick_params(axis='x', rotation=45)

    def plot_subloads(self, figsize: Tuple[int, int] = (20, 15), save_path: str | None = None):
        """Plot history, ground truth and prediction for every subload in a grid.

        Args:
            figsize: Figure size in inches.
            save_path: If given, the figure is also written to this path (dpi=300).

        Raises:
            ValueError: If there are no subload features to plot.
        """
        num_features = len(self.feature_names)
        if num_features == 0:
            raise ValueError("No subload features to plot.")

        # Grid with at most 4 columns.
        ncols = min(4, num_features)
        nrows = (num_features + ncols - 1) // ncols

        # squeeze=False always yields a 2-D array, so a single ravel() handles
        # the 1x1, 1xN and MxN layouts uniformly.
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=150, squeeze=False)
        axes = axes.ravel()

        panel_name, energy_use, start_dt = self._panel_header()
        fig.suptitle(f'{panel_name} - {energy_use}\nPrediction from {start_dt}',
                     fontsize=16, fontweight='bold')

        pred_denorm = self.denormalize_data(self.predictions)
        gt_denorm = self.denormalize_data(self.ground_truth)
        past_denorm = self.denormalize_data(self.past_data)

        for i, (ax, feature_name) in enumerate(zip(axes, self.feature_names)):
            ax.plot(self.past_times, past_denorm[:, i],
                    color='blue', linewidth=1.5, label='Historical Data', alpha=0.8)
            ax.plot(self.future_times, gt_denorm[:, i],
                    color='blue', linestyle='--', linewidth=1.5,
                    label='Ground Truth', alpha=0.8)
            ax.plot(self.future_times, pred_denorm[:, i],
                    color='red', linewidth=2, label='Predicted')
            self._mark_prediction_window(ax)

            ax.set_title(f'Subload: {feature_name}', fontweight='bold')
            ax.set_ylabel('Power (kW)')
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper left', fontsize=9)
            self._format_time_axis(ax, hour_interval=6)

        # Remove the unused cells of the last grid row.
        for ax in axes[num_features:]:
            fig.delaxes(ax)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"📊 Plot saved to: {save_path}")

        plt.show()

    def plot_aggregated(self, figsize: Tuple[int, int] = (16, 8), save_path: str | None = None):
        """Plot the per-timestep sum of all subloads (history, ground truth, prediction).

        Args:
            figsize: Figure size in inches.
            save_path: If given, the figure is also written to this path (dpi=300).
        """
        pred_denorm = self.denormalize_data(self.predictions)
        gt_denorm = self.denormalize_data(self.ground_truth)
        past_denorm = self.denormalize_data(self.past_data)

        # Sum across the feature axis to get one total per timestep.
        past_total = np.sum(past_denorm, axis=1)
        gt_total = np.sum(gt_denorm, axis=1)
        pred_total = np.sum(pred_denorm, axis=1)

        fig, ax = plt.subplots(figsize=figsize, dpi=150)

        ax.plot(self.past_times, past_total,
                color='blue', linewidth=2, label='Historical Total', alpha=0.8)
        ax.plot(self.future_times, gt_total,
                color='blue', linestyle='--', linewidth=2,
                label='Ground Truth Total', alpha=0.8)
        ax.plot(self.future_times, pred_total,
                color='red', linewidth=2.5, label='Predicted Total')
        self._mark_prediction_window(ax)

        panel_name, energy_use, start_dt = self._panel_header()
        ax.set_title(f'{panel_name} - {energy_use} (Total Aggregate)\nPrediction from {start_dt}',
                     fontsize=14, fontweight='bold')
        ax.set_ylabel('Total Power (kW)')
        ax.grid(True, alpha=0.3)
        ax.legend()
        self._format_time_axis(ax, hour_interval=4)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"📊 Aggregated plot saved to: {save_path}")

        plt.show()

    def print_metrics(self):
        """Print the metrics table to stdout and return it.

        The window line reports the actual lengths of ``past_data`` and
        ``predictions`` rather than assuming a fixed window size.

        Returns:
            The DataFrame produced by ``calculate_metrics``.
        """
        metrics_df = self.calculate_metrics()
        past_len = len(self.past_data)
        future_len = len(self.predictions)

        print("\n" + "=" * 80)
        print("📊 PREDICTION METRICS")
        print("=" * 80)
        print(f"Panel: {self.panel_info['name']}")
        print(f"Energy Use: {self.panel_info['energy_use']}")
        print(f"Start DateTime: {self.panel_info['start_datetime']}")
        print(f"Prediction Window: {past_len} past → {future_len} future timesteps")
        print("-" * 80)

        print(f"{'Subload':<12} {'Loss(MSE)':<12} {'MAE':<10} {'R2':<8} {'CVRMSE':<10}")
        print("-" * 80)

        last_pos = len(metrics_df) - 1
        for pos, (_, row) in enumerate(metrics_df.iterrows()):
            # calculate_metrics always appends the pooled row last. Identify it
            # by position so a subload that is itself named 'Average' is not
            # mistaken for the summary row.
            if pos == last_pos:
                print("-" * 80)
                label = 'AVERAGE'
            else:
                label = str(row['Subload'])[:10]  # truncate to keep columns aligned
            print(f"{label:<12} {row['Loss(MSE)']:<12.4f} {row['MAE']:<10.2f} "
                  f"{row['R2']:<8.3f} {row['CVRMSE']:<10.3f}")

        print("=" * 80)

        return metrics_df

    def save_results(self, output_dir: str | None = None):
        """Write metrics.csv, subloads_plot.png and aggregated_plot.png to a directory.

        Args:
            output_dir: Target directory. Defaults to
                ``results_<panel name>_<energy use>``. The directory and any
                missing parent directories are created.

        Returns:
            The output directory path as a string.
        """
        from pathlib import Path

        if output_dir is None:
            output_dir = f"results_{self.panel_info['name']}_{self.panel_info['energy_use']}"

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        metrics_df = self.calculate_metrics()
        metrics_df.to_csv(output_path / "metrics.csv", index=False)

        self.plot_subloads(save_path=str(output_path / "subloads_plot.png"))
        self.plot_aggregated(save_path=str(output_path / "aggregated_plot.png"))

        print(f"📁 Results saved to: {output_path}")
        return str(output_path)