import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig
import argparse
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, RobustScaler
import pickle
import os
import json
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
import logging
from utils import LIST_OF_MODELS, LIST_OF_DATASETS, LIST_OF_TRAIN_DATASETS, LIST_OF_NUMERICAL_POSITIONS, LIST_OF_EXACT_POSITIONS, LIST_OF_BLOCK_TYPES, LAYERS_TO_TRACE, LIST_OF_TEST_DATASETS, N_LAYERS, extract_activations_and_labels, validate, init_and_train_classifier, load_ori_data, get_data_idx_list, load_llm, init_and_train_mlp_classifier
import random
import time
from LLM import LLM
from sklearn.model_selection import train_test_split
from transformers import set_seed

# Logging setup: INFO-level records formatted as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Output root read from the AMLT_OUTPUT_DIR environment variable; the literal
# string 'null' is the placeholder used when the variable is not set.
OUTPUT_DIR = os.environ.get("AMLT_OUTPUT_DIR", 'null')

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the model / dataset selection, the probing
        location options, and the training options defined below.
    """
    cli = argparse.ArgumentParser(description='Train a probe on a language model')

    # Model and data selection.
    cli.add_argument(
        '--model', type=str, choices=LIST_OF_MODELS, required=True,
        help='Name of the model',
    )
    cli.add_argument(
        '--prefix', type=str, default='../probing',
        help='Prefix of the data file',
    )
    cli.add_argument(
        '--dataset', type=str, choices=LIST_OF_TRAIN_DATASETS, required=True,
        help='Name of the dataset',
    )
    cli.add_argument(
        '--dataset_test', type=str, choices=LIST_OF_TEST_DATASETS, required=True,
        help='Name of the test dataset',
    )
    cli.add_argument(
        '--val_ratio', type=float, default=0.1,
        help='Validation ratio',
    )

    # Where in the model the activations are taken from.
    cli.add_argument(
        '--probe_at', type=str, default='mlp_output',
        help='Layer to extract activations from',
    )
    cli.add_argument(
        '--position', type=str, default='-1',
        choices=LIST_OF_NUMERICAL_POSITIONS + LIST_OF_EXACT_POSITIONS,
        help='Position to extract activations',
    )
    cli.add_argument(
        "--block_type", type=str, choices=LIST_OF_BLOCK_TYPES, default='e_ques',
        help="Block type for probing",
    )

    # Sampling, classifier choice and seeding.
    cli.add_argument(
        '--balanced', type=int, default=0, choices=[0, 1],
        help='Whether to use balanced data (1) or not (0)',
    )
    cli.add_argument(
        '--num_samples', type=str, default='2000',
        help='Number of samples to use for training and validation',
    )
    cli.add_argument("--deployment_name", type=str, default='gpt-4o_2024-11-20')
    cli.add_argument(
        "--use_mlp", type=int, default=0, choices=[0, 1],
        help='Whether to use MLP classifier (1) or Logistic Regression (0)',
    )
    cli.add_argument(
        "--seed", type=int, default=42,
        help='Random seed for reproducibility',
    )

    return cli.parse_args()

def _resolve_save_dir(save_dir, output_dir):
    """Return ``save_dir``, re-rooted under ``output_dir`` unless it is 'null'.

    Leading ``../``, ``./`` and ``/`` path segments are removed as whole
    prefixes before joining. (``str.lstrip("../")`` strips a *character set*
    and would also eat the dot of a name such as ``.hidden``.)
    """
    if output_dir == 'null':
        return save_dir
    prefixes = ("../", "./", "/")
    relative = save_dir
    while relative.startswith(prefixes):
        for prefix in prefixes:
            if relative.startswith(prefix):
                relative = relative[len(prefix):]
                break
    return os.path.join(output_dir, relative)


def _load_probe_inputs(dataset, llm, args):
    """Load the cached files for ``dataset`` and select the sample indices.

    Files are read from ``../output/{dataset}/{llm.model_name}/``; names carry
    a ``balanced_`` prefix when ``args.balanced`` is true.

    Returns:
        Tuple ``(data_idx_list, ori_datas, textual_answers,
        input_output_ids_list, output_ids_list)``, i.e. the leading positional
        arguments of ``extract_activations_and_labels`` in order.
    """
    model_dir = f'../output/{dataset}/{llm.model_name}'
    tag = "balanced_" if args.balanced else ""

    if not args.balanced:
        ori_datas = load_ori_data(dataset)
    else:
        with open(f'{model_dir}/balanced_ori_datas.json', 'r') as f:
            ori_datas = json.load(f)
    with open(f'{model_dir}/{tag}textual_answers.json', 'r') as f:
        textual_answers = json.load(f)
    assert len(ori_datas) == len(textual_answers['raw_question']), "Length of original data and textual answers do not match."

    input_output_ids_list = torch.load(f'{model_dir}/{tag}input_output_ids.pt', weights_only=True)
    output_ids_list = torch.load(f'{model_dir}/{tag}output_ids.pt', weights_only=True)
    data_idx_list = get_data_idx_list(ori_datas, textual_answers, args.num_samples, args.deployment_name)
    return data_idx_list, ori_datas, textual_answers, input_output_ids_list, output_ids_list


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle afterwards."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def main():
    """Train and evaluate one "vanilla" probe per model layer.

    For every layer index in ``range(N_LAYERS[llm.model_name])``:

    * split the training activations into train / validation parts,
    * fit one classifier on the train part and report validation metrics,
    * fit a second classifier on train + validation and report metrics on
      the separate test dataset,
    * write the second classifier, both prediction outputs and a JSON log
      into a per-layer directory (re-rooted under ``OUTPUT_DIR`` when set).
    """
    args = parse_args()
    args.balanced = bool(args.balanced)
    set_seed(args.seed)

    llm = load_llm(args.model)
    num_layers = N_LAYERS[llm.model_name]

    # Train inputs are loaded before test inputs, as in the original flow.
    train_inputs = _load_probe_inputs(args.dataset, llm, args)
    test_inputs = _load_probe_inputs(args.dataset_test, llm, args)

    # train vanilla probing
    extract_kwargs = dict(
        layer=None, blocked_layer=None, num_layers=None, window=None,
        block_type=args.block_type, block=False, block_all_layers=False,
        save_dir=None, deployment_name=args.deployment_name,
    )
    all_activations, all_labels, _ = extract_activations_and_labels(
        *train_inputs, llm, args.probe_at, args.position, **extract_kwargs)
    all_activations_test, all_labels_test, _ = extract_activations_and_labels(
        *test_inputs, llm, args.probe_at, args.position, **extract_kwargs)

    # all_activations and all_labels are dicts keyed by data index; the values
    # are the corresponding activations and labels. The index lists and label
    # arrays do not depend on the layer, so they are built once.
    training_data_idx = list(all_activations.keys())
    y = np.array([all_labels[idx] for idx in training_data_idx])
    test_data_idx = list(all_activations_test.keys())
    y_test = np.array([all_labels_test[idx] for idx in test_data_idx])

    train_fn = init_and_train_mlp_classifier if args.use_mlp else init_and_train_classifier
    # MLP runs get an extra "use_mlp_<n>" path component.
    variant = f'vanilla/use_mlp_{args.use_mlp}' if args.use_mlp else 'vanilla'
    suffix = f'block_type_{args.block_type}'

    for layer in range(num_layers):
        # Create directories
        save_dir = f'{args.prefix}/{args.dataset}/{args.dataset_test}/{args.model}/{variant}/balanced_{args.balanced}/num_samples_{args.num_samples}/probe_at_{args.probe_at}/position_{args.position}/layer_{layer}/seed_{args.seed}'
        save_dir = _resolve_save_dir(save_dir, OUTPUT_DIR)
        os.makedirs(save_dir, exist_ok=True)

        logger.info(f"Probing dataset: {args.dataset} with model: {args.model} at position: {args.position} and layer: {layer} using {args.probe_at} activation")
        logger.info(f"Using token position: {args.position}")
        logger.info(f"total training samples: {len(all_activations)}")

        # Per-layer feature matrices.
        X = np.array([all_activations[idx][layer] for idx in training_data_idx])
        X_test = np.array([all_activations_test[idx][layer] for idx in test_data_idx])

        # NOTE(review): the split seed is fixed at 42 and does not follow
        # --seed; kept as-is to preserve results -- confirm this is intended.
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=args.val_ratio, random_state=42, stratify=y)

        # clf_partial sees only the train part (used for validation metrics);
        # clf is refit on train + validation and evaluated on the test set.
        clf_partial = train_fn(args.seed, X_train, y_train)
        clf = train_fn(args.seed, np.vstack((X_train, X_val)), np.hstack((y_train, y_val)))

        val_metrics, val_preds_probs = validate(clf_partial, X_val, y_val)
        logger.info(f"Validation metrics for partial training: {val_metrics}")

        test_metrics, test_preds_probs = validate(clf, X_test, y_test)
        logger.info(f"Test metrics: {test_metrics}")

        log = {
            'args': vars(args),
            'results': {
                'training_data_idx': training_data_idx,
                'test_data_idx': test_data_idx,
                'val_metrics': val_metrics,
                'test_metrics': test_metrics
            },
        }

        # Save the model and the validation predictions
        _dump_pickle(clf, os.path.join(save_dir, f'vanilla_clf_{suffix}.pkl'))
        _dump_pickle(val_preds_probs, os.path.join(save_dir, f'vanilla_val_preds_probs_{suffix}.pkl'))
        # Save final log
        with open(os.path.join(save_dir, f'vanilla_log_{suffix}.json'), 'w') as f:
            json.dump(log, f)
        _dump_pickle(test_preds_probs, os.path.join(save_dir, f'vanilla_test_preds_probs_{suffix}.pkl'))
    

# Run the probing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()