from datasets import load_dataset
from pathlib import Path
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

import socket
import platform
from datetime import datetime

import os
import json

import argparse



def _last_token_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Take representation from the last non-padding token.
    B: batch size
    T: sequence length
    H: hidden state dimension
    
    Arguments
          hidden_states: (B, T, H)
          attention_mask: (B, T)
    Returns: 
          last_token_repr: (B, H)
    """
    lengths = attention_mask.sum(dim=1) - 1  # index of last token, (B,)
    lengths = lengths.clamp(min=0)
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_idx, lengths, :]  # (B, H)



#get hidden states for a specified layer
@torch.no_grad()
def get_last_token_representations(model, tokenizer, layer, df_train, text_col, batch_size, max_length=2048):
    '''
    Get the last token representations for a specified layer.
    # B: batch size
    # T: sequence length
    # H: hidden state dimension
    # N: total number of samples (sum of batch sizes)
    
    Arguments
        model: model from which to get representations, AutoModelForCausalLM
        tokenizer: model's tokenizer, used to tokenize the input text, tokenizer
                   (if it has no pad token, its pad token is set to its EOS token so that batches can be padded)
        layer: layer of the model from which to get last token representations, int
               (index into the model's hidden states, 0 is the embedding layer)
        df_train: dataframe with input text, df (any split, not only the training set)
        text_col: column name in df_train containing the input text, str
        batch_size: the batch size for input text dataloader, int
        max_length: texts are truncated to this many tokens, int (default 2048)

    Returns:
        representations: the last token representations from the specified layer, numpy array (N, H), float32

    Raises:
        ValueError: if df_train contains no rows
    '''

    texts = df_train[text_col].astype(str).tolist()
    if not texts:
        #fail early with a clear message instead of inside np.concatenate
        raise ValueError(f"No input texts found in column '{text_col}'")

    #padding=True needs a pad token; some causal LM tokenizers do not define one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text_loader = torch.utils.data.DataLoader(texts, batch_size=batch_size, shuffle=False)

    representations = []

    for batch_texts in tqdm(text_loader):
        #convert texts to input tokens
        input_tokens = tokenizer(
            list(batch_texts),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = input_tokens["input_ids"].to(model.device)
        attention_mask = input_tokens["attention_mask"].to(model.device)

        #get model output
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )

        #get hidden states for a specific layer
        hs = out.hidden_states[layer] #[B, T, H] , #0th layer is the embedding layer
        #get hidden state for last token
        hs_last_token = _last_token_pool(hs, attention_mask)

        #cast to float32 before moving to numpy (the model may run in a lower precision dtype)
        representations.append(hs_last_token.detach().to(torch.float32).cpu().numpy())

    #concatenate all the representations
    return np.concatenate(representations, axis=0) #(N, H)



def train_linear_probe(X_train, y_train, X_test, y_test):
    '''
    Fit a linear probe (feature standardization followed by multinomial logistic
    regression) on the training representations and evaluate it on both splits.

    Arguments
        X_train, X_test: last token representations for training/test set, numpy array (N_train, H) or (N_test, H)
        y_train, y_test: labels for training/test set, numpy array (N_train,) or (N_test,)

    Returns:
        acc_train: accuracy of the probe on the training set, float
        acc_test: accuracy of the probe on the test set, float
        c_report_train: classification report on the training set, string
        c_report_test: classification report on the test set, string
    '''

    logreg = LogisticRegression(
        multi_class="multinomial",
        solver="saga",
        max_iter=2000,
        n_jobs=-1,
        random_state=42,
    )
    #the scaler statistics are estimated on the training set only (inside fit)
    probe = Pipeline(steps=[("scaler", StandardScaler()), ("logreg", logreg)])
    probe.fit(X_train, y_train)

    def _score(features, labels):
        #predict once, then derive both metrics from the same predictions
        predictions = probe.predict(features)
        return accuracy_score(labels, predictions), classification_report(labels, predictions)

    acc_train, c_report_train = _score(X_train, y_train)
    acc_test, c_report_test = _score(X_test, y_test)

    return acc_train, acc_test, c_report_train, c_report_test





if __name__ == "__main__":
    # Script entry point: extract last token representations of one model layer
    # for a dataset, train a linear probe on them and append the metrics as one
    # JSON line to results/<model_name>-<dataset>.json.

    #record when the job started -- stored with the results to identify the job
    start_time = datetime.now()

    ### parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['sst', 'ag_news', 'yelp'])
    parser.add_argument('--model_name', type=str, required=True, choices=['llama-0.5B-10BT', 'llama-1B-20BT', 'llama-4B-80BT', 'olmo-1B-30BT', 'olmo-1B-210BT'])
    parser.add_argument('--wd', type=float, required=True) #options depend on the model
    parser.add_argument('--batch_size', type=int, default=128) #optional, default is used when omitted
    parser.add_argument('--layer', type=int, required=True) #options depend on the model
    args = parser.parse_args()


    dataset = args.dataset
    model_name = args.model_name
    wd = args.wd
    batch_size = args.batch_size
    layer = args.layer


    ### setup
    project_folder_path = str(Path.cwd().parent)
    linear_probing_folder_path = str(Path.cwd())


    ### load dataset
    print('\n\n\n---------------> Load dataset')

    if dataset == 'ag_news':
        data_hf_path = "hf/ag_news"
        text_col = 'text'
        label_col = 'label'
    elif dataset == 'sst':
        data_hf_path = "hf/sst2"
        text_col = 'sentence'
        label_col = 'label'
    elif dataset == 'yelp':
        data_hf_path = "hf/yelp_polarity"
        text_col = 'text'
        label_col = 'label'
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    df_train = load_dataset(path=data_hf_path, split="train").to_pandas()

    if dataset in ['sst']:
        df_test = load_dataset(path=data_hf_path, split="validation").to_pandas() #use validation set as test set for sst2, for sst2, test set has no labels
    elif dataset in ['ag_news', 'yelp']:
        df_test = load_dataset(path=data_hf_path, split="test").to_pandas()
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    print('# training samples: ', len(df_train))
    print('# testing samples: ', len(df_test))


    ### load model + tokenizer
    print('\n\n\n---------------> Load model + tokenizer')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_name in ['llama-0.5B-10BT', 'llama-1B-20BT']:
        model_folder = 'path/to/model/folder'
        model_path = 'path/to/model/path'
    elif model_name in ['llama-4B-80BT']:
        model_folder = 'path/to/model/folder'
        model_path = 'path/to/model/path'
    elif model_name in ['olmo-1B-30BT', 'olmo-1B-210BT']:
        model_folder = 'path/to/model/folder'
        model_path = 'path/to/model/path'
    else:
        #fail with a clear message instead of a NameError on model_path below
        raise ValueError(f"Unknown model: {model_name}")

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(device)
    model.eval()
    print('Model loaded successfully!')
    print('Total # hidden layers in model:', model.config.num_hidden_layers)
    print(model)
    print(model.config)

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    print('\nTokenizer loaded successfully!')



    ##### get model activations for a specific layer -- create X_train and X_test
    print(f'\n\n\n---------------> Get model activations for layer {layer}')
    print('Batch size:', batch_size)
    print('Get representations for training set...')
    X_train = get_last_token_representations(model, tokenizer, layer, df_train, text_col, batch_size) #(N_train, H)

    print('\nGet representations for test set...')
    X_test = get_last_token_representations(model, tokenizer, layer, df_test, text_col, batch_size) #(N_test, H)

    # encode labels for input texts -- create y_train and y_test
    # the encoder is fitted on the training labels only and reused for the test labels
    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(df_train[label_col].astype(str).values)
    y_test = label_encoder.transform(df_test[label_col].astype(str).values)
    n_classes = len(label_encoder.classes_)

    #check shapes
    print('\nB: batch size')
    print('T: sequence length')
    print('H: hidden state dimension')
    print('N: total number of samples (sum of batch sizes)')

    print('')
    print('X_train (N_train, H):', X_train.shape)
    print('X_test (N_test, H):', X_test.shape)
    print('y_train (N_train,):', y_train.shape)
    print('y_test (N_test,):', y_test.shape)



    ##### train linear probe
    print('\n\n\n---------------> Train linear probe')
    acc_train, acc_test, c_report_train, c_report_test = train_linear_probe(X_train, y_train, X_test, y_test)
    print(f"\nAccuracy, train: {acc_train:.4f}")
    print('Classification report, train: ')
    print(c_report_train)
    print(f"\nAccuracy, test: {acc_test:.4f}")
    print('Classification report, test: ')
    print(c_report_test)



    ##### save results
    print('\n\n\n---------------> Save results')

    results_dict = {
        'model': model_name,
        'wd': wd,
        'layer': layer,
        'total_n_hidden_layers': model.config.num_hidden_layers,
        'dataset': dataset,
        'n_classes': n_classes,
        'n_train': X_train.shape[0],
        'n_test': X_test.shape[0],
        'hidden_state_dim': X_train.shape[1],
        'train_acc': acc_train,
        'test_acc': acc_test,
        'c_report_train': c_report_train,
        'c_report_test': c_report_test,
        'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S') #to identify the job
    }

    #save results to json file -- one dict per line, appended if the file already exists
    filename = f'{linear_probing_folder_path}/results/{model_name}-{dataset}.json'

    #make sure the results folder exists, otherwise open() fails after all the work is done
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    #append mode creates the file when it does not exist yet
    with open(filename, 'a', encoding='utf-8') as f:
        json.dump(results_dict, f)
        f.write('\n')  # add newline so each dict is a separate line

    print(f'Results saved to: {filename}')


