from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer import _is_peft_model
sys.path.append(".")
from source.fine_tuning.dataset import  load_dataset_multi_env_minigrid
import os
from accelerate import Accelerator
import torch
from datetime import datetime
import numpy as np
from source.fine_tuning.fix_dataset import _prepare_non_packed_dataloader
import argparse
from transformers import logging

# Weights & Biases configuration, read by the W&B integration (report_to="wandb").
os.environ["WANDB_PROJECT"] = "SFT_minigrid"  # W&B project the runs are logged to
os.environ["WANDB_LOG_MODEL"] = "false"  # do not upload model checkpoints as W&B artifacts
os.environ["WANDB_DIR"] = "/tmp/cache"  # local directory for W&B run files (no trailing space)
os.environ["WANDB_CONSOLE"] = "auto"  # W&B's default console-capture mode
# One Accelerator per process; its process_index is used to place the model on
# this process's device.
accelerator = Accelerator()

# Seed the global torch and NumPy RNGs for reproducibility; the same seed is
# also handed to the trainer through SFTConfig(seed=seed).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
# Show INFO-level messages from transformers.
logging.set_verbosity_info()

# Root folder of the datasets; the dataset name is appended to it when loading.
data_folder = "./data/datasets/"

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine tune the LLM")
    parser.add_argument("--name", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument("--fraction", type=float, default=1, help="The fraction of the dataset to use")
    parser.add_argument(
        "--fraction_goal",
        type=float,
        default=1,
        help="The goal fraction forwarded to the dataset loader (fraction_goal)",
    )

    args = parser.parse_args()
    dataset_name = args.name
    fraction = args.fraction
    fraction_goal = args.fraction_goal

    os.makedirs("./results/models", exist_ok=True)

    # Load the single-environment dataset as sequences (obs=False, sequence=True).
    # The frac_test_* arguments presumably control how much is held out for
    # validation — semantics live in load_dataset_multi_env_minigrid.
    train_dataset, eval_dataset = load_dataset_multi_env_minigrid(
        f"{data_folder}/{dataset_name}",
        single_env=True,
        frac_test_env=0.1,
        frac_test_goal=0.1,
        fraction=fraction,
        fraction_goal=fraction_goal,
        obs=False,
        sequence=True,
        split_goal=False,
        frac_test_sample=0.01,
    )

    print(
        f"Dataset loaded : {len(train_dataset)} training samples, {len(eval_dataset)} validation samples"
    )

    # Maximum sequence length used for dataset preparation and by the trainer.
    max_length = 800

    # Define model
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    # LoRA adapter; lora_alpha == r gives a LoRA scaling factor (alpha / r) of 1.
    lora_config = LoraConfig(
        r=512,
        lora_alpha=512,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Load the full model onto this process's device ("" maps the whole model),
    # in bf16 with FlashAttention-2 and the KV cache disabled for training.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map={"": accelerator.process_index},
        use_cache=False,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )

    # NOTE(review): `padding`/`max_length` given to from_pretrained do not appear to
    # configure padding or truncation (those are call-time arguments); truncation
    # is applied via max_seq_length elsewhere — confirm before relying on them.
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding=True, max_length=max_length)
    # NOTE(review): with pad_token == eos_token, collators that set pad positions to
    # -100 also mask genuine EOS tokens in the labels, so the model may never be
    # trained to emit EOS — confirm against the collator in use.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    def formatting_prompts_func(example):
        """Concatenate prompt and completion for every row of a batched example.

        ``example`` is indexed by column name (``"prompt"`` and ``"completion"``),
        presumably a batched dataset slice. One string is produced per entry of
        ``example["completion"]``: the matching prompt followed by ``str()`` of the
        completion.
        """
        return [
            example["prompt"][idx] + str(completion)
            for idx, completion in enumerate(example["completion"])
        ]


    # Run both splits through the project's non-packed preparation helper
    # (presumably tokenisation/truncation) with the same formatting function and
    # length limit as the trainer. The SFTConfig later in this script sets
    # dataset_kwargs={"skip_prepare_dataset": True}, so SFTTrainer does not
    # prepare them again.
    train_dataset, eval_dataset = (
        _prepare_non_packed_dataloader(
            tokenizer,
            split,
            dataset_text_field=None,
            max_seq_length=max_length,
            formatting_func=formatting_prompts_func,
        )
        for split in (train_dataset, eval_dataset)
    )

    # Completion-only collator: labels up to and including the assistant header
    # are set to -100, so only the assistant response is supervised.
    collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        response_template="<|start_header_id|>assistant<|end_header_id|>",
        pad_to_multiple_of=2,
    )

    def compute_accuracy(outputs, inputs, verbose=False, tok=None):
        """Return the fraction of rows whose greedy prediction reproduces the label text.

        For each row, the supervised label tokens (``labels != -100``) are decoded,
        the argmax tokens of ``outputs.logits`` at the aligned positions are
        decoded, and the row counts as correct when the predicted string starts
        with the label string.

        Args:
            outputs: model output exposing ``.logits`` (assumed shape
                ``(batch, seq, vocab)`` — TODO confirm for non-HF models).
            inputs: batch dict with a 2-D ``"labels"`` tensor ``(batch, seq)`` in
                which ``-100`` marks ignored positions. It is not modified.
            verbose: if true, print every prediction/label pair.
            tok: tokenizer used for decoding; defaults to the script-level
                ``tokenizer``.

        Returns:
            The mean of the per-row 0/1 matches, as computed by ``np.mean``.
        """
        tok = tokenizer if tok is None else tok
        label_ids = inputs["labels"]
        label_mask = label_ids != -100

        # Logits at position t predict token t + 1, so the prediction for the first
        # supervised token sits one position earlier. Widen the mask one step to the
        # left: pred_mask[t] = label_mask[t] or label_mask[t + 1]. This is done
        # vectorised rather than with a per-element Python loop, which forces a
        # device synchronisation per element when the tensors live on GPU.
        pred_mask = label_mask.clone()
        pred_mask[:, :-1] |= label_mask[:, 1:]

        # argmax is not differentiable; avoid building autograd nodes for it.
        with torch.no_grad():
            pred_ids = outputs.logits.argmax(dim=-1)

        labels = tok.batch_decode([row[keep] for row, keep in zip(label_ids, label_mask)])
        predicted_sequences = tok.batch_decode(
            [row[keep] for row, keep in zip(pred_ids, pred_mask)]
        )
        if verbose:
            for p, l in zip(predicted_sequences, labels):
                print(f"Predicted : {p} | Label : {l}")
        # The prediction made at the last supervised position (for the token after
        # it) is an extra trailing token, so compare only a label-length prefix.
        # NOTE(review): a row with no supervised tokens (e.g. the response was
        # truncated away) decodes to "" and therefore counts as correct.
        accuracy = np.mean([int(p[: len(l)] == l) for p, l in zip(predicted_sequences, labels)])
        return accuracy


    sft_config = SFTConfig(
        # Run identity and reporting.
        output_dir=f"./results/models/{dataset_name}",
        run_name=f"SFT-{dataset_name}_{fraction}_{fraction_goal}",
        report_to="wandb",
        logging_strategy="steps",
        logging_steps=500,
        seed=seed,
        # Optimisation.
        num_train_epochs=10,
        learning_rate=2e-5,
        warmup_ratio=0.01,
        max_grad_norm=0.3,
        per_device_train_batch_size=10,
        gradient_accumulation_steps=1,
        bf16=True,
        # Evaluation and checkpointing: eval and save run on the same 2000-step
        # cadence so the best checkpoint by eval_loss can be restored at the end.
        # NOTE(review): newer transformers releases rename `evaluation_strategy`
        # to `eval_strategy`; confirm against the pinned version.
        evaluation_strategy="steps",
        eval_steps=2000,
        save_steps=2000,
        per_device_eval_batch_size=24,
        eval_accumulation_steps=1,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        # The datasets were already passed through _prepare_non_packed_dataloader
        # earlier in this script, so SFTTrainer's own preparation is skipped and
        # the prepared columns are kept.
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
    )

    os.makedirs(sft_config.output_dir, exist_ok=True)

    class SFTTrainer_custom(SFTTrainer):
        """``SFTTrainer`` that also logs a training accuracy computed by ``compute_accuracy``.

        Each training batch's accuracy is buffered in ``self._acc_record``. Whenever
        ``global_step`` is a multiple of ``logging_steps``, the buffer (including the
        current batch, whose predictions are printed) is averaged, logged as
        ``"accuracy"`` and cleared.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Training accuracies accumulated since the last "accuracy" log.
            self._acc_record = []

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """Compute the loss as ``Trainer.compute_loss`` does and track training accuracy.

            Args:
                model: the (possibly wrapped) model being trained or evaluated.
                inputs: batch dict passed to ``model(**inputs)``. ``"labels"`` is popped
                    from it when a label smoother is configured, as upstream does.
                return_outputs: if true, return ``(loss, outputs)`` instead of ``loss``.
                num_items_in_batch: passed by newer ``transformers`` releases (the
                    original signature raised TypeError there). Like upstream, it is
                    forwarded to the model when ``model_accepts_loss_kwargs`` is set.
                    Older releases never pass it.

            Returns:
                ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.

            Raises:
                ValueError: if no label smoother is used and the model returns a dict
                    without a ``"loss"`` key.
            """
            if self.label_smoother is not None and "labels" in inputs:
                labels = inputs.pop("labels")
            else:
                labels = None
            if getattr(self, "model_accepts_loss_kwargs", False) and num_items_in_batch is not None:
                inputs = {**inputs, "num_items_in_batch": num_items_in_batch}
            outputs = model(**inputs)
            # Save past state if it exists
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]

            if labels is not None:
                unwrapped_model = self.accelerator.unwrap_model(model)
                if _is_peft_model(unwrapped_model):
                    model_name = unwrapped_model.base_model.model._get_name()
                else:
                    model_name = unwrapped_model._get_name()
                if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                    loss = self.label_smoother(outputs, labels, shift_labels=True)
                else:
                    loss = self.label_smoother(outputs, labels)
            else:
                if isinstance(outputs, dict) and "loss" not in outputs:
                    raise ValueError(
                        "The model did not return a loss from the inputs, only the following keys: "
                        f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                    )
                # We don't use .loss here since the model may return tuples instead of ModelOutput.
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

            if model.training:
                at_log_step = self.state.global_step % self.state.logging_steps == 0
                # compute_accuracy reads inputs["labels"]; re-attach them if they were
                # popped for the label smoother above (previously a KeyError).
                acc_inputs = inputs if labels is None else {**inputs, "labels": labels}
                self._acc_record.append(compute_accuracy(outputs, acc_inputs, verbose=at_log_step))
                if at_log_step:
                    self.log({"accuracy": float(np.mean(self._acc_record))})
                    self._acc_record = []

            return (loss, outputs) if return_outputs else loss



    # Build the trainer directly. Accelerator.prepare() only handles models,
    # optimizers, schedulers and dataloaders and returns any other object
    # unchanged, so wrapping the trainer in it was a no-op (and with a DeepSpeed
    # auto batch-size config it may raise because no dataloader is passed). The
    # Trainer manages its own accelerate state internally.
    trainer = SFTTrainer_custom(
        model=model,
        args=sft_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        tokenizer=tokenizer,
        max_seq_length=max_length,
        data_collator=collator,
        formatting_func=formatting_prompts_func,
    )

    trainer.train()

    # Save the final model; with load_best_model_at_end=True this is the best
    # checkpoint by eval_loss.
    # NOTE(review): this path omits fraction/fraction_goal, so runs on the same
    # dataset with different fractions overwrite each other — confirm intended.
    trainer.save_model(f"./results/models/SFT-{dataset_name}")
