# %%
from video_blip.data.video_entity_dataset import VideoEntityDataset
from dataclasses import dataclass
import transformers

# %%
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1"

# %%
@dataclass
class ModelArguments:
    """Model-side settings for this fine-tuning script.

    Attributes:
        language_model: Model id or path handed to the ``from_pretrained``
            loaders; it is also embedded in the training output directory.
        device: Only compared against ``"cpu"`` to decide whether CUDA is
            disabled for training; ``"cuda:0"`` was the GPU alternative.
    """

    # Pretrained encoder to fine-tune.
    language_model: str = "google-bert/bert-base-uncased"
    # "cpu" disables CUDA in the training arguments.
    device: str = "cpu"
    

# %%
@dataclass
class DataArguments:
    """Locations of the video directory and the train/validation CSV files.

    Attributes:
        video_dir: Directory passed as the first argument to
            ``VideoEntityDataset``.
        train_csv_dir: Path to the training CSV file (despite the ``_dir``
            suffix, the default value is a single ``.csv`` file).
        val_csv_dir: Path to the validation CSV file.
    """

    video_dir: str = "data/video/"
    # CSV file paths for the two splits.
    train_csv_dir: str = "data/data_train_sample.csv"
    val_csv_dir: str = "data/data_val_sample.csv"
# %%
from transformers import TrainingArguments



from transformers import EarlyStoppingCallback
# Stop training once the monitored metric (TrainingArguments.metric_for_best_model)
# fails to improve for a single evaluation round; any improvement counts.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=1,
    early_stopping_threshold=0.0,
)
# %%
# parser = transformers.HfArgumentParser(
#     (ModelArguments, DataArguments, TrainingArguments)
# )
# model_args: ModelArguments
# data_args: DataArguments
# training_args: TrainingArguments
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# %%
data_args = DataArguments()
model_args = ModelArguments()

# Checkpointing runs on the same cadence as evaluation. With
# load_best_model_at_end=True and an EarlyStoppingCallback, the best step is
# judged per evaluation; if checkpoints were only written on a coarser grid
# (e.g. every 1000 steps while evaluating every 100), the best evaluated step
# may have no checkpoint to reload and, depending on the transformers version,
# early stopping (patience=1) can end training before the first save.
# save_total_limit bounds disk usage; the Trainer keeps the best checkpoint in
# addition to the most recent ones when load_best_model_at_end is enabled.
training_args = TrainingArguments(
    output_dir=f"output/{model_args.language_model}_complaint",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    eval_steps=100,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`
    # (and `no_cuda` to `use_cpu`); switch when upgrading.
    evaluation_strategy="steps",
    report_to="none",
    logging_steps=4,
    disable_tqdm=False,
    # Key produced by compute_metrics; a higher value is better.
    metric_for_best_model="f1",
    load_best_model_at_end=True,
    no_cuda=model_args.device == "cpu",
    # label_names = ["aspect"]
)

# %%
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
# Tokenizer matching the pretrained encoder selected in ModelArguments.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_args.language_model,
)

# %%
# Notebook cell: display the tokenizer (a no-op when run as a plain script).
tokenizer

# %%
def preprocess(item):
    """Tokenize one dataset item for sequence classification.

    Used as the ``transform`` of ``VideoEntityDataset``. Encodes
    ``item["transcript"]`` with the module-level ``tokenizer``, truncating and
    padding it to exactly 128 tokens, then copies ``item["labels"]`` onto the
    encoding.

    Returns:
        The tokenizer's encoding with an added ``"labels"`` entry.
    """
    text = item["transcript"]
    encoding = tokenizer(text, max_length=128, truncation=True, padding="max_length")
    encoding["labels"] = item["labels"]
    return encoding

# %%
# tokenizer("transcript", padding=True, truncation=True, max_length=512)

# %%
def _build_split(csv_path):
    """Create a VideoEntityDataset for one CSV split, tokenized via ``preprocess``."""
    return VideoEntityDataset(data_args.video_dir, csv_path, transform=preprocess)


train_data = _build_split(data_args.train_csv_dir)
val_data = _build_split(data_args.val_csv_dir)

# %%
# Notebook cell: show which fields a preprocessed training item exposes.
train_data[0].keys()

# %%
from transformers import Trainer

# %%
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def compute_metrics(p):
    """Compute classification metrics for the Trainer's evaluation loop.

    Args:
        p: An ``EvalPrediction`` or a ``(predictions, label_ids)`` pair.
            ``predictions`` holds per-class scores whose last axis indexes the
            classes; ``label_ids`` holds the integer gold labels.

    Returns:
        Dict with:
          * ``accuracy``;
          * micro-averaged ``precision``/``recall``/``f1`` (kept unchanged
            because ``metric_for_best_model`` selects ``"f1"``);
          * macro-averaged ``precision_macro``/``recall_macro``/``f1_macro``.

    Note:
        For single-label multi-class data, micro-averaged precision, recall
        and F1 are all identical to accuracy. The macro variants are
        therefore reported too, giving a class-balanced view. Classes that
        are never predicted score 0 instead of raising warnings.
    """
    predictions, labels = p[0], p[1]
    # When the model returns extra outputs, predictions may arrive as a tuple
    # whose first element is the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pred = np.argmax(predictions, axis=-1)

    metrics = {
        "accuracy": accuracy_score(y_true=labels, y_pred=pred),
        "precision": precision_score(y_true=labels, y_pred=pred, average="micro"),
        "recall": recall_score(y_true=labels, y_pred=pred, average="micro"),
        "f1": f1_score(y_true=labels, y_pred=pred, average="micro"),
    }
    for name, scorer in (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    ):
        metrics[f"{name}_macro"] = scorer(
            y_true=labels, y_pred=pred, average="macro", zero_division=0
        )
    return metrics

# %%
# model = model.to("cpu")

# %%
# Size the classification head from the dataset's own label map rather than a
# hard-coded count, so num_labels always agrees with id2label/label2id.
config = AutoConfig.from_pretrained(
    model_args.language_model,
    num_labels=len(VideoEntityDataset.id2label),
    label2id=VideoEntityDataset.label2id,
    id2label=VideoEntityDataset.id2label,
)
# ignore_mismatched_sizes allows loading even if the checkpoint ships a
# classifier head whose shape differs from this config.
model = AutoModelForSequenceClassification.from_pretrained(
    model_args.language_model,
    config=config,
    ignore_mismatched_sizes=True,
)

# %%
from transformers import DataCollatorWithPadding
import torch

class DataCollatorForVideoClassfication(DataCollatorWithPadding):
    """Pad tokenized features and batch their labels.

    Labels are taken out before delegating to ``DataCollatorWithPadding`` and
    re-attached afterwards as a tensor built with ``torch.tensor``.
    """

    def __call__(self, features, return_tensors="pt"):
        """Collate a list of feature mappings into a batch.

        Args:
            features: Sequence of mappings, each holding tokenizer outputs
                plus a ``"labels"`` entry. They are not modified.
            return_tensors: Accepted for signature compatibility but not
                forwarded. The parent collator uses its own configured
                ``return_tensors``, and labels are always a torch tensor.

        Returns:
            The padded batch from ``DataCollatorWithPadding`` with
            ``"labels"`` set to a tensor of the per-feature labels.

        Raises:
            KeyError: if a feature has no ``"labels"`` entry.
        """
        labels = torch.tensor([feature["labels"] for feature in features])
        # Copy each feature without its labels instead of ``pop``-ing them, so
        # items a dataset might cache and hand out again keep their labels.
        inputs = [
            {key: value for key, value in feature.items() if key != "labels"}
            for feature in features
        ]

        collated = super().__call__(inputs)
        collated["labels"] = labels
        return collated

# %%
print("Training Arguments")
print(training_args)

# %%
# Wire the model, data, collator, metrics and early stopping into a Trainer.
collator = DataCollatorForVideoClassfication(tokenizer=tokenizer)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

# %%
n_params = sum(param.numel() for param in model.parameters())
print(f"MODEL PARAMETERS: {n_params}")

# %%
trainer.train()


