import argparse
import torch
import os
import json
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from accelerate import Accelerator  # Added for data parallelism
from accelerate.utils import gather_object
from torch.nn import CrossEntropyLoss
from peft import get_peft_model, LoraConfig
from accelerate import DistributedDataParallelKwargs
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    LlavaForConditionalGeneration,
    LlavaOnevisionForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from datasets import load_dataset
from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
import torch.distributed as dist

    
def add_adv_noise(model, optimizer,accelerator, completion, noise_area=1, noise_value=4, noise_step=20, noise_range=64):
    """Craft an additive adversarial perturbation for ``completion["pixel_values"]``.

    Runs ``noise_step`` iterations of signed-gradient ascent on the loss returned
    by ``model(**completion)``: each step moves the noise by ``noise_value / 255``
    in the direction of the gradient sign and clamps it to
    ``[-noise_range / 255, noise_range / 255]``.

    Args:
        model: Callable accepting ``**completion`` and returning an object with a
            ``loss`` attribute; must expose ``device``.
        optimizer: Optimizer whose ``zero_grad()`` clears the model gradients
            between steps (``step()`` is never called here).
        accelerator: Object whose ``backward(loss)`` runs the backward pass.
        completion: Dict of model inputs containing ``"pixel_values"``. The entry
            is swapped for the perturbed tensor during each forward pass and is
            restored to the original tensor before returning.
        noise_area: Fraction of gradient entries (largest magnitude first) that
            may be updated per step. Values outside ``[0, 1]`` are clamped.
        noise_value: Per-step size, in 1/255 units.
        noise_step: Number of ascent iterations; ``0`` yields all-zero noise.
        noise_range: Bound on the absolute noise value, in 1/255 units.

    Returns:
        A tensor shaped like ``completion["pixel_values"]`` holding the noise,
        with ``requires_grad`` set to ``False``.

    Raises:
        RuntimeError: If the backward pass produced no gradient for the noise,
            i.e. the loss does not depend on ``pixel_values``.
    """
    pixel_values = completion["pixel_values"]
    noise = torch.zeros_like(pixel_values, requires_grad=True, device=model.device)
    step_size = noise_value / 255
    bound = noise_range / 255

    try:
        for _ in range(noise_step):
            optimizer.zero_grad()
            completion["pixel_values"] = pixel_values + noise
            loss = model(**completion).loss
            accelerator.backward(loss)

            if noise.grad is None:
                raise RuntimeError(
                    "No gradient reached the noise tensor; the loss does not "
                    "depend on completion['pixel_values']."
                )
            grad = noise.grad.detach()

            # Keep only the top `noise_area` fraction of gradient entries (by
            # magnitude); the rest are zeroed so sign() leaves them untouched.
            total_elements = grad.numel()
            k = min(max(int(noise_area * total_elements), 0), total_elements)
            if k < total_elements:  # k == total keeps everything: skip the top-k pass
                grad_flat = grad.flatten()
                _, indices = torch.topk(grad_flat.abs(), k, largest=True, sorted=False)
                grad_masked = torch.zeros_like(grad_flat)
                grad_masked[indices] = grad_flat[indices]
                grad = grad_masked.view_as(grad)

            # Update the leaf tensor in place without recording the op in autograd.
            with torch.no_grad():
                noise.copy_(torch.clamp(noise + step_size * torch.sign(grad), min=-bound, max=bound))
            # Drop the gradient so the next backward pass starts from scratch.
            noise.grad = None
    finally:
        # Do not leak the perturbed pixels (or stale parameter gradients from the
        # last backward pass) back to the caller.
        completion["pixel_values"] = pixel_values
        optimizer.zero_grad()

    noise.requires_grad_(False)
    return noise

def add_random_noise(model, completion, noise_range=64):
    """Sample standard-normal noise shaped like ``completion["pixel_values"]``.

    Args:
        model: Object exposing ``device``; the noise is allocated there.
        completion: Dict containing the ``"pixel_values"`` tensor whose shape and
            dtype the noise mirrors.
        noise_range: Currently unused.
            NOTE(review): the noise is not scaled or bounded by this value —
            confirm whether that is intentional.

    Returns:
        A tensor of ``torch.randn_like`` samples that does not require grad.
    """
    reference = completion["pixel_values"]
    target_device = model.device
    return torch.randn_like(reference, device=target_device, requires_grad=False)

def _load_model(model_id):
    """Load the pretrained model class selected by substrings of ``model_id``."""
    init_kwargs = {
        "attn_implementation": "flash_attention_2",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
    }
    lowered_id = model_id.lower()

    if "Qwen2-VL" in model_id:
        return Qwen2VLForConditionalGeneration.from_pretrained(model_id, **init_kwargs)
    if "Qwen2.5-VL" in model_id:
        return Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **init_kwargs)

    # The remaining multimodal classes are loaded without `use_cache`
    # (presumably they do not accept it at construction time — TODO confirm).
    if "Aria" in model_id:
        init_kwargs.pop("use_cache")
        return AriaForConditionalGeneration.from_pretrained(model_id, **init_kwargs)
    if "llava-1.5" in lowered_id:
        init_kwargs.pop("use_cache")
        return LlavaForConditionalGeneration.from_pretrained(model_id, **init_kwargs)
    if "llava-onevision" in lowered_id:
        init_kwargs.pop("use_cache")
        return LlavaOnevisionForConditionalGeneration.from_pretrained(model_id, **init_kwargs)
    return AutoModelForCausalLM.from_pretrained(model_id, **init_kwargs)


def _load_processor(model, model_id, max_pixels):
    """Return ``(processing_class, pad_token_id)`` for ``model_id``."""
    is_qwen_vl = "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
    if is_qwen_vl or "Aria" in model_id or "llava" in model_id.lower():
        processing_class = AutoProcessor.from_pretrained(model_id)
        pad_token_id = processing_class.tokenizer.pad_token_id
        # Mirror the tokenizer's special-token ids onto the processor itself.
        processing_class.pad_token_id = pad_token_id
        processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
        if is_qwen_vl:
            processing_class.image_processor.max_pixels = max_pixels
            processing_class.image_processor.min_pixels = 3136
        return processing_class, pad_token_id

    processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
    return processing_class, processing_class.pad_token_id


def eval_model(args, accelerator):
    """Generate a "loser" answer for every sample from a noise-perturbed image.

    For each row of the dataset named by ``args.query`` the function builds a
    chat prompt from ``problem`` and ``image``, optionally perturbs the pixel
    values (``args.noise_type`` of ``"adv"`` or ``"random"``), samples a
    completion and stores it under ``"loser_answer"``. Results gathered from all
    processes are appended line by line to ``args.save_path + "l"`` by the main
    process; once the loop ends the main process writes the records with a
    non-empty answer to ``args.save_path`` as a JSON list.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``query``,
            ``save_path``, ``max_pixels``, ``temperature``, ``num_responses``,
            ``noise_type``, ``noise_area``, ``noise_value``, ``noise_step`` and
            ``noise_range``.
        accelerator: ``accelerate.Accelerator`` used to prepare the model,
            optimizer and dataloader and to coordinate processes.

    Raises:
        TypeError: If ``args.model_path`` is not a string.
    """
    model_id = args.model_path
    if not isinstance(model_id, str):
        raise TypeError(f"model_path must be a string, got {type(model_id).__name__}")

    is_qwen_vl = "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
    # Case-insensitive, matching the check used when selecting the model class.
    is_llava_onevision = "llava-onevision" in model_id.lower()

    model = _load_model(model_id)

    # A LoRA adapter is attached and input gradients are enabled before the
    # model is handed to the adversarial-noise search.
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=1,
        target_modules=["v_proj"],
        lora_alpha=32,
        lora_dropout=0,
        bias="none",
        use_rslora=False,
        modules_to_save=None,
    )
    model.enable_input_require_grads()
    model = get_peft_model(model, peft_config)

    processing_class, pad_token_id = _load_processor(model, model_id, args.max_pixels)

    max_prompt_length = 8096
    generation_config = GenerationConfig(
        max_new_tokens=2048,
        do_sample=True,
        temperature=args.temperature,  # HACK
        num_return_sequences=args.num_responses,
        pad_token_id=pad_token_id,
    )

    # Load the dataset
    if "json" in args.query:
        dataset = load_dataset("json", data_files=args.query)
    else:
        dataset = load_dataset(args.query)

    def make_conversation_image(example):
        # Single-turn user message: one image placeholder followed by the question.
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": example["problem"]},
                    ],
                },
            ],
        }

    dataset = dataset.map(make_conversation_image)

    def data_collator(features):
        # Keep the raw list of examples; batching is done by the processor below.
        return features

    dataloader = DataLoader(
        dataset["train"],
        batch_size=1,
        shuffle=False,
        collate_fn=data_collator,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    # `prepare` only wraps the model in some launch configurations, so `.module`
    # is not always present; `unwrap_model` handles both cases.
    generation_model = accelerator.unwrap_model(model)

    generated_data = []
    jsonl_path = args.save_path + "l"

    for inputs in tqdm(dataloader):
        prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
        images = [Image.open(x["image"]) if isinstance(x["image"], str) else x["image"] for x in inputs]
        if is_llava_onevision:
            images = [image.resize((384, 384)) for image in images]
        prompt_inputs = processing_class(
            text=prompts_text,
            images=images,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        # Keep only the right-most `max_prompt_length` prompt tokens.
        prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -max_prompt_length:]
        prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -max_prompt_length:]

        prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].to(torch.bfloat16)
        for key in prompt_inputs.keys():
            prompt_inputs[key] = prompt_inputs[key].to(model.device)

        # Teacher-forced batch (prompt + reference answer) whose loss drives the
        # adversarial-noise search.
        tokenizer = processing_class.tokenizer
        winner_ids = tokenizer(inputs[0]["winner_answer"], return_tensors='pt', add_special_tokens=False)["input_ids"]
        # add eos token to keep the same format as generated completions
        eos_column = torch.tensor([[tokenizer.eos_token_id]]).repeat(winner_ids.size(0), 1)
        winner_ids = torch.cat([winner_ids, eos_column], dim=1).to(model.device)

        completion = {
            "input_ids": torch.cat([prompt_inputs["input_ids"], winner_ids], dim=1),
            "labels": torch.cat(
                [
                    torch.full_like(prompt_inputs["input_ids"], -100),  # ignore loss on prompt tokens
                    winner_ids,
                ],
                dim=1,
            ),
            "pixel_values": prompt_inputs["pixel_values"],
        }
        if is_qwen_vl:
            completion["image_grid_thw"] = prompt_inputs["image_grid_thw"]
        elif is_llava_onevision:
            completion["image_sizes"] = prompt_inputs["image_sizes"]

        # Perturb the image before generation.
        if args.noise_type == "adv":
            if args.noise_step > 0:
                noise = add_adv_noise(model, optimizer, accelerator, completion, args.noise_area, args.noise_value, args.noise_step, args.noise_range)
                prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"] + noise
        elif args.noise_type == "random":
            noise = add_random_noise(model, completion, args.noise_range)
            prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"] + noise

        # Generation
        with torch.inference_mode():
            output_ids = generation_model.generate(
                **prompt_inputs,
                generation_config=generation_config,
            )

        # Decode only the newly generated tokens; the first sequence is kept.
        input_token_len = prompt_inputs["input_ids"].shape[1]
        outputs = processing_class.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )

        record = inputs[0]
        record['loser_answer'] = outputs[0].strip()
        record.pop("prompt")

        accelerator.wait_for_everyone()
        gathered_data = gather_object([record])
        if accelerator.is_main_process:
            # Append after every step so partial results survive an interruption.
            with open(jsonl_path, 'a') as f:
                for gathered_data_item in gathered_data:
                    generated_data.append(gathered_data_item)
                    f.write(json.dumps(gathered_data_item)+"\n")

    if accelerator.is_main_process:
        # Build a new list instead of popping while indexing, which would skip
        # the element following each removed one.
        generated_data = [item for item in generated_data if item["loser_answer"] != ""]
        with open(args.save_path, 'w') as f:
            json.dump(generated_data, f, indent=4)
        print(f"final save {len(generated_data)} data")

if __name__ == "__main__":
    # Every CLI option as (flag, value type, default).
    cli_options = (
        ("--model-path", str, "Qwen/Qwen2-VL-2B-Instruct"),
        ("--save_path", str, ""),
        ("--query", str, ""),
        ("--conv-mode", str, None),
        ("--noise_type", str, "adv"),
        ("--noise_step", int, 20),
        ("--noise_area", float, 1),
        ("--noise_value", int, 4),
        ("--noise_range", int, 32),
        ("--num_responses", int, 1),
        ("--temperature", float, 1.0),
        ("--max_pixels", int, 401408),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Configure DDP with find_unused_parameters=True.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    eval_model(args, accelerator)

