import argparse
import json
import numpy as np
import os
import pandas as pd
import random
import time
import torch
import tqdm
from client import vllmClientModel
from config import (
    GPQA_DIR, GPQA_MAX_LEN, GPQA_NUM_CHAINS,
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS,
    OLYMPIAD_DIR, OLYMPIAD_MAX_LEN, OLYMPIAD_NUM_CHAINS,
    MODEL_IDS,
)
from config import *
from evaluator import extract_answer, extract_first_boxed_answer
from math_answer import MathAnswer
from sklearn.model_selection import train_test_split
from utils import process_math_id

import argparse
import os
import pandas as pd
import torch
import numpy as np
from config import DSET_TO_DIR, MODEL_IDS
from create_tensor import create_tensor_for_dataset
from predictor import ReasoningModel, ReasoningMLPModel
from torch.utils.data import DataLoader, TensorDataset, Dataset
from utils import convert_model_setting_to_str, convert_data_setting_to_str
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import _LRScheduler, StepLR
import math
import gc
import torch.nn as nn


class SimpleMLPModel(nn.Module):
    """Feed-forward head that maps an embedding vector to a single logit.

    The input is layer-normalised, passed through dropout, then through a
    stack of ``Linear -> ReLU -> Dropout`` stages and a final ``Linear``
    layer with one output unit.

    Args:
        tensor_dim: Size of the last dimension of the input embedding.
        dropout: Dropout probability used on the input and after every
            hidden stage.
        hidden_layers: Widths of the hidden stages. When ``None`` the widths
            ``[tensor_dim // 2, tensor_dim // 4]`` are used.
    """

    def __init__(self, tensor_dim, dropout=0.15, hidden_layers=None):
        super().__init__()

        widths = (
            [tensor_dim // 2, tensor_dim // 4]
            if hidden_layers is None
            else hidden_layers
        )

        self.ln = nn.LayerNorm(tensor_dim)
        self.dropout = nn.Dropout(dropout)

        # Build the hidden stages in order, tracking the running input width.
        stack = []
        fan_in = tensor_dim
        for fan_out in widths:
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
            fan_in = fan_out
        stack.append(nn.Linear(fan_in, 1, bias=True))
        self.mlp = nn.Sequential(*stack)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every Linear weight and zero its bias."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def forward(self, intermediate_emb, seq_mask):
        """Return the logits for ``intermediate_emb``.

        Args:
            intermediate_emb: Embedding tensor whose last dimension is
                ``tensor_dim``. A 3-D input with a singleton second dimension
                is squeezed to 2-D after normalisation.
            seq_mask: Accepted for call compatibility; it is not read here.

        Returns:
            The output of the MLP, with a trailing dimension of size 1.
        """
        normed = self.ln(intermediate_emb)
        has_singleton_seq = normed.dim() == 3 and normed.size(1) == 1
        if has_singleton_seq:
            normed = normed.squeeze(1)  # [batch, 1, dim] -> [batch, dim]

        return self.mlp(self.dropout(normed))


def _load_test_split(dataset):
    """Read the CSV for ``dataset`` and return only its held-out rows.

    Exits with status 1 (and a message) when the dataset is not supported or
    the CSV has neither a ``train`` nor a ``category`` column.
    """
    csv_locations = {
        "math": (MATH_DIR, 'math3k.csv'),
        "mmlu": (MMLU_DIR, 'mmlu.csv'),
        "gsm8k": (GSM8K_DIR, 'gsm8k.csv'),
    }
    if dataset not in csv_locations:
        raise SystemExit(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(csv_locations)}"
        )

    directory, filename = csv_locations[dataset]
    df = pd.read_csv(os.path.join(directory, filename))

    if 'train' in df.columns:
        return df[df['train'] == 0]
    if 'category' in df.columns:
        return df[df['category'] == 'test']
    raise SystemExit(
        f"Cannot select the test split of {filename}: "
        "no 'train' or 'category' column"
    )


def _load_chain_embedding(embedding_dir, uid, chain_id):
    """Load the ``data`` array of one chain's .npz file as a torch tensor."""
    path = os.path.join(embedding_dir, f"{uid}.chain{chain_id}.npz")
    # The context manager closes the archive's file handle once the array
    # has been read into memory.
    with np.load(path) as archive:
        return torch.from_numpy(archive['data'])


def _timed_forward_ms(model, activations, device):
    """Run one no-grad forward pass and return its duration in milliseconds.

    CUDA events are used on a CUDA device; a wall-clock timer is used
    otherwise, since CUDA events are unavailable without a GPU. The mask
    tensor is created inside the timed region in both cases.
    """
    if device.type == 'cuda':
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            model(activations, torch.ones(activations.size(0), 1, dtype=torch.bool, device=device))
        end.record()
        end.synchronize()
        return start.elapsed_time(end)

    begin = time.perf_counter()
    with torch.no_grad():
        model(activations, torch.ones(activations.size(0), 1, dtype=torch.bool, device=device))
    return (time.perf_counter() - begin) * 1000.0


def main():
    """Measure the predictor's forward-pass overhead over saved traces.

    For every test-split question that has a trace JSON in the requested
    trace directory, the iterations of the trace are replayed: at each
    iteration the stored activations of the active chains are batched and
    pushed through the trained MLP, and the forward time is accumulated.
    The per-question totals (in seconds) are written to
    ``latency_torch.csv`` inside the trace directory and their mean is
    printed.

    Command-line arguments:
        --dataset: Dataset name (``math``, ``mmlu`` or ``gsm8k``).
        --model_id: Key into ``MODEL_IDS``.
        --trace_dir: Sub-directory of ``latency_traces`` to process.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--model_id', type=str, default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
    parser.add_argument('--trace_dir', type=str, required=True)

    args = parser.parse_args()

    # Validate the dataset (and load its test split) before any path lookup
    # so an unsupported name yields a clear message.
    df = _load_test_split(args.dataset)

    run_root = os.path.join(DSET_TO_DIR[args.dataset], MODEL_IDS[args.model_id])
    trace_dir_full = os.path.join(run_root, "latency_traces", args.trace_dir)
    embedding_dir = os.path.join(run_root, "embedding")

    selected_layer = 14
    selected_epoch = 2

    uids = [u for u in df.unique_id.values if os.path.isfile(os.path.join(trace_dir_full, f"{u}.json"))]

    tensor_dim = 4096
    dropout = 0.15
    if args.dataset == 'mmlu':
        hidden_layers = [512, 256]
    else:
        hidden_layers = [2048, 1024]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SimpleMLPModel(tensor_dim, dropout, hidden_layers)
    checkpoint_path = os.path.join(run_root, "model", f"L{selected_layer}_mlp", f"e{selected_epoch}.pt")
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    results = list()
    for uid in tqdm.tqdm(uids):
        with open(os.path.join(trace_dir_full, f"{uid}.json"), 'r') as f:
            trace = json.load(f)
        chain_ids = trace['chain_ids']

        # Key embeddings by chain id so lookups stay correct even when the
        # ids are not exactly 0..n-1 in order.
        embeddings = {}
        for chain_id in chain_ids:
            if chain_id not in embeddings:
                embeddings[chain_id] = _load_chain_embedding(embedding_dir, uid, chain_id)

        latency_prediction = 0
        for iteration in trace['iterations']:

            tb = iteration['token_budget']
            # One stored activation per 16 tokens; position of the latest one.
            idx = int(tb / 16) - 1
            activations = list()
            # A negative idx (budget below 16 tokens) would wrap around to
            # the last position, so it is treated as "nothing available".
            if idx >= 0:
                for chain_id in chain_ids:
                    chain_emb = embeddings[chain_id]
                    if idx < chain_emb.size(1):
                        activations.append(chain_emb[selected_layer][idx])
            if activations:
                batch = torch.stack(activations).to(device)
                latency_prediction += _timed_forward_ms(model, batch, device)

            # Apply this iteration's branching: the target slot continues
            # from the source slot's chain.
            branch_outs = iteration['branch_outs']
            for branch_target, branch_source in branch_outs:
                chain_ids[branch_target] = chain_ids[branch_source]

        results.append({
            'uid': uid,
            'latency_prediction': latency_prediction / 1000,  # ms -> s
        })

    # Explicit columns keep the summary and CSV well-formed when no trace
    # was found.
    results = pd.DataFrame(results, columns=['uid', 'latency_prediction'])
    print("Average prediction overhead (s)", results.latency_prediction.mean())
    results.to_csv(
        os.path.join(trace_dir_full, "latency_torch.csv"),
        index=False, header=True)

# Run the latency benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
