import tqdm
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import re  
import time
import warnings
import numpy as np
import json
from sklearn.metrics import r2_score
import shap
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from utils.tools import EarlyStopping, adjust_learning_rate
from data_provider.data_loader import Dataset_Custom
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import pandas as pd  # 导入pandas库
from tqdm.auto import tqdm  # Change this import statement
class Exp_Long_Term_Forecast(Exp_Basic):
    """Long-term forecasting experiment: training, evaluation, prediction export and SHAP analysis.

    Relies on ``Exp_Basic`` for ``self.model``, ``self.device`` and
    ``self.model_dict`` (used below but set up outside this class). Data loaders
    yield ``(batch_x, batch_y)`` pairs; the last ``args.pred_len`` steps of
    ``batch_y`` are used as the forecast target.
    """

    def __init__(self, args):
        super(Exp_Long_Term_Forecast, self).__init__(args)
        self.args = args  # make sure args is stored on the instance

    def _build_model(self):
        """Instantiate ``args.model`` in float32, wrapped in DataParallel for multi-GPU runs."""
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for split ``flag`` ('train', 'val' or 'test')."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Adam over all model parameters with ``args.learning_rate``."""
        return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

    def _select_criterion(self):
        """Return the training loss module for ``args.loss`` ('MSE' or 'MAE', case-insensitive).

        Raises:
            ValueError: for any other loss name.
        """
        loss_name = str(self.args.loss).lower()
        if loss_name == 'mse':
            return nn.MSELoss()
        if loss_name == 'mae':
            return nn.L1Loss()
        raise ValueError(f"Unsupported loss '{self.args.loss}': expected 'MSE' or 'MAE'.")

    # ------------------------------------------------------------------ helpers

    def _forward(self, batch_x):
        """Run the model on one input batch, under CUDA autocast when ``args.use_amp`` is set."""
        if self.args.use_amp:
            with torch.cuda.amp.autocast():
                return self.model(batch_x)
        return self.model(batch_x)

    def _collect_predictions(self, loaders):
        """Run the model over every batch of each loader in ``loaders`` and stack the results.

        For each ``(batch_x, batch_y)`` the truth is ``batch_y[:, -pred_len:, :]``.
        Inference runs under ``torch.no_grad()``; the caller is responsible for
        putting the model in eval mode.

        Returns:
            (preds, trues): numpy arrays concatenated along the batch axis.

        Raises:
            ValueError: if the loaders yield no batches at all.
        """
        preds, trues = [], []
        with torch.no_grad():
            for loader in loaders:
                for batch_x, batch_y in loader:
                    batch_x = batch_x.float().to(self.device, non_blocking=True)
                    batch_y = batch_y[:, -self.args.pred_len:, :].float()
                    outputs = self._forward(batch_x)
                    preds.append(outputs.detach().cpu().numpy())
                    trues.append(batch_y.detach().numpy())
        if not preds:
            raise ValueError("No batches to predict on: the data loader(s) are empty.")
        return np.concatenate(preds, axis=0), np.concatenate(trues, axis=0)

    def _load_checkpoint(self, path):
        """Load ``<path>/checkpoint.pth`` into the model, mapping tensors onto ``self.device``."""
        checkpoint_path = os.path.join(path, 'checkpoint.pth')
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    def _result_dir(self):
        """Create and return ./test_dict/<data>/<seq>_to_<pred>/<model>/<loss>/bz_<bs>/lr_<lr>/."""
        data_name = os.path.splitext(self.args.data_path)[0]
        head = f'./test_dict/{data_name}/{self.args.seq_len}_to_{self.args.pred_len}/'
        tail = f'{self.args.model}/{self.args.loss}/bz_{self.args.batch_size}/lr_{self.args.learning_rate}/'
        dict_path = head + tail
        os.makedirs(dict_path, exist_ok=True)
        return dict_path

    @staticmethod
    def _pad_to_features(values, n_features):
        """Reshape 3-D ``values`` to 2-D ``(rows, n_features)``, zero-padding when channel counts differ.

        When padding is needed only the last channel of ``values`` is kept; it is
        written into the last column and every other column is zero.

        Returns:
            (array_2d, padded): the 2-D array and whether padding was applied.
        """
        if values.shape[-1] == n_features:
            return values.reshape(-1, n_features), False
        padded = np.zeros((values.shape[0], values.shape[1], n_features))
        # Index the last channel explicitly; .squeeze() would also drop batch/time axes of size 1.
        padded[:, :, -1] = values[..., -1]
        return padded.reshape(-1, n_features), True

    def _align_and_invert(self, data_set, preds, trues, warn=False):
        """Undo scaling on predictions/truths and extract the target column.

        ``preds``/``trues`` are ``(samples, pred_len, channels)``. Arrays whose
        channel count differs from ``data_set.data_x`` are zero-padded (see
        ``_pad_to_features``) so ``data_set.inverse_transform`` can be applied.
        The last column is read back as the target (assumes the target is the
        last feature of ``data_x`` — TODO confirm against the dataset).

        Args:
            data_set: split whose ``data_x`` width and ``inverse_transform`` are used.
            preds, trues: model outputs and ground truth in scaled space.
            warn: print a warning when an array has to be padded.

        Returns:
            (preds_inverse, trues_inverse, preds_scored, trues_scored): 1-D
            inverse-transformed target values, plus 2-D scaled arrays restricted
            to the columns that hold real values in both inputs (all columns when
            no padding happened, otherwise only the target column).
        """
        n_features = data_set.data_x.shape[-1]

        if warn and preds.shape[-1] != n_features:
            print(f"Warning: Prediction shape {preds.shape} does not match original data shape. Adjusting...")
        preds_2d, preds_padded = self._pad_to_features(preds, n_features)

        if warn and trues.shape[-1] != n_features:
            print(f"Warning: True values shape {trues.shape} does not match original data shape. Adjusting...")
        trues_2d, trues_padded = self._pad_to_features(trues, n_features)

        # Padded values always live in the last column, so read the target from there too.
        target_col = n_features - 1
        preds_inverse = data_set.inverse_transform(preds_2d)[:, target_col]
        trues_inverse = data_set.inverse_transform(trues_2d)[:, target_col]

        # Zero padding columns are not data; scoring them would dilute MSE/MAE and inflate R^2.
        score_cols = [target_col] if (preds_padded or trues_padded) else slice(None)
        return preds_inverse, trues_inverse, preds_2d[:, score_cols], trues_2d[:, score_cols]

    @staticmethod
    def _snap_to_int(values):
        """Round ``values`` to integers, rounding up only when the fractional part is >= 0.9.

        Intended to absorb inverse-scaling error (e.g. 4.9999 -> 5) while
        otherwise rounding down.

        Raises:
            ValueError: if any value is NaN or infinite.
        """
        values = np.asarray(values).ravel()
        if not np.all(np.isfinite(values)):
            raise ValueError("Cannot convert non-finite truth values to integers.")
        floor = np.floor(values)
        return (floor + (values - floor >= 0.9)).astype(np.int64)

    def _write_evaluation(self, dict_path, preds_inverse, trues_inverse, preds_scored, trues_scored):
        """Write test_results.csv, matching_test_results.csv and records.json into ``dict_path``.

        ``preds_inverse``/``trues_inverse`` are the 1-D inverse-scaled target
        values saved to CSV; ``preds_scored``/``trues_scored`` are the scaled
        arrays used for MSE, MAE and R^2.
        """
        results = pd.DataFrame({
            'Truth': trues_inverse.flatten(),
            'Prediction': preds_inverse.flatten()
        })
        results.to_csv(os.path.join(dict_path, 'test_results.csv'), index=False)
        print('Results saved to', os.path.join(dict_path, 'test_results.csv'))

        # Keep only rows whose inverse-scaled truth appears exactly in the raw target column.
        # NOTE(review): exact float equality; inverse-scaling round-off can make this miss rows.
        original_data = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
        target_values = original_data[self.args.target].values
        matching_indices = np.where(np.isin(trues_inverse, target_values))[0]
        matching_results = results.iloc[matching_indices]
        matching_results.to_csv(os.path.join(dict_path, 'matching_test_results.csv'), index=False)

        # Metrics are computed in scaled space.
        mse = mean_squared_error(trues_scored, preds_scored)
        mae = mean_absolute_error(trues_scored, preds_scored)
        r2 = r2_score(trues_scored, preds_scored)
        print('mse: {:.3f}  mae: {:.3f}  r2: {:.3f}'.format(mse, mae, r2))

        my_dict = {
            'mse': "{:.3f}".format(mse),
            'mae': "{:.3f}".format(mae),
            'r2': "{:.3f}".format(r2)
        }
        with open(os.path.join(dict_path, 'records.json'), 'w') as f:
            json.dump(my_dict, f)

    # --------------------------------------------------------------- public API

    def vali(self, vali_data, vali_loader, criterion):
        """Compute the validation loss over ``vali_loader``.

        Args:
            vali_data: unused; kept for interface compatibility.
            vali_loader: loader yielding ``(batch_x, batch_y)``.
            criterion: loss *name* (e.g. ``args.loss``); 'mae' (any case) selects
                MAE, anything else selects MSE.

        Returns:
            The MAE or MSE returned by ``utils.metrics.metric``.

        Raises:
            ValueError: if ``vali_loader`` is empty.
        """
        self.model.eval()
        try:
            preds, trues = self._collect_predictions([vali_loader])
        finally:
            # Always hand the model back in training mode, even if prediction failed.
            self.model.train()
        mse, mae = metric(preds, trues)
        vali_loss = mae if str(criterion).lower() == 'mae' else mse
        torch.cuda.empty_cache()
        return vali_loss

    def train(self, setting):
        """Train with early stopping, then export train/val/test predictions.

        Checkpoints go to ``<args.checkpoints>/<setting>/``. After training, the
        saved ``checkpoint.pth`` (written by ``EarlyStopping``) is reloaded when
        present, so the exports and ``test()`` use the same weights. Per-split
        CSVs and ``combined_results.csv`` are written to that directory.

        Returns:
            The trained model (``self.model``).
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()
        scaler = torch.cuda.amp.GradScaler() if self.args.use_amp else None

        for epoch in range(self.args.train_epochs):
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for batch_x, batch_y in train_loader:
                model_optim.zero_grad(set_to_none=True)
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                batch_y = batch_y[:, -self.args.pred_len:, :].float().to(self.device, non_blocking=True)

                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x)
                        loss = criterion(outputs, batch_y)
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    outputs = self.model(batch_x)
                    loss = criterion(outputs, batch_y)
                    loss.backward()
                    model_optim.step()
                train_loss.append(loss.item())
            # Clearing the CUDA cache every step forces re-allocation each iteration; once per epoch suffices.
            torch.cuda.empty_cache()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, self.args.loss)
            test_loss = self.vali(test_data, test_loader, self.args.loss)
            print("Epoch: {}, Steps: {} | Train Loss: {:.3f}  vali_loss: {:.3f}   test_loss: {:.3f}".format(epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        # Export with the checkpoint EarlyStopping saved (presumably the best-validation
        # weights — see utils.tools.EarlyStopping), not whatever the last epoch left behind.
        if os.path.exists(os.path.join(path, 'checkpoint.pth')):
            self._load_checkpoint(path)

        self.model.eval()
        train_results = self.export_predictions('train', train_loader, train_data, path)
        vali_results = self.export_predictions('val', vali_loader, vali_data, path)
        test_results = self.export_predictions('test', test_loader, test_data, path)

        combined_results = pd.concat([train_results, vali_results, test_results], keys=['Train', 'Validation', 'Test'])
        combined_results.to_csv(os.path.join(path, 'combined_results.csv'))
        print('Combined results saved to', os.path.join(path, 'combined_results.csv'))

        torch.cuda.empty_cache()
        return self.model

    def export_predictions(self, dataset_name, loader, data_set, path):
        """Predict on one split and save per-row results to ``<path>/<dataset_name>_results.csv``.

        Columns:
            Model_Truth: inverse-scaled truth snapped to integers (see ``_snap_to_int``).
            Prediction: inverse-scaled prediction.
            Percentage_Decrease: (truth - prediction) / truth * 100, or 0 when truth is 0.
            Absolute_Percentage_Decrease: absolute value of the above.

        Rows are truncated to the row count of the raw CSV at
        ``args.root_path/args.data_path``. Row order follows ``loader``'s iteration
        order (not time order if the loader shuffles — check data_provider).

        Returns:
            pandas.DataFrame: the saved results.
        """
        preds, trues = self._collect_predictions([loader])
        preds_inverse, trues_inverse, _, _ = self._align_and_invert(data_set, preds, trues)

        model_truth = self._snap_to_int(trues_inverse)

        # Division by a zero truth is masked to 0 below, so silence the warning it would raise.
        with np.errstate(divide='ignore', invalid='ignore'):
            percentage_decreases = np.where(
                model_truth != 0,
                (model_truth - preds_inverse) / model_truth * 100,
                0.0,
            )
        abs_percentage_decreases = np.abs(percentage_decreases)

        raw_rows = len(pd.read_csv(os.path.join(self.args.root_path, self.args.data_path)))

        min_length = min(raw_rows, len(model_truth), len(preds_inverse))
        model_truth = model_truth[:min_length]
        preds_inverse = preds_inverse[:min_length]
        percentage_decreases = percentage_decreases[:min_length]
        abs_percentage_decreases = abs_percentage_decreases[:min_length]

        print(f"Lengths - Original: {raw_rows}, Model Truth: {len(model_truth)}, Predictions: {len(preds_inverse)}")

        results = pd.DataFrame({
            'Model_Truth': model_truth,
            'Prediction': preds_inverse.flatten(),
            'Percentage_Decrease': percentage_decreases,
            'Absolute_Percentage_Decrease': abs_percentage_decreases
        })

        results.to_csv(os.path.join(path, f'{dataset_name}_results.csv'), index=False)
        print(f'{dataset_name.capitalize()} results saved to', os.path.join(path, f'{dataset_name}_results.csv'))

        return results

    def shap_analysis(self):
        """Explain the model with SHAP ``DeepExplainer`` on up to 100 random training rows.

        ``train_data.data_x`` must be 2-D (samples, features); each row is fed to
        the model as one input. The same sampled rows serve as background and
        as the rows to explain. The last feature column is excluded from the
        ranking and plots. Prints per-feature mean SHAP values and writes
        ``shap_summary_plot-1.png`` and ``feature_importance_plot.png`` to the
        current working directory.

        Raises:
            ValueError: if ``train_data.data_x`` is not 2-D.
            Exception: any error from the model or SHAP is printed and re-raised.
        """
        print("Starting SHAP analysis...")

        try:
            train_data, _ = self._get_data(flag='train')
            print(f"Train data size: {len(train_data)}")

            train_data_tensor = torch.tensor(train_data.data_x, dtype=torch.float32).to(self.device)
            print(f"Train data tensor shape: {train_data_tensor.shape}")

            # The explainer treats each row as one (features,) input, so require 2-D data.
            if len(train_data_tensor.shape) != 2:
                raise ValueError(f"Unexpected shape of train_data_tensor: {train_data_tensor.shape}")

            n_samples, n_features = train_data_tensor.shape
            print(f"Data dimensions: samples={n_samples}, features={n_features}")

            train_data_2d = train_data_tensor.cpu().numpy()
            print(f"Train data 2d shape: {train_data_2d.shape}")

            sample_size = min(100, n_samples)
            sample_indices = np.random.choice(n_samples, sample_size, replace=False)
            sample_data = train_data_2d[sample_indices]
            print(f"Sample data shape: {sample_data.shape}")

            background_size = min(100, sample_size)
            background = sample_data[:background_size]
            print(f"Background shape: {background.shape}")
            # shap's PyTorch DeepExplainer takes torch tensors (on the model's device), not numpy arrays.
            background_tensor = torch.tensor(background, dtype=torch.float32).to(self.device)

            self.model.eval()

            # Sanity-check that the model accepts the background batch.
            with torch.no_grad():
                background_outputs = self.model(background_tensor)
                print(f"Background outputs shape: {background_outputs.shape}")
                print(f"Background outputs: {background_outputs}")

            print("Creating DeepExplainer...")
            try:
                explainer = shap.DeepExplainer(self.model, background_tensor)
                print("DeepExplainer created successfully.")
            except Exception as e:
                print(f"Error creating DeepExplainer: {e}")
                raise

            print("Calculating SHAP values...")
            batch_size = 100
            shap_values = []
            for start in tqdm(range(0, sample_data.shape[0], batch_size), desc="SHAP Calculation"):
                batch = torch.tensor(sample_data[start:start + batch_size], dtype=torch.float32).to(self.device)
                shap_values.append(explainer.shap_values(batch))

            if isinstance(shap_values[0], list):
                # Multi-output (older shap): concatenate each output's batches separately.
                shap_values = [np.concatenate([batch_vals[k] for batch_vals in shap_values], axis=0)
                               for k in range(len(shap_values[0]))]
            else:
                shap_values = np.concatenate(shap_values, axis=0)

            print("SHAP values calculated successfully.")
            print(f"SHAP values shape: {np.array(shap_values).shape}")

            # Reduce multi-output SHAP values to one (samples, features) array by averaging outputs.
            if isinstance(shap_values, list):
                # Older shap: list of (samples, features) arrays, one per output.
                shap_values = np.mean(np.array(shap_values), axis=0)
            elif shap_values.ndim == 3:
                # Newer shap: a single (samples, features, outputs) array.
                shap_values = np.mean(shap_values, axis=-1)

            print(f"SHAP values shape after processing: {shap_values.shape}")

            shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
            shap_values_mean = np.mean(shap_values, axis=0)

            # Feature names: raw CSV columns minus the first two and the last.
            df_raw = pd.read_csv(os.path.join(self.args.root_path, self.args.data_path))
            cols_data = df_raw.columns[2:-1]  # Exclude the last column (TB incidence)
            df_data = df_raw[cols_data]
            feature_names = df_data.columns.tolist()
            print(f"All feature names: {feature_names}")
            if len(feature_names) != n_features - 1:
                # NOTE(review): data_x may not be built from columns[2:]; names would then be misaligned.
                print(f"Warning: {len(feature_names)} feature names for {n_features - 1} explained features; "
                      f"names may be misaligned.")

            # Rank by mean |SHAP|, excluding the last feature (TB incidence).
            feature_importance_order = np.argsort(shap_values_mean_abs[:-1])[::-1]
            top_10_features = feature_importance_order[:10]
            print(f"Top 10 features: {top_10_features}")

            print("Feature importance ranking and SHAP values:")
            print("Index: Feature Name - Mean Absolute SHAP Value")
            for rank, idx in enumerate(feature_importance_order):
                shap_value_abs = shap_values_mean_abs[idx]
                print(f"{rank + 1}: {feature_names[idx]} - shap_value_abs: {shap_value_abs:.6f} - shap_value: {shap_values_mean[idx]:.6f}")

            print("\nSHAP values for each feature:")
            for feature_idx in range(n_features - 1):  # Exclude the last feature
                shap_value = shap_values_mean[feature_idx]
                print(f"Feature: {feature_names[feature_idx]}, SHAP Value: {shap_value:.6f}")

            print("Plotting SHAP summary plot (violin)...")
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values[:, :-1], sample_data[:, :-1],
                              feature_names=feature_names,
                              plot_type="violin", show=False)

            plt.xlabel("SHAP value (impact on model output)", family='Times New Roman', fontsize=14)
            plt.rc('font', family='Times New Roman', size=15)
            plt.tight_layout()
            plt.savefig('shap_summary_plot-1.png')
            plt.close()
            print("SHAP summary plot saved as 'shap_summary_plot-1.png'")

            print("Plotting feature importance...")
            plt.figure(figsize=(12, 8))
            top_20_features = feature_importance_order[:20]
            top_20_feature_names = [feature_names[idx] for idx in top_20_features]
            plt.barh(top_20_feature_names, shap_values_mean_abs[top_20_features])
            plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)", fontfamily='Times New Roman', fontsize=14)
            plt.ylabel("Features", fontfamily='Times New Roman', fontsize=14)
            plt.title("Feature Importance", fontfamily='Times New Roman', fontsize=16)
            plt.yticks(fontfamily='Times New Roman', fontsize=12)
            plt.xticks(fontfamily='Times New Roman', fontsize=12)
            plt.gca().invert_yaxis()  # most important feature on top
            plt.tight_layout()
            plt.savefig('feature_importance_plot.png')
            plt.close()
            print("Feature importance plot saved as 'feature_importance_plot.png'")

        except Exception as e:
            print(f"An error occurred during SHAP analysis: {e}")
            raise  # Re-raise the exception for debugging purposes

    def test(self, setting, test=1):
        """Evaluate on the test split and write results under ``_result_dir()``.

        Args:
            setting: experiment name; the checkpoint is read from
                ``<args.checkpoints>/<setting>/checkpoint.pth``.
            test: when truthy, load that checkpoint first; otherwise use the
                current in-memory weights.
        """
        test_data, test_loader = self._get_data(flag='test')
        path = os.path.join(self.args.checkpoints, setting)
        if test:
            print('loading model')
            self._load_checkpoint(path)

        dict_path = self._result_dir()

        self.model.eval()
        preds, trues = self._collect_predictions([test_loader])
        print('test shape:', preds.shape, trues.shape)

        preds_inverse, trues_inverse, preds_scored, trues_scored = self._align_and_invert(
            test_data, preds, trues, warn=True)
        self._write_evaluation(dict_path, preds_inverse, trues_inverse, preds_scored, trues_scored)

        torch.cuda.empty_cache()

    def predict(self, setting):
        """Run the checkpointed model over the train, val and test splits together and score it.

        NOTE(review): writes the same file names as ``test()`` (test_results.csv,
        matching_test_results.csv, records.json) in the same directory, so it
        overwrites ``test()``'s outputs for this configuration.
        """
        _, train_loader = self._get_data(flag='train')
        _, val_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')
        path = os.path.join(self.args.checkpoints, setting)

        print('loading model')
        self._load_checkpoint(path)

        dict_path = self._result_dir()

        self.model.eval()
        # DataLoaders cannot be concatenated with '+'; iterate them one after another instead.
        preds, trues = self._collect_predictions([train_loader, val_loader, test_loader])
        print('all shape:', preds.shape, trues.shape)

        # test_data's inverse_transform is applied to every split; assumes the scaler is
        # shared across splits (TODO confirm in data_provider).
        preds_inverse, trues_inverse, preds_scored, trues_scored = self._align_and_invert(
            test_data, preds, trues, warn=True)
        self._write_evaluation(dict_path, preds_inverse, trues_inverse, preds_scored, trues_scored)

        torch.cuda.empty_cache()

