import numpy as np
import torch
import torch.nn as nn
import pickle
import os
import re

class Memory:
    """Bounded in-memory store of ``(features, target)`` training samples.

    Each entry is a tuple ``(features, target)`` where ``features`` is a numpy
    array built from the pushed job features and ``target`` is the associated
    utilization value. Targets can be standardized (zero mean / unit std) and
    the parameters used are kept in ``normalization_params`` as a dict with
    the keys ``'target_mean'`` and ``'target_std'``.
    """

    # Eviction picks a random victim among the first (oldest) entries only;
    # this is the size of that window.
    _EVICTION_WINDOW = 100

    # Characters replaced by '_' in file names before touching the disk.
    _ILLEGAL_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')

    def __init__(self, num_machines, capacity=100000):
        """Create an empty memory.

        Args:
            num_machines: Number of machines; stored and persisted as-is.
            capacity: Maximum number of samples kept by ``push``.
        """
        self.capacity = capacity
        self.num_machines = num_machines
        self.memory = []
        # Not used by this class; kept so existing attribute access still works.
        self.position = 0
        self.normalization_params = None

    def normalize(self, normalization_params=None):
        """Standardize every stored target in place (features are untouched).

        Args:
            normalization_params: Optional dict with ``'target_mean'`` and
                ``'target_std'``. When given it is used directly (e.g. to
                apply training-set statistics to a test set); otherwise the
                statistics are computed from the stored targets.

        Does nothing when the memory is empty.
        """
        if not self.memory:
            return

        if normalization_params is None:
            targets = np.array([target for _, target in self.memory])
            target_std = np.std(targets)
            normalization_params = {
                'target_mean': np.mean(targets),
                # A constant target would otherwise divide by zero.
                'target_std': target_std if target_std != 0 else 1,
            }

        self.normalization_params = normalization_params
        mean = normalization_params['target_mean']
        std = normalization_params['target_std']
        self.memory = [
            (features, (target - mean) / std) for features, target in self.memory
        ]

    def denormalize(self, normalized_value):
        """Map a standardized value back to the original target scale.

        Returns the input unchanged when no normalization has been set.
        """
        params = self.normalization_params
        if params is None:
            return normalized_value
        return normalized_value * params['target_std'] + params['target_mean']

    def normalize_value(self, original_value):
        """Standardize a single value with the stored parameters.

        Returns the input unchanged when no normalization has been set, which
        mirrors ``denormalize`` instead of failing on ``None``.
        """
        params = self.normalization_params
        if params is None:
            return original_value
        return (original_value - params['target_mean']) / params['target_std']

    def denormalize_value(self, normalized_value):
        """Alias of ``denormalize``, kept for existing callers."""
        return self.denormalize(normalized_value)

    def push(self, jobs_features, utilization):
        """Store one sample, evicting an old one when the memory is full.

        Args:
            jobs_features: Array-like features; converted with ``np.array``.
            utilization: Target value stored alongside the features.

        When full, one entry chosen uniformly among the first
        ``_EVICTION_WINDOW`` entries is removed before appending. With a
        non-positive capacity nothing can be stored and the sample is dropped.
        """
        features = np.array(jobs_features)

        if len(self.memory) >= self.capacity:
            if not self.memory:
                return
            # Clamp the window to the current size so small buffers never
            # draw an index past the end of the list.
            window = min(self._EVICTION_WINDOW, len(self.memory))
            self.memory.pop(np.random.randint(0, window))

        self.memory.append((features, utilization))

    def sample(self, batch_size):
        """Draw ``batch_size`` samples uniformly at random, with replacement.

        Returns:
            A ``(features, utilizations)`` pair of float32 tensors.
        """
        indices = np.random.choice(len(self.memory), batch_size)
        features, utilizations = zip(*(self.memory[i] for i in indices))
        # Stack through numpy first: building a tensor straight from a tuple
        # of ndarrays goes through a very slow element-wise path.
        return (
            torch.as_tensor(np.asarray(features), dtype=torch.float32),
            torch.as_tensor(np.asarray(utilizations), dtype=torch.float32),
        )

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.memory)

    @classmethod
    def _clean_filename(cls, filename):
        """Return ``filename`` with illegal file-name characters replaced by '_'."""
        return cls._ILLEGAL_FILENAME_CHARS.sub('_', filename)

    @classmethod
    def _clean_path(cls, filepath):
        """Return ``filepath`` with only its final component sanitized."""
        dirname, filename = os.path.split(filepath)
        clean_filename = cls._clean_filename(filename)
        return os.path.join(dirname, clean_filename) if dirname else clean_filename

    def save(self, filepath):
        """Pickle the memory to ``filepath`` (file name is sanitized first).

        The payload holds ``capacity``, ``num_machines``, ``memory`` (features
        converted to nested lists) and ``normalization_params`` so that a
        normalized memory can be restored together with its scale. Missing
        parent directories are created.
        """
        clean_filepath = self._clean_path(filepath)
        dirname = os.path.dirname(clean_filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        # Build the payload before opening so a failure here cannot leave a
        # truncated file behind.
        memory_data = [
            (features.tolist() if isinstance(features, np.ndarray) else features,
             utilization)
            for features, utilization in self.memory
        ]
        payload = {
            'capacity': self.capacity,
            'num_machines': self.num_machines,
            'memory': memory_data,
            'normalization_params': self.normalization_params,
        }
        with open(clean_filepath, 'wb') as f:
            pickle.dump(payload, f)

    def split_memory(self, test_ratio=0.1):
        """Randomly split the samples into a train and a test memory.

        Args:
            test_ratio: Fraction of samples (rounded down) moved to the test set.

        Returns:
            ``(train_memory, test_memory)``; both inherit a copy of this
            memory's ``normalization_params`` since they hold the same
            (possibly already normalized) targets.
        """
        total = len(self.memory)
        test_size = int(total * test_ratio)
        if test_size:
            test_indices = set(
                np.random.choice(total, test_size, replace=False).tolist()
            )
        else:
            test_indices = set()

        train_memory = Memory(num_machines=self.num_machines, capacity=total - test_size)
        test_memory = Memory(num_machines=self.num_machines, capacity=test_size)
        for split in (train_memory, test_memory):
            if self.normalization_params is not None:
                split.normalization_params = dict(self.normalization_params)

        for i, sample in enumerate(self.memory):
            destination = test_memory if i in test_indices else train_memory
            destination.memory.append(sample)

        return train_memory, test_memory

    @classmethod
    def load(cls, filepath):
        """Load a memory previously written by ``save``.

        Files written without ``normalization_params`` are still accepted; the
        attribute is then left as ``None``.
        """
        clean_filepath = cls._clean_path(filepath)
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files from trusted sources.
        with open(clean_filepath, 'rb') as f:
            data = pickle.load(f)

        memory = cls(num_machines=data['num_machines'], capacity=data['capacity'])
        memory.memory = [
            (np.array(features) if isinstance(features, list) else features,
             utilization)
            for features, utilization in data['memory']
        ]
        memory.normalization_params = data.get('normalization_params')
        return memory

    def augment(self, times):
        """Append ``times`` relabelled copies of every stored sample.

        For each round one random permutation of ``1..n_machines`` is drawn
        (``n_machines`` is the feature width divided by 3) and applied to the
        values in columns 1, 4, 7, ... of every sample; values outside
        ``[1, n_machines]`` are left alone. Presumably those columns hold
        1-based machine ids - verify against the feature builder. Targets are
        copied unchanged and ``capacity`` is not enforced here.

        Args:
            times: Number of augmented copies to create per stored sample.

        Raises:
            ValueError: If the first stored feature is not a 2-D numpy array.
        """
        if not self.memory:
            return
        sample_feature = self.memory[0][0]
        if not (isinstance(sample_feature, np.ndarray) and sample_feature.ndim == 2):
            raise ValueError("features必须为二维numpy数组")

        n_machines = sample_feature.shape[1] // 3
        new_data = []
        for _ in range(times):
            mapping = np.random.permutation(n_machines)
            for features, target in self.memory:
                new_features = features.copy()
                # Basic slicing gives a view, so writes land in new_features.
                id_columns = new_features[:, 1::3]
                valid = (id_columns >= 1) & (id_columns <= n_machines)
                id_columns[valid] = mapping[id_columns[valid].astype(int) - 1] + 1
                new_data.append((new_features, target))
        self.memory.extend(new_data)

if __name__ == "__main__":
    # Script entry point: load a stored dataset, append 10 augmented copies
    # of every sample, and write the enlarged dataset to a new file.
    dataset = Memory.load('TrainData1000_10_10_100')
    dataset.augment(10)
    dataset.save('TrainData_100_100_100_100_augmented')