R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/ll/ablate/ablate_02.py


CUDA_VISIBLE_DEVICES=0 python local_scripts/ll/ablate/ablate_02.py
CUDA_VISIBLE_DEVICES=1 python local_scripts/ll/ablate/ablate_02.py
CUDA_VISIBLE_DEVICES=2 python local_scripts/ll/ablate/ablate_02.py
CUDA_VISIBLE_DEVICES=3 python local_scripts/ll/ablate/ablate_02.py

"""
from importlib import reload
import random
import os

import numpy as np
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hf_util

from em.projects.wino import nmf_components_fisher as ncf

from em.projects.ll import hans_ablation_experiment as HAE
from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_merging_context as HMC


###############################################################################

# ---------------------------------------------------------------------------
# Filesystem layout: every directory used by this script hangs off EXPS_DIR.
# ---------------------------------------------------------------------------
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'

MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, _subdir)
    for _subdir in ('models', 'fishers', 'per_example_fishers')
)

EXP_RESULTS_DIR = os.path.join(EXPS_DIR, 'exps/hans_ablate_dev1')

# ---------------------------------------------------------------------------
# File-name templates. The `{model_number}` and `{n_fisher_values}` fields are
# left unformatted here; presumably the consumers of these patterns fill them
# in (confirm against the component context that receives them).
# ---------------------------------------------------------------------------
_PEF_BASENAME = "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.10k.131072.h5"
# The NMF file name is the per-example-Fisher file name with a prefix attached.
_NMF_PREFIX = "spH.nmf_decomp.c1024_2kIters_{n_fisher_values}pe."
_FISHER_BASENAME = "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.h5"

MODEL_PATTERN = 'connectivity/feather_berts_{model_number}'

# Full path patterns: per-example Fishers and their NMF decompositions share a
# directory, the dataset-level Fisher lives in FISHERS_DIR.
PEF_FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, _PEF_BASENAME)
NMF_FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, _NMF_PREFIX + _PEF_BASENAME)

FISHER_FILEPATH = os.path.join(FISHERS_DIR, _FISHER_BASENAME)

##########################################################################

# Model whose components are ablated by this script.
MODEL_NUMBER = 0
# COMP_INDEX = 82

# Tokenizer shared by the component context and by every experiment below.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# ---------------------------------------------------------------------------
# Component context built from the model / per-example-Fisher / NMF path
# patterns defined above.
# ---------------------------------------------------------------------------
hacc = HCC.HansLoneComponentContext(
    tokenizer=tokenizer,
    with_flipped=True,
    model_name_pattern=MODEL_PATTERN,
    pef_filepath_pattern=PEF_FILEPATH,
    nmf_filepath_pattern=NMF_FILEPATH,
)

# Model context for MODEL_NUMBER, registered under the name 'a'.
# NOTE(review): n_fisher_values presumably fills the `{n_fisher_values}` field
# of the NMF path pattern -- confirm in HansLoneComponentContext.
mc = hacc.make_model_context(
    'a',
    model_number=MODEL_NUMBER,
    n_fisher_values=65536,
)

##########################################################################

# Component indices of interest, grouped by label. The "Mixed" components are
# listed for reference only; they are not part of either list below.
# Corrects: 23, 24, 36, 43, 50, 69, 82, 85, 86, 112, 138
# Incorrects: 30, 34, 56, 81, 94, 97, 113, 122, 124
# Mixed: 42, 57

# Number of parallel invocations the component lists are striped across (one
# per GPU index, as in the usage commands at the top of the file).
N_SHARDS = 4

# The GPU index doubles as this process's shard index, so CUDA_VISIBLE_DEVICES
# must hold exactly one integer. Fail with an actionable message instead of a
# bare KeyError / int() parsing error.
_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
if _visible_devices is None:
    raise KeyError(
        'CUDA_VISIBLE_DEVICES is not set; launch with e.g. '
        '`CUDA_VISIBLE_DEVICES=0 python local_scripts/ll/ablate/ablate_02.py`'
    )
try:
    device = int(_visible_devices)
except ValueError:
    raise ValueError(
        'CUDA_VISIBLE_DEVICES must be a single integer GPU index to be used '
        f'as the shard index, got {_visible_devices!r}'
    ) from None

# Every N_SHARDS-th component, starting at this process's index.
# NOTE: an index >= N_SHARDS selects nothing from these lists, so such a
# process has no components to run.
CORRECT_COMP_INDS = [23, 24, 36, 43, 50, 69, 82, 85, 86, 112, 138][device::N_SHARDS]
INCORRECT_COMP_INDS = [30, 34, 56, 81, 94, 97, 113, 122, 124][device::N_SHARDS]

# Interleave both groups in random order.
ALL_COMP_INDS = CORRECT_COMP_INDS + INCORRECT_COMP_INDS
random.shuffle(ALL_COMP_INDS)

##########################################################################

# Run the enabled ablation(s) for every component assigned to this process and
# write one HDF5 results file per component.
for comp_ind in ALL_COMP_INDS:
    print(comp_ind)

    # Components from the "correct" list get a positive delta, all others a
    # negative one.
    if comp_ind in CORRECT_COMP_INDS:
        delta_sign = 1
    else:
        delta_sign = -1

    config = HAE.AblationExperimentConfig(
        model_number=MODEL_NUMBER,
        component_index=comp_ind,
        n_mnli_examples=3072,
        n_top_examples_sign_guide=12,
        n_coefficients=21,
    )
    exp = HAE.AblationExperiment(config=config, tokenizer=tokenizer, hacc=hacc, mc=mc)

    # --- Run 1 (disabled): merge-based ablation against the component H fisher.
    # print('Starting Run 1')
    # delta = delta_sign * 3.5e-4
    # results1 = exp.perform_merge_based_ablation_run(
    #     retaining_fisher=exp.full_dataset_fisher,
    #     ablating_fisher=exp.component_h_fisher,
    #     delta=delta,
    # )
    # HAE.save_results_list_to_h5(
    #     os.path.join(EXPS_DIR, f'hans_ablate_results.mm_1.model{MODEL_NUMBER}.comp{comp_ind}.h5'),
    #     results1,
    #     exp.make_metadata_for_saving(
    #         ablation_type='model_merging',
    #         retaining_fisher='full_dataset_fisher',
    #         ablating_fisher='component_h_fisher',
    #         delta=delta,
    #     ),
    # )

    # --- Run 2: merge-based ablation against the Fisher of the component's
    # top examples.
    print('Starting Run 2')
    # delta = delta_sign * 3.5e-4
    delta = delta_sign * 2.5e-5
    results2 = exp.perform_merge_based_ablation_run(
        retaining_fisher=exp.full_dataset_fisher,
        ablating_fisher=exp.top_component_examples_full_fisher,
        delta=delta,
    )

    # NOTE(review): results are written directly under EXPS_DIR, while
    # EXP_RESULTS_DIR is defined earlier in this file and not used here --
    # confirm which directory is intended.
    results_filename = f'hans_ablate_results.mm_2.model{MODEL_NUMBER}.comp{comp_ind}.h5'
    results_path = os.path.join(EXPS_DIR, results_filename)
    # The metadata names the Fishers used above so the file is self-describing.
    metadata = exp.make_metadata_for_saving(
        ablation_type='model_merging',
        retaining_fisher='full_dataset_fisher',
        ablating_fisher='top_component_examples_full_fisher',
        delta=delta,
    )
    HAE.save_results_list_to_h5(results_path, results2, metadata)

    # --- Run 3 (disabled): gradient-following ablation.
    # print('Starting Run 3')
    # delta = delta_sign * 3.5e-2
    # results3 = exp.perform_gradient_following_run(
    #     delta=delta,
    # )
    # HAE.save_results_list_to_h5(
    #     os.path.join(EXPS_DIR, f'hans_ablate_results.gf_1.model{MODEL_NUMBER}.comp{comp_ind}.h5'),
    #     results3,
    #     exp.make_metadata_for_saving(
    #         ablation_type='gradient_following',
    #         delta=delta,
    #     ),
    # )
