import os
import sys
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import json
import re
import itertools
from tqdm import tqdm
from collections import deque, defaultdict
from transformers import set_seed
# Seed every random number generator this script touches (Python's `random`,
# NumPy, PyTorch and the transformers helper) with one fixed value.
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
    _seed_fn(_SEED)
if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(_SEED)
    torch.cuda.manual_seed_all(_SEED)


# DQN (Deep Q-Network) used for learning pruning decisions.
class DQN(nn.Module):
    """Q-network with a single hidden layer.

    Maps a state vector of length ``state_dim`` to ``action_dim`` Q-values
    through one 256-unit ReLU layer followed by a linear output layer.
    """

    def __init__(self, state_dim, action_dim):
        """Create the two linear layers.

        Args:
            state_dim: Number of input features per state.
            action_dim: Number of Q-values produced per state.
        """
        super().__init__()
        hidden_width = 256
        self.fc1 = nn.Linear(state_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, action_dim)

    def forward(self, x):
        """Return the Q-values for ``x`` (no activation on the output)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


# PruningEnvironment simulates the environment for pruning decisions, including calculations for memory and FLOPs.
class PruningEnvironment:
    """Environment that scores block-level pruning decisions.

    A model with ``num_hidden_layers`` layers is treated as ``2 * layers``
    prunable blocks. The per-block importance records are loaded from disk and
    grouped by sequence length; every record carries an ``id``, a ``type``
    (``'mha'`` or ``'ffn'``) and an ``importance`` value.

    A state is a length-4 sequence ``(batch, seq_len, threshold, ratio)`` and
    an action is a 0/1 sequence with one entry per block (1 = keep).
    """

    def __init__(self, config):
        """Read the model dimensions from ``config`` and load the block table.

        Args:
            config: Mapping with ``num_hidden_layers``, ``hidden_size``,
                ``intermediate_size`` and ``num_attention_heads``. ``head_dim``
                is optional; when absent it falls back to
                ``hidden_size // num_attention_heads``.
        """
        self.layers = config['num_hidden_layers']
        self.hidden_size = config['hidden_size']
        self.intermediate_size = config['intermediate_size']
        self.num_heads = config['num_attention_heads']
        head_dim = config.get('head_dim')
        if head_dim is None:
            head_dim = self.hidden_size // self.num_heads
        self.head_dim = head_dim
        # Per-block parameter footprint, counted in elements rather than bytes
        # (presumably 4 attention projections / 3 FFN projections -- confirm).
        self.mha_block_memory = 4 * self.hidden_size * self.hidden_size
        self.ffn_block_memory = 3 * self.hidden_size * self.intermediate_size
        self.modules = self.init_modules()

    @staticmethod
    def init_modules(ppl_dir='ppl'):
        """Load the per-block importance records found in ``ppl_dir``.

        Only files whose name ends in ``_<seq_len>_<n>_ppl_importance.json``
        are read. Each non-blank line must be a JSON object with the keys
        ``block_id``, ``block_type`` and ``block_ppl``.

        Args:
            ppl_dir: Directory to scan. Defaults to ``'ppl'``.

        Returns:
            A ``defaultdict(list)`` mapping the integer sequence length to its
            records, ordered by block id with ``'mha'`` before the other type.

        Raises:
            FileNotFoundError: If ``ppl_dir`` does not exist.
        """
        modules = defaultdict(list)
        filename_pattern = re.compile(r'_(\d+)_\d+_ppl_importance\.json$')
        # os.listdir order is arbitrary; sort so the result is deterministic
        # when several files share one sequence length.
        for filename in sorted(os.listdir(ppl_dir)):
            match = filename_pattern.search(filename)
            if not match:
                continue
            seq_length = int(match.group(1))
            with open(os.path.join(ppl_dir, filename), 'r') as f:
                # Skip blank lines (e.g. a trailing newline) instead of
                # failing inside json.loads.
                blocks = [json.loads(line) for line in f if line.strip()]
            blocks.sort(key=lambda x: (x['block_id'], 0 if x['block_type'] == 'mha' else 1))
            modules[seq_length].extend(
                {
                    'id': block['block_id'],
                    'type': block['block_type'],
                    'importance': block['block_ppl'],
                }
                for block in blocks
            )
        return modules

    @staticmethod
    def reset_state():
        """Return the next initial state ``[batch, seq_len, threshold, 1]``.

        The fixed grid of batch sizes, sequence lengths and thresholds is
        shuffled once and then walked in order; it is reshuffled each time the
        walk wraps around. The cursor lives on the function object, so it is
        shared by every caller.
        """
        store = PruningEnvironment.reset_state
        if not hasattr(store, 'combinations'):
            batch_sizes = [1, 2, 4, 8, 16]
            seq_lens = [128, 256, 512, 1024, 2048, 4096]
            thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
            store.combinations = list(itertools.product(batch_sizes, seq_lens, thresholds))
            random.shuffle(store.combinations)
            store.index = 0
        elif store.index == 0:
            # A full pass has completed: draw a new order for the next pass.
            random.shuffle(store.combinations)

        batch, seq_len, threshold = store.combinations[store.index]
        store.index = (store.index + 1) % len(store.combinations)
        return np.array([batch, seq_len, threshold, 1])

    def _modules_for(self, seq_len):
        """Return the block records for ``seq_len`` or raise a clear error.

        ``self.modules`` is a defaultdict, so plain indexing would silently
        insert an empty list for an unknown length and fail later with an
        unrelated IndexError or ZeroDivisionError.
        """
        key = int(seq_len)  # states carry seq_len as a float
        modules = self.modules.get(key)
        if not modules:
            raise KeyError(f'no importance data loaded for sequence length {key}')
        return modules

    def _block_memory(self, block_type, batch, seq_len):
        """Return the memory estimate for one block (0 for unknown types)."""
        if block_type == 'mha':
            # Parameters plus the key/value cache for this batch and length.
            return self.mha_block_memory + 2 * batch * seq_len * self.hidden_size
        if block_type == 'ffn':
            return self.ffn_block_memory
        return 0

    def _block_flops(self, block_type, batch, seq_len):
        """Return the FLOPs estimate for one block (0 for unknown types)."""
        if block_type == 'mha':
            # NOTE(review): the second term squares batch, num_heads, seq_len
            # and head_dim; kept exactly as written -- confirm it is intended.
            return (4 * batch * seq_len * self.hidden_size**2
                    + 2 * batch**2 * self.num_heads**2 * seq_len**2 * self.head_dim**2)
        if block_type == 'ffn':
            return (3 * batch * seq_len * self.intermediate_size * self.hidden_size
                    + batch * seq_len * self.intermediate_size)
        return 0

    def calculate_memory(self, state, action):
        """Return the memory estimate of the blocks kept by ``action``.

        Raises:
            KeyError: If no records are loaded for the state's sequence length.
        """
        batch, seq_len, _, _ = state
        modules = self._modules_for(seq_len)
        return sum(
            self._block_memory(modules[i]['type'], batch, seq_len)
            for i, keep in enumerate(action)
            if keep == 1
        )

    def calculate_flops(self, state, action):
        """Return the FLOPs estimate of the blocks kept by ``action``.

        Raises:
            KeyError: If no records are loaded for the state's sequence length.
        """
        batch, seq_len, _, _ = state
        modules = self._modules_for(seq_len)
        return sum(
            self._block_flops(modules[i]['type'], batch, seq_len)
            for i, keep in enumerate(action)
            if keep == 1
        )

    def calculate_original_memory(self, state):
        """Return the memory estimate with all ``2 * layers`` blocks kept."""
        batch, seq_len, _, _ = state
        modules = self._modules_for(seq_len)
        return sum(
            self._block_memory(modules[i]['type'], batch, seq_len)
            for i in range(self.layers * 2)
        )

    def calculate_original_flops(self, state):
        """Return the FLOPs estimate with all ``2 * layers`` blocks kept."""
        batch, seq_len, _, _ = state
        modules = self._modules_for(seq_len)
        return sum(
            self._block_flops(modules[i]['type'], batch, seq_len)
            for i in range(self.layers * 2)
        )

    def get_importance(self, state, action):
        """Return ``(retained_importance, original_importance)``.

        The first value sums the importance of the blocks whose action entry
        is 1 after ``int()`` conversion; the second sums every loaded block for
        the state's sequence length.
        """
        _, seq_len, _, _ = state
        modules = self._modules_for(seq_len)
        retained_importance = sum(
            modules[i]['importance'] for i, a in enumerate(action) if int(a) == 1
        )
        original_importance = sum(module['importance'] for module in modules)
        return retained_importance, original_importance

    @staticmethod
    def get_next_state(state, importance_ratio):
        """Return ``state`` with its last entry replaced by ``importance_ratio``."""
        batch, seq_len, threshold, _ = state
        return np.array([batch, seq_len, threshold, importance_ratio])

    def step(self, state, action):
        """Apply ``action`` and return ``(reward, next_state)``.

        The reward is the retained importance; the next state carries the
        retained/original importance ratio in its last entry.
        """
        retained_importance, original_importance = self.get_importance(state, action)
        importance_budget = retained_importance / original_importance
        reward = retained_importance
        next_state = self.get_next_state(state, importance_budget)
        return reward, next_state


# ReplayBuffer stores experiences for training the DQN.
class ReplayBuffer:
    """FIFO store of unique ``(state, action, reward, next_state)`` tuples.

    Duplicates are rejected. Once ``capacity`` is reached the oldest entry is
    evicted and appended to ``removed_buffer``.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` experiences."""
        self.capacity = capacity
        self.buffer = []              # live experiences, oldest first
        self.experience_set = set()   # hashable keys of the live experiences
        self.removed_buffer = []      # every experience evicted so far

    def _experience_to_hashable(self, state, action, reward, next_state):
        """Return a hashable key for an experience, used for deduplication."""
        def freeze(observation):
            # NumPy arrays are flattened to plain Python numbers first.
            if isinstance(observation, np.ndarray):
                return tuple(observation.flatten().tolist())
            return tuple(observation)

        return (freeze(state), action, reward, freeze(next_state))

    def add(self, state, action, reward, next_state):
        """Store an experience unless an identical one is already held.

        Returns:
            True if the experience was stored, False if it was a duplicate.
        """
        key = self._experience_to_hashable(state, action, reward, next_state)
        if key in self.experience_set:
            return False

        if len(self.buffer) >= self.capacity:
            # Full: drop the oldest entry and remember it in removed_buffer.
            evicted = self.buffer.pop(0)
            self.experience_set.remove(self._experience_to_hashable(*evicted))
            self.removed_buffer.append(evicted)

        self.buffer.append((state, action, reward, next_state))
        self.experience_set.add(key)
        return True

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct experiences uniformly at random.

        Returns:
            Four tuples: states, actions, rewards and next states.
        """
        picked = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states = zip(*(self.buffer[k] for k in picked))
        return states, actions, rewards, next_states

    def size(self):
        """Return the number of experiences currently held."""
        return len(self.buffer)


# DQNAgent uses the DQN model to learn and update pruning strategies.
class DQNAgent:
    """DQN agent that learns which blocks to keep under a memory budget.

    The network outputs one Q-value per block. Blocks are ranked by Q-value
    and kept greedily, in rank order, until the memory budget derived from the
    state's threshold would be exceeded.
    """

    def __init__(self, config, epochs, episodes, sub_episodes, device):
        """Build the environment, networks, optimizer and replay buffer.

        Args:
            config: Model configuration mapping; ``num_hidden_layers`` is read
                here and the whole mapping is passed to ``PruningEnvironment``.
            epochs: Number of outer training iterations.
            episodes: Number of episodes per epoch.
            sub_episodes: Number of environment steps per episode.
            device: Torch device for both networks and all training tensors.
        """
        self.layers = config['num_hidden_layers']
        self.device = device
        self.env = PruningEnvironment(config)
        self.state_dim = 4
        self.action_dim = self.layers * 2
        self.episodes = episodes
        self.epochs = epochs
        self.sub_episodes = sub_episodes
        self.model = DQN(self.state_dim, self.action_dim).to(device)
        self.target_model = DQN(self.state_dim, self.action_dim).to(device)
        # Start the target network as a copy of the online network so the
        # first update does not bootstrap from unrelated random weights.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=2e-3)
        self.reply_buffer = ReplayBuffer(10000)
        self.minimal_size = 500      # buffer must exceed this before training
        self.batch_size = 32
        self.gamma = 0.98            # discount factor
        self.epsilon = 0.1           # exploration probability
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.count = 0               # number of gradient updates performed
        self.target_update_freq = 20

    def get_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Returns:
            ``(best_action, global_action)``: the index of the top-ranked
            block, and the 0/1 keep vector over all blocks.
        """
        if random.random() < self.epsilon:
            # Exploration: rank the blocks by random scores.
            q_values = np.random.randint(0, self.action_dim, size=self.action_dim)
        else:
            with torch.no_grad():
                state_tensor = torch.as_tensor(
                    np.asarray(state), dtype=torch.float, device=self.device)
                q_values = self.model(state_tensor).cpu().numpy()

        q_sort_index = np.argsort(-q_values)
        # Budget against the original state, not a float32 round-trip of it,
        # so both branches apply exactly the same threshold.
        global_action = self.action_mapping(q_sort_index, state)
        return q_sort_index[0], global_action

    def action_mapping(self, q_sort_index, state):
        """Turn a block ranking into a keep vector that fits the memory budget.

        Blocks are taken in the order given by ``q_sort_index`` and kept until
        adding the next one would push the running total above
        ``original_memory * threshold``. Even indices are treated as attention
        blocks and odd indices as feed-forward blocks.

        Args:
            q_sort_index: Block indices, most preferred first.
            state: ``(batch, seq_len, threshold, ratio)``.

        Returns:
            Float array of length ``2 * layers`` with 1 for kept blocks. All
            ones if the budget is never exceeded, all zeros if even the first
            block does not fit.
        """
        batch, seq_len, threshold, _ = state
        allowed_memory = self.env.calculate_original_memory(state) * threshold
        mha_cost = self.env.mha_block_memory + 2 * batch * seq_len * self.env.hidden_size
        ffn_cost = self.env.ffn_block_memory

        global_action = np.zeros(self.layers * 2)
        used_memory = 0
        for idx in q_sort_index:
            used_memory += mha_cost if idx % 2 == 0 else ffn_cost
            if used_memory > allowed_memory:
                # This block is the first that does not fit; every block
                # before it has already been marked as kept.
                break
            global_action[idx] = 1
        return global_action

    def update_model(self):
        """Run one gradient step on a minibatch drawn from the replay buffer.

        Minimises the MSE between ``Q(s, a)`` and
        ``r + gamma * max_a' Q_target(s', a')``, then copies the online weights
        into the target network every ``target_update_freq`` calls.
        """
        batch_state, batch_action, batch_reward, batch_next_state = \
            self.reply_buffer.sample(self.batch_size)
        states = torch.tensor(np.array(batch_state), dtype=torch.float, device=self.device)
        # gather() needs int64 indices.
        actions = torch.tensor(
            np.array(batch_action), dtype=torch.long, device=self.device).view(-1, 1)
        rewards = torch.tensor(
            np.array(batch_reward), dtype=torch.float, device=self.device).view(-1, 1)
        next_states = torch.tensor(
            np.array(batch_next_state), dtype=torch.float, device=self.device)

        q_values = self.model(states).gather(1, actions)
        # The target is a constant for this step; skip building its graph.
        with torch.no_grad():
            max_next_q_values = self.target_model(next_states).max(1)[0].view(-1, 1)
        q_targets = rewards + self.gamma * max_next_q_values
        dqn_loss = F.mse_loss(q_values, q_targets)

        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        self.count += 1

    def train(self):
        """Train for ``epochs * episodes`` episodes, then save the weights.

        Each episode starts from ``env.reset_state()`` and runs exactly
        ``sub_episodes`` steps. Epsilon decays after every step until it
        reaches ``epsilon_min``.
        """
        reward_list = []
        episodes = int(self.episodes)
        for epoch in range(self.epochs):
            with tqdm(total=episodes, desc='Iteration %d' % epoch) as pbar:
                for i_episode in range(episodes):
                    state = self.env.reset_state()
                    episode_reward = 0
                    for _ in range(int(self.sub_episodes)):
                        best_action, global_action = self.get_action(state)
                        reward, next_state = self.env.step(state, global_action)
                        self.reply_buffer.add(state, best_action, reward, next_state)
                        state = next_state
                        episode_reward += reward

                        if self.reply_buffer.size() > self.minimal_size:
                            self.update_model()

                        if self.epsilon > self.epsilon_min:
                            self.epsilon *= self.epsilon_decay

                    reward_list.append(episode_reward)
                    if (i_episode + 1) % 5 == 0:
                        pbar.set_postfix({
                            'episode': '%d' % (self.episodes * epoch + i_episode + 1),
                            'return': '%.3f' % np.mean(reward_list[-10:])
                        })
                    pbar.update(1)
        self.save_model()

    def save_model(self):
        """Save the model and optimizer state to ``weight/dqn_weight.pth``."""
        save_path = 'weight/dqn_weight.pth'
        # Create the directory first so a finished training run is not lost
        # to a missing folder.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, save_path)


def model_config(model_name):
    """Return the parsed contents of ``config/<model_name>.json``.

    The path is relative to the current working directory.
    """
    config_path = f'config/{model_name}.json'
    with open(config_path, 'r') as handle:
        return json.load(handle)


if __name__ == '__main__':
    # Script entry point: read the Llama-2-7b configuration, build the agent
    # on the GPU when one is available (CPU otherwise) and run training.
    llama_config = model_config('Llama-2-7b-hf')
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    agent = DQNAgent(
        llama_config,
        epochs=50,
        episodes=50,
        sub_episodes=50,
        device=device,
    )
    agent.train()
