from transformers import TrainingArguments, Trainer, GenerationConfig, AutoTokenizer, LlamaForCausalLM, DataCollatorForLanguageModeling
from transformers import Trainer
from peft import HOFTConfig, LoraConfig, get_peft_model
from datasets import load_dataset
from evaluate import load
import numpy as np
import argparse
import torch
import wandb
import math
import gc
import os
import re


# Hugging Face checkpoints keyed by the value of the ``-model`` CLI argument.
_CHECKPOINTS = {
    2: "meta-llama/Llama-2-7b-hf",
    3: "meta-llama/Llama-3.1-8B",
}

# Language names used inside the prompt, keyed by the CLI language code.
_SOURCE_LANGUAGES = {
    'sl': 'slovene',
    'de': 'german',
    'fr': 'french',
    'lv': 'latvian',
}
_TARGET_LANGUAGES = {'en': 'english'}

# Source languages whose train/validation splits are capped at 10000/1000 rows.
_SUBSAMPLED_SOURCES = ('de', 'fr')

_PEFT_TYPES = ('baseline', 'hoft', 'shoft', 'lora', 'dora')

# Dataset columns that are dropped right after loading.
_UNUSED_COLUMNS = ["client_id", "file", 'audio', 'id']


def _validate_args(args):
    """Return ``(checkpoint, src_code, tgt_code)`` or raise ``ValueError``."""
    if args.model not in _CHECKPOINTS:
        raise ValueError(
            f'{args.model} is not implemented, implemented llama models are 2 and 3.')
    if args.src not in _SOURCE_LANGUAGES:
        raise ValueError(
            f'{args.src} is not implemented, implemented src languages are de, fr, lv and sl.')
    if args.tgt not in _TARGET_LANGUAGES:
        raise ValueError(
            f'{args.tgt} is not implemented, implemented tgt language is en.')
    if args.peft_type not in _PEFT_TYPES:
        raise ValueError(
            f'{args.peft_type} is not implemented, implemented peft types are '
            f'{", ".join(_PEFT_TYPES)}.')
    return (_CHECKPOINTS[args.model],
            _SOURCE_LANGUAGES[args.src],
            _TARGET_LANGUAGES[args.tgt])


def _build_prompts(sentences, src_code, tgt_code):
    """Wrap every source sentence in the translation instruction prompt."""
    return [f"Translate the following text from {src_code} to {tgt_code}:\n\n{src_code}: " +
            sentence + f"\n\n{tgt_code}: " for sentence in sentences]


def _extract_translation(decoded, prompt):
    """Return the first generated line after ``prompt``, or "" if no newline was generated."""
    continuation = decoded.removeprefix(prompt).lstrip()
    first_line, newline, _ = continuation.partition("\n")
    return first_line.strip() if newline else ""


def _left_pad_and_stack(tensors, pad_value):
    """Left-pad 1-D tensors with ``pad_value`` to a common length and stack them."""
    width = max(tensor.shape[0] for tensor in tensors)
    return torch.stack([
        torch.cat([tensor.new_full((width - tensor.shape[0],), pad_value), tensor])
        for tensor in tensors
    ])


def _build_peft_config(args):
    """Build the PEFT config for ``args.peft_type``; ``None`` for the baseline."""
    target_modules = ["q_proj", "k_proj", "v_proj"]
    if args.peft_type in ('hoft', 'shoft'):
        extra = {"use_shoft": True} if args.peft_type == 'shoft' else {}
        return HOFTConfig(
            task_type="CAUSAL_LM",
            r=args.r,
            target_modules=target_modules,
            hoft_dropout=args.dropout,
            bias="none",
            **extra,
        )
    if args.peft_type in ('lora', 'dora'):
        extra = {"use_dora": True} if args.peft_type == 'dora' else {}
        return LoraConfig(
            task_type="CAUSAL_LM",
            r=args.r,
            lora_alpha=args.r * 2,
            target_modules=target_modules,
            lora_dropout=args.dropout,
            bias="none",
            **extra,
        )
    return None


def main(args, summary):
    """Fine-tune (optionally) and evaluate a Llama model on CoVoST 2 translation.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``src``, ``tgt``, ``lr``,
            ``r``, ``acc_steps``, ``epochs``, ``batch_size``,
            ``test_batch_size``, ``dropout``, ``peft_type``, ``use_wandb`` and
            ``log_steps``.
        summary: Human-readable run description, printed when W&B is enabled.

    Raises:
        ValueError: If the model number, source/target language or PEFT type
            is not one of the implemented options. Raised before any dataset
            or model is loaded.
    """
    # Fail fast on unsupported options, before downloading anything.
    checkpoint, src_code, tgt_code = _validate_args(args)

    # Load dataset
    raw_datasets = load_dataset(
        "facebook/covost2", f'{args.src}_{args.tgt}', data_dir=f'covost2/{args.src}', trust_remote_code=True)
    for split in ('train', 'validation', 'test'):
        raw_datasets[split] = raw_datasets[split].remove_columns(
            _UNUSED_COLUMNS)

    if args.src in _SUBSAMPLED_SOURCES:
        raw_datasets['train'] = raw_datasets['train'].take(10000)
        raw_datasets['validation'] = raw_datasets['validation'].take(1000)

    sentence_tok_length = 60
    # NOTE(review): these call-time options are passed at load time here;
    # confirm they have any effect on later `tokenizer(...)` calls.
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding=True,
        pad_to_multiple_of=8,
        truncation=True,
        max_length=sentence_tok_length
    )
    if args.model == 2:
        tokenizer.pad_token = tokenizer.eos_token
    if args.model == 3:
        tokenizer.pad_token = "<|finetune_right_pad_id|>"

    tokenizer.padding_side = "left"
    # Total sequence budget used for training examples and for generation
    # (3 * 60 + 2 = 182 tokens).
    max_tok_length = 3 * sentence_tok_length + 2

    def preprocess4training_function(sample):
        """Tokenize a batch into fixed-length prompt+target training examples."""
        prompts = _build_prompts(sample["sentence"], src_code, tgt_code)

        # Appending new line after each sample in the batch
        targets = [f"{s}\n" for s in sample["translation"]]

        # NOTE(review): targets are tokenized with the tokenizer's default
        # special-token handling - confirm that is intended mid-sequence.
        model_inputs = tokenizer(prompts)
        target_ids = tokenizer(targets)["input_ids"]

        input_ids, attention_mask, labels = [], [], []
        for prompt_ids, tgt_ids in zip(model_inputs["input_ids"], target_ids):
            # Input is prompt followed by target; the label masks the prompt
            # with -100 so that only target positions carry a label.
            example_ids = prompt_ids + tgt_ids
            example_labels = [-100] * len(prompt_ids) + tgt_ids

            # Left-pad up to max_tok_length (no padding when already longer),
            # then truncate everything to max_tok_length.
            pad = max(max_tok_length - len(example_ids), 0)
            input_ids.append(
                ([tokenizer.pad_token_id] * pad + example_ids)[:max_tok_length])
            attention_mask.append(
                ([0] * pad + [1] * len(example_ids))[:max_tok_length])
            labels.append(([-100] * pad + example_labels)[:max_tok_length])

        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = attention_mask
        model_inputs["labels"] = labels
        return model_inputs

    def preprocess4test_function(sample):
        """Tokenize a batch of prompts only, padded to the longest in the batch."""
        prompts = _build_prompts(sample["sentence"], src_code, tgt_code)
        return tokenizer(prompts, padding=True)

    # Tokenize datasets
    tokenized_train_dataset = raw_datasets['train'].map(
        preprocess4training_function, batched=True)
    tokenized_val_dataset = raw_datasets['validation'].map(
        preprocess4test_function, batched=True)
    tokenized_test_dataset = raw_datasets['test'].map(
        preprocess4test_function, batched=True)

    # Obtain model
    model = LlamaForCausalLM.from_pretrained(
        checkpoint, device_map='cuda', torch_dtype=torch.float16)
    generation_config = GenerationConfig.from_pretrained(checkpoint)
    if args.model == 3:
        model.config.use_cache = False
        model.config.pretraining_tp = 1
        generation_config.pad_token_id = tokenizer.pad_token_id

    # Load metrics
    bleu = load("sacrebleu")
    comet = load("comet")
    chrf = load("chrf")

    def compute_metrics(sample, output_sequences):
        """Decode generations and score them with BLEU, chrF and COMET."""
        sentences = sample["sentence"]
        references = sample["translation"]
        prompts = _build_prompts(sentences, src_code, tgt_code)
        decoded = tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True)

        # Keep only the first generated line after the prompt.
        preds = [_extract_translation(text, prompt)
                 for prompt, text in zip(prompts, decoded)]

        for p, l, s in zip(preds, references, sentences):
            print('Prediction: ', p)
            print('Translation: ', l)
            print('Sentence: ', s)
            print('\n')

        bleu_result = bleu.compute(predictions=preds, references=references)
        chrf_result = chrf.compute(predictions=preds, references=references)
        comet_result = comet.compute(
            sources=sentences, predictions=preds, references=references)

        return {
            "bleu": bleu_result["score"],
            'chrf': chrf_result['score'],
            "comet": np.mean(comet_result["scores"]) * 100,
        }

    if args.use_wandb:
        wandb.init(
            project="MT",
            name=f"llama{args.model}-{args.peft_type}-{args.src}-to-{args.tgt}",
            tags=[f"{args.peft_type}", f"llama{args.model}",
                  f"{args.src}_{args.tgt}"],
        )
        print(summary)

    if args.peft_type == 'baseline':
        peft_model = model
    else:
        peft_model = get_peft_model(model, _build_peft_config(args))
        peft_model.print_trainable_parameters()

    if args.peft_type != 'baseline':
        # NOTE(review): with mlm=False this collator appears to rebuild
        # `labels` from `input_ids`, which would discard the prompt masking
        # done in preprocess4training_function - confirm this is intended.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )
        train_args = TrainingArguments(
            f"models/llama{args.model}-{args.peft_type}-{args.src}-to-{args.tgt}",
            eval_strategy="epoch",
            learning_rate=args.lr,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            weight_decay=0.00001,
            save_total_limit=0,
            warmup_steps=100,
            gradient_accumulation_steps=args.acc_steps,
            num_train_epochs=args.epochs,
            optim="adamw_torch",
            fp16=True,
            # "none" (the string) disables every integration; passing None
            # falls back to the library default instead of disabling logging.
            report_to="wandb" if args.use_wandb else "none",
            logging_steps=args.log_steps,  # how often to log to W&B
        )

        trainer = Trainer(
            peft_model,
            train_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        peft_model.train()
        trainer.train()

    peft_model.eval()
    test_batch_size = args.test_batch_size

    for name, dataset in [('eval', tokenized_val_dataset), ('test', tokenized_test_dataset)]:
        batched_dataset = dataset.batch(test_batch_size).with_format("torch")
        output_sequences = []

        for i in range(len(batched_dataset)):
            # Format the row once; every access re-formats all of its columns.
            batch = batched_dataset[i]
            inputs_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]

            # Rows of different lengths come back as a list of tensors.
            # Pad them on the left, matching tokenizer.padding_side, so the
            # prompt stays adjacent to the tokens generated after it.
            if isinstance(inputs_ids, list):
                inputs_ids = _left_pad_and_stack(
                    inputs_ids, tokenizer.pad_token_id)
                attention_mask = _left_pad_and_stack(attention_mask, 0)

            with torch.no_grad():
                output_batch = peft_model.generate(
                    generation_config=generation_config,
                    input_ids=inputs_ids.cuda(),
                    attention_mask=attention_mask.cuda(),
                    max_length=max_tok_length,
                    num_beams=1,
                    do_sample=False,
                )
            output_sequences.extend(output_batch.cpu())

            # Release GPU memory between batches.
            del output_batch
            gc.collect()
            torch.cuda.empty_cache()

        # Disable tokenizers parallelism before computing metrics
        # (original author's note: comment this line out if errors arise).
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        result = compute_metrics(dataset, output_sequences)
        print(f'BLEU score: {result["bleu"]}')
        print(f'chrF score: {result["chrf"]}')
        print(f'COMET score: {result["comet"]}')

        if args.use_wandb:
            # Log the scores rounded up to two decimals
            wandb.log({
                f"{name}/bleu": math.ceil(result["bleu"] * 100) / 100,
                f"{name}/comet": math.ceil(result["comet"] * 100) / 100,
                f"{name}/chrf": math.ceil(result["chrf"] * 100) / 100
            })

    if args.use_wandb:
        wandb.finish()


# Script entry point: parse the CLI options, print a run summary and launch
# the experiment.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # `choices` rejects unsupported values at parse time instead of letting
    # them fail later inside main().
    parser.add_argument("-model", help="llama model number",
                        type=int, required=True, choices=[2, 3])
    parser.add_argument("-src", help="source language",
                        type=str, required=True,
                        choices=['sl', 'de', 'fr', 'lv'])
    parser.add_argument("-tgt", help="target language",
                        type=str, required=True, choices=['en'])
    parser.add_argument("-lr", help="learning rate", type=float, default=1e-4)
    parser.add_argument("-r", help="rank used peft method",
                        type=int, default=0)
    parser.add_argument(
        "-acc_steps", help="number of gradient accumulations", type=int, default=1)
    parser.add_argument("-epochs", help="number of epochs",
                        type=int, default=3)
    parser.add_argument("-batch_size", help="batch size", type=int, default=16)
    parser.add_argument("-test_batch_size",
                        help="test batch size", type=int, default=64)
    parser.add_argument(
        "-dropout", help="dropout for finetuned layers", type=float, default=0.0)
    parser.add_argument("-peft_type", help="select peft method",
                        type=str, default='baseline',
                        choices=['baseline', 'hoft', 'shoft', 'lora', 'dora'])
    parser.add_argument(
        '--use_wandb',  help="send results to wandb", action='store_true')
    parser.add_argument(
        "-log_steps", help="number of logging steps", type=int, default=3)

    args = parser.parse_args()
    summary = f"""

    Llama for MT
          - llama model : {args.model}
          - source language : {args.src}
          - target language : {args.tgt}

          - peft type : {args.peft_type}
          - rank (if needed) : {args.r}
          - dropout : {args.dropout}

          - epochs : {args.epochs}
          - learning rate : {args.lr}
          - test batch size : {args.test_batch_size}
          - batch size : {args.batch_size}
          - accumulation steps : {args.acc_steps}

          - use wandb : {args.use_wandb}
          - use logging steps : {args.log_steps}

    """
    print(summary)
    main(args, summary)
