import torch
from transformers import BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import numpy as np



# GLUE task identifiers handled by this script.
TASKS = [
    "rte",
    "sst2",
    "mrpc",
    "stsb",
    "cola",
    "wnli",
    "mnli_matched",
    "mnli_mismatched",
    "ax",
    "qnli",
    "qqp",
]

# Input text column(s) read from each GLUE task's rows.
# One entry -> single-sentence task; two entries -> sentence-pair task.
FEATURES = {
    "cola": ("sentence",),
    "sst2": ("sentence",),
    "mrpc": ("sentence1", "sentence2"),
    "stsb": ("sentence1", "sentence2"),
    "rte": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
    "mnli": ("premise", "hypothesis"),
    "mnli_matched": ("premise", "hypothesis"),
    "mnli_mismatched": ("premise", "hypothesis"),
    "ax": ("premise", "hypothesis"),
    # QNLI pairs a question with a candidate answer sentence; reading
    # "question" twice would silently drop the sentence from the input.
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
}

# Per-task .tsv file names (not used in this view; presumably the GLUE
# prediction files written elsewhere -- confirm against the caller).
filenames = dict(
    cola="CoLA.tsv",
    sst2="SST-2.tsv",
    mrpc="MRPC.tsv",
    qqp="QQP.tsv",
    stsb="STS-B.tsv",
    mnli_matched="MNLI-m.tsv",
    mnli_mismatched="MNLI-mm.tsv",
    qnli="QNLI.tsv",
    rte="RTE.tsv",
    wnli="WNLI.tsv",
    ax="AX.tsv",
)

# Class-index -> label-string mapping for tasks whose predictions are
# written as strings; tasks absent here emit the raw class index.
# "mnli" is listed alongside its matched/mismatched variants because it is a
# supported FEATURES key with the same three classes.
labelnames = {
    "mnli": ["entailment", "neutral", "contradiction"],
    "mnli_matched": ["entailment", "neutral", "contradiction"],
    "mnli_mismatched": ["entailment", "neutral", "contradiction"],
    "ax": ["entailment", "neutral", "contradiction"],
    "qnli": ["entailment", "not_entailment"],
    "rte": ["entailment", "not_entailment"],
}

# Module-wide configuration shared by the tokenization and training helpers.
model_name = "bert-base-uncased"
PADDING = 'max_length'        # passed as tokenizer(padding=...)
TRUNCATION = 'longest_first'  # passed as tokenizer(truncation=...)
CUDA = torch.cuda.is_available()  # True -> finetune() moves the model to GPU
# Created once at import time; from_pretrained may need network access or a
# local cache for model_name.
tokenizer = BertTokenizer.from_pretrained(model_name)





def get_datasets(task_name):
    """Load the GLUE (train, validation, test) splits for ``task_name``.

    The MNLI-family tasks ("mnli", "ax", "mnli_matched", "mnli_mismatched")
    all train on the MNLI training set and differ only in evaluation splits:

    * "mnli": MNLI ``validation_matched`` / ``test_matched``.
    * "ax": MNLI ``validation_matched`` and the AX diagnostic ``test`` split.
    * "mnli_matched" / "mnli_mismatched": ``validation`` / ``test`` of the
      GLUE config with that name.

    Every other task uses the ``train`` / ``validation`` / ``test`` splits of
    its own GLUE config.

    Args:
        task_name: GLUE configuration name (see ``TASKS`` / ``FEATURES``).

    Returns:
        Tuple ``(train_ds, validation_ds, test_ds)`` of ``datasets`` splits.
    """
    if task_name in ("mnli", "ax", "mnli_matched", "mnli_mismatched"):
        mnli = load_dataset("glue", "mnli")
        train_ds = mnli["train"]
        # Defaults for plain "mnli"; the variants below override them.
        validation_ds = mnli["validation_matched"]
        test_ds = mnli["test_matched"]

        if task_name == "ax":
            test_ds = load_dataset("glue", "ax")["test"]
        elif task_name in ("mnli_matched", "mnli_mismatched"):
            subset = load_dataset("glue", task_name)
            validation_ds, test_ds = subset["validation"], subset["test"]

        return train_ds, validation_ds, test_ds

    dataset = load_dataset("glue", task_name)
    return dataset["train"], dataset["validation"], dataset["test"]

def preprocess_input_features(sample, task_name):
    """Build the single ``input_text`` string fed to the tokenizer for one row.

    Single-sentence tasks use the raw sentence. Sentence-pair tasks are
    flattened into ``"<task_name> <col1>: <text1>, <col2>: <text2>"``, where
    the column names come from ``FEATURES[task_name]``.

    Args:
        sample: one dataset row (mapping from column name to value).
        task_name: key into ``FEATURES`` selecting which columns to read.

    Returns:
        ``{'input_text': str}``, suitable for ``Dataset.map``.

    Raises:
        ValueError: if the task declares neither one nor two input columns.
    """
    features = FEATURES[task_name]
    if len(features) == 1:
        input_text = sample[features[0]]
    elif len(features) == 2:
        first, second = features
        # Prefix with the actual task name and label each segment with its own
        # column name so the two halves of the pair stay distinguishable.
        input_text = (
            f"{task_name} {first}: {sample[first]}, "
            f"{second}: {sample[second]}"
        )
    else:
        raise ValueError(
            f"Task {task_name!r} must declare 1 or 2 input features, "
            f"got {len(features)}"
        )

    return {'input_text': input_text}

def preprocess_function(examples, task_name):
    """Tokenize a batch of ``input_text`` strings with the module tokenizer.

    ``task_name`` is accepted so every ``map`` callback shares one signature;
    it is not used. With ``PADDING == 'max_length'`` every sequence is
    truncated/padded to exactly 512 tokens and returned as PyTorch tensors.
    """
    return tokenizer(
        examples['input_text'],
        truncation=TRUNCATION,
        padding=PADDING,
        max_length=512,
        return_tensors="pt",
    )

def get_num_labels(task_name):
    """Return the classifier output size for a GLUE task.

    * 3 for the MNLI family ("mnli", "mnli_matched", "mnli_mismatched") and
      "ax", which is scored with MNLI's entailment/neutral/contradiction labels.
    * 1 for "stsb", which is a regression task.
    * 2 for every other (binary classification) task.
    """
    if task_name in ("mnli", "ax", "mnli_matched", "mnli_mismatched"):
        return 3
    if task_name == "stsb":
        return 1
    return 2

def finetune(task_name, train_ds, validation_ds, model, data_collator):
    """Fine-tune ``model`` on ``train_ds`` for ``task_name``; return the Trainer.

    Both splits are converted to ``input_text`` with
    ``preprocess_input_features``, tokenized with ``preprocess_function`` and
    formatted as torch tensors (``input_ids``, ``attention_mask``, ``label``).
    Training runs for 3 epochs at lr 2e-5 with batch size 32 and no periodic
    evaluation; ``output_dir`` is ``Finetuning/results/<task_name>``.

    Args:
        task_name: GLUE task key (see ``FEATURES``).
        train_ds, validation_ds: raw ``datasets`` splits with a ``label`` column.
        model: a ``transformers`` model; moved to GPU when ``CUDA`` is True.
        data_collator: passed straight through to ``Trainer``.

    Returns:
        The ``Trainer`` after ``train()`` has completed.
    """
    print("▶ Starting Finetuning... ")

    def to_text(row):
        return preprocess_input_features(row, task_name)

    def to_tokens(batch):
        return preprocess_function(batch, task_name)

    print("\t  ▶ Preparing datasets...")
    prepared = []
    for split in (train_ds, validation_ds):
        split = split.map(to_text, batched=False)
        split = split.map(to_tokens, batched=True)
        prepared.append(split)
    train_ds, validation_ds = prepared
    print("\t  ▶ End of Preparing datasets.")

    # Expose only the model inputs and the target as torch tensors.
    for split in prepared:
        split.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

    args = TrainingArguments(
        output_dir=f'Finetuning/results/{task_name}',
        evaluation_strategy="no",
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        logging_dir='Finetuning/logs',
        learning_rate=2e-5,
        num_train_epochs=3,
        run_name="BertTrain",
    )

    model.to(torch.device("cuda" if CUDA else "cpu"))

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=validation_ds,
        data_collator=data_collator,
    )

    print("\t ▶ Starting Training")
    trainer.train()
    print("\t ▶ Finished  Training")

    return trainer

def get_predictions(task_name, test_ds, trainer):
    """Predict outputs for every row of ``test_ds`` with a fine-tuned trainer.

    Rows are converted with ``preprocess_input_features``, tokenized with
    ``preprocess_function`` and passed to ``trainer.predict``. Raw model
    outputs are then mapped to per-row values:

    * ``"stsb"``: the single regression output, clamped to [0, 5] and
      formatted with three decimals (e.g. ``"3.142"``).
    * tasks listed in ``labelnames``: the label string of the argmax class.
    * all other tasks: the argmax class index as an ``int``.

    Args:
        task_name: GLUE task key (see ``FEATURES``).
        test_ds: raw ``datasets`` split; its labels are not used.
        trainer: a trained ``transformers.Trainer``.

    Returns:
        List with one prediction per row of ``test_ds``, in order.
    """
    print(" ▶ Starting Test Predictions ")

    labelname = labelnames.get(task_name)

    # Tokenize test dataset (only model inputs are needed for prediction).
    test_ds = test_ds.map(lambda x: preprocess_input_features(x, task_name), batched=False)
    test_ds = test_ds.map(lambda x: preprocess_function(x, task_name), batched=True)
    test_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    predictions = trainer.predict(test_ds)
    outputs = torch.tensor(predictions.predictions)

    if task_name == "stsb":
        # STS-B is trained with a single output (get_num_labels -> 1): that
        # output *is* the similarity score. Taking argmax over a size-1 axis
        # would always yield 0.
        scores = outputs.reshape(-1).tolist()
        predictions_list = [f"{min(max(score, 0), 5):.3f}" for score in scores]
    else:
        pred_labels = torch.argmax(outputs, dim=-1).tolist()
        if labelname:
            predictions_list = [labelname[pred] for pred in pred_labels]
        else:
            predictions_list = list(pred_labels)

    print(" ▶ End of Generating Predictions ")
    return predictions_list

def compute_score(task_name, test_ds, trainer):
    """Return the accuracy of ``trainer``'s argmax predictions on ``test_ds``.

    ``test_ds`` must contain a ``label`` column; a row counts as correct when
    ``label == argmax(logits)``. The accuracy is printed when the dataset is
    non-empty.

    NOTE(review): GLUE test splits in the HF ``glue`` dataset are believed to
    carry placeholder labels (-1), which would make this score 0 -- pass a
    labelled split such as validation; confirm with callers. For "stsb" the
    model has a single output (see get_num_labels), so argmax is always 0 and
    the score is not meaningful.

    Args:
        task_name: GLUE task key (see ``FEATURES``).
        test_ds: raw ``datasets`` split with a ``label`` column.
        trainer: a trained ``transformers.Trainer``.

    Returns:
        Accuracy in [0, 1], or 0 when ``test_ds`` is empty.
    """
    # Tokenize dataset; labels stay in the underlying table for scoring.
    test_ds = test_ds.map(lambda x: preprocess_input_features(x, task_name), batched=False)
    test_ds = test_ds.map(lambda x: preprocess_function(x, task_name), batched=True)
    test_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    predictions = trainer.predict(test_ds)
    pred_labels = torch.argmax(torch.tensor(predictions.predictions), dim=-1).tolist()

    # Read the label column once: indexing test_ds["label"] inside the loop
    # re-materializes the whole column per row (quadratic in dataset size).
    gold_labels = test_ds["label"]
    correct = sum(1 for gold, pred in zip(gold_labels, pred_labels) if gold == pred)
    total = len(pred_labels)

    finalScore = 0
    if total > 0:
        finalScore = correct / total
        print(f"final acc of: {finalScore}")

    return finalScore

def evaluate_model(trainer, validation_ds):
    """Run ``trainer.evaluate`` on ``validation_ds``, print the metrics and return them."""
    metrics = trainer.evaluate(eval_dataset=validation_ds)
    print("Evaluation results:", metrics)
    return metrics

def saveMa(name, matrix):
    """Compute the thin SVD of ``matrix`` and save the factors with ``torch.save``.

    Writes three files: ``{name}_U``, ``{name}_S`` (singular values) and
    ``{name}_Vh``, such that ``U @ diag(S) @ Vh`` reconstructs ``matrix``
    (``full_matrices=False``, so U and Vh are the reduced factors).

    Args:
        name: path prefix for the three output files.
        matrix: torch tensor whose last two dimensions form the matrix; may
            live on any device and may require grad (it is detached).
    """
    source = matrix.detach().cpu()

    # NumPy's LAPACK SVD computes in the input's own precision (float32 in ->
    # float32 math) and cannot handle float16/bfloat16 (and .numpy() rejects
    # bfloat16 outright), so upcast floating inputs to float64 to actually get
    # the extra precision.
    is_float = source.is_floating_point()
    if is_float:
        source = source.to(torch.float64)

    u, s, vh = np.linalg.svd(source.numpy(), full_matrices=False)
    factors = [torch.from_numpy(part) for part in (u, s, vh)]

    # Hand back factors in the caller's dtype so they combine with the
    # original weights without dtype mismatches.
    if is_float:
        factors = [part.to(matrix.dtype) for part in factors]
    U, Sval, Vh = factors

    # save the file
    torch.save(U, f"{name}_U")
    torch.save(Sval, f"{name}_S")
    torch.save(Vh, f"{name}_Vh")
