import logging
import os
from torch.utils.data import Dataset
import json
import argparse
from datasets import load_dataset, Dataset
from transformers.trainer_utils import is_main_process
from joint_model import *
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    Qwen2ForCausalLM,
    Qwen2Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from transformers import Trainer
import logging


class CrossEntropyTrainer(Trainer):
    """``Trainer`` that optimizes a weighted sum of two token-level cross-entropy losses.

    The wrapped model's forward must return a 4-tuple
    ``(outputs2, outputs1, labels1, rewards)``:

    * ``outputs2.logits`` is scored against ``inputs["labels"]`` with a plain
      cross-entropy (``loss2``);
    * ``outputs1.logits`` is scored against ``labels1`` with a per-token
      cross-entropy scaled by ``rewards`` (``loss1``).

    The optimized loss is ``combine_param1 * loss1 + combine_param2 * loss2``.

    Args:
        combine_param1: Weight of the reward-guided loss on ``outputs1``.
        combine_param2: Weight of the plain cross-entropy on ``outputs2``.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, combine_param1, combine_param2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.combine_param1 = combine_param1
        self.combine_param2 = combine_param2

    def fixed_cross_entropy(self, source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
        """Cross-entropy that skips targets equal to ``ignore_index``.

        Without ``num_items_in_batch`` the loss is averaged over non-ignored
        targets; with it, the summed loss is divided by ``num_items_in_batch``
        instead. Extra ``kwargs`` are accepted and ignored.
        """
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
        if reduction == "sum":
            loss = loss / num_items_in_batch
        return loss

    def reward_guided_cross_entropy(self, rewards, logits, labels):
        """Return the mean of the per-token cross-entropy multiplied by ``rewards``.

        ``rewards`` must broadcast against the flat per-token loss (one value per
        element of ``labels``). Targets equal to -100 (``cross_entropy``'s default
        ``ignore_index``) contribute 0 to the sum.
        """
        # NOTE(review): unlike fixed_cross_entropy, this mean divides by ALL
        # positions, including ignored (-100) ones, so heavy padding shrinks the
        # loss. Confirm this normalization is intended.
        per_token_loss = nn.functional.cross_entropy(logits, labels, reduction="none")
        guided_loss = rewards * per_token_loss
        return guided_loss.mean()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute ``combine_param1 * loss1 + combine_param2 * loss2`` for one batch.

        Returns the scalar loss, or ``(loss, outputs2)`` when ``return_outputs``
        is true. Extra kwargs (e.g. ``num_items_in_batch``) are accepted and ignored.
        """
        outputs2, outputs1, labels1, rewards = model(**inputs)

        # loss2: plain cross-entropy of model2's logits against the dataset labels.
        # reshape (not view) also works if the tensors are non-contiguous.
        logits2 = outputs2.logits.float()
        labels2 = inputs["labels"].to(logits2.device)
        loss2 = self.fixed_cross_entropy(
            logits2.reshape(-1, logits2.size(-1)),
            labels2.reshape(-1),
        )

        # loss1: model1's per-token cross-entropy, each token weighted by a reward.
        # Presumably one reward per sequence (TODO confirm); it is repeated once
        # per position (logits1.size(1)) to line up with the flattened tokens.
        logits1 = outputs1.logits.float()
        labels1 = labels1.to(logits1.device)
        token_rewards = (
            torch.tensor(rewards, dtype=torch.float32)
            .repeat_interleave(logits1.size(1))
            .to(logits1.device)
        )
        loss1 = self.reward_guided_cross_entropy(
            token_rewards,
            logits1.reshape(-1, logits1.size(-1)),
            labels1.reshape(-1),
        )

        loss = self.combine_param1 * loss1 + self.combine_param2 * loss2

        # The previous code returned an undefined name ``outputs`` here (NameError
        # whenever return_outputs=True). outputs2 is returned because its logits
        # are the ones aligned with inputs["labels"].
        return (loss, outputs2) if return_outputs else loss

        
class RawDataCollator:
    """Turn a list of preprocessed examples into one batch dictionary.

    Text columns (``meta_prompt``, ``question``, ``answer``) are returned as
    plain Python lists; token columns (``meta_input_ids``,
    ``meta_attention_mask``, ``labels``) are stacked with ``torch.tensor``.
    Keys appear in the output in the order listed in ``_FIELD_ORDER``.
    """

    # Output keys, in the order they are gathered and emitted.
    _FIELD_ORDER = (
        "meta_prompt",
        "meta_input_ids",
        "meta_attention_mask",
        "question",
        "answer",
        "labels",
    )
    # Subset of _FIELD_ORDER that is converted to a tensor.
    _TENSOR_FIELDS = frozenset({"meta_input_ids", "meta_attention_mask", "labels"})

    def __call__(self, batch):
        collated = {}
        for field in self._FIELD_ORDER:
            column = [example[field] for example in batch]
            if field in self._TENSOR_FIELDS:
                column = torch.tensor(column)
            collated[field] = column
        return collated


def preprocess_function(examples, tokenizer, max_length=512):
    """Tokenize a batch of raw examples (for ``Dataset.map(batched=True)``).

    Args:
        examples: Batched columns with keys ``"meta_prompt"``, ``"question"``
            and ``"answer"`` (lists of strings).
        tokenizer: Tokenizer used for both the meta prompts and the answers.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        dict with the original ``meta_prompt`` and ``question`` lists, the
        answers with ``<|im_end|>`` appended, the tokenized meta prompts
        (``meta_input_ids`` / ``meta_attention_mask``), and ``labels``: the
        tokenized answers with padding positions replaced by -100.
    """
    meta_prompt = examples["meta_prompt"]
    question = examples["question"]
    # Append the end-of-turn marker so the target includes where the answer stops.
    answer = [f"{ans}<|im_end|>" for ans in examples["answer"]]

    meta_encodings = tokenizer(
        meta_prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    meta_input_ids = meta_encodings["input_ids"]
    meta_attention_mask = meta_encodings["attention_mask"]

    # `text_target=` is the non-deprecated replacement for the
    # `as_target_tokenizer()` context manager; with only a target given, the
    # tokenizer returns the target encodings directly.
    labels_encodings = tokenizer(
        text_target=answer,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # Mask padding using the attention mask rather than comparing ids with
    # pad_token_id: when pad_token is aliased to eos_token, an id comparison
    # would also mask genuine end-of-sequence tokens inside the answer (e.g. the
    # appended <|im_end|> if that is the eos token). Vectorized, unlike a
    # per-token Python loop.
    label = labels_encodings["input_ids"].masked_fill(
        labels_encodings["attention_mask"] == 0, -100
    )

    return {
        "meta_prompt": meta_prompt,
        "meta_input_ids": meta_input_ids,
        "meta_attention_mask": meta_attention_mask,
        "question": question,
        "answer": answer,
        "labels": label,
    }


def hyper_train(
        lr,
        encoding_dim,
        bottleneck,
        hyperlambda,
        pg_t,
        per_device_train_batch_size,
        gradient_accumulation_steps,
        model1_path,
        model2_path,
        tokenizer_path,
        train_file,
        output_dir,
        deepspeed_path,
        combine_param1,
        combine_param2,
        max_length=512,
):
    """Build the joint model, tokenize ``train_file``, train one epoch and save.

    Args:
        lr: Learning rate for ``TrainingArguments``.
        encoding_dim, bottleneck, hyperlambda, pg_t: Forwarded to ``Qwen2JointConfig``.
        per_device_train_batch_size, gradient_accumulation_steps: Forwarded to
            ``TrainingArguments``.
        model1_path, model2_path, tokenizer_path: Forwarded to ``Qwen2JointConfig``;
            ``model2_path`` also supplies the hidden size, ``tokenizer_path`` the
            tokenizer used for preprocessing.
        train_file: JSON/JSONL file loaded with ``load_dataset("json", ...)``.
        output_dir: Destination for checkpoints, logs (``<output_dir>/logs``) and
            the final model.
        deepspeed_path: DeepSpeed JSON config file, or ``None``/empty to train
            without DeepSpeed.
        combine_param1, combine_param2: Loss weights for ``CrossEntropyTrainer``.
        max_length: Pad/truncate length for prompts and answers (default 512,
            the value that used to be hard-coded).
    """
    # Only model2's hidden size is needed here: load its config instead of the
    # whole checkpoint so no weights are read just to get one integer.
    model2_config = Qwen2ForCausalLM.config_class.from_pretrained(model2_path)
    hidden_size_model2 = model2_config.hidden_size
    print(f"Model2 Hidden Size: {hidden_size_model2}")

    config = Qwen2JointConfig(
        encoding_dim=encoding_dim,
        input_dim=hidden_size_model2,
        embedding_dim=hidden_size_model2,
        bottleneck=bottleneck,
        hyperlambda=hyperlambda,
        pg_t=pg_t,
        model1_path=model1_path,
        model2_path=model2_path,
        tokenizer_path=tokenizer_path,
    )
    model = Qwen2ForJointLM(config)
    tokenizer = Qwen2Tokenizer.from_pretrained(
        tokenizer_path,
        padding_side="left"
    )
    tokenizer.pad_token = tokenizer.eos_token

    raw_dataset = load_dataset('json', data_files=train_file)['train']
    train_dataset = raw_dataset.map(
        lambda batch: preprocess_function(batch, tokenizer, max_length),
        batched=True
    )
    data_collator = RawDataCollator()

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=lr,
        logging_dir=os.path.join(output_dir, "logs"),
        logging_steps=1,
        save_steps=500,
        save_total_limit=2,
        prediction_loss_only=True,
        fp16=True,
        gradient_checkpointing=False,
        dataloader_num_workers=4,
        deepspeed=deepspeed_path,
    )

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # Attach a handler only once; repeated calls would otherwise duplicate every line.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.info("TrainingArguments: %s", training_args)
    if deepspeed_path:
        with open(deepspeed_path, "r", encoding="utf-8") as f:
            ds_config = json.load(f)
        logger.info("DeepSpeed Config: %s", ds_config)

    trainer = CrossEntropyTrainer(
        combine_param1=combine_param1,
        combine_param2=combine_param2,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
    print("model trained and saved.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='hyper_train')

    # Model, tokenizer, data and DeepSpeed locations.
    parser.add_argument("--model1_path", type=str, default="path_to_initial_prompt_genrator",
                        help="Checkpoint of the initial prompt generator (model1).")
    parser.add_argument("--model2_path", type=str, default="path_to_initial_actor_model",
                        help="Checkpoint of the initial actor model (model2); its hidden size "
                             "sizes the joint config.")
    parser.add_argument("--tokenizer_path", type=str, default="path_to_initial_tokenizer",
                        help="Tokenizer used for meta prompts and answers.")
    parser.add_argument("--train_data_path", type=str, default="../data/comosqa/train_rewrite.jsonl",
                        help="JSONL training file with meta_prompt, question and answer fields.")
    parser.add_argument("--deepspeed_config", type=str, default="../../ds_config_z0.json",
                        help="DeepSpeed JSON config passed to TrainingArguments.")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="Supplied by distributed launchers; not read by this script.")
    parser.add_argument("--base_model_path", type=str, default="./adapters",
                        help="Parent output directory; the model is saved to "
                             "<base_model_path>/<model_name>.")

    # Run naming and hyper-parameters.
    parser.add_argument("--model_name", type=str, default="hyper_test",
                        help="Sub-directory name for this run's outputs.")
    parser.add_argument("--lr", type=float, default=1e-6, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Currently unused; the batch size is controlled by "
                             "--per_device_train_batch_size and --gradient_accumulation_steps.")
    parser.add_argument("--encoding_dim", type=int, default=8,
                        help="Forwarded to Qwen2JointConfig as encoding_dim.")
    parser.add_argument("--bottleneck", type=int, default=16,
                        help="Forwarded to Qwen2JointConfig as bottleneck.")
    parser.add_argument("--hyperlambda", type=float, default=0.01,
                        help="Forwarded to Qwen2JointConfig as hyperlambda.")
    parser.add_argument("--pg_t", type=float, default=0,
                        help="Forwarded to Qwen2JointConfig as pg_t.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Training batch size per device.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Number of batches accumulated per optimizer step.")
    parser.add_argument("--combine_param1", type=float, default=1,
                        help="Weight of the reward-guided loss on model1's outputs.")
    parser.add_argument("--combine_param2", type=float, default=1,
                        help="Weight of the plain cross-entropy loss on model2's outputs.")

    args = parser.parse_args()
    print('Args in experiment:')
    print(args)

    # os.path.join avoids doubled or missing separators in the output path.
    model_save_path = os.path.join(args.base_model_path, args.model_name)

    hyper_train(
        lr=args.lr,
        encoding_dim=args.encoding_dim,
        bottleneck=args.bottleneck,
        hyperlambda=args.hyperlambda,
        pg_t=args.pg_t,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        model1_path=args.model1_path,
        model2_path=args.model2_path,
        tokenizer_path=args.tokenizer_path,
        train_file=args.train_data_path,
        output_dir=model_save_path,
        deepspeed_path=args.deepspeed_config,
        combine_param1=args.combine_param1,
        combine_param2=args.combine_param2
    )