from avs_v3 import AVSSDataset
from dataloader import S4Dataset, MS3Dataset
import torch
from pengi_wrapper import PengiWrapper as Pengi
import re
from tqdm import tqdm
import argparse
import pandas as pd
import os


def contains_alpha(string):
    """Tell whether ``string`` has at least one ASCII letter (a-z or A-Z).

    Used to skip beam answers made only of punctuation, digits or whitespace.
    """
    first_letter = re.search(r'[a-zA-Z]', string)
    return first_letter is not None


class TextExtractor:
    """Caption audio clips with a Pengi model using one fixed text prompt."""

    # Score recorded for a response when none of its beam answers contains a letter.
    _NO_ANSWER_SCORE = -100

    def __init__(self, device, prompt="this is a sound of", enable_uncertain=False):
        """Load the base Pengi model.

        Args:
            device: Device spec such as 'cuda', 'cuda:1', 'cpu' or a torch.device.
                Any value whose string form starts with 'cuda' enables CUDA in
                Pengi. The device index itself is not forwarded to Pengi.
            prompt: Text prompt passed to Pengi for every clip.
            enable_uncertain: Forwarded unchanged to the Pengi wrapper.
        """
        # An exact ``device == 'cuda'`` test silently ran on CPU for 'cuda:0'
        # or torch.device('cuda'); match on the prefix instead.
        use_cuda = str(device).startswith('cuda')
        self.pengi = Pengi(config="base", use_cuda=use_cuda, enable_uncertain=enable_uncertain)
        self.prompt = prompt

    def get_audio_embeddings(self, paths, lengths):
        """Return Pengi audio embeddings for ``paths``, truncated with ``[:lengths]``.

        NOTE(review): ``lengths`` is used both as Pengi's ``clips`` argument and
        as a slice bound, so it presumably is a single int here. Confirm against
        callers.
        """
        _, audio_embeddings = self.pengi.get_audio_embeddings(paths, clips=lengths)
        return audio_embeddings[:lengths]

    @classmethod
    def _first_alpha_answer(cls, response):
        """Return (answer, score) for the first beam answer containing a letter.

        ``response`` is unpacked via ``zip(*response)``: parallel sequences of
        answers and confidence tensors. Returns ("", -100) when no answer qualifies.
        """
        for answer, confidence in zip(*response):
            if contains_alpha(answer):
                return answer, confidence.detach().cpu().numpy().item()
        return "", cls._NO_ANSWER_SCORE

    def extract(self, paths, lengths):
        """Generate captions and confidence scores for the given audio files.

        Args:
            paths: Audio file paths passed to Pengi as ``audio_paths``.
            lengths: Forwarded to Pengi as ``clips``.

        Returns:
            ``(all_result, all_scores)``: parallel nested lists, one inner list
            per item Pengi returns and one element per response inside it. The
            caller treats these as per-video and per-frame respectively. Each
            caption is the first beam answer containing an ASCII letter, and its
            score is that answer's confidence as a Python number. When no answer
            qualifies, the caption is "" and the score is -100.
        """
        generated_responses = self.pengi.generate(
            audio_paths=paths,
            text_prompts=[self.prompt],
            add_texts=[""],
            max_len=30,
            beam_size=3,
            temperature=1.0,
            stop_token=' <|endoftext|>',
            clips=lengths,
        )

        all_result, all_scores = [], []
        for generated_response in generated_responses:
            picks = [self._first_alpha_answer(response) for response in generated_response]
            all_result.append([answer for answer, _ in picks])
            all_scores.append([score for _, score in picks])
        return all_result, all_scores


def _load_dataset(ds, split):
    """Build the dataset named by ``ds``. Any name other than S4/MS3/V3 is treated as AVSS (V2)."""
    if ds == 'S4':
        return S4Dataset(split=split)
    if ds == 'MS3':
        return MS3Dataset(split=split)
    if ds == 'V3':
        return AVSSDataset(split=split, subdomain='V3')
    return AVSSDataset(split=split, subdomain='AVSS')


def _unpack_batch(ds, batch_data):
    """Pick ``(video_names, audio_paths, audio_clip_length, categories)`` out of a batch.

    The batch layout differs per dataset. Only S4 carries categories, so
    ``categories`` is None for every other dataset.
    """
    if ds == 'S4':
        _, _, video_names, audio_paths, audio_clip_length, categories = batch_data
        return video_names, audio_paths, audio_clip_length, categories
    if ds == 'MS3':
        _, _, video_names, audio_paths, audio_clip_length = batch_data
    else:  # V3 / AVSS: an extra gt_temporal_mask_flag precedes the names
        _, _, _, video_names, audio_paths, audio_clip_length = batch_data
    return video_names, audio_paths, audio_clip_length, None


def _audio_files(ds, audio_paths, video_names):
    """Return the .wav file path for every sample in the batch."""
    if ds in ('S4', 'MS3'):
        return [os.path.join(ap, f'{name}.wav') for ap, name in zip(audio_paths, video_names)]
    return [os.path.join(ap, 'audio.wav') for ap in audio_paths]


def _empty_table(ds):
    """Return empty CSV columns. S4 additionally records a category column."""
    table = {'name': [], 'text': [], 'frame': [], 'score': []}
    if ds == 'S4':
        table['category'] = []
    return table


def _append_rows(table, video_names, captions, scores, categories=None):
    """Flatten per-video, per-frame captions and scores into ``table``, one row per frame."""
    if categories is None:
        categories = [None] * len(video_names)
    for name, category, frame_texts, frame_scores in zip(video_names, categories, captions, scores):
        for frame, (text, score) in enumerate(zip(frame_texts, frame_scores)):
            table['name'].append(name)
            if category is not None:
                table['category'].append(category.cpu().item())
            table['text'].append(text)
            table['frame'].append(frame)
            table['score'].append(score)


def generate_audio_caption(ds='S4', split='train', batch_size=1, num_workers=2, *,
                           device=None, text_prompt=None, output_root=None):
    """Caption every audio clip of a dataset split with Pengi and write the results to CSV.

    The CSV is written to
    ``<output_root>/audio captions/<ds>/<split>/audio prompt = <text_prompt>.csv``.
    It has one row per (video, frame) with columns name, text, frame and score,
    plus category for S4. The header is written first, then rows are appended
    one batch at a time.

    Args:
        ds: 'S4', 'MS3' or 'V3'. Any other value is treated as AVSS.
        split: Split passed to the dataset constructor.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        device: Device for TextExtractor. Defaults to ``args.device``.
        text_prompt: Pengi text prompt. Defaults to ``args.text_prompt``.
        output_root: Root output directory. Defaults to ``args.output_root``.

    ``args`` is the module-level namespace created by the ``__main__`` block.
    Falling back to it keeps the original script behaviour.

    Raises:
        ValueError: If one of device/text_prompt/output_root is omitted and no
            module-level ``args`` exists (e.g. when this module is imported).
    """
    if device is None or text_prompt is None or output_root is None:
        cli_args = globals().get('args')
        if cli_args is None:
            raise ValueError(
                "device, text_prompt and output_root must be given explicitly "
                "when no module-level `args` namespace exists")
        device = cli_args.device if device is None else device
        text_prompt = cli_args.text_prompt if text_prompt is None else text_prompt
        output_root = cli_args.output_root if output_root is None else output_root

    dataset = _load_dataset(ds, split)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    model = TextExtractor(device, text_prompt)

    output_path = os.path.join(output_root, 'audio captions', ds, split)
    os.makedirs(output_path, exist_ok=True)
    output_csv = os.path.join(output_path, f"audio prompt = {text_prompt}.csv")
    # Truncate the file to just the header. Each batch is appended below, so
    # rows already written stay on disk if the run stops early.
    pd.DataFrame(_empty_table(ds)).to_csv(output_csv, index=False, header=True, mode='w')

    # len(dataloader) counts the final partial batch; len(dataset) // batch_size did not.
    for batch_data in tqdm(dataloader, total=len(dataloader), leave=False):
        video_names, audio_paths, audio_clip_length, categories = _unpack_batch(ds, batch_data)
        audio_files = _audio_files(ds, audio_paths, video_names)
        captions, scores = model.extract(audio_files, audio_clip_length)

        table = _empty_table(ds)
        _append_rows(table, video_names, captions, scores, categories)
        pd.DataFrame(table).to_csv(output_csv, index=False, header=False, mode='a')


def parse_args(argv=None):
    """Parse command-line options for audio-caption generation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with device, text_prompt, output_root, dataset,
        split, batch_size and num_workers.
    """
    parser = argparse.ArgumentParser(
        description="Generate Pengi audio captions for one dataset split.")
    parser.add_argument('--device', default='cuda', type=str,
                        help="device passed to TextExtractor (default: %(default)s)")
    parser.add_argument('--text_prompt', default='generate audio caption', type=str,
                        help="Pengi text prompt; also used in the CSV file name (default: %(default)s)")
    parser.add_argument('--output_root', default='output', type=str,
                        help="root directory for the CSV output (default: %(default)s)")
    parser.add_argument('--dataset', default='S4', type=str,
                        help="S4, MS3 or V3; any other value is treated as AVSS (default: %(default)s)")
    parser.add_argument('--split', default='train', type=str,
                        help="dataset split (default: %(default)s)")
    parser.add_argument('--batch_size', default=1, type=int,
                        help="DataLoader batch size (default: %(default)s)")
    parser.add_argument('--num_workers', default=2, type=int,
                        help="DataLoader worker count (default: %(default)s)")
    return parser.parse_args(argv)


if __name__ == '__main__':
    # ``args`` is deliberately module-level: generate_audio_caption reads it.
    args = parse_args()
    print("Generate audio captions")
    generate_audio_caption(
        ds=args.dataset,
        split=args.split,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
