from data_from_json import PreferenceData, prefderence_data
import torch
from raie.core.utils import load_vlm
from torch.nn import functional as F
import wandb
from torch.optim import AdamW, Adam
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset
import os
import numpy as np
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig, get_peft_model
from datasets import Features, Sequence, Value, ClassLabel, Dataset, features
from transformers.image_utils import load_image
from copy import deepcopy
from transformers import AutoProcessor, LlavaForConditionalGeneration, AutoModelForVision2Seq, AutoModel, Idefics3Processor

import logging
# Configure root logging once at import time: INFO level, timestamped
# messages, emitted through a single StreamHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]  # make sure log output is emitted as text
)

def calculate_DPO_loss(model_prefered_logprob, model_disprefered_logprob,
                       ref_prefered_logprob, ref_disprefered_logprob,
                       beta=0.5):
    """Compute the DPO loss and its companion statistics.

    Each log-probability of the policy model is first expressed relative to
    the reference model (policy minus reference). Every returned quantity is
    averaged over the last dimension of the inputs.

    Returns a 5-tuple:
        loss: ``-logsigmoid(beta * (chosen_rel - rejected_rel))``
        mean relative log-prob of the preferred responses
        mean relative log-prob of the dispreferred responses
        reward accuracy: fraction where the preferred relative log-prob is
            strictly greater than the dispreferred one
        reward margin: mean of ``chosen_rel - rejected_rel``
    """
    chosen_rel = model_prefered_logprob - ref_prefered_logprob
    rejected_rel = model_disprefered_logprob - ref_disprefered_logprob

    # The same difference drives both the margin statistic and the loss.
    margin = chosen_rel - rejected_rel

    accuracy = (chosen_rel > rejected_rel).float().mean(dim=-1)
    mean_margin = margin.mean(dim=-1)
    dpo_loss = -F.logsigmoid(beta * margin).mean(dim=-1)

    return (
        dpo_loss,
        chosen_rel.mean(dim=-1),
        rejected_rel.mean(dim=-1),
        accuracy,
        mean_margin,
    )


def get_log_prob(logits, labels):
    """Return the mean log-probability assigned to ``labels``.

    ``logits`` are normalised with a log-softmax over the last dimension, the
    entry at each label index is picked out, and the picked values are
    averaged over the last remaining dimension.
    """
    normalised = F.log_softmax(logits, dim=-1)
    # gather needs an index with the same rank as the source, so add a
    # trailing axis and drop it again right after.
    picked = normalised.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return picked.mean(-1)


def finetune(vlm, ref_vlm, processor, lr, train_dataloader, checkpoint, epochs=1, beta=0.1, resume_check='', batch_size=1, lora=False):
    """Finetune ``vlm`` with TRL's ``DPOTrainer`` and save it to ``checkpoint``.

    Args:
        vlm: Policy model to train. When ``lora`` is true it is wrapped with
            a LoRA adapter targeting all linear layers before training.
        ref_vlm: Currently unused; the trainer is always built with
            ``ref_model=None``. Kept so existing callers keep working.
        processor: Passed to the trainer as ``processing_class``.
        lr: Learning rate.
        train_dataloader: Passed to the trainer as ``train_dataset``; despite
            the name, the trainer receives it as a dataset, not a DataLoader.
        checkpoint: Output directory for trainer checkpoints and the final
            ``save_pretrained`` call. Created if it does not exist.
        epochs: Number of training epochs.
        beta: DPO ``beta`` forwarded to ``DPOConfig``.
        resume_check: Checkpoint path to resume training from. An empty
            string (the default) starts a fresh run.
        batch_size: Per-device training batch size.
        lora: Whether to train a LoRA adapter instead of the full model.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoint, exist_ok=True)

    if lora:
        lora_config = LoraConfig(target_modules='all-linear')
        vlm = get_peft_model(vlm, lora_config)
        vlm.print_trainable_parameters()

    dpo_config = DPOConfig(
        output_dir=checkpoint,
        beta=beta,
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
        logging_steps=10,
        logging_dir=None,
        report_to='wandb',
        save_strategy='epoch',
    )

    trainer = DPOTrainer(
        model=vlm,
        ref_model=None,
        args=dpo_config,
        train_dataset=train_dataloader,
        processing_class=processor,
    )

    # Previously ``resume_check`` was accepted but ignored. An empty value is
    # mapped to None so the default behaviour (fresh run) is unchanged.
    trainer.train(resume_from_checkpoint=resume_check or None)
    vlm.save_pretrained(checkpoint)
    print('done')
    # for epoch in range(epochs):
    #     loss_ = []
    #     with tqdm(train_dataloader, desc='Finetuning', ncols=100) as dbar:
    #         for chosen, rejected in dbar:
    #             optimizer.zero_grad()
    #
    #             chosen = chosen.to(vlm.device)
    #             rejected = rejected.to(vlm.device)
    #
    #             prefer_vlm_out = vlm(**chosen).logits
    #             rejected_vlm_out = vlm(**rejected).logits
    #
    #             vlm_prefer_log_prob = get_log_prob(prefer_vlm_out, chosen.input_ids)
    #             vlm_rejected_log_prob = get_log_prob(rejected_vlm_out, rejected.input_ids)
    #
    #             with torch.inference_mode():
    #                 prefer_refvlm_out = ref_vlm(**chosen).logits
    #                 rejected_refvlm_out = ref_vlm(**rejected).logits
    #
    #             refvlm_prefer_log_prob = get_log_prob(prefer_refvlm_out, chosen.input_ids)
    #             refvlm_rejected_log_prob = get_log_prob(rejected_refvlm_out, rejected.input_ids)
    #
    #             loss, prefered_relative_logprob, disprefered_relative_logprob, reward_accuracies, reward_margins = calculate_DPO_loss(
    #                 vlm_prefer_log_prob, vlm_rejected_log_prob,
    #                 refvlm_prefer_log_prob, refvlm_rejected_log_prob,
    #                 beta=beta)
    #
    #             loss.backward()
    #             optimizer.step()
    #
    #             loss_.append(loss.item())
    #
    #             dbar.set_postfix(loss=f'{np.mean(loss_):.4f}')
    #
    #             wandb.log({'loss': loss.item(),
    #                        'prefered_relative_logprob': prefered_relative_logprob,
    #                        'disprefered_relative_logprob': disprefered_relative_logprob,
    #                        'reward_accuracy': reward_accuracies,
    #                        'reward_margin': reward_margins
    #                        })
    #
    #     llm_state_dict = {name: param
    #                       for name, param in vlm.state_dict().items() if vlm.get_parameter(name).requires_grad}
    #
    #     if (epoch+1) % 1 == 0:
    #         torch.save(llm_state_dict, os.path.join(checkpoint, f'finetuned_checkpoint{epoch+1}.pth'))


def freezen_vm(vlm):
    """Disable gradients for every parameter under ``vlm.visual``.

    The model is modified in place and returned for convenience.
    """
    for visual_param in vlm.visual.parameters():
        visual_param.requires_grad_(False)
    return vlm


def freezen_all(vlm):
    """Disable gradients for every parameter of ``vlm``.

    The model is modified in place and returned for convenience.
    """
    for model_param in vlm.parameters():
        model_param.requires_grad_(False)
    return vlm


def load_llava(path):
    """Load a LLaVA model and its processor from ``path``.

    The model is loaded in float16 with ``device_map='auto'``.

    Returns:
        ``(model, processor)``
    """
    llava_processor = AutoProcessor.from_pretrained(path)
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        path,
        torch_dtype=torch.float16,
        device_map='auto',
    )
    return llava_model, llava_processor

def load_ide(path, v=3):
    """Load an Idefics model and its processor from ``path``.

    Args:
        path: Location of the pretrained weights and processor files.
        v: When equal to 3 the processor is loaded with
            ``Idefics3Processor``; any other value falls back to
            ``AutoProcessor``.

    Returns:
        ``(model, processor)``; the model is loaded in bfloat16 with
        ``device_map='auto'``.
    """
    processor_cls = Idefics3Processor if v == 3 else AutoProcessor
    ide_processor = processor_cls.from_pretrained(path)
    ide_model = AutoModelForVision2Seq.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    return ide_model, ide_processor

def main(args):
    """Build the preference dataset and launch DPO finetuning.

    Reads ``vlm_path``, ``preference_data_root``, ``lr``, ``epochs``,
    ``beta``, ``save_checkpoint_path``, ``resume_check``, ``batch_size`` and
    ``lora`` from ``args``.
    """
    # Other loaders tried here: load_vlm(args.vlm_path) and
    # load_llava(args.vlm_path). freezen_vm(vlm) can be applied to finetune
    # the llm only; no separate reference model is loaded.
    model, processor = load_ide(args.vlm_path, 3)

    def to_dpo_record(example):
        """Render one raw preference record into chat-template text fields."""
        prompt_messages = example['prompt']
        chosen_messages = [{"role": "assistant",
                            "content": [{"type": "text", "text": f"{example['chosen']}"}]}]
        rejected_messages = [{"role": "assistant",
                              "content": [{"type": "text", "text": f"{example['rejected']}"}]}]

        prompt_text = processor.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=False,
        )
        chosen_text = processor.apply_chat_template(chosen_messages, tokenize=False)
        rejected_text = processor.apply_chat_template(rejected_messages, tokenize=False)

        image = load_image(example['images'])

        return {
            "images": [image],
            "prompt": prompt_text,
            "chosen": chosen_text,
            "rejected": rejected_text,
        }

    raw_records = prefderence_data(args.preference_data_root)
    dataset = Dataset.from_list(raw_records)
    dataset = dataset.map(to_dpo_record, num_proc=32)

    # Re-declare the "images" column as a sequence of decoded images.
    schema = dataset.features
    schema["images"] = features.Sequence(features.Image(decode=True))
    dataset = dataset.cast(schema)

    print(dataset.features)
    print(dataset[0])

    finetune(
        vlm=model,
        ref_vlm=None,
        processor=processor,
        lr=args.lr,
        train_dataloader=dataset,
        epochs=args.epochs,
        beta=args.beta,
        checkpoint=args.save_checkpoint_path,
        resume_check=args.resume_check,
        batch_size=args.batch_size,
        lora=args.lora,
    )


import argparse


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse the finetuning command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Numeric options are converted
        to ``int``/``float`` and ``--lora`` to ``bool``, so values given on
        the command line have the same types as the defaults.
    """
    parser = argparse.ArgumentParser()

    # parser.add_argument('--vlm_path', type=str, default='/home/15t/fzy/code/raie/Qwen2-VL-2B/Qwen2-VL-2B-Instruct')
    parser.add_argument('--vlm_path', default='/home/15t/fzy/code/raie/idefics3/Idefics3-8B-Llama3')
    # parser.add_argument('--preference_data_root', default=['/home/15t/fzy/code/raie/preference_json_data/realworld_new',
    #                                                        '/home/15t/fzy/code/raie/preference_json_data/mmbench_new',
    #                                                        '/home/15t/fzy/code/raie/preference_json_data/mmstar_new',
    #                                                        # '/home/15t/fzy/code/raie/preference_json_data/seedbench_new',
    #                                                        '/home/15t/fzy/code/raie/preference_json_data/science_new'])
    parser.add_argument('--preference_data_root', default='/home/15t/fzy/code/raie/preference_json_data/animals')
    # Without ``type=`` argparse would hand CLI-supplied values through as
    # strings, while the defaults are numbers.
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--save_checkpoint_path', default='/home/15t/fzy/code/raie/checkpoint/dpo_animals_new_ide')
    parser.add_argument('--beta', type=float, default=0.1)
    parser.add_argument('--wandb_project', default='finetune_all')
    parser.add_argument('--resume_check', default='')
    # A plain string such as "False" is truthy, so parse it explicitly.
    parser.add_argument('--lora', type=_str_to_bool, default=True)
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    # No API key is kept in source: with no ``key`` argument wandb falls back
    # to its standard credential lookup (e.g. the WANDB_API_KEY environment
    # variable or a previous ``wandb login``).
    wandb.login()
    wandb.init(project=args.wandb_project, config=args)

    main(args)



