import os
import argparse
import json
from tqdm import tqdm
from video_chatgpt.eval.model_utils import initialize_model, load_video
from video_chatgpt.inference import video_chatgpt_infer_batch
# from video_chatgpt.audio_transcript.transcribe import Transcriber
import torch
from add_param import get_config

def parse_args():
    """
    Parse command-line arguments for the batch video inference script.

    Returns:
        argparse.Namespace: The parsed arguments. Note that argparse exposes
        ``--model-name`` as ``args.model_name``.
    """
    parser = argparse.ArgumentParser()

    # Input / output locations
    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_caption', help='Path to the ground truth file containing answers.', required=True)
    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)

    # Model / conversation settings
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--conv_mode", type=str, required=False, default='pg-video-llava')
    parser.add_argument("--projection_path", type=str, required=True)
    parser.add_argument("--bs", type=int, default=4, help='Batch size for inference.')
    parser.add_argument("--use_asr", action='store_true', help='Whether to use audio transcripts or not')
    # The help text used to be a copy of the --use_asr one; the flag is simply
    # forwarded as ``debug=`` to the batch inference call.
    parser.add_argument("--debug", action='store_true',
                        help='Enable debug mode (forwarded to the inference routine).')

    # tell-a-video options
    parser.add_argument("--tav", action='store_true', help='tell-a-video')
    parser.add_argument("-vis", "--save_spixel_visualization", action='store_true', help='tell-a-video')
    parser.add_argument("--temp_pool_type", type=str, default='uniform', help='tell-a-video')
    parser.add_argument("--k_means", type=int, default=20, help='tell-a-video')
    parser.add_argument("--max_spread_scale", type=int, default=4, help='tell-a-video')
    parser.add_argument("--rewrite", action='store_true', help='tell-a-video')
    parser.add_argument("--rewrite_per_sentence", action='store_true', help='tell-a-video')
    parser.add_argument("--end_marks", action='store_true', help='tell-a-video')
    parser.add_argument("--sam", type=str, default='na', choices=['na', 'mask_init', 'mask_gen'])

    parser.add_argument("--shift_unit", type=int, default=1, help='tell-a-video')
    parser.add_argument("--save_masks_and_features", action='store_true', help='tell-a-video')
    parser.add_argument("--for_cond_gen_examples", action='store_true', help='tell-a-video')
    parser.add_argument("--for_video_editing_examples", action='store_true', help='tell-a-video')
    # run_inference() reads args.for_davis_editing_examples, so the flag must
    # exist on the namespace; without it the attribute lookup fails unless
    # some later config step happens to add it.
    parser.add_argument("--for_davis_editing_examples", action='store_true',
                        help="Read 'video_path', 'vid' and 'prompt' from each ground-truth sample "
                             "instead of resolving videos under --video_dir.")

    return parser.parse_args()


def run_inference(args):
    """
    Caption every video listed in the ground-truth file and save the results.

    Each sample is loaded with ``load_video``, grouped into batches of
    ``args.bs`` and sent to ``video_chatgpt_infer_batch`` with the fixed
    question 'Please describe this video.'. Predictions and ground-truth
    captions are keyed by ``"<running index>-<video id>"`` and written to
    ``<output_name>_pred.json`` and ``<output_name>_gt.json`` in
    ``args.output_dir``.

    Args:
        args: Command-line arguments (see ``parse_args``). ``args.SPIXEL`` is
            also read when ``args.tav`` is set; it is not defined by the
            parser, so it is presumably attached by ``get_config`` --
            TODO confirm.
    """
    # Initialize the model
    if args.tav:
        (model, vision_tower, tokenizer, image_processor, video_token_len,
         spixel_encoder, clip_text_package) = initialize_model(args.model_name,
                                                               args.projection_path, args.SPIXEL)
    else:
        model, vision_tower, tokenizer, image_processor, video_token_len = initialize_model(
            args.model_name, args.projection_path)

    if args.sam != 'na':
        # Imported lazily so the dependency is only needed when SAM is requested.
        from segment_anything_fast import SamAutomaticMaskGenerator, sam_model_fast_registry
        sam = sam_model_fast_registry["vit_b"](checkpoint="/your_path/segment-anything/sam_vit_b_01ec64.pth")
        sam.to(model.device)
        mask_generator = SamAutomaticMaskGenerator(sam)
        sam_package = {'sam': sam,
                       'mask_generator': mask_generator,
                       'sam_type': args.sam}
    else:
        sam_package = {'sam_type': args.sam}

    frame_size = (image_processor.crop_size['height'], image_processor.crop_size['width'])
    conv_mode = args.conv_mode

    # Load the ground truth file describing the samples
    with open(args.gt_caption) as file:
        gt_data = json.load(file)

    # Create the output directory if it doesn't exist (exist_ok avoids the
    # check-then-create race of a separate os.path.exists test).
    os.makedirs(args.output_dir, exist_ok=True)

    video_formats = ['.mp4', '.avi', '.mov', '.mkv']

    # The parser may not define this flag, so fall back to False instead of
    # raising AttributeError on the first sample.
    for_davis_editing_examples = getattr(args, 'for_davis_editing_examples', False)

    # The key holding the reference caption depends on the dataset flavour.
    if args.for_cond_gen_examples or for_davis_editing_examples:
        caption_key = 'prompt'
    elif args.for_video_editing_examples:
        caption_key = 'sentence'
    else:
        caption_key = 'gt_caption'

    predict = {}
    gt = {}
    # Never filled in this function; still dumped below when
    # --rewrite_per_sentence is set, so that output file keeps being produced.
    re_predict = {}
    masks_and_feats = []

    def new_batch():
        """Return a fresh, empty set of parallel per-batch lists."""
        return {'video_name': [], 'question': [], 'gt': [], 'video': [], 'id': []}

    def run_batch(batch, minibatch_index, passed_videos):
        """Run inference on one batch and record predictions / ground truth."""
        transcript_text = None

        if args.tav:
            # visualize the first batch only
            save_spixel_visualization = (args.save_spixel_visualization and (minibatch_index == 0))

            output, save_dict = video_chatgpt_infer_batch(batch['video'], batch['question'], conv_mode, model,
                                                          vision_tower, tokenizer, image_processor,
                                                          video_token_len, transcript_text,
                                                          spixel_encoder=spixel_encoder,
                                                          save_spixel_visualization=save_spixel_visualization,
                                                          temp_pool_type=args.temp_pool_type,
                                                          output_name=args.output_name,
                                                          debug=args.debug,
                                                          rewrite=args.rewrite,
                                                          end_marks=args.end_marks,
                                                          clip_text_package=clip_text_package,
                                                          sam_package=sam_package,
                                                          batch_video_name=batch['video_name'],
                                                          save_masks_and_features=args.save_masks_and_features)

            masks_and_feats.extend(save_dict)
        else:
            # NOTE(review): unpacking assumes the call returns a pair in this
            # mode as well -- confirm against video_chatgpt_infer_batch.
            output, save_dict = video_chatgpt_infer_batch(batch['video'], batch['question'], conv_mode, model,
                                                          vision_tower, tokenizer, image_processor,
                                                          video_token_len, transcript_text)

        for out, gt_, sample_id, vn in zip(output, batch['gt'], batch['id'], batch['video_name']):
            key = f'{sample_id}-{vn}'
            predict[key] = out
            gt[key] = gt_

        print(f'minibatch_index: {minibatch_index}, passed_videos: {passed_videos}')

    batch = new_batch()
    minibatch_index = 0
    i = 0  # number of videos successfully loaded so far

    # Iterate over each sample in the ground truth file
    for sample in tqdm(gt_data):
        # Reset per sample: otherwise a video that is not found would reuse
        # the previous sample's path under the new video id.
        video_path = None
        if for_davis_editing_examples or args.for_cond_gen_examples:
            video_path = f"{sample['video_path']}"
            video_id_ = sample['vid']
        else:
            video_id_ = sample['video_id']
            for fmt in video_formats:
                temp_path = os.path.join(args.video_dir, f"{video_id_}{fmt}")
                if os.path.exists(temp_path):
                    video_path = temp_path
                    break

        if video_path is None or not os.path.exists(video_path):
            print(f"Video file not found for sample '{video_id_}', skipping.")
            continue

        # Gather everything that can fail BEFORE touching the batch lists so a
        # failure can never leave them with different lengths.
        try:
            gt_caption_ = sample[caption_key]
            video_frames = load_video(video_path, num_frm=100, shape=frame_size)
        except Exception as e:
            # Best effort: report the problem and move on to the next sample.
            print(f"Error processing video file : {e}")
            continue

        batch['video'].append(video_frames)
        batch['video_name'].append(video_id_)
        batch['question'].append('Please describe this video.')
        batch['gt'].append(gt_caption_)
        batch['id'].append(str(i))
        i += 1

        if len(batch['id']) >= args.bs:
            run_batch(batch, minibatch_index, i)
            batch = new_batch()
            minibatch_index += 1

    # Flush the trailing partial batch; previously these samples were dropped
    # whenever the number of loaded videos was not a multiple of the batch size.
    if batch['id']:
        run_batch(batch, minibatch_index, i)

    # Save the ground truth and predictions to JSON files
    with open(os.path.join(args.output_dir, f"{args.output_name}_gt.json"), 'w') as file:
        json.dump(gt, file)

    with open(os.path.join(args.output_dir, f"{args.output_name}_pred.json"), 'w') as file:
        json.dump(predict, file)

    if args.save_masks_and_features:
        torch.save(masks_and_feats, os.path.join(args.output_dir, f"{args.output_name}_masks_and_features.pt"))

    if args.rewrite_per_sentence:
        with open(os.path.join(args.output_dir, f"{args.output_name}_pred_rewrite.json"), 'w') as file:
            json.dump(re_predict, file)
    
    
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, pass them through get_config
    # (imported from add_param), then run batched inference.
    args = get_config(parse_args())
    run_inference(args)
