import os
import argparse

# Process-wide environment configuration. It sits above the third-party and
# project imports that follow, so the variables are already set when those
# modules are loaded.
# NOTE(review): 'HF_HOME' is a relative path, so where it points presumably
# depends on the directory the script is launched from -- confirm.
os.environ.update({
    'HF_HOME': '../../_hf',
    "HF_ALLOW_CODE_EVAL": "1",
    'KMP_DUPLICATE_LIB_OK': "True",
    "TOKENIZERS_PARALLELISM": "false",
    'TRANSFORMERS_VERBOSITY': "error",
})
# Optional switches, currently disabled:
# os.environ["TRANSFORMERS_OFFLINE"] = "1"
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# if os.environ["PYTHONHASHSEED"] != "0":
#     logger.info("Warning, please set environment variable PYTHONHASHSEED to 0 for determinism")

import argparse
import pickle

import numpy as np
import torch
from loguru import logger

from me_shared import MODEL_REGISTRY
from me_cfg import Configure
from me_load import NeoLoader
from me_util import stabilize, timeit
from we_core import Core


def visualize(model_name, data_name):
    """Run the ``Core`` pipeline over the test texts and plot the results.

    For each option ('input-side', then 'output-side') the function switches
    ``Core`` to that option, calls ``Core.pipeline`` on every test text, and
    finally passes the per-option results to ``Core.plot_annotated_curve``.

    Args:
        model_name: Accepted for call-site compatibility but not read here;
            ``Core`` is built from the module-level ``cfg`` instead.
        data_name: Dataset key. 'conala', 'ia32', 'spider' and 'tldr' are
            loaded with ``NeoLoader.load_data``; 'bigcode', 'human', 'numpy'
            and 'pandas' with ``NeoLoader.load_corpus``.

    Raises:
        NotImplementedError: If ``data_name`` is not one of the keys above.
    """
    split_datasets = ('conala', 'ia32', 'spider', 'tldr')
    corpus_datasets = ('bigcode', 'human', 'numpy', 'pandas')

    if data_name in split_datasets:
        # load_data returns four values; only the last two (test texts and
        # test codes) are used by this function.
        _, _, test_texts, test_codes = NeoLoader.load_data(data_name)
    elif data_name in corpus_datasets:
        test_texts, test_codes = NeoLoader.load_corpus(data_name=data_name)
    else:
        # Name the rejected key and the accepted ones so a typo on the command
        # line can be diagnosed from the traceback alone.
        raise NotImplementedError(
            f'unsupported data_name {data_name!r}; expected one of '
            f'{split_datasets + corpus_datasets}'
        )

    # do generation for each datum
    num_iteration = len(test_texts)
    core = Core(cfg)

    # Maps each option name to a list holding one pipeline result per datum.
    dict_simis = {}
    # 'xpolation' was a further option in earlier runs and is currently disabled.
    for option in ('input-side', 'output-side'):
        core.switch_xxx(option)
        batched_simis = []
        # Texts and codes are still zipped together, so iteration stops at the
        # shorter of the two sequences, even though only the text is consumed.
        for datum_no, (test_text, _) in enumerate(zip(test_texts, test_codes), start=1):
            logger.success('@' * 9 + f'{datum_no}/{num_iteration}' + '@' * 9)
            batched_simis.append(core.pipeline(test_text, option))
        dict_simis[option] = batched_simis
    core.plot_annotated_curve(dict_simis)

    # # report results
    # Metric.contrast_scoring(pre_gens, post_gens, oracle_gens)
    # Metric.contrast_scoring2(pre_gens, post_gens, oracle_gens)


@timeit
def main():
    """Script entry point: empty the CUDA cache, then run ``visualize``.

    Reads the module-level ``MODEL_NAME`` and ``DATA_NAME`` that the
    ``__main__`` section assigns before calling this function. The call is
    wrapped by the ``timeit`` decorator imported from ``me_util``.
    """
    # stabilize()
    torch.cuda.empty_cache()
    visualize(model_name=MODEL_NAME, data_name=DATA_NAME)


# https://kexue.fm/archives/8069
# https://kexue.fm/archives/8321
# https://kexue.fm/archives/8578
# https://kexue.fm/archives/10592
# TODO integrate latent-representation fitting? (12345...) (maybe no longer needed)
# TODO think about the critical uses of "semantic parts"...!!!
# TODO (softmax+cosine) for better probing loss? ...
if __name__ == '__main__':
    # Command-line interface. Every flag has a default, so the script can be
    # launched without arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--rq', default=1, type=int, help='[1] 2 3')
    parser.add_argument('--data', default='human', type=str, help='[human] ...')
    # --model used to be parsed and then ignored because MODEL_NAME was
    # hard-coded. It now defaults to None so that "flag omitted" can be told
    # apart from "flag given": omitted keeps the previously hard-coded
    # checkpoint, given is resolved through MODEL_REGISTRY.
    parser.add_argument(
        '--model', default=None, type=str,
        help='MODEL_REGISTRY key; when omitted, meta-llama/Llama-2-7b-chat-hf is used',
    )
    parser.add_argument('--subtask', default='approach.0.finetune', type=str, help='[approach.0.finetune] ...')
    args = parser.parse_args()
    logger.info(f'{args=}')

    # The names below are module-level globals; main() reads MODEL_NAME and
    # DATA_NAME, and visualize() reads cfg.
    RQ_ID = args.rq

    # Dataset keys used in earlier runs: 'conala', 'ia32', 'spider', 'tldr'.
    DATA_NAME = args.data

    # Registry keys used in earlier runs (not verified against MODEL_REGISTRY
    # here): 'gpt2', 'codegen', 'codegen-2b', 'codegen2-1b', 'codegen2-3.7b',
    # 'codegen2-7b', 'codegen25-7b', 'starcoder2-3b', 'starcoder2-7b',
    # 'codellama-7b', 'llama3.2-1b', 'qwen3-0.6b', 'qwen3-8b'.
    if args.model is None:
        MODEL_NAME = 'meta-llama/Llama-2-7b-chat-hf'
    else:
        MODEL_NAME = MODEL_REGISTRY[args.model]

    # Subtask codes used in earlier runs: 'approach.1.me', 'approach.1.ft'.
    SUBTASK_CODE = args.subtask

    # MODEL_NAME is deliberately passed as both the second and third argument,
    # exactly as before.
    cfg = Configure(DATA_NAME, MODEL_NAME, MODEL_NAME, SUBTASK_CODE)
    main()
