import os
import argparse

# Process-wide environment configuration. These assignments sit above the
# numpy / torch / project imports that follow, so the variables are already
# set when those modules are loaded.
os.environ.update({
    'HF_HOME': '../../_hf',  # relative path
    'HF_ALLOW_CODE_EVAL': '1',
    # 'TRANSFORMERS_OFFLINE': '1',
    # 'CUDA_VISIBLE_DEVICES': '0',
    'KMP_DUPLICATE_LIB_OK': 'True',
    'TOKENIZERS_PARALLELISM': 'false',
    'TRANSFORMERS_VERBOSITY': 'error',
})
# Disabled determinism check (it would need `logger`, which is imported further down):
# if os.environ["PYTHONHASHSEED"] != "0":
#     logger.info("Warning, please set environment variable PYTHONHASHSEED to 0 for determinism")

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from me_shared import MODEL_REGISTRY
from me_cfg import Configure
from me_core import Core
from me_load import NeoLoader, Metric
from me_util import stabilize, timeit, calculate_statistics
from me_shared import MODEL_REGISTRY, DEVICE
from me_util import stabilize, timeit, format_score, get_attr


def exp(data_name):
    """Run the pre/post/oracle generation experiment on one corpus and score it.

    Builds a ``Core`` from the module-level ``cfg``, loads the corpus named
    ``data_name`` through ``NeoLoader.load_corpus``, calls ``core.workflow``
    on every (text, code) pair, and passes the three collected generation
    lists to ``Metric.contrast_scoring`` and ``Metric.contrast_scoring2``.

    NOTE: this function reads the module-level ``cfg``, which this file only
    creates inside the ``if __name__ == '__main__'`` guard.

    Args:
        data_name: corpus identifier; one of ``'bigcode'``, ``'human'``,
            ``'numpy'`` or ``'pandas'``.

    Raises:
        NotImplementedError: if ``data_name`` is not a supported corpus. The
            check runs first, before ``Core`` is constructed.
    """
    supported_corpora = ('bigcode', 'human', 'numpy', 'pandas')
    # Fail fast: reject an unsupported corpus before building Core.
    if data_name not in supported_corpora:
        raise NotImplementedError(
            f'unsupported data_name {data_name!r}; expected one of {supported_corpora}'
        )

    # ...
    core = Core(cfg)
    # core.run_align()

    # test_text = "The capital of France is"
    # test_code = "Paris"
    # core.workflow(test_text, test_code)

    # load the corpus
    test_texts, test_codes = NeoLoader.load_corpus(data_name=data_name)

    pre_gens = list()
    post_gens = list()
    oracle_gens = list()
    num_iteration = len(test_texts)

    # zip() stops at the shorter sequence, which would silently drop samples
    # from the evaluation; surface a mismatch instead of hiding it.
    if len(test_codes) != num_iteration:
        logger.warning(
            f'corpus size mismatch: {num_iteration} texts vs {len(test_codes)} codes; '
            f'only the first {min(num_iteration, len(test_codes))} pairs will be evaluated'
        )

    # SEFT: semantic-efficient fine-tuning...
    for datum_idx, (test_text, test_code) in enumerate(zip(test_texts, test_codes)):
        # Progress banner: 1-based index of the current sample.
        logger.success('@' * 9 + f'{datum_idx + 1}/{num_iteration}' + '@' * 9)

        pre_gen, post_gen, oracle_gen = core.workflow(test_text, test_code)

        logger.debug(f'{pre_gen=}')
        logger.debug(f'{post_gen=}')
        logger.debug(f'{oracle_gen=}')
        pre_gens.append(pre_gen)
        post_gens.append(post_gen)
        oracle_gens.append(oracle_gen)

    logger.info('complex effects')
    Metric.contrast_scoring(pre_gens, post_gens, oracle_gens)
    Metric.contrast_scoring2(pre_gens, post_gens, oracle_gens)


@timeit
def main():
    """Run one experiment for the corpus named in the module-level ``cfg``.

    Calls ``stabilize()``, releases PyTorch's cached CUDA memory, then runs
    ``exp`` with ``cfg.data_name``. Wrapped by the ``timeit`` decorator
    imported from ``me_util``.
    """
    stabilize()
    torch.cuda.empty_cache()
    corpus_name = cfg.data_name
    exp(corpus_name)


# TODO implement the whole thing (layer attribution, loss, etc)
# TODO slightly compare align variants (w/o whitening, etc)
# TODO design RQs
# TODO submit experiments
# TODO write the paper
if __name__ == '__main__':
    # Script entry: parse the command line, resolve the model keys through
    # MODEL_REGISTRY, build the module-level `cfg` read by exp()/main(), run.
    # Help-string convention used here: the bracketed value is the default.
    parser = argparse.ArgumentParser()
    parser.add_argument('--rq', default=1, type=int, help='[1] 2 3')
    parser.add_argument('--data', default='human', type=str, help='[human] ...')
    # '%(default)s' lets argparse print the real default, so the help text
    # cannot drift away from the `default=` value again.
    # Another --src-model default left disabled in this file: 'qwen3-1.7b'.
    parser.add_argument('--src-model', default='qwen3-4b', type=str,
                        help='[%(default)s] llama3.2-3b ...')
    parser.add_argument('--tgt-model', default='qwen3-0.6b', type=str,
                        help='[%(default)s] llama3.2-1b ...')
    # Disabled single-model flag:
    # parser.add_argument('--model', default='starcoder2-3b', type=str, help='[codegen] ...')
    parser.add_argument('--subtask', default='approach.0.finetune', type=str, help='[approach.0.finetune] ...')
    args = parser.parse_args()
    logger.info(f'{args=}')

    # RQ_ID = 4
    RQ_ID = args.rq

    # Disabled hard-coded overrides:
    # DATA_NAME = 'conala' | 'ia32' | 'spider' | 'tldr' | 'human' | 'bigcode'
    DATA_NAME = args.data

    # Disabled overrides for the disabled --model flag:
    # args.model = 'starcoder2-3b' | 'starcoder2-7b' | 'codellama-7b'
    try:
        SRC_MODEL_NAME = MODEL_REGISTRY[args.src_model]
        TGT_MODEL_NAME = MODEL_REGISTRY[args.tgt_model]
    except KeyError as err:
        # Report an unknown key as a usage error (prints usage, exits with
        # status 2) rather than as a raw traceback.
        parser.error(f'unknown model key {err}: not found in MODEL_REGISTRY')

    # Disabled hard-coded overrides:
    # SUBTASK_CODE = 'approach.1.finetune' | 'approach.1.me-iter' | 'approach.1.me-batch'
    SUBTASK_CODE = args.subtask

    cfg = Configure(DATA_NAME, SRC_MODEL_NAME, TGT_MODEL_NAME, SUBTASK_CODE)
    main()
