from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from models import Informer, Autoformer, Transformer, DLinear, Linear, NLinear, PatchTST, SegRNN, CycleNet, QuLTSF, \
    iTransformer, TimeXer, PQNet
from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
from utils.metrics import metric

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')


class Exp_Main(Exp_Basic):
    def __init__(self, args):
        """Initialise the experiment; all setup is delegated to ``Exp_Basic.__init__``."""
        super().__init__(args)

    def _build_model(self):
        """Instantiate the network named by ``self.args.model``.

        The name is looked up in a fixed name -> module table; that module's
        ``Model`` is built from ``self.args`` and cast to float.

        Returns:
            The model, wrapped in ``nn.DataParallel`` over
            ``self.args.device_ids`` when both ``use_multi_gpu`` and
            ``use_gpu`` are truthy.

        Raises:
            KeyError: If ``self.args.model`` is not a key of the table.
        """
        cfg = self.args
        available = dict(
            Autoformer=Autoformer,
            Transformer=Transformer,
            Informer=Informer,
            DLinear=DLinear,
            NLinear=NLinear,
            Linear=Linear,
            PatchTST=PatchTST,
            SegRNN=SegRNN,
            CycleNet=CycleNet,
            QuLTSF=QuLTSF,
            iTransformer=iTransformer,
            TimeXer=TimeXer,
            PQNet=PQNet,
        )
        net = available[cfg.model].Model(cfg).float()

        wrap_parallel = cfg.use_multi_gpu and cfg.use_gpu
        if not wrap_parallel:
            return net
        return nn.DataParallel(net, device_ids=cfg.device_ids)

    def _get_data(self, flag):
        """Fetch the dataset and its loader for one split via ``data_provider``.

        Args:
            flag: Split selector forwarded unchanged to ``data_provider``
                (this class passes 'train', 'val', 'test' and 'pred').

        Returns:
            A ``(data_set, data_loader)`` tuple.
        """
        dataset, loader = data_provider(self.args, flag)
        return dataset, loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_cycle) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_cycle = batch_cycle.int().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if any(substr in self.args.model for substr in {'CycleNet', 'PQ'}):
                            outputs = self.model(batch_x, batch_cycle)
                        elif any(substr in self.args.model for substr in
                                 {'Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'}):
                            outputs = self.model(batch_x)
                        else:
                            if self.args.output_attention:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if any(substr in self.args.model for substr in {'CycleNet', 'PQ'}):
                        outputs = self.model(batch_x, batch_cycle)
                    elif any(substr in self.args.model for substr in {'Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'}):
                        outputs = self.model(batch_x)
                    else:
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Fit the model, validating after every epoch, then reload the checkpoint.

        Args:
            setting: Run identifier; the checkpoint directory is
                ``os.path.join(self.args.checkpoints, setting)`` and is
                created when missing.

        Returns:
            ``self.model`` after loading ``checkpoint.pth`` from that directory
            (presumably written by ``EarlyStopping`` — confirm in utils.tools).
        """
        cfg = self.args
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(cfg.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=cfg.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        use_amp = bool(cfg.use_amp)
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        # NOTE(review): the scheduler is built for every lradj mode, and
        # constructing OneCycleLR already rewrites the optimizer's lr to its
        # warm-up start value — confirm this is intended when lradj != 'TST'.
        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=train_steps,
                                            pct_start=cfg.pct_start,
                                            epochs=cfg.train_epochs,
                                            max_lr=cfg.learning_rate)

        # The forward signature depends on the model family; resolve it once.
        takes_cycle = any(tag in cfg.model for tag in ('CycleNet', 'PQ'))
        input_only = any(tag in cfg.model for tag in ('Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'))
        # 'MS' scores only the last channel, every other mode scores all of them.
        f_dim = -1 if cfg.features == 'MS' else 0

        for epoch in range(cfg.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_cycle) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_cycle = batch_cycle.int().to(self.device)

                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -cfg.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :cfg.label_len, :], dec_inp], dim=1).float().to(self.device)

                # A disabled autocast context is a no-op, so one code path
                # serves both the AMP and the full-precision case.
                with torch.cuda.amp.autocast(enabled=use_amp):
                    if takes_cycle:
                        outputs = self.model(batch_x, batch_cycle)
                    elif input_only:
                        outputs = self.model(batch_x)
                    else:
                        # Same four-argument call with and without AMP, as in
                        # vali/test/predict: the target batch_y is never
                        # handed to the model.
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        if cfg.output_attention:
                            outputs = outputs[0]

                    outputs = outputs[:, -cfg.pred_len:, f_dim:]
                    batch_y = batch_y[:, -cfg.pred_len:, f_dim:]
                    loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    # Average time per iteration since the last report, scaled
                    # by the number of iterations still to run.
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((cfg.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

                # 'TST' steps the one-cycle schedule after every batch.
                if cfg.lradj == 'TST':
                    adjust_learning_rate(model_optim, scheduler, epoch + 1, cfg, printout=False)
                    scheduler.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            # Every other mode adjusts the learning rate once per epoch.
            if cfg.lradj != 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, cfg)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0]))

        best_model_path = os.path.join(path, 'checkpoint.pth')
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate the model on the test split and record MSE/MAE.

        For every 20th batch the last channel of the first sample (input
        followed by truth, and input followed by prediction) is handed to
        ``visual`` with the target ``./test_results/<setting>/<i>.pdf``. The
        metrics are printed and appended to ``result.txt`` in the working
        directory.

        Args:
            setting: Run identifier used for the checkpoint and output folders.
            test: When truthy, first load ``checkpoint.pth`` from
                ``os.path.join(self.args.checkpoints, setting)``.

        Returns:
            None. When ``self.args.test_flop`` is set, the process exits right
            after ``test_params_flop`` and no metrics are computed.
        """
        cfg = self.args
        test_data, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            # Same location train() and predict() use, so a non-default
            # --checkpoints directory is honoured here as well.
            checkpoint = os.path.join(cfg.checkpoints, setting, 'checkpoint.pth')
            # map_location lets a checkpoint saved on another device load here.
            self.model.load_state_dict(torch.load(checkpoint, map_location=self.device))

        preds = []
        trues = []
        folder_path = './test_results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        # The forward signature depends on the model family; resolve it once.
        takes_cycle = any(tag in cfg.model for tag in ('CycleNet', 'PQ'))
        input_only = any(tag in cfg.model for tag in ('Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'))
        # 'MS' scores only the last channel, every other mode scores all of them.
        f_dim = -1 if cfg.features == 'MS' else 0

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_cycle) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_cycle = batch_cycle.int().to(self.device)

                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -cfg.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :cfg.label_len, :], dec_inp], dim=1).float().to(self.device)

                # A disabled autocast context is a no-op, so one code path
                # serves both the AMP and the full-precision case.
                with torch.cuda.amp.autocast(enabled=bool(cfg.use_amp)):
                    if takes_cycle:
                        outputs = self.model(batch_x, batch_cycle)
                    elif input_only:
                        outputs = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        if cfg.output_attention:
                            outputs = outputs[0]

                pred = outputs[:, -cfg.pred_len:, f_dim:].detach().cpu().numpy()
                true = batch_y[:, -cfg.pred_len:, f_dim:].detach().cpu().numpy()

                preds.append(pred)
                trues.append(true)

                if i % 20 == 0:
                    history = batch_x.detach().cpu().numpy()

                    # Last channel of the first sample: input window followed
                    # by the ground truth / by the prediction.
                    truth_curve = np.concatenate((history[0, :, -1], true[0, :, -1]), axis=0)
                    pred_curve = np.concatenate((history[0, :, -1], pred[0, :, -1]), axis=0)

                    visual(truth_curve, pred_curve, os.path.join(folder_path, str(i) + '.pdf'))

        if cfg.test_flop:
            # Uses the shape of the last batch seen in the loop above.
            test_params_flop(self.model, (batch_x.shape[1], batch_x.shape[2]))
            exit()

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)

        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])

        # Optional de-normalisation before scoring (disabled):
        # denorm_preds = np.stack([test_data.inverse_transform(pred) for pred in preds])
        # denorm_trues = np.stack([test_data.inverse_transform(true) for true in trues])

        # result save
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
        # mae, mse, rmse, mape, mspe, rse, corr = metric(denorm_preds, denorm_trues)

        print('mse:{}, mae:{}'.format(mse, mae))

        # Optional LPV visualisation (disabled):
        # if 'PQ' in self.args.model and hasattr(self.model, 'temporalQuery'):
        #     self._visualize_lpv(test_loader, folder_path)

        # The context manager closes the log even if a write fails.
        with open("result.txt", 'a') as f:
            f.write(setting + "  \n")
            f.write('mse:{}, mae:{}'.format(mse, mae))
            f.write('\n')
            f.write('\n')

        # np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe,rse, corr]))
        # np.save(folder_path + 'pred.npy', preds)
        # np.save(folder_path + 'true.npy', trues)
        return

    def predict(self, setting, load=False):
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))

        preds = []

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_cycle) in enumerate(pred_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_cycle = batch_cycle.int().to(self.device)

                # decoder input
                dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to(
                    batch_y.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if any(substr in self.args.model for substr in {'CycleNet', 'PQ'}):
                            outputs = self.model(batch_x, batch_cycle)
                        elif any(substr in self.args.model for substr in
                                 {'Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'}):
                            outputs = self.model(batch_x)
                        else:
                            if self.args.output_attention:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if any(substr in self.args.model for substr in {'CycleNet', 'PQ'}):
                        outputs = self.model(batch_x, batch_cycle)
                    elif any(substr in self.args.model for substr in {'Linear', 'MLP', 'SegRNN', 'TST', 'QuLTSF'}):
                        outputs = self.model(batch_x)
                    else:
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                pred = outputs.detach().cpu().numpy()  # .squeeze()
                preds.append(pred)

        preds = np.array(preds)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        np.save(folder_path + 'real_prediction.npy', preds)

        return
    
    def _visualize_lpv(self, test_loader, folder_path, num_samples=1000):
        """Extract and visualize LPV representations"""
        print('Starting LPV representation extraction...')
        
        lpv_representations = []
        predicted_sequences = []  # 改为预测序列
        channel_indices = []
        
        self.model.eval()
        with torch.no_grad():
            sample_count = 0
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_cycle) in enumerate(test_loader):
                if sample_count >= num_samples:
                    break
                    
                batch_x = batch_x.float().to(self.device)
                batch_cycle = batch_cycle.int().to(self.device)
                
                # 获取模型预测输出
                if any(substr in self.args.model for substr in {'CycleNet', 'PQ'}):
                    outputs = self.model(batch_x, batch_cycle)
                elif any(substr in self.args.model for substr in {'Linear', 'MLP', 'SegRNN', 'TST','QuLTSF'}):
                    outputs = self.model(batch_x)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                
                # 获取LPV表示
                if hasattr(self.model, 'temporalQuery') and hasattr(self.model, 'use_lpv') and self.model.use_lpv:
                    # 计算gather_index，使用pred_len获取预测长度的LPV
                    gather_index = (batch_cycle.view(-1, 1) + torch.arange(self.args.pred_len).view(1, -1).to(self.device)) % self.model.cycle_len
                    query_input = self.model.temporalQuery[gather_index]  # (batch, pred_len, enc_in)
                    
                    # 存储LPV表示和预测序列 (只取第一个样本的所有通道)
                    if sample_count == 0:  # 只处理第一个batch的第一个样本
                        # LPV表示 (pred_len, enc_in)
                        lpv_rep = query_input[0].cpu().numpy()
                        # 预测序列 (pred_len, enc_in)
                        pred_seq = outputs[0].detach().cpu().numpy()
                        
                        # 为每个通道存储数据
                        for ch in range(lpv_rep.shape[1]):
                            lpv_representations.append(lpv_rep[:, ch])  # (pred_len,)
                            predicted_sequences.append(pred_seq[:, ch])  # (pred_len,)
                            channel_indices.append(ch)
                            
                        sample_count = lpv_rep.shape[1]  # 设置为通道数
                        break  # 处理完第一个样本就退出
        print(len(predicted_sequences))                   
        if len(lpv_representations) == 0:
            print('Failed to extract LPV representations')
            return
            
        print(f'Successfully extracted {len(lpv_representations)} LPV representations')
        
        # 转换为numpy数组
        lpv_representations = np.array(lpv_representations)  # (num_samples*channels, pred_len)
        predicted_sequences = np.array(predicted_sequences)  # (num_samples*channels, pred_len)
        channel_indices = np.array(channel_indices)
        
        # 标准化
        scaler_lpv = StandardScaler()
        scaler_pred = StandardScaler()
        lpv_scaled = scaler_lpv.fit_transform(lpv_representations)
        pred_scaled = scaler_pred.fit_transform(predicted_sequences)
        
        # t-SNE dimensionality reduction
        print('Performing t-SNE dimensionality reduction...')
        n_samples = min(len(lpv_scaled), 1000)  # 限制样本数量以提高性能
        perplexity = min(30, n_samples // 4)  # 确保perplexity合理
        
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        lpv_tsne = tsne.fit_transform(lpv_scaled[:n_samples])
        
        # 定义特殊通道 - 选择相邻的通道ID来验证相似性
        special_channels = [10, 550, 626]  # 连续的通道ID
        
        # 创建颜色映射
        colors = np.full(len(channel_indices[:n_samples]), 0.5)  # 默认灰色
        markers = np.full(len(channel_indices[:n_samples]), 'o')  # 默认圆形
        sizes = np.full(len(channel_indices[:n_samples]), 20)  # 默认大小
        
        # 为特殊通道分配不同颜色和标记 - 使用更鲜艳的颜色
        special_colors = [(0.2, 0.8, 0.9), (0.6, 0.4, 0.9), (0.9, 0.4, 0.4)]  # 青色、紫色、红色
        special_markers = ['o', 's', '^']
        other_channel_color = (0.7, 0.7, 0.7)  # 灰色
        
        for i, ch in enumerate(special_channels):
            mask = channel_indices[:n_samples] == ch
            if np.sum(mask) > 0:
                colors[mask] = i + 1  # 特殊颜色索引
                sizes[mask] = 50  # 更大的点
        
        # Visualization - 改为1x4水平布局
        fig = plt.figure(figsize=(24, 6))
        
        # LPV representation t-SNE (main plot) - 第一个子图
        ax1 = plt.subplot(1, 4, 1)
        # First draw normal channels
        normal_mask = np.isin(channel_indices[:n_samples], special_channels, invert=True)
        if np.sum(normal_mask) > 0:
            ax1.scatter(lpv_tsne[normal_mask, 0], lpv_tsne[normal_mask, 1], 
                       c=[other_channel_color], alpha=0.6, s=15, label='Other Channels')
        
        # Then draw special channels with more prominent display
        for i, ch in enumerate(special_channels):
            ch_mask = channel_indices[:n_samples] == ch
            if np.sum(ch_mask) > 0:
                ax1.scatter(lpv_tsne[ch_mask, 0], lpv_tsne[ch_mask, 1], 
                           c=[special_colors[i]], marker=special_markers[i], 
                           s=80, alpha=1.0, edgecolors='black', linewidth=1, label=f'Channel {ch}')
        
        # 计算相邻通道在t-SNE空间中的距离
        special_positions = []
        for ch in special_channels:
            ch_mask = channel_indices[:n_samples] == ch
            if np.sum(ch_mask) > 0:
                pos_idx = np.where(ch_mask)[0][0]
                special_positions.append(lpv_tsne[pos_idx])
        
        if len(special_positions) == 3:
            # 计算两两之间的欧氏距离
            dist_100_101 = np.linalg.norm(special_positions[0] - special_positions[1])
            dist_101_102 = np.linalg.norm(special_positions[1] - special_positions[2])
            dist_100_102 = np.linalg.norm(special_positions[0] - special_positions[2])
            
            # 在图上添加距离信息
            ax1.text(0.02, 0.98, f'Channel Distances:\nCh100-101: {dist_100_101:.2f}\nCh101-102: {dist_101_102:.2f}\nCh100-102: {dist_100_102:.2f}', 
                     transform=ax1.transAxes, fontsize=10, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
         
        ax1.set_title('t-SNE: Adjacent Channels Spatial Distribution', fontsize=14, fontweight='bold')
        ax1.set_xlabel('t-SNE Dimension 1')
        ax1.set_ylabel('t-SNE Dimension 2')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Draw predicted sequences for special channels - 水平排列的后三个子图
        for idx, ch in enumerate(special_channels):
            ch_mask = channel_indices[:n_samples] == ch
            if np.sum(ch_mask) > 0:
                sample_idx = np.where(ch_mask)[0][0]
                
                ax = plt.subplot(1, 4, idx + 2)  # 第2、3、4个子图
                
                # 确保time_steps和sequence_data长度一致
                sequence_data = predicted_sequences[sample_idx]
                sequence_length = len(sequence_data)
                time_steps = range(sequence_length)
                    
                ax.plot(time_steps, sequence_data, color=special_colors[idx], 
                       linewidth=2.5, label=f'Channel {ch} Prediction', alpha=0.9)
                ax.set_title(f'Channel {ch} Predicted Sequence', fontsize=12, fontweight='bold')
                ax.set_xlabel('Prediction Time Steps', fontsize=10)
                ax.set_ylabel('Predicted Value', fontsize=10)
                ax.legend(fontsize=10)
                ax.grid(True, alpha=0.3)
                
                # 设置更好的坐标轴范围
                ax.set_xlim(0, sequence_length-1)
                y_min, y_max = np.min(sequence_data), np.max(sequence_data)
                y_range = y_max - y_min
                ax.set_ylim(y_min - 0.1*y_range, y_max + 0.1*y_range)
        
        plt.tight_layout()
        plt.savefig(os.path.join(folder_path, 'lpv_visualization.png'), dpi=300, bbox_inches='tight')
        plt.close()
        
        # 保存数据
        np.save(os.path.join(folder_path, 'lpv_representations.npy'), lpv_representations)
        np.save(os.path.join(folder_path, 'predicted_sequences.npy'), predicted_sequences)
        np.save(os.path.join(folder_path, 'channel_indices.npy'), channel_indices)
        
        print(f'LPV visualization completed, processed {len(lpv_representations)} samples')
