import os
import sys
# Set environment variables for CUDA and Hugging Face endpoint.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import random
import json
import re
import argparse
from tqdm import tqdm
from tabulate import tabulate
from collections import deque, defaultdict
from transformers import set_seed
from transformers import AutoModelForCausalLM, AutoTokenizer
from agent_training import DQN # Import DQN model from agent_training.py
from typing import Optional, Tuple

# Seed every random number generator this script relies on (Python, NumPy,
# PyTorch, transformers) with one shared value so repeated runs are reproducible.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)
set_seed(_SEED)
if torch.cuda.is_available():
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(_SEED)
    torch.cuda.manual_seed_all(_SEED)


def init_modules():
    """Load per-module importance scores from the 'ppl' directory.

    Every file in 'ppl' whose name ends with
    '_<seq_len>_<number>_ppl_importance.json' is read as JSON lines: one JSON
    object per line with the keys 'block_id', 'block_type' and 'block_ppl'.
    Blank lines are ignored. Files whose names do not match are skipped.

    Within a file, blocks are ordered by 'block_id', with the 'mha' block of
    a given id placed before any other block type. Files are processed in
    sorted filename order so the result does not depend on the arbitrary
    order returned by os.listdir().

    Returns:
        defaultdict: Maps each sequence length (int, taken from the filename)
            to a list of dicts with the keys 'id', 'type' and 'importance'
            (the block's 'block_ppl' value). If several files share a
            sequence length, their blocks are appended one file after another.
    """
    modules = defaultdict(list)
    # First number in the filename suffix is the sequence length.
    filename_pattern = re.compile(r'_(\d+)_\d+_ppl_importance\.json$')
    for filename in sorted(os.listdir('ppl')):
        match = filename_pattern.search(filename)
        if match is None:
            continue
        seq_length = int(match.group(1))
        filepath = os.path.join('ppl', filename)
        with open(filepath, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. trailing newlines) instead of crashing
            # in json.loads on an empty document.
            blocks = [json.loads(line) for line in f if line.strip()]
        # Order by block_id, then mha before everything else.
        blocks.sort(key=lambda b: (b['block_id'], 0 if b['block_type'] == 'mha' else 1))
        for block in blocks:
            modules[seq_length].append({
                'id': block['block_id'],
                'type': block['block_type'],
                'importance': block['block_ppl'],
            })
    return modules

def get_importance(sql_modules, state, action):
    """Sum the importance of the kept modules and of all modules.

    Args:
        sql_modules (defaultdict): Maps a sequence length to its list of
            module dicts; each dict carries an 'importance' entry.
        state (list/np.array): Four-element state; only the second element
            (the sequence length) is used, as the lookup key.
        action (list/np.array): One flag per module; a module counts as kept
            when int(flag) == 1.

    Returns:
        tuple: (retained_importance, original_importance), where the first is
            the importance summed over kept modules and the second is the
            importance summed over every module for that sequence length.
    """
    _batch, seq_len, _threshold, _extra = state
    candidates = sql_modules[seq_len]

    # Importance of the modules the action keeps.
    retained_importance = 0
    for position, keep_flag in enumerate(action):
        if int(keep_flag) == 1:
            retained_importance += candidates[position]['importance']

    # Importance of the full, unpruned module list.
    original_importance = 0
    for entry in candidates:
        original_importance += entry['importance']

    return retained_importance, original_importance

def get_reward(action, input_state, config, sql_modules):
    """Return the reward for an action: the total importance of kept modules.

    Args:
        action (list/np.array): Binary array indicating kept (1) / pruned (0)
            modules.
        input_state (list/np.array): Four-element state forwarded to
            get_importance, which uses its second element (sequence length).
        config (dict): Model configuration. Not used by the reward; kept so
            the signature stays compatible with existing callers.
        sql_modules (defaultdict): Maps sequence length to module dicts with
            importance scores.

    Returns:
        The summed importance of the modules kept by ``action``.
    """
    # Only the retained importance defines the reward. The retained/original
    # ratio was previously computed and discarded, which did nothing except
    # raise ZeroDivisionError when the total importance was zero.
    retained_importance, _original_importance = get_importance(sql_modules, input_state, action)
    return retained_importance

def action_mapping(q_sort_index, config, state):
    """Map a Q-value ranking to a binary keep/prune vector under a memory budget.

    Modules are visited in ranking order and kept one by one. The walk stops
    at the first module that would push the running total over the budget;
    that module and every module ranked after it are pruned. If even the
    top-ranked module does not fit, the result is all zeros.

    Memory is counted in elements (no bytes-per-element factor is applied):
      * an MHA module costs ``4 * hidden_size**2`` weights plus a KV cache of
        ``2 * batch * seq_len * hidden_size``;
      * an FFN module costs ``3 * hidden_size * intermediate_size`` weights.
    The budget is ``threshold`` times the cost of the full model
    (``num_hidden_layers`` MHA modules plus as many FFN modules).

    Args:
        q_sort_index (np.array): Module indices sorted by Q-value, best
            first. Even indices are MHA modules, odd indices are FFN modules.
        config (dict): Model configuration providing 'num_hidden_layers',
            'hidden_size' and 'intermediate_size'.
        state (list/np.array): Four-element state
            [batch, seq_len, threshold, <unused>].

    Returns:
        np.array: Float vector of length ``2 * num_hidden_layers`` holding 1
            for kept modules and 0 for pruned ones.
    """
    batch, seq_len, threshold, _ = state
    # The state may arrive as a float32 array; do the arithmetic in Python
    # floats (double precision) so multi-billion totals are not rounded.
    batch = float(batch)
    seq_len = float(seq_len)
    threshold = float(threshold)

    num_hidden_layers = config['num_hidden_layers']
    hidden_size = config['hidden_size']
    intermediate_size = config['intermediate_size']

    # Per-module cost: weights, plus the KV cache for attention modules.
    mha_block_memory = 4 * hidden_size * hidden_size + 2 * batch * seq_len * hidden_size
    ffn_block_memory = 3 * hidden_size * intermediate_size

    # Every layer contributes exactly one MHA and one FFN module.
    original_memory = num_hidden_layers * (mha_block_memory + ffn_block_memory)
    allowed_memory = original_memory * threshold

    global_action = np.zeros(num_hidden_layers * 2)
    used_memory = 0
    for idx in q_sort_index:
        block_memory = mha_block_memory if idx % 2 == 0 else ffn_block_memory
        if used_memory + block_memory > allowed_memory:
            # Budget exhausted: this module and all lower-ranked ones are pruned.
            break
        used_memory += block_memory
        global_action[idx] = 1

    return global_action


def get_policy(model_name, sql_modules, state, device):
    """Derive a pruning policy from the trained DQN for the given state.

    Loads the model configuration from 'config/<model_name>.json' and the DQN
    weights from 'weight/dqn_weight.pth', ranks the modules by the network's
    Q-values, converts the ranking into a keep/prune vector that respects the
    memory budget (action_mapping) and scores it (get_reward). The
    evaluation is repeated 100 times and the highest-reward action is
    returned.

    Args:
        model_name (str): Name of the model (e.g., 'Llama-2-7b-hf'); selects
            the configuration file.
        sql_modules (defaultdict): Maps sequence length to module dicts with
            importance scores.
        state (list): [batch_size, seq_len, threshold]. A constant 1 is
            appended to form the four-element DQN input.
        device (str or torch.device): Device the DQN runs on.

    Returns:
        tuple: (best_action, best_reward). ``best_action`` is None if no
            evaluated reward exceeded negative infinity.
    """
    with open(f'config/{model_name}.json', 'r') as f:
        config = json.load(f)

    dqn_model = DQN(state_dim=4, action_dim=config['num_hidden_layers'] * 2).to(device)

    # map_location lets a checkpoint saved from a CUDA run load on a CPU-only
    # host. Only the network weights are needed for inference, so no
    # optimizer is created or restored.
    checkpoint = torch.load('weight/dqn_weight.pth', map_location=device, weights_only=True)
    dqn_model.load_state_dict(checkpoint['model_state_dict'])
    dqn_model.eval()

    # DQN input: the three caller-supplied values plus a trailing constant 1.
    # NOTE(review): the meaning of the fourth feature is defined by the DQN
    # training code, which is not visible here.
    input_state_np = np.array([*state, 1])
    input_state_torch = torch.FloatTensor(input_state_np).to(device)
    # float32 NumPy view of the state handed to action_mapping / get_reward;
    # it never changes, so convert once instead of on every iteration.
    current_eval_state_np = input_state_torch.cpu().numpy()

    best_reward = float('-inf')
    best_action = None
    # The forward pass must run under no_grad: otherwise the output requires
    # grad and .numpy() raises on a CPU tensor.
    with torch.no_grad():
        # NOTE(review): with a deterministic network in eval mode every
        # iteration yields the same action; the repetition only matters if
        # DQN contains stochastic layers — confirm and reduce if not.
        for _ in range(100):
            q_values_np = dqn_model(input_state_torch).cpu().numpy()
            q_sort_index = np.argsort(-q_values_np)  # Best Q-value first.
            action = action_mapping(q_sort_index, config, current_eval_state_np)
            reward = get_reward(action, current_eval_state_np, config, sql_modules)
            if reward > best_reward:
                best_reward = reward
                best_action = action

    return best_action, best_reward

def model_config(model_name):
    """Read a model's configuration from 'config/<model_name>.json'.

    Args:
        model_name (str): Name of the model; selects the JSON file to read.

    Returns:
        dict: The parsed contents of the configuration file.
    """
    config_path = f'config/{model_name}.json'
    with open(config_path, 'r') as handle:
        return json.load(handle)

if __name__ == "__main__":
    # Command-line interface. The device default is always a plain string so
    # it has the same type as a value supplied on the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for computation (e.g., "cpu", "cuda").')
    parser.add_argument('--model-path', type=str, default='meta-llama/Llama-2-7b-hf',
                        help='Path to load the model and tokenizer')
    parser.add_argument('--model-name', type=str, default='Llama-2-7b-hf',
                        help='Name of the model, used for loading config.')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Batch size for state representation.')
    parser.add_argument('--seq-len', type=int, default=2048,
                        help='Sequence length for state representation.')
    parser.add_argument('--threshold', type=float, default=0.8,
                        help='Memory threshold for pruning.')
    args = parser.parse_args()

    # Load the importance scores and ask the DQN for the best pruning policy.
    # get_policy reads the model configuration itself, so it is not loaded here.
    sql_modules = init_modules()
    policy, reward = get_policy(args.model_name, sql_modules,
                                state=[args.batch_size, args.seq_len, args.threshold],
                                device=args.device)

    # Translate the policy vector into per-layer pruning lists: a 0 at an even
    # position prunes that layer's attention block, a 0 at an odd position
    # prunes its MLP block. Position i belongs to layer i // 2.
    pruned_attn_idx = []
    pruned_mlp_idx = []
    if policy is not None:
        for position, keep_flag in enumerate(policy):
            if keep_flag == 0:
                target = pruned_attn_idx if position % 2 == 0 else pruned_mlp_idx
                target.append(position // 2)

    output_policy_dict = {
        "pruned_attn_idx": pruned_attn_idx,
        "pruned_mlp_idx": pruned_mlp_idx
    }

    # Write the pruning configuration; the threshold is part of the filename.
    output_dir = 'pruning_config'
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'pruning_block_config_{args.threshold}.json')
    with open(output_path, 'w') as f:
        json.dump(output_policy_dict, f, indent=4)

    print(f"Policy saved to {output_path}")
    print(f"Best policy: {policy}")
