# %%
# label2id = {'camera':0, 'os':1, 'design':2, 'battery':3, 'price':4, 'speaker':5, 'storage':6}
# id2label = {0:'camera', 1:'os', 2:'design', 3:'battery', 4:'price', 5:'speaker', 6:'storage'}
# Binary complaint-detection label space. `id2label` is derived from `label2id`
# so the two mappings can never disagree (a hand-written inverse is easy to
# misspell, and the label names are stored in the model config).
label2id = {'complaint': 0, 'non-complaint': 1}
id2label = {idx: name for name, idx in label2id.items()}

# %%
# Hub checkpoint whose pre-trained weights are used as the starting point for fine-tuning.
model_ckpt = "facebook/timesformer-base-finetuned-k400"

# %%
from transformers import AutoImageProcessor, TimesformerForVideoClassification


# The image processor provides the normalization statistics and the target
# frame size that the dataset transforms read later in this file.
# NOTE(review): it is loaded from a VideoMAE checkpoint rather than from
# `model_ckpt` — confirm this is the intended preprocessing for the model.
image_processor = AutoImageProcessor.from_pretrained(
    "MCG-NJU/videomae-base-finetuned-kinetics"
)

# `ignore_mismatched_sizes=True` is required when starting from an already
# fine-tuned checkpoint: its classification head does not match our label set
# and is replaced by a freshly initialized one.
model = TimesformerForVideoClassification.from_pretrained(
    model_ckpt,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# %% [markdown]
# The warning is telling us we are throwing away some weights (e.g. the weights and bias of the `classifier` layer) and randomly initializing some others (the weights and bias of a new `classifier` layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.

# %% [markdown]
# **Note** that [this checkpoint](https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics) leads to better performance on this task as the checkpoint was obtained fine-tuning on a similar downstream task having considerable domain overlap. You can check out [this checkpoint](https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset) which was obtained by fine-tuning `MCG-NJU/videomae-base-finetuned-kinetics` and it obtains much better performance.  

# %% [markdown]
# ### Constructing the datasets for training

# %% [markdown]
# For preprocessing the videos, we'll leverage the [PyTorch Video library](https://pytorchvideo.org/). We start by importing the dependencies we need.

# %%
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

# %% [markdown]
# For the training dataset transformations, we use a combination of uniform temporal subsampling, pixel normalization, random cropping, and random horizontal flipping. For the validation and evaluation dataset transformations, we keep the transformation chain the same except for random cropping and horizontal flipping. To learn more about the details of these transformations check out the [official documentation of PyTorch Video](https://pytorchvideo.org).  
# 
# We'll use the `image_processor` associated with the pre-trained model to obtain the following information:
# 
# * Image mean and standard deviation with which the video frame pixels will be normalized.
# * Spatial resolution to which the video frames will be resized.

# %%
import os
from video_blip.data.video_entity_dataset import VideoEntityDataset


# Per-channel normalization statistics taken from the image processor.
mean, std = image_processor.image_mean, image_processor.image_std

# Target spatial size: the processor config stores either a single
# "shortest_edge" value or explicit "height"/"width" entries.
if "shortest_edge" not in image_processor.size:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Temporal sampling: a clip covers `num_frames_to_sample * sample_rate` source
# frames, converted to seconds using the assumed frame rate `fps`.
fps = 30
sample_rate = 4
num_frames_to_sample = model.config.num_frames
clip_duration = num_frames_to_sample * sample_rate / fps


# Training dataset transformations.
# Training-time preprocessing, applied only to the "video" entry of a sample:
# temporal subsampling, division by 255, normalization, then random
# short-side scaling, random cropping and random horizontal flipping.
train_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda x: x / 255.0),
            Normalize(mean, std),
            RandomShortSideScale(min_size=256, max_size=320),
            RandomCrop(resize_to),
            RandomHorizontalFlip(p=0.5),
        ]),
    ),
])


# Training dataset.
# NOTE(review): presumably the first argument is the video directory and the
# second the CSV listing the training samples — confirm against
# `VideoEntityDataset`.
train_dataset = VideoEntityDataset(
    "data/video/",
    "data/data_train_sample.csv",
    transform=train_transform,
)

# train_dataset = pytorchvideo.data.Ucf101(
#     data_path=os.path.join(dataset_root_path, "train"),
#     clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
#     decode_audio=False,
#     transform=train_transform,
# )

# Validation and evaluation datasets' transformations.
# Deterministic preprocessing for validation/evaluation: the same subsampling,
# division by 255 and normalization as in training, followed by a plain resize
# instead of the random augmentations.
val_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda x: x / 255.0),
            Normalize(mean, std),
            Resize(resize_to),
        ]),
    ),
])


# Validation and evaluation datasets.
# NOTE(review): `test_dataset` is the very same object as `val_dataset`, so
# metrics reported on it are not computed on held-out data.
val_dataset = VideoEntityDataset(
    "data/video/",
    "data/data_val_sample.csv",
    transform=val_transform,
)
test_dataset = val_dataset

# %% [markdown]
# **Note**: The above dataset pipelines are taken from the [official PyTorch Video example](https://pytorchvideo.org/docs/tutorial_classification#dataset). We're using the [`pytorchvideo.data.Ucf101()`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.Ucf101) function because it's tailored for the UCF-101 dataset. Under the hood, it returns a [`pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.LabeledVideoDataset) object. `LabeledVideoDataset` class is the base class for all things video in the PyTorch Video dataset. So, if you wanted to use a custom dataset not supported off-the-shelf by PyTorch Video, you can extend the `LabeledVideoDataset` class accordingly. Refer to the `data` API [documentation to](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html) learn more. Also, if your dataset follows a similar structure (as shown above), then using the `pytorchvideo.data.Ucf101()` should work just fine.

# %%
# We can access the `num_videos` argument to know the number of videos we have in the
# dataset.
# train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos

# %% [markdown]
# Let's now take a preprocessed video from the dataset and investigate it.

# %%
# sample_video = next(iter(train_dataset))
# sample_video.keys()

# %%
# sample_video["labels"]

# %%
# def investigate_video(sample_video):
#     """Utility to investigate the keys present in a single video sample."""
#     for k in sample_video:
#         if k == "video":
#             print(k, sample_video["video"].shape)
#         else:
#             print(k, sample_video[k])

#     print(f"Video label: {id2label[sample_video[k]]}")


# investigate_video(sample_video)

# %% [markdown]
# We can also visualize the preprocessed videos for easier debugging.

# %%
import imageio
import numpy as np
from IPython.display import Image


def unnormalize_img(img, img_mean=None, img_std=None):
    """Undo channel normalization and convert an image to 8-bit pixels.

    Args:
        img: NumPy array of normalized pixel values. Its trailing axis must be
            broadcast-compatible with the statistics (the caller in this file
            passes channels-last frames).
        img_mean: Per-channel mean used for normalization. Defaults to the
            module-level ``mean``.
        img_std: Per-channel standard deviation used for normalization.
            Defaults to the module-level ``std``.

    Returns:
        A ``uint8`` array of the same shape as ``img`` with values in
        ``[0, 255]``.
    """
    if img_mean is None:
        img_mean = mean
    if img_std is None:
        img_std = std

    img = (img * np.asarray(img_std)) + np.asarray(img_mean)
    # Clip *before* casting: converting out-of-range floats to uint8 wraps
    # around, so clipping afterwards can no longer repair the values.
    img = np.clip(img, 0.0, 1.0)
    return (img * 255).astype("uint8")


def create_gif(video_tensor, filename="sample.gif"):
    """Write the frames of a video tensor to a GIF file and return its path.

    The tensor is iterated along its first axis, so it is expected to have the
    shape (num_frames, num_channels, height, width). Each frame is converted
    to channels-last layout and un-normalized before being written.
    """
    frames = [
        unnormalize_img(frame.permute(1, 2, 0).numpy())
        for frame in video_tensor
    ]
    imageio.mimsave(filename, frames, "GIF", duration=0.25)
    return filename


def display_gif(video_tensor, gif_name="sample.gif"):
    """Save a video tensor as a GIF and return an IPython `Image` showing it."""
    # Swap the first two axes so that frames come first, as `create_gif` expects.
    frames_first = video_tensor.permute(1, 0, 2, 3)
    return Image(filename=create_gif(frames_first, gif_name))

# %%
# video_tensor = sample_video["video"]
# display_gif(video_tensor)

# %% [markdown]
# ### Training the model

# %% [markdown]
# We'll leverage [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer) from  🤗 Transformers for training the model. To instantiate a `Trainer`, we will need to define the training configuration and an evaluation metric. The most important is the [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which is a class that contains all the attributes to configure the training. It requires an output folder name, which will be used to save the checkpoints of the model. It also helps sync all the information in the model repository on 🤗 Hub.
# 
# Most of the training arguments are pretty self-explanatory, but one that is quite important here is `remove_unused_columns=False`. When this option is `True` (the default), the `Trainer` drops any features not used by the model's call function, which is usually ideal because it makes it easier to unpack inputs into the model's call function. But, in our case, we need the unused features ('video' in particular) in order to create `pixel_values` (which is a mandatory key our model expects in its inputs), so we set it to `False`.

# %%
from transformers import TrainingArguments, Trainer

# Number of passes over the training data.
num_epochs = 4

# The run/output name is the last path component of the checkpoint id plus a
# task-specific suffix.
model_name = model_ckpt.rsplit("/", 1)[-1]
new_model_name = f"{model_name}-finetuned-complaint"

# %%
# len(train_dataset)

# %%
batch_size = 2  # per-device batch size for both training and evaluation

# Training configuration.
# Evaluation and checkpointing both run once per epoch; `load_best_model_at_end`
# needs the two strategies to match, and the best checkpoint is selected by
# the "accuracy" value returned from `compute_metrics`.
args = TrainingArguments(
    new_model_name,
    remove_unused_columns=False,  # keep the "video" entry for `collate_fn`
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    logging_steps=25,
    # The training length is defined in whole epochs only. Do not also pass
    # `max_steps`: a positive value overrides `num_train_epochs`, and a value
    # derived with floor division ends the run before the last epoch (and its
    # end-of-epoch evaluation) completes, or becomes 0 for very small datasets.
    num_train_epochs=num_epochs,
    optim="adafactor",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none",
)

# %% [markdown]
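# Normally there's no need to define `max_steps` when instantiating `TrainingArguments`. It only has to be specified when the training dataset doesn't implement the `__len__()` method, as is the case for the dataset returned by `pytorchvideo.data.Ucf101()`. Keep in mind that a positive `max_steps` overrides `num_train_epochs`.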

# %% [markdown]
# Next, we need to define a function for how to compute the metrics from the predictions, which will just use the `metric` we'll load now. The only preprocessing we have to do is to take the argmax of our predicted logits:

# %%
import evaluate

# Accuracy metric from the `evaluate` library, used by `compute_metrics` below.
metric = evaluate.load("accuracy")

# %%
# `compute_metrics` receives a named tuple with two NumPy arrays:
# `predictions` (the logits produced by the model) and
# `label_ids` (the ground-truth labels).
def compute_metrics(eval_pred):
    """Return the accuracy of the arg-max class predictions against the labels."""
    logits, references = eval_pred.predictions, eval_pred.label_ids
    predicted_ids = np.argmax(logits, axis=1)
    print(predicted_ids, references)
    return metric.compute(predictions=predicted_ids, references=references)

# %% [markdown]
# **A note on evaluation**:
# 
# In the [VideoMAE paper](https://arxiv.org/abs/2203.12602), the authors use the following evaluation strategy. They evaluate the model on several clips from test videos and apply different crops to those clips and report the aggregate score. However, in the interest of simplicity and brevity, we don't consider that in this tutorial.

# %% [markdown]
# We also define a `collate_fn`, which will be used to batch examples together.
# Each batch consists of 2 keys, namely `pixel_values` and `labels`.

# %%
import torch



def collate_fn(examples):
    """Batch individual samples into the dict of tensors passed to the model.

    Each example must provide a 4-D "video" tensor and a "labels" value. The
    first two video axes are swapped (to frames-first, i.e.
    (num_frames, num_channels, height, width)) before the videos are stacked
    along a new leading batch axis.
    """
    videos = []
    targets = []
    for sample in examples:
        videos.append(sample["video"].permute(1, 0, 2, 3))
        targets.append(sample["labels"])
    return {
        "pixel_values": torch.stack(videos),
        "labels": torch.tensor(targets),
    }

# %%
# from video_blip.data.utils import DataCollatorForVideoClassfication

# %% [markdown]
# Then we just need to pass all of this along with our datasets to the `Trainer`:

# %%
# device = torch.device("cuda", 1)
# device

# %%
# Move the model to the device chosen by the training arguments.
model = model.to(args.device)

# %%
# Log where the model ended up, for debugging device placement.
print(f"## Model Device: {model.device}")

# %%
# from torch.optim import SGD

# %%
# args.device
# Log the device selected by `TrainingArguments`, for comparison with the above.
print(f"## Args Device: {args.device}")

# Early stopping with a patience of 1.
# NOTE(review): presumably this halts training after one evaluation without
# improvement of `metric_for_best_model` — confirm against the
# `EarlyStoppingCallback` documentation.
from transformers import EarlyStoppingCallback
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=1)
# %%
# Wire model, configuration, data and callbacks together.
# `image_processor` is passed through the `tokenizer` argument so that its
# configuration is stored together with the model.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

# %% [markdown]
# You might wonder why we pass along the `image_processor` as a tokenizer when we already preprocessed our data. This is only to make sure the feature extractor configuration file (stored as JSON) will also be uploaded to the repo on the hub.

# %% [markdown]
# Now we can finetune our model by calling the `train` method:

# %%
# !wandb login --relogin

# %%
# Run fine-tuning; the object returned by `trainer.train()` is kept in
# `train_results` for later inspection.
print("Training Started!")
train_results = trainer.train()

# %% [markdown]
# We can check with the `evaluate` method that our `Trainer` did reload the best model properly (if it was not the last one):

# %%
# trainer.evaluate(test_dataset)

# %%
# Persist the model, evaluate it on `test_dataset`, then log and save the
# resulting metrics under the "test" split name and store the trainer state.
# NOTE(review): `test_dataset` is the same object as `val_dataset` in this
# file, so these are validation metrics rather than held-out test metrics.
trainer.save_model()
test_results = trainer.evaluate(test_dataset)
trainer.log_metrics("test", test_results)
trainer.save_metrics("test", test_results)
trainer.save_state()

# %% [markdown]
# You can now upload the result of the training to the Hub, just execute this instruction (note that the Trainer will automatically create a model card as well as Tensorboard logs - see the "Training metrics" tab - amazing isn't it?):

# %%
# trainer.push_to_hub()

# %% [markdown]
# Now that our model is trained, let's use it to run inference on a video from `test_dataset`.

# %% [markdown]
# ## Inference

# %% [markdown]
# Let's load the trained model checkpoint and fetch a video from `test_dataset`.

# %%
# trained_model = VideoMAEForVideoClassification.from_pretrained(new_model_name)

# %%
# sample_test_video = next(iter(test_dataset))
# investigate_video(sample_test_video)

# %% [markdown]
# We then prepare the video as a `torch.Tensor` and run inference.

# %%
# def run_inference(model, video):
#     """Utility to run inference given a model and test video.

#     The video is assumed to be preprocessed already.
#     """
#     # (num_frames, num_channels, height, width)
#     perumuted_sample_test_video = video.permute(1, 0, 2, 3)

#     inputs = {
#         "pixel_values": perumuted_sample_test_video.unsqueeze(0),
#         "labels": torch.tensor(
#             [sample_test_video["label"]]
#         ),  # this can be skipped if you don't have labels available.
#     }
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     inputs = {k: v.to(device) for k, v in inputs.items()}
#     model = model.to(device)

#     # forward pass
#     with torch.no_grad():
#         outputs = model(**inputs)
#         logits = outputs.logits

#     return logits

# %%
# logits = run_inference(trained_model, sample_test_video["video"])

# %% [markdown]
# We can now check if the model got the prediction right.

# %%
# display_gif(sample_test_video["video"])

# %%
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])

# %% [markdown]
# And it looks like it got it right!
# 
# You can also use this model to bring in your own videos. Check out [this Space](https://huggingface.co/spaces/sayakpaul/video-classification-ucf101-subset) to know more. The Space will also show you how to run inference for a single video file.
# 
# <br><div align=center>
#     <img src="https://i.ibb.co/7nW4Rkn/sample-results.gif" width=700/>
# </div>

# %% [markdown]
# ## Next steps
# 
# Now that you've learned to train a well-performing video classification model on a custom dataset here is some homework for you:
# 
# * Increase the dataset size: include more classes and more samples per class.
# * Try out different hyperparameters to study how the model converges.
# * Analyze the classes for which the model fails to perform well.
# * Try out a different video encoder.
# 
# Don't forget to share your models with the community =)


