import numpy as np
from transformers import AutoTokenizer
from models.pengi import PENGI
import os
import torch
from collections import OrderedDict
import librosa
from importlib_resources import files
import yaml
import argparse

class PENGIWrapper():
    """
    A class for interfacing Pengi model.

    Builds a ``PENGI`` model from the ``config.yml`` bundled in the ``configs``
    package, restores its weights from ``model_path`` and exposes helpers that
    turn audio / text into decoder prefix embeddings and decode them with beam
    search.

    Attributes set by ``__init__``:
        model: the loaded ``PENGI`` module (on GPU only if ``use_cuda`` and CUDA
            is available, otherwise CPU).
        enc_tokenizer / dec_tokenizer: tokenizers for the text encoder / decoder.
        args: ``argparse.Namespace`` built from the config (plus derived fields).
    """
    def __init__(self, model_path, use_cuda=False):
        """
        Args:
            model_path: path to a ``torch.load``-able checkpoint whose ``'model'``
                entry is the model state dict.
            use_cuda: move the model to the GPU when CUDA is available. All
                inputs are sent to whatever device the model ends up on.
        """
        self.file_path = os.path.realpath(__file__)
        self.config_path = files('configs').joinpath('config.yml')
        self.model_path = model_path
        self.use_cuda = use_cuda
        self.model, self.enc_tokenizer, self.dec_tokenizer, self.args = self.get_model_and_tokenizer(config_path=self.config_path)
        # Weights are loaded with map_location='cpu'; only move them when the
        # caller asked for CUDA and a device is actually present.
        if self.use_cuda and torch.cuda.is_available():
            self.model = self.model.cuda()

    def _device(self):
        """Return the device holding the model parameters (inputs are moved there)."""
        return next(self.model.parameters()).device

    def read_config_as_args(self, config_path):
        """Load a YAML config file and expose its top-level keys as a Namespace."""
        with open(config_path, "r") as f:
            yml_config = yaml.load(f, Loader=yaml.FullLoader)
        return argparse.Namespace(**yml_config)

    def get_model_and_tokenizer(self, config_path):
        """Build the PENGI model, load its weights and create both tokenizers.

        Returns:
            ``(model, enc_tokenizer, dec_tokenizer, args)``; the model is on the CPU.
        """
        args = self.read_config_as_args(config_path)
        args.prefix_dim = args.d_proj
        args.total_prefix_length = 2*args.prefix_length

        # Copy relevant configs from dataset_config
        args.cuda = torch.cuda.is_available()
        args.dataset_config['cuda'] = args.cuda
        args.sampling_rate = args.dataset_config['sampling_rate']
        args.use_precomputed_melspec = args.dataset_config['use_precomputed_melspec']

        model = PENGI(
            # audio
            audioenc_name = args.audioenc_name,
            sample_rate = args.sampling_rate,
            window_size = args.window_size,
            hop_size = args.hop_size,
            mel_bins = args.mel_bins,
            fmin = args.fmin,
            fmax = args.fmax,
            classes_num = None,
            out_emb = args.out_emb,
            specaug = args.specaug,
            mixup = args.mixup,
            # text encoder
            use_text_encoder = args.use_text_model,
            text_encoder = args.text_model,
            text_encoder_embed_dim = args.transformer_embed_dim,
            freeze_text_encoder_weights = args.freeze_text_encoder_weights,
            # text decoder
            text_decoder = args.text_decoder,
            prefix_length = args.prefix_length,
            clip_length = args.prefix_length_clip,
            prefix_size = args.prefix_dim,
            num_layers = args.num_layers,
            normalize_prefix = args.normalize_prefix,
            mapping_type = args.mapping_type,
            freeze_text_decoder_weights = args.freeze_gpt_weights,
            # common
            d_proj = args.d_proj,
            use_pretrained_audioencoder = args.use_pretrained_audioencoder,
            freeze_audio_encoder_weights= args.freeze_audio_encoder_weights,
            use_precomputed_melspec = args.use_precomputed_melspec,
            pretrained_audioencoder_path = None,
        )
        model.enc_text_length = args.dataset_config['enc_text_len']
        model.dec_text_length = args.dataset_config['dec_text_len']
        model_state_dict = torch.load(self.model_path, map_location=torch.device('cpu'))['model']
        try:
            model.load_state_dict(model_state_dict)
        except RuntimeError:
            # Key mismatch: retry with a leading 'module.' removed from the keys
            # (typically added by multi-GPU training wrappers). Keys without the
            # prefix are kept unchanged instead of being blindly truncated.
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in model_state_dict.items()
            )
            model.load_state_dict(new_state_dict)

        enc_tokenizer = AutoTokenizer.from_pretrained(args.text_model)
        if 'gpt' in args.text_model:
            enc_tokenizer.add_special_tokens({'pad_token': '!'})

        dec_tokenizer = AutoTokenizer.from_pretrained(args.text_decoder)
        if 'gpt' in args.text_decoder:
            dec_tokenizer.add_special_tokens({'pad_token': '!'})

        return model, enc_tokenizer, dec_tokenizer, args

    def get_audio_embed(self, audio_path):
        """Encode an audio file into decoder prefix embeddings.

        Returns:
            Tensor viewed as ``(-1, prefix_length, gpt_embedding_size)``.
        """
        self.model.eval()
        device = self._device()
        decoder = self.model.caption_decoder
        with torch.no_grad():
            # Load at the rate the audio encoder was configured with (config value).
            x, _ = librosa.load(audio_path, sr=self.args.sampling_rate)
            x = torch.tensor(x, device=device).reshape(1, -1)
            audio_embed, _ = self.model.audio_encoder(x)
            audio_embed = audio_embed / audio_embed.norm(2, -1).reshape(-1, 1)
            audio_projections = decoder.audio_project(audio_embed).contiguous().view(
                -1, decoder.prefix_length, decoder.gpt_embedding_size)
        return audio_projections

    def get_task_input_embed(self, input_text):
        """Encode a task prompt with the text encoder and project it to prefix embeddings.

        Returns:
            Tensor viewed as ``(-1, prefix_length, gpt_embedding_size)``.
        """
        self.model.eval()
        device = self._device()
        decoder = self.model.caption_decoder
        with torch.no_grad():
            if 'gpt' in self.args.text_model:
                input_text = input_text + ' <|endoftext|>'
            tok_inp_text = self.enc_tokenizer.encode_plus(
                text=input_text, add_special_tokens=True, return_tensors="pt").to(device)

            caption_embed = self.model.caption_encoder(tok_inp_text)
            caption_embed = caption_embed / caption_embed.norm(2, -1).reshape(-1, 1)
            caption_embed = decoder.text_project(caption_embed).contiguous().view(
                -1, decoder.prefix_length, decoder.gpt_embedding_size)

        return caption_embed

    def get_context_embed(self, context_text):
        """Embed free-form context text with the decoder's token embedding table."""
        self.model.eval()
        device = self._device()
        with torch.no_grad():
            tok_context = self.dec_tokenizer.encode_plus(
                text=context_text, add_special_tokens=True, return_tensors="pt")['input_ids'].to(device)
            context_embed = self.model.caption_decoder.gpt.transformer.wte(tok_context)

        return context_embed

    def predict(self, audio_path, task_text, context, output_length, beam_size, temperature, stop_token):
        """Generate text for an audio file, optionally conditioned on a task and context.

        The prefix is ``[audio | task (if non-empty) | context (if non-empty)]``,
        concatenated along the sequence axis and decoded with ``generate_beam``.

        Returns:
            ``(texts, scores)`` as returned by ``generate_beam``.
        """
        if stop_token is None:
            stop_token = ' <|endoftext|>'

        embeds = [self.get_audio_embed(audio_path)]
        if task_text is not None and task_text != '':
            embeds.append(self.get_task_input_embed(task_text))
        if context is not None and context != '':
            embeds.append(self.get_context_embed(context))

        prefix_embed_context = torch.cat(embeds, dim=1)
        generated_captions, scores = self.generate_beam(embed=prefix_embed_context, beam_size=beam_size, temperature=temperature, stop_token=stop_token, entry_length=output_length)
        return generated_captions, scores

    def generate_beam(self, beam_size: int = 5, embed=None,
                      entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
        """Beam-search decode text from a prefix embedding.

        Args:
            beam_size: number of beams kept at every step.
            embed: prefix fed to the decoder as ``inputs_embeds``. Its batch
                dimension must be 1: the first step expands it to ``beam_size``.
            entry_length: maximum number of generated tokens; must be >= 1.
            temperature: logits are divided by it when positive (ignored otherwise).
            stop_token: text whose first encoded token id terminates a beam.

        Returns:
            ``(texts, scores)``: decoded beams and their length-normalised
            log-probability scores, both ordered from best to worst.

        Raises:
            ValueError: if ``entry_length`` < 1 (nothing would be generated).
        """
        if entry_length < 1:
            raise ValueError("entry_length must be a positive integer")

        self.model.eval()
        gpt = self.model.caption_decoder.gpt
        stop_token_index = self.dec_tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.model.parameters()).device
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            generated = embed

            for _ in range(entry_length):
                logits = gpt(inputs_embeds=generated).logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                # log_softmax is the numerically stable form of softmax().log().
                logits = logits.log_softmax(-1)
                if scores is None:
                    # First step: seed beam_size beams from the single prefix.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    tokens = next_tokens
                else:
                    # Finished beams may only emit token 0 at zero cost, so their
                    # cumulative score stays frozen; the padding token is cut off
                    # at decode time via seq_lengths.
                    logits[is_stopped] = float('-inf')
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Flattened (beam, vocab) index -> source beam and token id.
                    vocab_size = scores_sum.shape[1]
                    next_tokens_source = next_tokens // vocab_size
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = (next_tokens % vocab_size).unsqueeze(1)
                    tokens = torch.cat((tokens[next_tokens_source], next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                # view(-1) keeps a 1-D shape even when beam_size == 1.
                flat_next = next_tokens.view(-1)
                next_token_embed = gpt.transformer.wte(flat_next).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped | flat_next.eq(stop_token_index)
                if is_stopped.all():
                    break
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.dec_tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        # Reorder scores with the same permutation so each score matches its text.
        return output_texts, scores[order]