import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig, MllamaForConditionalGeneration


def name_lvlm(args):
    """Return a short family tag for the checkpoint named by ``args.lvlm_path``.

    Args:
        args: Namespace-like object exposing a ``lvlm_path`` attribute.

    Returns:
        ``"molmo"`` for ``"allenai/Molmo-7B-D-0924"``, ``"llama"`` for
        ``"meta-llama/Llama-3.2-11B-Vision"``, and ``"other"`` for any other value.
    """
    known_checkpoints = (
        ("allenai/Molmo-7B-D-0924", "molmo"),
        ("meta-llama/Llama-3.2-11B-Vision", "llama"),
    )
    # Compare with == in the same order as the supported list; first match wins.
    return next(
        (tag for checkpoint, tag in known_checkpoints if args.lvlm_path == checkpoint),
        "other",
    )

def get_lvlm(args):
    """Load the processor and model for the LVLM selected by ``args.lvlm_path``.

    Args:
        args: Namespace-like object with a ``lvlm_path`` attribute. Supported
            values are ``"allenai/Molmo-7B-D-0924"`` and
            ``"meta-llama/Llama-3.2-11B-Vision"``. For the Llama checkpoint the
            weights are read from ``args.lvlm_local_path`` when that optional
            attribute is set (and truthy). Otherwise they are read from the
            cwd-relative directory ``"../../Models/meta-llama/Llama-3.2-11B-Vision"``.

    Returns:
        A ``(processor, model)`` tuple. The model is loaded with
        ``torch_dtype=torch.float16`` and ``device_map="auto"``.

    Raises:
        NotImplementedError: if ``args.lvlm_path`` is not a supported checkpoint id.
    """
    # Identical options were previously repeated for every from_pretrained call.
    # ``**load_kwargs`` unpacks into a fresh dict per call, so callees cannot
    # mutate this one.
    load_kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    if args.lvlm_path == "allenai/Molmo-7B-D-0924":
        processor = AutoProcessor.from_pretrained(args.lvlm_path, **load_kwargs)
        model = AutoModelForCausalLM.from_pretrained(args.lvlm_path, **load_kwargs)
    elif args.lvlm_path == "meta-llama/Llama-3.2-11B-Vision":
        # The hub id is only used as a selector; weights come from a local copy.
        # The default path is relative to the current working directory.
        local_path = (
            getattr(args, "lvlm_local_path", None)
            or "../../Models/meta-llama/Llama-3.2-11B-Vision"
        )
        processor = AutoProcessor.from_pretrained(local_path, **load_kwargs)
        model = MllamaForConditionalGeneration.from_pretrained(local_path, **load_kwargs)
    else:
        raise NotImplementedError(f"Unsupported LVLM: {args.lvlm_path!r}")

    return processor, model


def run_lvlm(args, model, processor, image, text):
    """Generate a text continuation for an image + prompt with the selected LVLM.

    Args:
        args: Namespace-like object providing ``lvlm_path`` (selects the model
            family) and ``max_new_tokens`` (generation budget).
        model: Model loaded for the same ``lvlm_path`` (e.g. by ``get_lvlm``).
        processor: Processor loaded for the same ``lvlm_path``.
        image: Image handed directly to the processor (presumably a PIL image;
            TODO confirm against callers).
        text: Prompt text handed directly to the processor.

    Returns:
        The decoded generated text for the first sequence, with the prompt
        tokens removed and special tokens skipped.

    Raises:
        NotImplementedError: if ``args.lvlm_path`` is not a supported checkpoint id.
    """
    if args.lvlm_path == "allenai/Molmo-7B-D-0924":
        processed_inputs = processor.process(images=image, text=text)
        # Index tensors must stay integral. Every other tensor is cast to the
        # model's own floating dtype so it matches the weights; this is float16
        # for the model get_lvlm loads. unsqueeze(0) adds a batch dimension of 1.
        index_keys = {"input_ids", "batch_idx", "image_input_idx"}
        inputs = {
            key: value.to(
                model.device,
                dtype=torch.long if key in index_keys else model.dtype,
            ).unsqueeze(0)
            for key, value in processed_inputs.items()
        }
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=args.max_new_tokens, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )
        decode = processor.tokenizer.decode

    elif args.lvlm_path == "meta-llama/Llama-3.2-11B-Vision":
        inputs = processor(image, text, return_tensors="pt").to(model.device)
        output = model.generate(
            **inputs,
            generation_config=GenerationConfig(
                max_new_tokens=args.max_new_tokens, stop_strings="<|eot_id|>"
            ),
            tokenizer=processor.tokenizer,  # required for stop_strings matching
        )
        decode = processor.decode

    else:
        raise NotImplementedError(f"Unsupported LVLM: {args.lvlm_path!r}")

    # Drop the leading prompt tokens from the first sequence so only newly
    # generated tokens are decoded.
    prompt_len = inputs["input_ids"].size(1)
    return decode(output[0, prompt_len:], skip_special_tokens=True)

