import os
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, DataCollatorWithPadding
import gc
from transformers import AutoConfig
import torch
import numpy as np
import os
import argparse
from transformers import  BertForSequenceClassification
import sys
sys.path.append("util/")
from finetuneUtils import *


# Disable Weights & Biases logging for this run.
os.environ["WANDB_MODE"] = "disabled"


parser = argparse.ArgumentParser(description="Training script run number.")
parser.add_argument('--runNumber', type=int, required=True, help='The run number to identify this run.')
parser.add_argument('--task', type=str, required=True, help='The task for the run.')
args = parser.parse_args()

# Fail fast on an unknown task: the main loop only processes the entry of TASKS that
# equals --task, so a typo would otherwise end "successfully" without training or
# saving any results.
if args.task not in TASKS:
    parser.error(f"unknown --task {args.task!r}; expected one of {list(TASKS)}")

torch.cuda.empty_cache()
gc.collect()
# k tags this run's output paths (SVD factor directory and results file);
# task selects the single task from TASKS that is processed below.
k = args.runNumber
task = args.task

for task_name in TASKS:

    #############################
    ######## train the model ####
    #############################

    # Only the task requested via --task is processed; all other TASKS entries are skipped.
    # (Background note: the RTE validation set is balanced, see
    # https://huggingface.co/datasets/evaluate/glue-ci/viewer/rte/validation?views%5B%5D=rte_validation)
    if task_name != task:
        continue

    num_labels = get_num_labels(task_name)

    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    #############################
    ## Save SVD decomposition ###
    #############################

    # A state-dict key is treated as a target weight matrix if it contains any of these
    # substrings. Note that "output.dense.weight" already also matches
    # "attention.output.dense.weight".
    validWeighM = ["attention.self.query.weight", "attention.self.key.weight", "attention.self.value.weight", "attention.output.dense.weight", "intermediate.dense.weight", "output.dense.weight"]

    stateDic = model.state_dict()

    # Directory holding the per-matrix factors for this run (shared by all removal steps).
    pathSVDmodel = f"Finetuning/BertSVD_{k}/"
    os.makedirs(pathSVDmodel, exist_ok=True)

    # Persist every target matrix via saveMa. The removal loop below reads the factors
    # back from "{key}_U", "{key}_S" and "{key}_Vh", so saveMa presumably writes exactly
    # those files — confirm against finetuneUtils.
    for key, weight in stateDic.items():
        if any(validM in key for validM in validWeighM):
            saveMa(f"{pathSVDmodel}/{key}", weight)

    # The factors are on disk; this model is not needed any more (each removal step
    # below loads a fresh one).
    del model, stateDic

    data_collator = DataCollatorWithPadding(tokenizer)

    train_ds, validation_ds, test_ds = get_datasets(task_name)

    torch.cuda.empty_cache()
    gc.collect()

    ###################################
    ## Evaluate removal Performance ###
    ###################################

    accuracy = []

    # Step remInd zeroes the slice S[remInd*len(S)//10 : (remInd+1)*len(S)//10] of each
    # matrix's stored singular values (one tenth per step), rebuilds U @ diag(S) @ Vh,
    # then fine-tunes and scores the modified model.
    for remInd in range(10):

        model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        stateDic = model.state_dict()

        for key in stateDic:
            if any(validM in key for validM in validWeighM):
                U = torch.load(f"{pathSVDmodel}/{key}_U", weights_only=True)
                S = torch.load(f"{pathSVDmodel}/{key}_S", weights_only=True)
                Vh = torch.load(f"{pathSVDmodel}/{key}_Vh", weights_only=True)
                # Remove this step's tenth of the singular values.
                S[remInd * len(S) // 10:(remInd + 1) * len(S) // 10] = 0

                stateDic[key] = U @ torch.diag(S) @ Vh

        model.load_state_dict(stateDic)
        model.to("cuda")

        trainer = finetune(task_name, train_ds, validation_ds, model=model, data_collator=data_collator)

        finalScore = compute_score(task_name, validation_ds, trainer)
        print(remInd, finalScore)
        accuracy.append(finalScore)

        # Without this, the previous step's trainer (and the model it presumably
        # references) stays alive until `trainer` is reassigned in the next step, so two
        # fine-tuning runs' worth of GPU state would be held at once.
        del model, trainer, stateDic
        torch.cuda.empty_cache()
        gc.collect()

    # np.save raises if the target directory is missing, which would discard the
    # results of the entire sweep.
    os.makedirs("Data/FinetuningResults", exist_ok=True)
    np.save(f"Data/FinetuningResults/Bert_{task_name}_decentile_removeFirst_{k}", accuracy)