import argparse
from pathlib import Path
import time
from glob import glob
import os
import shutil
from torch.backends import cudnn
import random
import numpy as np
import jsonlines
import cv2
from PIL import ImageFile

import torch
import wandb  # Quit early if user doesn't have wandb installed.
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F

from models import distributed_utils
from models.loader import TextImageDataset
from models.pretrain_dataset import AllDataset

# libraries needed for webdataset support
from torchvision import transforms as T
from PIL import Image
from io import BytesIO

from clip import clip
import copy

# transformer model
from models import CoTransformer

import jsonlines
from tqdm import tqdm


# argument parsing

# Command-line interface. Flags are registered in the same order and with the
# same options as before, so parsing and the --help listing are unchanged.
# NOTE(review): --batch_size, --target_file, --hidden_size, --image_size,
# --num_heads, --num_layers and --topk are not read by the code in this file;
# the Transformer is built from the checkpoint's saved hparams instead.
parser = argparse.ArgumentParser()

# At most one checkpoint flag may be given.
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
    '--model_load_path', type=str, default='',
    help='path to your trained Transformer',
)
group.add_argument(
    '--transformer_path', type=str,
    help='path to your partially trained Transformer',
)

# Mandatory input/output locations.
for _flag, _help in (
    ('--data_dir', 'path to your folder of frames'),
    ('--json_file', 'path to your json file of captions and shots'),
    ('--output_dir', 'path to save results'),
):
    parser.add_argument(_flag, type=str, required=True, help=_help)

parser.add_argument('--seed', type=int, default=42, help='Seed for random number')
parser.add_argument('--target_file', type=str, default='custom')

# The wrapped parser returned here is the one used for everything below.
parser = distributed_utils.wrap_arg_parser(parser)

train_group = parser.add_argument_group('Training settings')
train_group.add_argument('--batch_size', type=int, default=16, help='Batch size')
train_group.add_argument('--seq_len', type=int, default=10, help='Max length of sequence')

model_group = parser.add_argument_group('Model settings')
model_group.add_argument('--clip_model', type=str, default='ViT-B/32', help='Name of CLIP')
for _flag, _default, _help in (
    ('--hidden_size', 512, 'Model dimension'),
    ('--image_size', 256, 'Size of image'),
    ('--num_heads', 8, 'Model number of heads'),
    ('--num_layers', 2, 'Model depth'),
):
    model_group.add_argument(_flag, type=int, default=_default, help=_help)
model_group.add_argument('--topk', type=int, default=4)
model_group.add_argument('--threshold', type=float, default=0.9)
model_group.add_argument('--weight', type=float, default=1.0)

# Keep the module namespace free of the loop temporaries.
del _flag, _default, _help

args = parser.parse_args()

# random seed
# Seed every RNG the script touches (Python, NumPy, CUDA, CPU torch), in the
# same order as before. cudnn.benchmark stays enabled; per the PyTorch
# reproducibility notes it may still pick nondeterministic kernels.
cudnn.benchmark = True
for _seed_fn in (random.seed, np.random.seed, torch.cuda.manual_seed, torch.manual_seed):
    _seed_fn(args.seed)
del _seed_fn

# Module-level copy of --seq_len; check_length reads this global.
seq_len = args.seq_len

# Video-export settings (not referenced by the code visible in this file).
FOURCC = cv2.VideoWriter_fourcc(*'mp4v')
FPS = 30
VIDEO_SIZE = (1920, 1080)

shot_root = 'dataset/shots'

# helper fns

def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or "" count as existing)."""
    if val is None:
        return False
    return True

def check_length(sequence, mask, max_len=None):
    """Truncate or zero-pad a list of tensors to a fixed length and stack them.

    Args:
        sequence: non-empty list (or tuple) of same-shaped tensors. The
            caller's container is left unmodified.
        mask: indexable mask with at least ``max_len`` entries (a 1-D tensor
            at every call site in this file). Entries for padded positions
            are set to 0 in place, and the same object is returned.
        max_len: target length; defaults to the module-level ``seq_len``.

    Returns:
        ``(stacked, mask)`` where ``stacked`` has a new leading dimension of
        size ``max_len``.

    Raises:
        ValueError: if ``sequence`` is empty (no element to infer the
            padding shape from) or ``max_len`` is not positive.
    """
    if max_len is None:
        max_len = seq_len
    if max_len <= 0:
        raise ValueError(f'max_len must be positive, got {max_len}')
    if len(sequence) == 0:
        raise ValueError('check_length needs at least one element to infer the padding shape')

    # Work on a copy so padding never leaks back into the caller's list.
    items = list(sequence[:max_len])
    n_real = len(items)
    for pos in range(n_real, max_len):
        items.append(torch.zeros_like(items[0]))
        mask[pos] = 0

    return torch.stack(items, dim=0), mask
    
# load model

# The two checkpoint flags are mutually exclusive on the CLI; either may supply
# the trained Transformer (--model_load_path defaults to '').
TRANSFORMER_PATH = args.transformer_path or args.model_load_path

# Validate before the (slow) CLIP load so a bad path fails fast.
if not TRANSFORMER_PATH:
    raise SystemExit('error: pass --transformer_path (or --model_load_path)')
if not Path(TRANSFORMER_PATH).exists():
    raise FileNotFoundError(f'trained Transformer must exist: {TRANSFORMER_PATH}')

# CLIP is used below as the caption text encoder; eval mode on the GPU.
clip_model, _ = clip.load(args.clip_model, jit=False)
clip_model.eval()
clip_model = clip_model.cuda()

# SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
# The checkpoint dict holds 'hparams' (CoTransformer kwargs) and 'weights'
# (its state dict).
loaded_obj = torch.load(str(TRANSFORMER_PATH), map_location='cpu')
transformer_params, weights = loaded_obj['hparams'], loaded_obj['weights']
transformer_params = dict(**transformer_params)  # shallow copy of the saved hparams

transformer = CoTransformer(**transformer_params)
transformer = transformer.cuda()
transformer.load_state_dict(weights)
transformer.eval()

# get dataset
# Captions, candidate shots (features + names, index-aligned) and the
# ground-truth shot lists, read from the jsonlines file.
texts = []
shots = []
shot_names = []
gt_shots = []
ground_truth = {}
seen_shots = set()

with open(args.json_file, "r", encoding="utf8") as f:
    for item in jsonlines.Reader(f):
        texts.append(item['caption'])
        for shot in item['shots']:
            # A shot listed under several captions maps to the same fea.npy,
            # so load and score it only once; selection is by name anyway.
            if shot in seen_shots:
                continue
            seen_shots.add(shot)
            shot_names.append(shot)
            frame_path = os.path.join(args.data_dir, shot, 'fea.npy')
            shots.append(torch.from_numpy(np.load(frame_path)))
        gt_shots.append(item['shots'])

# Number of captions read.
i = len(texts)

print('Load {} {}'.format(args.json_file, i))

# generate videos

# Each candidate's normalized mean embedding depends only on its features, so
# compute it once for the whole run instead of on every selection round.
# NOTE(review): assumes each entry of `shots` is a (frames, dim) tensor whose
# dtype matches the float32 text embeddings — TODO confirm against fea.npy.
shot_embeds_all = [
    torch.nn.functional.normalize(torch.mean(shot, dim=0).unsqueeze(0), dim=-1)
    for shot in shots
]


def _split_sentences(text):
    """Split a caption on '.', drop empty/blank pieces, and re-append the period.

    Filtering with a comprehension (rather than removing from the list while
    iterating it) also drops consecutive empty pieces such as those from "..".
    """
    return [piece + '.' for piece in text.strip().split('.') if piece.strip()]


def _select_shots(sentences):
    """Greedily pick shot names for one caption given its non-empty sentences.

    Each round scores every not-yet-chosen shot as softmax(logits)[0, 1] from
    the transformer plus ``args.weight`` times the cosine similarity between
    the mean caption embedding and the shot embedding. The best shot is kept
    if its score exceeds ``args.threshold``. Selection stops when no shot
    qualifies or ``args.seq_len`` shots have been chosen. Returns the chosen
    names in selection order.
    """
    text_infos = [clip.tokenize(sentence).squeeze(0) for sentence in sentences]

    # CLIP caption embedding: mean over sentences, then L2-normalized.
    text_clip = torch.stack(text_infos, dim=0).cuda()
    text_embeds = clip_model.encode_text(text_clip)
    text_embeds = torch.mean(text_embeds, dim=0).unsqueeze(0).float()
    text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)

    # Mask sized by seq_len so it lines up with the padded captions.
    text_mask = torch.ones([args.seq_len], dtype=torch.float)
    captions, text_mask = check_length(text_infos, text_mask)
    captions = captions.unsqueeze(dim=0).cuda()
    text_mask = text_mask.unsqueeze(dim=0).cuda()

    # The similarity term is constant across rounds: compute it once per caption.
    text_embeds_cpu = text_embeds.cpu()
    sims = [
        torch.mm(text_embeds_cpu, shot_embeds.T).float()[0][0].item()
        for shot_embeds in shot_embeds_all
    ]

    chosen_shots = []
    chosen_set = set()
    # NOTE(review): shot_list is never extended, so every candidate is scored
    # by the transformer on its own. If the intent was to condition on the
    # shots chosen so far, append shots[best_index] after each pick — first
    # confirm all shots share one shape, since check_length stacks them.
    shot_list = []

    while len(chosen_shots) < args.seq_len:
        best_score = None
        best_index = None
        for idx, shot in enumerate(shots):
            if shot_names[idx] in chosen_set:
                continue

            shot_mask = torch.ones([args.seq_len], dtype=torch.float)
            shot_seq, shot_mask = check_length(shot_list + [shot], shot_mask)
            shot_seq = shot_seq.unsqueeze(dim=0).cuda()
            shot_mask = shot_mask.unsqueeze(dim=0).cuda()

            logits, _ = transformer(shot_seq, shot_seq, captions, None, None, text_mask, shot_mask, shot_mask, return_logist = True)
            # NOTE(review): assumes logits are (1, num_classes) with index 1 the
            # "match" class — TODO confirm with CoTransformer.
            probs = F.softmax(logits, dim = -1)

            score = probs[0, 1].item() + args.weight * sims[idx]

            # None sentinel: any score above the threshold qualifies, even <= 0.
            if score > args.threshold and (best_score is None or score > best_score):
                best_score = score
                best_index = idx

        if best_index is None:
            break

        chosen_shots.append(shot_names[best_index])
        chosen_set.add(shot_names[best_index])

    return chosen_shots


# Pure inference: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for text in tqdm(texts):
        sentences = _split_sentences(text)
        # A caption with no sentences is recorded with an empty shot list.
        chosen_shots = _select_shots(sentences) if sentences else []

        item = {}
        item['caption'] = text
        item['shots'] = chosen_shots

        # Despite its name, --output_dir is used as the output .jsonl file path.
        # Appending per caption keeps earlier results if the run stops early.
        with jsonlines.open(args.output_dir, mode='a') as f:
            f.write(item)
    
    