import argparse
import json
import numpy as np
import os
import pandas as pd
import torch
import tqdm
from client import apply_chat_template
from config import (
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS, MATH_PROBE_FREQ,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS, MMLU_PROBE_FREQ,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS, GSM8K_PROBE_FREQ,
    MODEL_IDS,
)
from sklearn.model_selection import train_test_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import convert_math_data_setting_to_str, process_math_id, convert_data_setting_to_str, load_gzip_file


def parse_args():
    """Parse the command line for tensor construction.

    Seven options are mandatory (dataset, model, layer, warm-up windows,
    stable windows, interval, max sequence); the validation fraction and
    the split seed are optional.

    Returns:
        argparse.Namespace with the parsed options.
    """
    cli = argparse.ArgumentParser()
    # (short flag, long flag, value type) for every mandatory option,
    # registered in this order so --help lists them as before.
    mandatory_options = (
        ('-d', '--dataset', str),
        ('-m', '--model', str),
        ('-l', '--layer', int),
        ('-w', '--warmup-windows', int),
        ('-s', '--stable-windows', int),
        ('-i', '--interval', int),
        ('-e', '--max-seq', int),
    )
    for short_flag, long_flag, value_type in mandatory_options:
        cli.add_argument(short_flag, long_flag, type=value_type, required=True)
    cli.add_argument('-v', '--val-frac', type=float, default=0.1, help='Fraction of training data to use for validation')
    cli.add_argument('-r', '--random-seed', type=int, default=42, help='Random seed for train/val split')
    return cli.parse_args()


# Trailing Unicode superscript digit -> the exponent it denotes ("12²" == 12 ** 2).
_SUPERSCRIPT_EXPONENTS = {
    '⁰': 0, '¹': 1, '²': 2, '³': 3, '⁴': 4,
    '⁵': 5, '⁶': 6, '⁷': 7, '⁸': 8, '⁹': 9,
}


def _is_plain_decimal(text):
    """Return True for an optionally signed decimal literal: "42", "-3", "18.0", ".5".

    Unlike ``float()``, this rejects "inf", "nan", exponents, surrounding
    whitespace and ``_`` digit separators, so only genuinely numeric answers
    are accepted.
    """
    unsigned = text[1:] if text[:1] in ('+', '-') else text
    whole, _, frac = unsigned.partition('.')
    if not whole and not frac:
        return False
    return all(part == '' or part.isdecimal() for part in (whole, frac))


def parse_gsm8k_ans(ans_raw):
    """Convert a raw GSM8K answer to a float, or ``np.nan`` if it is not numeric.

    Accepted forms (``ans_raw`` is first converted with ``str()``):
      * decimal literals, optionally signed: "42", "-3", "18.0" -- the last
        form is what ``str()`` yields for a float cell, e.g. when pandas read
        the answer column as float64;
      * an unsigned integer followed by one superscript digit, read as a
        power: "12²" -> 144.0, "10⁰" -> 1.0.

    Anything else ("Invalid", "1,000", "½", "²", "") yields ``np.nan``; the
    function never raises on malformed input.
    """
    text = str(ans_raw)
    if not text:
        return np.nan

    exponent = _SUPERSCRIPT_EXPONENTS.get(text[-1])
    if exponent is not None:
        base = text[:-1]
        if not base.isdecimal():
            return np.nan
        try:
            return float(int(base) ** exponent)
        except (OverflowError, ValueError):
            # Base too long for int(), or the power does not fit in a float.
            return np.nan

    if _is_plain_decimal(text):
        return float(text)
    return np.nan


# Maps the ``category`` column of mmlu.csv / gsm8k.csv to the numeric ``train``
# code used downstream: 1 = train, 0.5 = val, 0 = test.
_CATEGORY_TO_TRAIN_CODE = {'train': 1, 'val': 0.5, 'test': 0}


def _read_probe_postprocessed(data_dir, model_id, per_file=None):
    """Concatenate every ``*.csv`` in ``<data_dir>/<MODEL_IDS[model_id]>/probe_postprocessed``.

    Files are read in sorted filename order so that the resulting row order
    (and anything order-dependent downstream, such as ``unique()`` on
    ``unique_id`` before a train/val split) is reproducible; ``os.listdir``
    order is filesystem-dependent. ``per_file``, when given, transforms each
    loaded frame before concatenation.

    Raises:
        ValueError: if the directory contains no ``.csv`` files.
    """
    csv_dir = os.path.join(data_dir, MODEL_IDS[model_id], 'probe_postprocessed')
    frames = list()
    for fn in sorted(os.listdir(csv_dir)):
        if fn.endswith('.csv'):
            frame = pd.read_csv(os.path.join(csv_dir, fn))
            frames.append(frame if per_file is None else per_file(frame))
    if not frames:
        raise ValueError(f"No postprocessed CSV files found in {csv_dir}")
    return pd.concat(frames, axis=0, ignore_index=True)


def _regroup_mmlu_chains(df_loaded_raw):
    """Reassemble one MMLU postprocessed file chain by chain.

    ``groupby`` orders rows by ``chain_id`` (keeping the original order within
    each chain) and, with pandas' default ``dropna=True``, drops rows whose
    ``chain_id`` is NaN. A disabled per-chain cleanup (replacing an 'Invalid'
    final answer with the chain's last intermediate answer) used to sit here;
    it is intentionally not applied.
    """
    return pd.concat(
        [chain_df for _, chain_df in df_loaded_raw.groupby(by='chain_id')],
        axis=0, ignore_index=True)


def _keep_training_rows(df, max_len):
    """Keep only intermediate probes taken before ``max_len`` tokens."""
    df = df[df['type'] == 'intermediate']
    return df[df['tokens'] < max_len]


def _mark_correct_by_string(df):
    """Compare current and reference answers as strings and store the result in ``correct``."""
    df['curr_answer'] = df['curr_answer'].astype(str)
    df['answer'] = df['answer'].astype(str)
    df['correct'] = df['curr_answer'] == df['answer']
    return df


def load_metadata(dataset, model_id, for_training=True):
    """Load the per-probe metadata of ``model_id`` on ``dataset`` and mark correct answers.

    Rows come from the model's ``probe_postprocessed`` CSVs and are joined
    (inner join on ``unique_id``) with the dataset's reference table, which
    supplies ``answer`` and the split code ``train`` (1 = train, 0.5 = val,
    0 = test for MMLU/GSM8K; MATH takes ``train`` directly from math3k.csv).

    Args:
        dataset: 'math', 'mmlu' or 'gsm8k'.
        model_id: key into ``MODEL_IDS``.
        for_training: if True, keep only 'intermediate' probes whose token
            count is below the dataset's ``*_MAX_LEN``.

    Returns:
        DataFrame with columns unique_id, chain_id, tokens, correct, train,
        type, curr_answer, answer (plus curr_answer_raw for MMLU), or None
        for an unrecognised ``dataset``.

    Raises:
        ValueError: if a ``probe_postprocessed`` directory holds no CSV files.
    """
    cols_to_keep = ['unique_id', 'chain_id', 'tokens', 'correct', 'train', 'type', 'curr_answer', 'answer']

    if dataset == 'math':
        postprocessed = _read_probe_postprocessed(MATH_DIR, model_id)
        if for_training:
            postprocessed = _keep_training_rows(postprocessed, MATH_MAX_LEN)
        split_df = pd.read_csv(os.path.join(MATH_DIR, 'math3k.csv'))
        postprocessed = pd.merge(
            postprocessed, split_df[['unique_id', 'answer', 'train']], on='unique_id', how='inner')
        return _mark_correct_by_string(postprocessed)[cols_to_keep]

    if dataset == 'mmlu':
        cols_to_keep.append('curr_answer_raw')
        mmlu = pd.read_csv(os.path.join(MMLU_DIR, 'mmlu.csv'))
        postprocessed = _read_probe_postprocessed(MMLU_DIR, model_id, per_file=_regroup_mmlu_chains)
        if for_training:
            postprocessed = _keep_training_rows(postprocessed, MMLU_MAX_LEN)
        # All categories (train, val, test) are kept; only the code is attached.
        mmlu['train'] = mmlu['category'].map(_CATEGORY_TO_TRAIN_CODE)
        postprocessed = pd.merge(
            postprocessed, mmlu[['unique_id', 'answer', 'train']], on='unique_id', how='inner')
        return _mark_correct_by_string(postprocessed)[cols_to_keep]

    if dataset == 'gsm8k':
        gsm8k = pd.read_csv(os.path.join(GSM8K_DIR, 'gsm8k.csv'))
        postprocessed = _read_probe_postprocessed(GSM8K_DIR, model_id)
        if for_training:
            postprocessed = _keep_training_rows(postprocessed, GSM8K_MAX_LEN)
        gsm8k['train'] = gsm8k['category'].map(_CATEGORY_TO_TRAIN_CODE)
        postprocessed = pd.merge(
            postprocessed, gsm8k[['unique_id', 'answer', 'train']], on='unique_id', how='inner')
        # GSM8K answers are compared numerically: unparsable answers become NaN
        # and never count as correct.
        # NOTE(review): assumes gsm8k.csv stores ``answer`` as a number -- a
        # string column would never compare equal to the parsed floats; confirm.
        postprocessed['curr_answer'] = postprocessed['curr_answer'].map(parse_gsm8k_ans)
        postprocessed['correct'] = postprocessed['curr_answer'] == postprocessed['answer']
        return postprocessed[cols_to_keep]

    return None


def get_split_file_path(dataset, seed, val_frac):
    """Return the CSV path ``<dataset dir>/split/s{seed}_v{val_frac}.csv``.

    As a side effect the ``split`` directory is created if missing.

    Raises:
        ValueError: if ``dataset`` is not 'math', 'mmlu' or 'gsm8k'.
    """
    if dataset not in ('math', 'mmlu', 'gsm8k'):
        raise ValueError(f"Unknown dataset: {dataset}")

    root = MATH_DIR if dataset == 'math' else (MMLU_DIR if dataset == 'mmlu' else GSM8K_DIR)
    split_dir = os.path.join(root, 'split')
    os.makedirs(split_dir, exist_ok=True)
    file_name = f's{seed}_v{val_frac}.csv'
    return os.path.join(split_dir, file_name)


def load_or_create_split(metadata, dataset, val_frac, seed=42):
    """
    Split the training questions of ``metadata`` into train and val parts.

    The split is made per question (``unique_id``), not per row, and is
    cached as a CSV (columns ``unique_id``, ``split``) at
    ``get_split_file_path(dataset, seed, val_frac)``. An existing cache file
    is reused as-is; otherwise the ``unique_id`` values of rows with
    ``train == 1`` are divided by sklearn's ``train_test_split`` with
    ``test_size=val_frac`` and ``random_state=seed``, and the result is saved.

    Args:
        metadata: DataFrame with at least ``unique_id`` and ``train`` columns.
        dataset: Dataset name ('math', 'mmlu' or 'gsm8k').
        val_frac: Fraction of training questions assigned to validation.
        seed: Random seed used when creating a new split.

    Returns:
        (train_metadata, val_metadata): the rows of ``metadata`` whose
        ``unique_id`` falls on each side of the split.
    """
    split_file_path = get_split_file_path(dataset, seed, val_frac)

    if not os.path.exists(split_file_path):
        print(f"Creating new split and saving to {split_file_path}")

        candidate_ids = metadata.loc[metadata['train'] == 1, 'unique_id'].unique()
        print(f"Total unique questions in train set: {len(candidate_ids)}")

        train_unique_ids, val_unique_ids = train_test_split(
            candidate_ids,
            test_size=val_frac,
            random_state=seed,
            stratify=None,
        )
        print(f"Created split - Train questions: {len(train_unique_ids)}, Val questions: {len(val_unique_ids)}")

        records = [{'unique_id': uid, 'split': 'train'} for uid in train_unique_ids]
        records += [{'unique_id': uid, 'split': 'val'} for uid in val_unique_ids]
        pd.DataFrame(records).to_csv(split_file_path, index=False)
        print(f"Split saved to {split_file_path}")
    else:
        print(f"Loading existing split from {split_file_path}")
        cached = pd.read_csv(split_file_path)
        train_unique_ids = cached.loc[cached['split'] == 'train', 'unique_id'].values
        val_unique_ids = cached.loc[cached['split'] == 'val', 'unique_id'].values
        print(f"Loaded split - Train questions: {len(train_unique_ids)}, Val questions: {len(val_unique_ids)}")

    in_train = metadata['unique_id'].isin(train_unique_ids)
    in_val = metadata['unique_id'].isin(val_unique_ids)
    return metadata[in_train], metadata[in_val]


# Order in which splits are built, saved, loaded and returned.
_TENSOR_SPLIT_NAMES = ('train', 'val', 'test')


def _load_npz_data(path):
    """Return the ``'data'`` array of the ``.npz`` archive at ``path`` as a CPU tensor.

    The archive is opened in a ``with`` block so its file handle is closed
    right away instead of whenever the ``NpzFile`` is garbage-collected.
    """
    with np.load(path) as archive:
        return torch.from_numpy(archive['data'])


def _split_metadata_for_tensors(metadata, has_val_split, dataset, val_frac, seed):
    """Partition ``metadata`` into (train, val, test) frames using its ``train`` code.

    Codes are 1 = train, 0.5 = val, 0 = test. When no 0.5 rows exist, the
    train/val split is carved out of the ``train == 1`` questions by
    ``load_or_create_split``.
    """
    if has_val_split:
        print("Data is pre-split to train/val/test")
        train_df = metadata[metadata['train'] == 1]
        val_df = metadata[metadata['train'] == 0.5]
        print("Using existing val split from dataset")
    else:
        print("Creating train/val/test split")
        train_df, val_df = load_or_create_split(metadata, dataset, val_frac, seed)
    test_df = metadata[metadata['train'] == 0]
    return train_df, val_df, test_df


def _window_filter_and_sort(df, interval):
    """Keep rows whose ``tokens`` lie on an ``interval`` boundary, ordered by question, chain, tokens."""
    on_boundary = df[df['tokens'] % interval == 0]
    return on_boundary.sort_values(by=['unique_id', 'chain_id', 'tokens'], ascending=[True, True, True])


def _collect_split_samples(split_name, split_df, embedding_dir, layer_idx, warmup, stable, max_seq,
                           probing_frequency):
    """Build the per-window samples of one split.

    Returns ``(info_list, prompt_list, intermediate_list, label_list)``:
    per-chain info frames, and per-sample prompt embeddings, embedding
    windows (at most ``max_seq`` rows each) and 0/1 labels.

    Raises:
        ValueError: if a chain's token counts map to probe positions outside
            its stored chain embedding.
    """
    print(f"Processing {split_name} set...")
    info_list, prompt_list, intermediate_list, label_list = [], [], [], []
    n_chains = split_df[['unique_id', 'chain_id']].drop_duplicates().shape[0]

    with tqdm.tqdm(total=n_chains) as pbar:
        for unique_id, prompt_df in split_df.groupby('unique_id'):
            prompt_path = os.path.join(embedding_dir, f"{unique_id}.prompt.lasttoken.npz")
            # Questions without a stored prompt embedding are skipped entirely.
            if not os.path.isfile(prompt_path):
                continue
            prompt_embedding = _load_npz_data(prompt_path)[layer_idx]

            for chain_id, chain_df in prompt_df.groupby('chain_id'):
                pbar.update(1)
                pbar.set_description(f"Processing {split_name} {unique_id} chain {chain_id}")
                # Chains with fewer than ``warmup`` windows produce no samples.
                if chain_df.shape[0] < warmup:
                    continue

                tokens = chain_df.tokens.values
                corrects = chain_df.correct.values
                chain_embedding = _load_npz_data(
                    os.path.join(embedding_dir, f"{unique_id}.chain{chain_id}.npz"))
                # Token count -> probe index; presumably probe k was recorded
                # after (k + 1) * probing_frequency tokens.
                positions = (tokens / probing_frequency).astype(int) - 1
                n_stored = chain_embedding.shape[1]
                # A negative index would silently wrap to the last probe, so it
                # is rejected along with indices past the end.
                if positions.min() < 0 or positions.max() >= n_stored:
                    raise ValueError(
                        f"Probe positions {positions.min()}..{positions.max()} out of range for "
                        f"{n_stored} stored positions (unique_id={unique_id}, chain_id={chain_id})")
                chain_embedding = chain_embedding[layer_idx, positions, :]
                info_list.append(chain_df[['unique_id', 'chain_id', 'tokens']].iloc[warmup - 1:])

                for idx in range(warmup - 1, len(tokens)):
                    # 1 iff the answer is correct at this window and the next
                    # ``stable - 1`` windows (truncated at the end of the chain).
                    label = int(np.amin(corrects[idx:min(len(tokens), idx + stable)]) == 1)
                    prompt_list.append(prompt_embedding.cpu())
                    intermediate_list.append(chain_embedding[max(0, idx - max_seq + 1):idx + 1, :].cpu())
                    label_list.append(label)

    return info_list, prompt_list, intermediate_list, label_list


def _finalize_split(split_name, samples, tensor_dir, cache_and_load):
    """Turn collected samples into ``(info_df, prompt_tensor, intermediate_list, label_tensor)``.

    When ``cache_and_load`` is true the four parts are also written to
    ``tensor_dir`` as ``<split>.info.csv`` and ``<split>.{prompt,intermediate,label}.pt``.
    """
    info_list, prompt_list, intermediate_list, label_list = samples
    print(f"Processing {split_name} set")
    print("construct sample info")
    info_df = pd.concat(info_list, axis=0, ignore_index=True) if info_list else pd.DataFrame()
    print('construct prompt embedding')
    # torch.stack raises if this split produced no samples.
    prompt_tensor = torch.stack(prompt_list, dim=0).cpu()
    # Windows have different lengths, so they stay a list of tensors.
    print("construct intermediate tensor embedding")
    print('construct label tensor')
    label_tensor = torch.tensor(label_list, device='cpu')

    if cache_and_load:
        info_df.to_csv(os.path.join(tensor_dir, f"{split_name}.info.csv"), index=False, header=True)
        with tqdm.tqdm(total=3) as pbar:
            pbar.update(1)
            pbar.set_description("Saving prompts")
            torch.save(prompt_tensor, os.path.join(tensor_dir, f"{split_name}.prompt.pt"))
            pbar.update(1)
            pbar.set_description("Saving intermediate")
            torch.save(intermediate_list, os.path.join(tensor_dir, f"{split_name}.intermediate.pt"))
            pbar.update(1)
            pbar.set_description("Saving labels")
            torch.save(label_tensor, os.path.join(tensor_dir, f"{split_name}.label.pt"))

    return info_df, prompt_tensor, intermediate_list, label_tensor


def _load_cached_tensor_splits(tensor_dir):
    """Read the cached ``(info, prompt, intermediate_list, label)`` tuples for every split."""
    print(f"Reading cached tensor from path {tensor_dir}")
    loaded = list()
    for split_name in _TENSOR_SPLIT_NAMES:
        info = pd.read_csv(os.path.join(tensor_dir, f"{split_name}.info.csv"))
        prompt = torch.load(os.path.join(tensor_dir, f"{split_name}.prompt.pt"),
                            weights_only=True, map_location='cpu')
        intermediate_list = torch.load(os.path.join(tensor_dir, f"{split_name}.intermediate.pt"),
                                       weights_only=True, map_location='cpu')
        label = torch.load(os.path.join(tensor_dir, f"{split_name}.label.pt"),
                           weights_only=True, map_location='cpu')
        loaded.append((info, prompt, intermediate_list, label))
    return loaded


def create_tensor(
        metadata, embedding_dir, output_dir,
        layer_idx, warmup, stable, interval, max_seq,
        probing_frequency,
        cache_and_load=True,
        val_frac=0.1,
        seed=42,
        dataset=None
    ):
    """Build (or load cached) probe tensors for the train, val and test splits.

    Only metadata rows whose ``tokens`` is a multiple of ``interval`` are used;
    ``warmup``, ``stable`` and ``max_seq`` are counted in those windows:

    * a chain needs at least ``warmup`` windows, and samples start at its
      ``warmup``-th window;
    * a sample's label is 1 iff ``correct`` holds at that window and the
      following ``stable - 1`` windows (truncated at the end of the chain);
    * a sample's intermediate input is the embeddings of the last ``max_seq``
      windows up to and including the current one.

    Caching: if ``cache_and_load`` is true and the tensor directory already
    exists, its contents are loaded. Otherwise the tensors are computed, and
    they are written to that directory only when ``cache_and_load`` is true.
    With ``cache_and_load=False`` an existing cache is never read.

    Args:
        metadata: output of ``load_metadata``.
        embedding_dir: directory holding ``<uid>.prompt.lasttoken.npz`` and
            ``<uid>.chain<cid>.npz`` archives.
        output_dir: parent directory of the per-setting tensor cache.
        layer_idx: layer index selected from the stored embeddings.
        probing_frequency: token spacing between stored chain probes.
        val_frac, seed: split parameters (ignored when ``metadata`` already
            contains val rows with ``train == 0.5``).
        dataset: dataset name, used only to locate the split file.

    Returns:
        list of three ``(info_df, prompt_tensor, intermediate_list, label_tensor)``
        tuples in train, val, test order.

    Raises:
        ValueError: if a chain's token counts map outside its stored embedding.
    """
    has_val_split = (metadata['train'] == 0.5).any()
    if has_val_split:
        # With a pre-made val split, val_frac/seed only affect the cache
        # directory name here, so they are pinned to fixed values.
        val_frac = 0.1
        seed = 0

    tensor_dir = os.path.join(
        output_dir,
        convert_data_setting_to_str(val_frac, seed, layer_idx, warmup, stable, interval, max_seq))

    if cache_and_load and os.path.isdir(tensor_dir):
        return _load_cached_tensor_splits(tensor_dir)

    splits = _split_metadata_for_tensors(metadata, has_val_split, dataset, val_frac, seed)
    for split_name, split_df in zip(_TENSOR_SPLIT_NAMES, splits):
        print(f"{split_name.capitalize()} samples: {len(split_df)}")

    samples = [
        _collect_split_samples(
            split_name, _window_filter_and_sort(split_df, interval), embedding_dir,
            layer_idx, warmup, stable, max_seq, probing_frequency)
        for split_name, split_df in zip(_TENSOR_SPLIT_NAMES, splits)
    ]

    if cache_and_load:
        os.makedirs(tensor_dir, exist_ok=True)

    return [
        _finalize_split(split_name, split_samples, tensor_dir, cache_and_load)
        for split_name, split_samples in zip(_TENSOR_SPLIT_NAMES, samples)
    ]


def create_tensor_for_dataset(dataset, model_id, layer_idx, warmup, stable, interval, max_seq, cache_and_load=True, val_frac=0.1, seed=42):
    """Build (or load cached) probe tensors for ``dataset`` from ``model_id``'s embeddings.

    Loads the metadata with ``load_metadata``, resolves the dataset's data
    directory and probing frequency from the imported config constants, and
    delegates to ``create_tensor``. Embeddings are read from
    ``<data dir>/<model name>/embedding`` and tensors are cached under
    ``<data dir>/<model name>/tensor``.

    Returns:
        whatever ``create_tensor`` returns (train/val/test tuples).

    Raises:
        ValueError: if ``dataset`` is not 'math', 'mmlu' or 'gsm8k'. It is
            checked before any data is loaded.
    """
    if dataset not in ('math', 'mmlu', 'gsm8k'):
        raise ValueError(f"Unknown dataset: {dataset}")

    metadata = load_metadata(dataset, model_id)
    model_name = MODEL_IDS[model_id]
    if dataset == 'math':
        data_dir, probing_frequency = MATH_DIR, MATH_PROBE_FREQ
    elif dataset == 'mmlu':
        data_dir, probing_frequency = MMLU_DIR, MMLU_PROBE_FREQ
    else:
        data_dir, probing_frequency = GSM8K_DIR, GSM8K_PROBE_FREQ

    embedding_dir = os.path.join(data_dir, model_name, 'embedding')
    output_dir = os.path.join(data_dir, model_name, 'tensor')
    return create_tensor(
        metadata, embedding_dir, output_dir, layer_idx, warmup, stable, interval, max_seq,
        probing_frequency, cache_and_load, val_frac, seed=seed, dataset=dataset)


if __name__ == "__main__":
    # Command-line entry point: parse the flags and build (or load) tensors
    # with caching enabled (the default of create_tensor_for_dataset).
    cli_args = parse_args()
    create_tensor_for_dataset(
        dataset=cli_args.dataset,
        model_id=cli_args.model,
        layer_idx=cli_args.layer,
        warmup=cli_args.warmup_windows,
        stable=cli_args.stable_windows,
        interval=cli_args.interval,
        max_seq=cli_args.max_seq,
        val_frac=cli_args.val_frac,
        seed=cli_args.random_seed,
    )
