
import transformers
import os
import pydub
import torch
import wandb
import numpy as np
import pandas as pd
from typing import Any
from dataclasses import dataclass
from torch.utils.data import Dataset
from pytorchvideo.data.video import VideoPathHandler
from transformers import WhisperFeatureExtractor, TrainingArguments, Trainer, HfArgumentParser, TrainingArguments, EarlyStoppingCallback
from src.models import load_model, load_preprocess, load_model_from_config
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from data.video_dataset import VideoEntityDataset

# Restrict the transformers library's own logging to WARNING and above.
transformers.logging.set_verbosity_warning()

@dataclass
class DataArguments:
    """Data locations and target device for the training run.

    Parsed from the command line by ``HfArgumentParser`` further down in
    this script.
    """

    # Forwarded as the first positional argument of VideoEntityDataset
    # (train and validation alike).
    video_dir: str = "data/video/"
    # Forwarded as the second positional argument of VideoEntityDataset.
    audio_dir: str = "data/audio/"
    # Despite the "_dir" suffix, the default is a path to a CSV file; used
    # to build the training dataset.
    train_csv_dir: str = "data/data_train_sample.csv"
    # Same as above, for the validation dataset.
    val_csv_dir: str = "data/data_val_sample.csv"
    # CSV path shared by both datasets (fourth positional argument).
    intervals_path: str = "data/predicted_time_intervals_sample.csv"
    # "cpu" turns on `no_cuda` in TrainingArguments; the value is also
    # forwarded to load_model_from_config(device=...).
    device: str = "cuda"

@dataclass
class ModelArguments:
    """Model hyperparameters for the training run.

    Parsed from the command line by ``HfArgumentParser`` and handed as a
    whole to ``load_model_from_config(cmd_config=...)``. Only ``run_name``
    is read directly in this script.
    """

    # Names the output directory ("output/<run_name>") and the wandb run.
    run_name: str = "maaca"
    # NOTE(review): the fields below are not read anywhere in this file; they
    # reach the model builder through `cmd_config`. Presumably they size the
    # audio/video transform stacks, the fusion layer and the classifier --
    # confirm against src.models.
    audio_transform_hidden_dim: int = 768
    video_transform_hidden_dim: int = 768
    audio_transform_num_layers: int = 2
    video_transform_num_layers: int = 2
    audio_output_seq_len: int = 128
    video_output_seq_len: int = 128
    audio_transform_output_dim: int = 768
    video_transform_output_dim: int = 768
    fusion_output_dim: int = 768
    linear_layer_hidden_dim: int = 64
    add_pooling: bool = False
    # NOTE(review): compute_metrics labels any head whose width is not 2 as
    # "aspect"; presumably this is the width of that head -- confirm.
    num_classes: int = 7
    max_txt_len: int = 128

# Build both argument dataclasses from the command line.
parser = HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()

training_args = TrainingArguments(
    # Each run writes under output/<run_name>.
    output_dir=os.path.join("output", model_args.run_name),
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=15,
    # Checkpoint every 1000 steps, evaluate every 100 steps.
    save_strategy="steps",
    save_steps=1000,
    eval_steps=100,
    evaluation_strategy="steps",
    report_to="none",
    logging_steps=10,
    disable_tqdm=False,
    # Change to "eval_complaint_f1" to select/stop on the complaint metrics
    # instead of the aspect metrics.
    metric_for_best_model="eval_aspect_f1",
    load_best_model_at_end=True,
    # CUDA is disabled only when the user explicitly asks for the CPU.
    no_cuda=data_args.device == "cpu",
    label_names=["complaint", "aspect"],
    save_total_limit=1,
)


# Vision/text preprocessors come from the project loader and are indexed by
# split ("train" / "eval") below; audio features use Whisper's extractor.
vis_processor, txt_processor = load_preprocess("alpro_qa_audio2", model_type="product2")
audio_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


def _build_split_dataset(csv_path, split):
    """Create the VideoEntityDataset for one split.

    Args:
        csv_path: CSV file listing the samples of this split.
        split: Key ("train" or "eval") selecting the text and vision
            preprocessors.
    """
    return VideoEntityDataset(
        data_args.video_dir,
        data_args.audio_dir,
        csv_path,
        data_args.intervals_path,
        txt_processor[split],
        vis_processor[split],
        audio_processor,
    )


train_data = _build_split_dataset(data_args.train_csv_dir, "train")
val_data = _build_split_dataset(data_args.val_csv_dir, "eval")

# The parsed ModelArguments are passed through unchanged as `cmd_config`.
model = load_model_from_config(
    "alpro_qa_audio2",
    model_type="product2",
    cmd_config=model_args,
    device=data_args.device,
)
print(f"MODEL:\n{model}")

def compute_metrics(p):
    """Score every classification head of an evaluation pass.

    ``p.predictions[1]`` is walked in lockstep with ``p.label_ids``. A head
    whose score matrix has exactly two columns is reported as "complaint";
    any other width is reported as "aspect".

    Args:
        p: Prediction object exposing ``predictions`` and ``label_ids``.

    Returns:
        dict mapping ``"<entity>_accuracy"``, ``"<entity>_precision"``,
        ``"<entity>_recall"`` and ``"<entity>_f1"`` to their values.
        Precision, recall and F1 are micro-averaged.
    """
    gold_per_head = p.label_ids
    scores_per_head = p.predictions[1]
    results = {}

    for scores, gold in zip(scores_per_head, gold_per_head):
        # The head is identified purely by the width of its score matrix.
        prefix = "aspect" if scores.shape[1] != 2 else "complaint"

        # Predicted class = highest-scoring column; labels become a column
        # vector before being handed to the sklearn scorers.
        targets = {
            "y_true": np.expand_dims(gold, axis=1),
            "y_pred": np.argmax(scores, axis=1),
        }

        results[f"{prefix}_accuracy"] = accuracy_score(**targets)
        results[f"{prefix}_precision"] = precision_score(average="micro", **targets)
        results[f"{prefix}_recall"] = recall_score(average="micro", **targets)
        results[f"{prefix}_f1"] = f1_score(average="micro", **targets)

    return results

class DataCollatorForVideoClassfication():
    """Collate per-sample feature mappings into a single batch dict."""

    def __call__(self, features, return_tensors="pt"):
        """Build one batch from a list of samples.

        Args:
            features: Sequence of mappings, each providing the keys
                "aspect", "complaint", "video", "audio" and "text_input".
                "video" and "audio" must be tensors that can be stacked.
            return_tensors: Accepted for interface compatibility; not used.

        Returns:
            dict with keys "video" and "audio" (stacked tensors),
            "text_input" (list, one entry per sample), and "aspect" and
            "complaint" (tensors built from the per-sample labels).

        Raises:
            KeyError: If a sample lacks one of the expected keys.
        """
        # Read with indexing rather than pop(): the samples belong to the
        # caller and must survive collation, otherwise a dataset that reuses
        # its sample dicts fails with KeyError the second time they are
        # collated.
        aspect = torch.tensor([feature["aspect"] for feature in features])
        complaint = torch.tensor([feature["complaint"] for feature in features])
        video = torch.stack([feature["video"] for feature in features])
        audio = torch.stack([feature["audio"] for feature in features])
        text_input = [feature["text_input"] for feature in features]

        return {
            "video": video,
            "audio": audio,
            "text_input": text_input,
            "aspect": aspect,
            "complaint": complaint,
        }

# Stop training once the tracked metric has failed to improve for two
# consecutive evaluations.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)

# TrainingArguments normalises `report_to` to a list of integration names, so
# comparing it to the string "wandb" with == never matches and the run would
# never be initialised. Test membership instead, and tolerate a plain string
# or None.
_report_targets = training_args.report_to or []
if isinstance(_report_targets, str):
    _report_targets = [_report_targets]
if "wandb" in _report_targets:
    wandb.init(name=model_args.run_name)

print(f"MODEL ARGUMENTS:\n{model_args}")

trainer = Trainer(
    model=model,
    args=training_args,
    # Data.
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=DataCollatorForVideoClassfication(),
    # Evaluation and early stopping.
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

print(f"MODEL DEVICE:{model.device}")

# Train, run one more evaluation pass, then persist its metrics and the model
# in the run's output directory.
trainer.train()
metrics = trainer.evaluate()
trainer.save_metrics("eval", metrics)
trainer.save_model(training_args.output_dir)
