import os
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, DataCollatorWithPadding
import gc
from transformers import AutoConfig
import torch
import numpy as np
import os
import argparse
from transformers import  BertForSequenceClassification
import sys
sys.path.append("util/")
from finetuneUtils import *


# Turn off Weights & Biases logging for this process.
os.environ["WANDB_MODE"] = "disabled"


# Command-line interface: the run number tags every output path of this run,
# the task selects which entry of TASKS is fine-tuned and analysed.
parser = argparse.ArgumentParser(description="Training script run number.")
parser.add_argument('--runNumber', type=int, required=True, help='The run number to identify this run.')
parser.add_argument('--task', type=str, required=True, help='The task for the run.')
args = parser.parse_args()

# Fail fast on an unknown task. The main loop skips every entry of TASKS that
# differs from --task, so a misspelled name would otherwise make the script
# finish "successfully" without training or saving anything.
if args.task not in TASKS:
    parser.error(
        f"unknown task {args.task!r}; expected one of: "
        + ", ".join(str(name) for name in TASKS)
    )


# Start from a clean slate: release cached CUDA blocks and collect garbage
# before any model is loaded.
torch.cuda.empty_cache()
gc.collect()
k = args.runNumber   # run identifier used in all output paths
task = args.task     # name of the single task processed by this run

# Number of equally sized slices the singular-value spectrum is cut into; one
# slice is zeroed per evaluation (10 -> deciles, matching the output file name).
NUM_REMOVAL_STEPS = 10

# If any of these strings is in a state-dict key, the entry is a weight matrix
# that gets SVD-decomposed.
validWeighM = ["attention.self.query.weight", "attention.self.key.weight","attention.self.value.weight" ,"attention.output.dense.weight", "intermediate.dense.weight", "output.dense.weight"]


def _is_svd_target(key):
    """Return True if state-dict entry *key* names one of the decomposed weight matrices."""
    return any(validM in key for validM in validWeighM)


def _make_eval_trainer(model, train_ds, validation_ds, data_collator):
    """Build a default-argument Trainer around *model* that is only used for scoring."""
    return Trainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=validation_ds,
        data_collator=data_collator
    )


for task_name in TASKS:

    #############################
    ######## train the model ####
    #############################

    # Only the task selected with --task is processed; all others are skipped.
    if task_name != task:
        # Note kept from the original author:
        # the RTE validation set is balanced:
        # https://huggingface.co/datasets/evaluate/glue-ci/viewer/rte/validation?views%5B%5D=rte_validation
        # for now only use this one
        continue

    # Where the fine-tuned checkpoint of this (task, run) pair is stored.
    save_directory = f"Finetuning/models/{task_name}_{k}"

    num_labels = get_num_labels(task_name)

    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    data_collator = DataCollatorWithPadding(tokenizer)

    # test_ds is returned by the helper but not used in this script.
    train_ds, validation_ds, test_ds = get_datasets(task_name)

    trainer = finetune(task_name, train_ds, validation_ds, model=model, data_collator=data_collator)

    model.save_pretrained(save_directory)

    torch.cuda.empty_cache()
    gc.collect()


    #############################
    ## Save SVD decomposition ###
    #############################

    stateDic = model.state_dict()

    # Location for the SVD factors of the final model.
    # NOTE(review): this path contains the run number but not the task name, so
    # runs of different tasks sharing a run number write to the same files.
    # Left unchanged because other scripts may rely on this layout -- confirm.
    pathSVDmodel =  f"Finetuning/BertSVD_{k}/"

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(pathSVDmodel, exist_ok=True)

    # Save every selected weight matrix in its SVD form. saveMa presumably
    # writes "<prefix>_U", "<prefix>_S" and "<prefix>_Vh", since those are the
    # files read back below -- verify against finetuneUtils.
    for key in stateDic:
        if _is_svd_target(key):
            saveMa(f"{pathSVDmodel}/{key}", stateDic[key])

    ###################################
    ## Evaluate removal Performance ###
    ###################################

    config = AutoConfig.from_pretrained(save_directory)
    # Load the model from the checkpoint
    model = BertForSequenceClassification.from_pretrained(save_directory, config=config)

    trainer = _make_eval_trainer(model, train_ds, validation_ds, data_collator)
    base_score = compute_score(task_name, validation_ds, trainer)

    print("base score: ", base_score)

    # Index 0 holds the unmodified score; index i + 1 the score with slice i removed.
    accuracy = [base_score]

    # Loop over the different removal steps.
    for remInd in range(NUM_REMOVAL_STEPS):

        # Reload the checkpoint so every step starts from the untouched weights.
        model = BertForSequenceClassification.from_pretrained(save_directory, config=config)

        stateDic = model.state_dict()

        for key in stateDic:
            if not _is_svd_target(key):
                continue

            U = torch.load(f"{pathSVDmodel}/{key}_U", weights_only=True)
            S = torch.load(f"{pathSVDmodel}/{key}_S", weights_only=True)
            Vh = torch.load(f"{pathSVDmodel}/{key}_Vh", weights_only=True)

            # Zero the remInd-th slice of the singular values. Integer
            # division on the product keeps the slices contiguous and covering
            # even when len(S) is not a multiple of NUM_REMOVAL_STEPS.
            # Presumably S is sorted descending, so step 0 drops the largest
            # values -- TODO confirm against saveMa.
            sliceStart = remInd * len(S) // NUM_REMOVAL_STEPS
            sliceEnd = (remInd + 1) * len(S) // NUM_REMOVAL_STEPS
            S[sliceStart:sliceEnd] = 0

            # Rebuild the weight matrix from the truncated spectrum.
            reconstrMa = U @ torch.diag(S) @ Vh
            stateDic[key] = reconstrMa

        model.load_state_dict(stateDic)

        # Move to the GPU when one is present; on a CPU-only host the
        # unconditional move would raise.
        if torch.cuda.is_available():
            model.to("cuda")

        trainer = _make_eval_trainer(model, train_ds, validation_ds, data_collator)

        finalScore = compute_score(task_name, validation_ds, trainer)

        accuracy.append(finalScore)

    # Make sure the results directory exists: a missing directory would make
    # np.save fail and discard the scores of the whole run.
    os.makedirs("Data/FinetuningResults", exist_ok=True)
    np.save(f"Data/FinetuningResults/Bert_{task_name}_decentile_removal_{k}",accuracy )