
import os
import pydub
import pandas as pd
import transformers
import torch
from torch.utils.data import DataLoader
from dataclasses import dataclass
from src.models import load_model
from data.video_dataset import VideoEntityDataset
from transformers import HfArgumentParser, TrainingArguments, Trainer, EarlyStoppingCallback
from src.models import load_preprocess
from transformers import WhisperFeatureExtractor

transformers.logging.set_verbosity_info()

@dataclass
class DataArguments:
    """Data-location and device options, exposed as CLI flags via HfArgumentParser.

    The first five fields are passed positionally to ``VideoEntityDataset``
    further down in this script; their exact interpretation is defined by
    that dataset class, which is not visible here.
    """

    # First positional argument to VideoEntityDataset; presumably the folder
    # holding the video clips -- TODO confirm against the dataset class.
    video_dir: str = "data/video/"
    # Second positional argument to VideoEntityDataset; presumably the folder
    # holding the matching audio tracks -- TODO confirm.
    audio_dir: str = "data/audio/"
    # NOTE: despite the "_dir" suffix, the default is a path to a CSV *file*
    # (used to build the training dataset).
    train_csv_dir: str = "data/data_train_sample.csv"
    # NOTE: likewise a CSV file path (used to build the validation dataset).
    val_csv_dir: str = "data/data_val_sample.csv"
    # CSV file path handed to both datasets; judging by the default file name
    # it holds predicted time intervals -- schema not visible here.
    intervals_path: str = "data/predicted_time_intervals.csv"
    # Forwarded to load_model(device=...); the value "cpu" additionally turns
    # on `no_cuda` in the TrainingArguments built by this script.
    device: str = "cuda"
    
@dataclass
class ModelArguments:
    """Model/run hyper-parameters, exposed as CLI flags via HfArgumentParser.

    NOTE(review): within the visible part of this script only ``run_name`` is
    read; the remaining fields are parsed but not forwarded to ``load_model``.
    They presumably describe the audio/video transform and fusion layers of
    the model -- TODO confirm where (or whether) they are consumed.
    """

    # Used to build the training output directory: os.path.join("output", run_name).
    run_name: str = "maaca_pretrain"
    # Hidden sizes of the per-modality transform modules (presumed meaning).
    audio_transform_hidden_dim: int = 768
    video_transform_hidden_dim: int = 768
    # Layer counts of the per-modality transform modules (presumed meaning).
    audio_transform_num_layers: int = 2
    video_transform_num_layers: int = 2
    # Output sequence lengths per modality (presumed meaning).
    audio_output_seq_len: int = 128
    video_output_seq_len: int = 128
    # Output feature sizes of the per-modality transform modules (presumed meaning).
    audio_transform_output_dim: int = 768
    video_transform_output_dim: int = 768
    # Feature size after modality fusion (presumed meaning).
    fusion_output_dim: int = 768
    # Hidden size of the classification head's linear layer (presumed meaning).
    linear_layer_hidden_dim: int = 64
    # Whether to add a pooling step (presumed meaning; consumer not visible).
    add_pooling: bool = False
    # Number of target classes (presumed meaning; consumer not visible).
    num_classes: int = 7
    # Maximum text length (presumed to be in tokens -- TODO confirm).
    max_txt_len: int = 128

# Turn command-line flags into the two argument dataclasses; the unpacking
# order follows the order of the tuple given to the parser.
parser = HfArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()

# Pre-processing pipelines: visual/text processors come from the project's
# model registry (indexed by "train"/"eval" below), audio features from the
# Whisper feature extractor.
vis_processor, txt_processor = load_preprocess("alpro_retrieval2", model_type="msrvtt")
audio_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


def _build_dataset(csv_path, split):
    """Create a VideoEntityDataset for ``csv_path`` using the ``split`` processors.

    ``split`` selects the entry ("train" or "eval") of the text and visual
    processor mappings; every other argument is shared between splits.
    """
    return VideoEntityDataset(
        data_args.video_dir,
        data_args.audio_dir,
        csv_path,
        data_args.intervals_path,
        txt_processor[split],
        vis_processor[split],
        audio_processor,
    )


train_data = _build_dataset(data_args.train_csv_dir, "train")
val_data = _build_dataset(data_args.val_csv_dir, "eval")

# Instantiate the model on the requested device and dump its structure.
model = load_model("alpro_retrieval2", model_type="msrvtt", device=data_args.device)
print(f"MODEL:\n{model}")

class DataCollatorForVideoClassfication():
    """Collate per-sample feature mappings into a single batch dictionary.

    Every feature must provide the keys ``"aspect"``, ``"complaint"``,
    ``"video"``, ``"audio"`` and ``"text_input"``.
    """

    def __call__(self, features, return_tensors="pt"):
        """Build one batch from ``features``.

        Args:
            features: Sequence of per-sample mappings. ``"video"`` and
                ``"audio"`` values must be tensors that ``torch.stack`` can
                combine (i.e. identical shapes within the batch);
                ``"aspect"`` and ``"complaint"`` values must be accepted by
                ``torch.tensor`` once gathered into a list.
            return_tensors: Unused; kept so the call signature stays
                compatible with existing callers.

        Returns:
            dict with ``"video"`` and ``"audio"`` (stacked along a new
            leading dimension), ``"text_input"`` (plain list, one entry per
            sample), and ``"aspect"`` / ``"complaint"`` tensors.

        Raises:
            KeyError: If a feature is missing one of the required keys.
        """
        # Values are read with [] instead of pop(): popping would strip the
        # keys from the caller's dicts, so collating the same samples a
        # second time (e.g. a dataset that hands out cached dicts) would
        # fail with KeyError.
        aspect = torch.tensor([feature["aspect"] for feature in features])
        complaint = torch.tensor([feature["complaint"] for feature in features])
        video = torch.stack([feature["video"] for feature in features])
        audio = torch.stack([feature["audio"] for feature in features])
        # Text stays a list of raw entries; it is not tensorised here.
        text_input = [feature["text_input"] for feature in features]

        return {
            "video": video,
            "audio": audio,
            "text_input": text_input,
            "aspect": aspect,
            "complaint": complaint,
        }

# NOTE(review): this callback is created but never handed to the Trainer
# below (no `callbacks=` argument), so it has no effect on the run as written.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=1)

training_args = TrainingArguments(
    # Checkpoints and the final model land in output/<run_name>.
    output_dir=os.path.join("output", model_args.run_name),
    # Schedule and batching.
    num_train_epochs=10,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Checkpoint every 1000 optimisation steps.
    save_strategy="steps",
    save_steps=1000,
    # No evaluation loop is requested for this run.
    do_eval=False,
    label_names=["aspect"],
    # Logging / reporting.
    logging_steps=10,
    report_to="none",
    disable_tqdm=False,
    # Keep the Trainer off the GPU when the user asked for CPU.
    no_cuda=data_args.device == "cpu",
)

# Only the training split is given to the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    data_collator=DataCollatorForVideoClassfication(),
)

print(f"MODEL DEVICE:{model.device}")
trainer.train()
# Persist the final weights next to the intermediate checkpoints.
trainer.save_model(training_args.output_dir)



