import os
import copy
import warnings
import shutil
from functools import partial

import torch
# import kornia.filters
from .model import load_pretrained_model
from .mm_utils import process_image, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, process_audio_file
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, DEFAULT_AUDIO_TOKEN

import torch.nn.functional as F
import numpy as np

import torch

from sklearn.decomposition import PCA



def model_init(model_path=None, **kwargs):
    """Load a pretrained checkpoint and build per-modality preprocessors.

    Args:
        model_path: Repository id or local path of the checkpoint. When
            ``None``, ``"DAMO-NLP-SG/VideoLLaMA2-7B"`` is used.
        **kwargs: Forwarded unchanged to ``load_pretrained_model``.

    Returns:
        tuple: ``(model, processor, tokenizer)``. ``processor`` is a dict
        mapping ``'image'``, ``'video'`` and ``'audio'`` to preprocessing
        callables; image/video callables are bound to the processor returned
        by ``load_pretrained_model``.
    """
    if model_path is None:
        model_path = "DAMO-NLP-SG/VideoLLaMA2-7B"

    tokenizer, model, base_processor, _context_len = load_pretrained_model(
        model_path, None, get_model_name_from_path(model_path), **kwargs
    )

    # Fall back to the unknown token for padding when the tokenizer has no
    # pad token (and an unk token exists to borrow).
    lacks_pad = tokenizer.pad_token is None
    has_unk = tokenizer.unk_token is not None
    if lacks_pad and has_unk:
        tokenizer.pad_token = tokenizer.unk_token

    # Prefer the frame count recorded in the model config, else the package default.
    frames = getattr(model.config, "num_frames", NUM_FRAMES)

    modal_processors = dict(
        image=partial(process_image, processor=base_processor, aspect_ratio=None),
        video=partial(
            process_video,
            processor=base_processor,
            aspect_ratio=None,
            num_frames=frames,
        ),
        audio=process_audio_file,
    )
    return model, modal_processors, tokenizer


from torch.nn.functional import cosine_similarity
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt





import json  
def _mask_modality(tensor, key):
    """Return a copy of the modality dict ``tensor`` with ``tensor[key]`` zeroed.

    Every other entry is cloned, so the result shares no storage with the
    input (the isolation a ``copy.deepcopy`` would give) without first copying
    the tensor that is about to be replaced. ``torch.zeros_like`` is used
    instead of ``x * 0`` so inf/NaN values in the input cannot leak through
    the mask as NaN.
    """
    return {k: (torch.zeros_like(v) if k == key else v.clone()) for k, v in tensor.items()}


def _fuse_av_hidden_states(video_hidden, audio_hidden, instruction_len, num_video_tokens,
                           av_token_end, dominant_weight):
    """Blend one layer's hidden states from the audio-masked and video-masked passes.

    Both inputs are indexed as ``[:, position, :]`` (presumably
    ``(batch, seq, dim)`` — TODO confirm). Positions are split into four spans
    and concatenated back along dim 1:

    * ``[0, instruction_len)``: plain average of both passes.
    * ``[instruction_len, instruction_len + num_video_tokens)``: video-dominant
      blend, ``dominant_weight * video + (1 - dominant_weight) * audio``.
    * ``[instruction_len + num_video_tokens, av_token_end)``: audio-dominant blend.
    * ``[av_token_end, end)``: plain average (this span grows as tokens are generated).
    """
    video_end = instruction_len + num_video_tokens
    other_weight = 1 - dominant_weight

    prefix = 0.5 * (video_hidden[:, :instruction_len, :] + audio_hidden[:, :instruction_len, :])
    video_span = (dominant_weight * video_hidden[:, instruction_len:video_end, :]
                  + other_weight * audio_hidden[:, instruction_len:video_end, :])
    audio_span = (dominant_weight * audio_hidden[:, video_end:av_token_end, :]
                  + other_weight * video_hidden[:, video_end:av_token_end, :])
    suffix = 0.5 * (video_hidden[:, av_token_end:, :] + audio_hidden[:, av_token_end:, :])
    return torch.cat((prefix, video_span, audio_span, suffix), dim=1)


def mm_infer(image_or_video, instruct, model, tokenizer, modal='video', number=None, layer_number=None,
             layer_to_mask=None, modality_to_mask=None, *, instruction_len=14, num_video_tokens=676,
             av_token_end=2186, dominant_weight=0.8, **kwargs):
    """Greedy audio-visual decoding with hidden-state injection at one layer.

    For every generated token the model is run three times:

    1. with the video input zeroed (forward flagged ``modality=["audio", layer]``),
    2. with the audio input zeroed (forward flagged ``modality=["video", layer]``),
    3. on the full input with the blended hidden states of passes 1 and 2 at
       ``layer_to_mask`` passed as ``av_mask`` (see ``_fuse_av_hidden_states``).

    The next token is the argmax of pass 3's last-position logits.

    Args:
        image_or_video (dict): Must contain ``'video'`` and ``'audio'`` tensors;
            every value is converted with ``.half().cuda()``.
        instruct (str | list): Instruction text; for a list, only the first
            element is used.
        model: Project model whose forward returns a 3-tuple: the first item has
            ``.logits``, the third is indexable by layer (presumably per-layer
            hidden states — TODO confirm).
        tokenizer: Tokenizer with ``apply_chat_template``; ``pad_token_id`` is
            used to build attention masks.
        modal (str): ``'image'``, ``'video'`` or ``'audio'`` — selects the
            placeholder token inserted into the prompt. ``'text'`` is rejected.
        number, layer_number, modality_to_mask: Unused; kept for interface
            compatibility.
        layer_to_mask: Index (anything ``int()`` accepts) of the layer whose
            hidden states are fused and injected. Required.
        instruction_len, num_video_tokens, av_token_end: Token layout used by
            the fusion. The defaults (14 / 676 / 2186) are hard-coded for one
            prompt template and input resolution — NOTE(review): confirm them
            for the template/encoders in use.
        dominant_weight (float): Weight of the matching modality inside the
            video and audio token spans (default 0.8).
        **kwargs: ``max_new_tokens`` (default 2048). ``do_sample``,
            ``temperature`` and ``top_p`` are not supported; decoding is greedy.

    Returns:
        str: The decoded answer (prompt tokens excluded, stripped). It is also
        printed.

    Raises:
        ValueError: On an unsupported ``modal``, ``modal='text'``, missing
            ``layer_to_mask``, malformed ``image_or_video`` or ``instruct``.
    """
    # --- Validate everything up front, before any GPU work. ---
    if modal not in ('image', 'video', 'text', 'audio'):
        raise ValueError(f"Unsupported modal: {modal}")
    if modal == 'text':
        # The masked passes below need both audio and video inputs.
        raise ValueError("mm_infer requires audio-visual input; modal='text' is not supported.")
    if layer_to_mask is None:
        raise ValueError("layer_to_mask must be provided (index of the layer to fuse and inject).")
    if not isinstance(image_or_video, dict) or not all(k in image_or_video for k in ('video', 'audio')):
        raise ValueError("image_or_video must be a dict containing 'video' and 'audio' tensors.")
    if isinstance(instruct, str):
        instruction = instruct
    elif isinstance(instruct, list):
        if not instruct:
            raise ValueError("instruct must not be an empty list.")
        instruction = instruct[0]
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    layer = int(layer_to_mask)

    if kwargs.get('do_sample', False):
        warnings.warn("mm_infer decodes greedily; do_sample/temperature/top_p are ignored.", stacklevel=2)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
    else:
        modal_token = DEFAULT_AUDIO_TOKEN

    # --- Inputs: full, video-zeroed and audio-zeroed copies. ---
    full = {k: v.half().cuda() for k, v in image_or_video.items()}
    full_batch = [(full, modal)]
    video_masked_batch = [(_mask_modality(full, 'video'), modal)]
    audio_masked_batch = [(_mask_modality(full, 'audio'), modal)]

    # --- Prompt. ---
    if model.config.model_type in ('videollama2', 'videollama2_mistral', 'videollama2_mixtral'):
        system_message = [
            {'role': 'system', 'content': (
            """<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."""
            """\n"""
            """If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>""")
            }
        ]
    else:
        system_message = []

    message = system_message + [{'role': 'user', 'content': modal_token + '\n' + instruction}]
    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer_multimodal_token(prompt, tokenizer, modal_token, return_tensors='pt').unsqueeze(0).long().cuda()

    # --- Greedy decoding; every step re-runs the full sequence (no KV cache),
    # because the injected hidden states cover every position. ---
    with torch.no_grad():
        generated = input_ids.clone()
        prompt_len = input_ids.shape[1]

        for _ in range(max_new_tokens):
            attention_mask = generated.ne(tokenizer.pad_token_id).long()

            _, _, audio_hidden = model(
                input_ids=generated,
                attention_mask=attention_mask,
                images=video_masked_batch,
                return_dict=True,
                modality=["audio", layer],
            )
            _, _, video_hidden = model(
                input_ids=generated,
                attention_mask=attention_mask,
                images=audio_masked_batch,
                return_dict=True,
                modality=["video", layer],
            )

            injection = _fuse_av_hidden_states(
                video_hidden[layer], audio_hidden[layer],
                instruction_len, num_video_tokens, av_token_end, dominant_weight,
            )

            orig, _, _ = model(
                input_ids=generated,
                attention_mask=attention_mask,
                images=full_batch,
                return_dict=True,
                layer_num=layer,
                av_mask=injection,
            )
            # argmax over raw logits: softmax is monotonic, so it is redundant
            # and can only introduce ties in low precision.
            next_token = orig.logits[:, -1, :].argmax(dim=-1).item()

            if next_token == tokenizer.eos_token_id:
                break

            next_token_tensor = torch.tensor([[next_token]], device=generated.device, dtype=torch.long)
            generated = torch.cat([generated, next_token_tensor], dim=1)

    answer_ids = generated[:, prompt_len:]
    outputs = tokenizer.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()
    print("answer", outputs)
    return outputs
