import pickle
import matplotlib.pyplot as plt
import numpy as np
import random
from copy import deepcopy
from collections import OrderedDict

import torch
from torch.utils.data import Dataset


import logging

def crop_data(driving_cycle, chunk_size):
    """
    Split a sequence into consecutive, non-overlapping chunks.

    Parameters:
        driving_cycle (sequence): Sliceable input sequence.
        chunk_size (int): Length of every returned chunk.

    Returns:
        list: Chunks of exactly ``chunk_size`` elements, in order. A trailing
        remainder shorter than ``chunk_size`` is discarded.
    """
    total = len(driving_cycle)
    # A chunk starting at `start` is complete only if it ends within the sequence.
    return [
        driving_cycle[start:start + chunk_size]
        for start in range(0, total, chunk_size)
        if start + chunk_size <= total
    ]

def crop_data_repeated(driving_cycle, chunk_size, slide_length = 100):
    """
    Cut a sequence into (possibly overlapping) windows with a fixed stride.

    Parameters:
        driving_cycle (sequence): Sliceable input sequence.
        chunk_size (int): Length of every window.
        slide_length (int): Offset between the starts of consecutive windows.

    Returns:
        list: Windows of exactly ``chunk_size`` elements starting at offsets
        0, slide_length, 2*slide_length, ... Only windows that fit entirely
        inside the sequence are produced.
    """
    last_start = len(driving_cycle) - chunk_size
    return [
        driving_cycle[offset:offset + chunk_size]
        for offset in range(0, last_start + 1, slide_length)
    ]


def symmetric_padding(data, window_size=21):
    """
    Mirror-pad the data so a sliding window can be applied at the edges.

    ``window_size // 2`` samples are added on each side using numpy's
    ``'reflect'`` mode. That mode mirrors about the edge sample without
    repeating it, e.g. [1, 2, 3] -> [2, 1, 2, 3, 2]. It is not numpy's
    ``'symmetric'`` mode.

    Parameters:
        data (array_like): Input data.
        window_size (int): Size of the window the padding is prepared for.

    Returns:
        numpy.ndarray: Padded data.
    """
    half_window = window_size // 2
    return np.pad(data, half_window, mode='reflect')

def moving_average(data, window_size=21):
    """
    Apply moving average smoothing to the given data.

    The data is mirror-padded with ``symmetric_padding`` before a box filter
    is applied, so the result has the same length as the input.

    Parameters:
        data (array_like): 1-D input data to be smoothed.
        window_size (int): Number of samples averaged for each output point.
            An odd size gives a window centred on each sample. For an even
            size the window reaches one sample further to the left than to
            the right.

    Returns:
        numpy.ndarray: Smoothed data with ``len(data)`` elements.
    """
    # Apply mirror padding to the data
    padded_data = symmetric_padding(data, window_size)

    # Box kernel: every sample in the window gets equal weight
    kernel = np.ones(window_size) / window_size

    # Apply the moving average filter
    smoothed_data = np.convolve(padded_data, kernel, mode='valid')

    # symmetric_padding adds window_size // 2 samples on each side. That is
    # window_size - 1 samples in total for odd windows but window_size for even
    # ones. 'valid' convolution therefore returns len(data) + 1 points for even
    # windows; drop the surplus trailing point so output aligns with input.
    return smoothed_data[:len(data)]

def shuffle_list(input_list):
    """
    Produce a randomly reordered copy of a list.

    The argument is left untouched. A shallow copy (taken by slicing) is
    shuffled in place with ``random.shuffle`` and returned.

    Parameters:
        input_list (list): Items to reorder.

    Returns:
        list: New list with the same items in random order.
    """
    reordered = input_list[:]
    random.shuffle(reordered)
    return reordered

def draw_figure(data, label='label'):
    """
    Add ``data`` as a line on the current matplotlib axes.

    The line is tagged with ``label`` (formatted as a string) for the legend.
    Nothing is displayed; ``show_figure`` renders the accumulated plot.
    """
    legend_text = f'{label}'
    plt.plot(data, label=legend_text)



def show_figure():
    """
    Decorate the current matplotlib figure and display it.

    Sets the title and axis labels, adds a legend built from the labelled
    lines, then calls ``plt.show()``.
    """
    plt.ylabel('vs')
    plt.xlabel('Time')
    plt.title('vs vs. time for Each Trip')
    plt.legend()
    plt.show()

def create_logger(logging_dir):
    """
    Configure root logging to write to the console (stderr) and to a file.

    The file is ``<logging_dir>/log.txt``. The level is INFO, and the
    timestamp is wrapped in ANSI colour codes in both outputs.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers. The file handler is still created (which creates/opens the
    file) but is then not attached.

    Parameters:
        logging_dir (str): Existing directory that will hold ``log.txt``.

    Returns:
        logging.Logger: The logger named after this module.
    """
    log_path = f"{logging_dir}/log.txt"
    handlers = [logging.StreamHandler(), logging.FileHandler(log_path)]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=handlers,
    )
    return logging.getLogger(__name__)

def requires_grad(model, flag=True):
    """
    Turn gradient tracking on or off for every parameter of ``model``.

    Parameters:
        model (torch.nn.Module): Module whose parameters are modified in place.
        flag (bool): Value assigned to each parameter's ``requires_grad``.
    """
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = flag

def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    For every named parameter of ``model``, the parameter with the same name
    in ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``.

    Parameters:
        ema_model (torch.nn.Module): Model holding the running average. It
            must have every parameter name that ``model`` has; otherwise a
            KeyError is raised.
        model (torch.nn.Module): Model whose current weights are blended in.
        decay (float): Fraction of the previous EMA value that is kept.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    # An in-place update of a leaf tensor that requires grad raises a
    # RuntimeError, and would otherwise be recorded by autograd. Disable
    # gradient tracking so the update works whether or not requires_grad has
    # been switched off on ema_model.
    with torch.no_grad():
        for name, param in model_params.items():
            # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
            ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)

def if_nan(dataset):
    """
    Print the index of every sample whose input tensor contains a NaN.

    Each item of ``dataset`` is unpacked as ``(x, _)``, i.e. it is assumed
    to be a pair whose first element is the tensor to check. The label is
    ignored. Output goes to stdout via ``print``; nothing is returned.
    """
    for idx, (x, _) in enumerate(dataset):
        has_nan = torch.isnan(x).any()
        if has_nan:
            print(f"NaN found in sample {idx}")

class MyDataset(Dataset):
    """
    In-memory dataset of (sample, label) pairs.

    Samples given to the constructor are converted to float32 tensors,
    min-max scaled onto [-1, 1] and given a leading dimension of size 1 via
    ``unsqueeze(0)``. Labels are converted to ``torch.int`` tensors.
    """

    def __init__(self, data=None, labels=None):
        """
        Parameters:
            data (iterable, optional): Samples. Each must be accepted by
                ``torch.tensor``. Defaults to an empty dataset.
            labels (iterable, optional): Labels, one per sample. Each must be
                accepted by ``torch.tensor``. Defaults to no labels.
        """
        self.data = [
            self.map_to_range(torch.tensor(item, dtype=torch.float32)).unsqueeze(0)
            for item in data
        ] if data is not None else []
        self.labels = [
            torch.tensor(label, dtype=torch.int) for label in labels
        ] if labels is not None else []

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        x = self.data[index]
        y = self.labels[index]

        return x, y

    def add_sample(self, data, label):
        """
        Append a sample and its label as-is.

        Unlike the constructor, no tensor conversion, scaling or unsqueeze is
        applied. Callers presumably pass already-prepared tensors; verify
        against the call sites.
        """
        self.data.append(data)

        self.labels.append(label)

    def map_to_range(self, tensor):
        """
        Linearly min-max scale ``tensor`` onto the closed range [-1, 1].

        The minimum maps to -1 and the maximum to 1. A constant tensor
        (max == min) has no spread to scale. It is mapped to all zeros
        instead of the NaNs (0 / 0) the plain formula would produce.
        """
        # Find the minimum and maximum values
        min_val = tensor.min().item()
        max_val = tensor.max().item()
        value_range = max_val - min_val

        # Guard against division by zero for constant samples
        if value_range == 0:
            return torch.zeros_like(tensor)

        # Map values onto [-1, 1]
        return 2 * (tensor - min_val) / value_range - 1

    def modify_labels(self, new_labels):
        """
        Replace the dataset's labels with ``new_labels``.

        The new labels are stored as given, without tensor conversion.
        """
        self.labels = new_labels