import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig
import argparse
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, RobustScaler
import pickle
import os
import json
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
import logging
from utils import LIST_OF_MODELS, LIST_OF_DATASETS, LIST_OF_TRAIN_DATASETS, LIST_OF_NUMERICAL_POSITIONS, LIST_OF_EXACT_POSITIONS, LIST_OF_BLOCK_TYPES, LAYERS_TO_TRACE, LIST_OF_TEST_DATASETS, N_LAYERS, extract_activations_and_labels, validate, init_and_train_classifier, load_ori_data, get_data_idx_list, load_llm, init_and_train_mlp_classifier
import random
import time
from LLM import LLM
from sklearn.model_selection import train_test_split
from transformers import set_seed

# Module-wide logging: every record is emitted as "<timestamp> - <level> - <message>",
# with INFO as the root threshold.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Optional output root taken from the AMLT_OUTPUT_DIR environment variable.
# The literal string 'null' is the "not set" sentinel.
OUTPUT_DIR = os.environ.get("AMLT_OUTPUT_DIR", 'null')

def parse_args():
    """Parse the command-line options of the probing script.

    Options are declared as ``(flag, add_argument keyword arguments)`` pairs
    and registered in declaration order, so the generated ``--help`` output
    keeps the same ordering.

    Note that ``--num_samples`` is parsed as a string, and that ``--balanced``
    and ``--use_mlp`` are integer flags restricted to 0 or 1.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    argument_specs = (
        ('--model', dict(type=str, choices=LIST_OF_MODELS, required=True, help='Name of the model')),
        ('--prefix', dict(type=str, default='../probing', help='Prefix of the data file')),
        ('--dataset', dict(type=str, choices=LIST_OF_TRAIN_DATASETS, required=True, help='Name of the dataset')),
        ('--dataset_test', dict(type=str, choices=LIST_OF_TEST_DATASETS, required=True, help='Name of the test dataset')),
        ('--val_ratio', dict(type=float, default=0.1, help='Validation ratio')),
        ('--probe_at', dict(type=str, default='mlp_output', help='Layer to extract activations from')),
        ('--position', dict(type=str, default='-1', choices=LIST_OF_NUMERICAL_POSITIONS + LIST_OF_EXACT_POSITIONS, help='Position to extract activations')),
        ('--window', dict(type=int, default=None, help='Window size for attention analysis')),
        ('--block_type', dict(type=str, choices=LIST_OF_BLOCK_TYPES, default='e_ques', help='Block type for probing')),
        ('--balanced', dict(type=int, default=0, choices=[0, 1], help='Whether to use balanced data (1) or not (0)')),
        ('--num_samples', dict(type=str, default='2000', help='Number of samples to use for training and validation')),
        ('--deployment_name', dict(type=str, default='gpt-4o_2024-11-20')),
        ('--use_mlp', dict(type=int, default=0, choices=[0, 1], help='Whether to use MLP classifier (1) or Logistic Regression (0)')),
        ('--seed', dict(type=int, default=42, help='Random seed for reproducibility')),
    )

    cli = argparse.ArgumentParser(description='Train a probe on a language model')
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def _strip_leading_relative_segments(path):
    """Return *path* without its leading empty, ``.`` and ``..`` segments.

    Unlike ``path.lstrip("../")``, which strips any leading run of the
    characters ``.`` and ``/``, this only removes whole path segments, so a
    first component such as ``.cache`` keeps its leading dot.
    """
    segments = path.split("/")
    first_real = 0
    while first_real < len(segments) and segments[first_real] in ("", ".", ".."):
        first_real += 1
    return "/".join(segments[first_real:])


def _load_json(path):
    """Read and return the JSON document stored at *path*, closing the file."""
    with open(path, 'r') as handle:
        return json.load(handle)


def main():
    """Evaluate pre-trained vanilla probes on blocked (knockout) test activations.

    Steps:
      1. Parse the CLI arguments and seed the run.
      2. Load the test data, textual answers and token ids from
         ``../output/<dataset_test>/<model_name>/`` (the ``balanced_`` file
         variants are used when ``--balanced 1``).
      3. Extract test activations once, with ``block=True`` and
         ``block_all_layers=True``.
      4. For every layer, load the vanilla classifier trained for that layer,
         evaluate it on the blocked activations (no retraining), and write a
         JSON log plus the pickled test predictions into the layer's
         save directory.

    Raises:
        ValueError: if ``--window`` is given, or if the original data and the
            textual answers have different lengths.
        FileNotFoundError: if the vanilla classifier for a layer is missing.
    """
    args = parse_args()
    args.balanced = bool(args.balanced)
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if args.window is not None:
        raise ValueError("Window size must be None")
    set_seed(args.seed)

    llm = load_llm(args.model)
    num_layers = N_LAYERS[llm.model_name]

    # ---- Load the test-set artifacts -------------------------------------
    test_output_dir = f'../output/{args.dataset_test}/{llm.model_name}'
    file_prefix = "balanced_" if args.balanced else ""

    if not args.balanced:
        ori_datas_test = load_ori_data(args.dataset_test)
    else:
        ori_datas_test = _load_json(f'{test_output_dir}/balanced_ori_datas.json')
    textual_answers_test = _load_json(f'{test_output_dir}/{file_prefix}textual_answers.json')
    if len(ori_datas_test) != len(textual_answers_test['raw_question']):
        raise ValueError("Length of original data and textual answers do not match.")

    input_output_ids_list_test = torch.load(f'{test_output_dir}/{file_prefix}input_output_ids.pt', weights_only=True)
    output_ids_list_test = torch.load(f'{test_output_dir}/{file_prefix}output_ids.pt', weights_only=True)
    data_idx_list_test = get_data_idx_list(ori_datas_test, textual_answers_test, args.num_samples, args.deployment_name)

    # Activations are extracted once for all layers; the per-layer loop below
    # only indexes into the result.
    all_activations_test, all_labels_test, all_block_configs_test = extract_activations_and_labels(
        data_idx_list_test,
        ori_datas_test,
        textual_answers_test,
        input_output_ids_list_test,
        output_ids_list_test,
        llm,
        args.probe_at,
        args.position,
        layer=None,
        blocked_layer=None,
        num_layers=num_layers,
        window=args.window,
        block_type=args.block_type,
        block=True,
        block_all_layers=True,
        save_dir=None,
        deployment_name=args.deployment_name,
    )

    # ---- Layer-invariant values, computed once ---------------------------
    test_data_idx = list(all_activations_test.keys())
    y_test = np.array([all_labels_test[idx] for idx in test_data_idx])

    # MLP probes live one directory level deeper than logistic-regression ones.
    mlp_segment = f'use_mlp_{args.use_mlp}/' if args.use_mlp else ''
    probe_suffix = f'balanced_{args.balanced}/num_samples_{args.num_samples}/probe_at_{args.probe_at}/position_{args.position}'
    artifact_stem = f'knockout_attn_window_{args.window}_block_type_{args.block_type}'
    vanilla_clf_name = f'vanilla_clf_block_type_{args.block_type}.pkl'

    # Now probing with knockout attention without retraining
    for layer in range(num_layers):

        logger.info(f"Processing layer {layer} for model {args.model}")

        # Create directories
        layer_suffix = f'{probe_suffix}/layer_{layer}/seed_{args.seed}'
        save_dir = f'{args.prefix}/{args.dataset}/{args.dataset_test}/{args.model}/retrain_False/blocked_all_layers/{mlp_segment}{layer_suffix}'
        if OUTPUT_DIR != 'null':
            # Re-root the relative path under OUTPUT_DIR. Whole leading
            # segments are removed; str.lstrip("../") would also eat the
            # leading dot of a component such as ".cache".
            save_dir = os.path.join(OUTPUT_DIR, _strip_leading_relative_segments(save_dir))
        os.makedirs(save_dir, exist_ok=True)

        logger.info(f"Probing dataset: {args.dataset} with model: {args.model} at position: {args.position} and layer: {layer} using {args.probe_at} activation")
        logger.info(f"Using token position: {args.position}")

        # Locate the classifier trained by the vanilla probing run.
        # NOTE(review): this path is hard-coded to "../probing" and ignores both
        # --prefix and OUTPUT_DIR - confirm that this is intentional.
        vanilla_probe_prefix = f"../probing/{args.dataset}/{args.dataset}_test/{args.model}/vanilla/{mlp_segment}{layer_suffix}"
        vanilla_clf_path = os.path.join(vanilla_probe_prefix, vanilla_clf_name)
        if not os.path.exists(vanilla_clf_path):
            raise FileNotFoundError("Vanilla classifier not found, please run the vanilla probing first.")
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load classifiers produced by a trusted vanilla probing run.
        with open(vanilla_clf_path, 'rb') as clf_file:
            clf = pickle.load(clf_file)

        logger.info("Loading test dataset")
        logger.info("Test dataset loaded successfully")
        X_test = np.array([all_activations_test[idx][layer] for idx in test_data_idx])

        test_metrics, test_preds_probs = validate(clf, X_test, y_test)
        logger.info(f"Test metrics: {test_metrics}")

        log = {
            'args': vars(args),
            'results': {
                'test_data_idx': test_data_idx,
                'all_block_configs_test': all_block_configs_test,
                'test_metrics': test_metrics,
            },
        }

        # Save final log and predictions; `with` guarantees the files are
        # flushed and closed even if serialization fails.
        with open(os.path.join(save_dir, f'{artifact_stem}_log.json'), 'w') as log_file:
            json.dump(log, log_file)
        with open(os.path.join(save_dir, f'{artifact_stem}_test_preds_probs.pkl'), 'wb') as preds_file:
            pickle.dump(test_preds_probs, preds_file)
    

# Script entry point: run the probing pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()