#!/usr/bin/env python
import tqdm
import os

import torch
import torch.nn.parallel
import torch.distributed as dist

from utils import distributed_ops


def extract_features(loader, model, args, writer=None, chunk_size=4):
    """Extract per-clip embeddings for each loader item and save them to disk.

    Each item's ``'video_data'`` is cut into consecutive, non-overlapping clips
    of ``args.data.args.num_frames`` frames. Trailing frames that do not fill a
    whole clip are dropped. The clips are run through ``model`` in chunks of
    ``chunk_size``. The resulting ``(num_clips, encoder_embed_dim)`` CPU tensor
    is saved to ``<args.logging.embed_dir>/<data_uid>.pt``. Items whose output
    file already exists are skipped, so an interrupted run can be resumed.

    Args:
        loader: Data loader whose ``batch_sampler`` provides ``set_epoch`` and
            ``advance_batches_seen``. Each item is a dict with a 5-D
            ``'video_data'`` tensor and a string ``'data_uid'`` used as the
            output file stem. Only sample ``[0]`` of ``'video_data'`` is used,
            so this presumably expects batch size 1 (TODO confirm).
        model: Module called with a dict of tensors. It is expected to return
            a mapping whose ``'embedding'`` entry has shape
            ``(clips_in_chunk, encoder_embed_dim)``.
        args: Config object. Reads ``data.args.num_frames``,
            ``backbone.args.encoder_embed_dim``, ``logging.embed_dir`` and
            ``environment.gpu``.
        writer: Unused; kept for interface compatibility.
        chunk_size: Number of clips per forward pass. The default of 4 is the
            previously hard-coded value.
    """
    loader.batch_sampler.set_epoch(epoch=0)
    model.eval()  # switch to eval mode
    num_frames = args.data.args.num_frames
    dim = args.backbone.args.encoder_embed_dim
    embed_dir = args.logging.embed_dir
    gpu = args.environment.gpu

    with torch.no_grad():
        for data in tqdm.tqdm(loader):
            # Use one key for both the resume check and the output path.
            # Previously the check used 'data_uid' but the save used 'video_uid',
            # so the skip never matched what was written.
            uid = data['data_uid']
            out_path = os.path.join(embed_dir, uid + '.pt')
            if os.path.exists(out_path):
                print(f"{uid} is already.")
                continue

            # Called for its side effect on the sampler; the return value was unused.
            loader.batch_sampler.advance_batches_seen()

            # NOTE: no per-item barrier here. Ranks may skip different numbers of
            # already-extracted items, so an in-loop barrier would be reached a
            # different number of times per rank and deadlock. Ranks synchronize
            # once after the loop instead.

            # Trim to a whole number of clips and reshape to (num_clips, T, C, H, W).
            _, f, c, h, w = data['video_data'].shape
            usable_frames = f // num_frames * num_frames
            data['video_data'] = data['video_data'][0][:usable_frames].reshape(
                -1, num_frames, c, h, w)

            # Only tensors can be moved to the GPU. Non-tensor entries whose key
            # happens to contain "video"/"label" are excluded.
            input_data = {
                k: v.cuda(gpu, non_blocking=True)
                for k, v in data.items()
                if ("video" in k or "label" in k) and torch.is_tensor(v)
            }

            num_clips = input_data['video_data'].shape[0]
            outs = torch.zeros(num_clips, dim)

            # Step by chunk_size and clamp the end, so a final partial chunk (or
            # fewer clips than chunk_size) is still embedded.
            for start in range(0, num_clips, chunk_size):
                end = min(start + chunk_size, num_clips)
                chunk_input = dict(input_data)
                chunk_input['video_data'] = input_data['video_data'][start:end]
                # Other tensors (e.g. labels) are passed through unchanged.
                # TODO confirm the model does not expect them per clip.
                outs[start:end] = model(chunk_input)['embedding'].cpu()

            torch.save(outs, out_path)
            print(f"Saved {uid}. ")

    if torch.distributed.is_initialized():
        torch.distributed.barrier()