from transformers import TrainingArguments, Trainer, GenerationConfig, AutoTokenizer, LlamaForCausalLM, DataCollatorForLanguageModeling
from transformers import Trainer, BitsAndBytesConfig
from peft import HOFTConfig, LoraConfig, get_peft_model
from datasets import load_dataset
from evaluate import load
import numpy as np
import argparse
import torch
import wandb
import math
import gc
import os
import re

def extract_last_number_or_ratio(s):
    """Return the last number-like token found in ``s``, or ``None``.

    A token is a run of digits with an optional leading currency symbol
    (``$``, ``€`` or ``£``), an optional decimal part and an optional
    ``:<number>`` ratio suffix. The token is returned exactly as it appears
    in ``s`` (currency symbol included); no numeric conversion is performed.
    """
    number_pattern = r'[\$€£]?\d+(?:\.\d+)?(?:\:\d+(?:\.\d+)?)?'
    matches = re.findall(number_pattern, s)

    # Only the final occurrence matters: it is treated as the answer.
    return matches[-1] if matches else None

def extract_last_option(s):
    """Return the last stand-alone answer option (``A``-``E``) in ``s``, or ``None``.

    Only capital letters A-E that form a whole word are considered, so
    capitals that merely start or sit inside ordinary words (for example the
    ``D`` of "Done" or the ``A`` of "Answer") are not mistaken for an option.
    Options wrapped in punctuation such as ``(B)`` or ``C.`` are still found.
    """
    # \b on both sides keeps letters embedded in longer words from matching.
    options = re.findall(r'\b[A-E]\b', s)

    # The final occurrence is treated as the chosen option.
    return options[-1] if options else None

def main(args, summary):
    """Fine-tune (unless baseline) and evaluate a LLaMA checkpoint on a math dataset.

    The run loads the dataset named by ``args.dataset_name``, tokenizes it,
    loads the checkpoint (optionally 4-bit quantized), wraps it with the
    requested PEFT adapter, trains it with ``Trainer`` unless
    ``args.peft_type`` is ``'baseline'``, and finally generates greedy answers
    for the validation and test splits to compute exact-match accuracy on the
    last number found in each answer.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``dataset_name``,
            ``peft_type``, ``r``, ``dropout``, ``lr``, ``batch_size``,
            ``test_batch_size``, ``acc_steps``, ``epochs``, ``log_steps``,
            ``use_wandb`` and ``quantize_model``.
        summary: Human-readable run description; printed again when W&B
            logging is enabled.

    Raises:
        ValueError: If the dataset name, model number or PEFT type is not
            one of the supported values.
    """
    checkpoints = {
        2: "meta-llama/Llama-2-7b-hf",
        3: "meta-llama/Llama-3.1-8B",
    }
    supported_peft_types = ('baseline', 'hoft', 'shoft', 'lora', 'dora')

    # Fail fast on unsupported options, before any W&B run is created or any
    # dataset/model is downloaded. Without these checks the run would die
    # later with an unbound-variable error instead of a clear message.
    if "orca" not in args.dataset_name and "gsm8k" not in args.dataset_name:
        raise ValueError(
            f'{args.dataset_name} is not available.')
    if args.model not in checkpoints:
        raise ValueError(
            f'llama model {args.model} is not available; '
            f'choose one of {sorted(checkpoints)}.')
    if args.peft_type not in supported_peft_types:
        raise ValueError(
            f'peft type {args.peft_type} is not available; '
            f'choose one of {list(supported_peft_types)}.')

    if args.use_wandb:
        wandb.init(
            project="QHOFT",
            name=f"llama{args.model}-{args.peft_type}-{args.dataset_name}",
            tags=[f"{args.peft_type}", f"llama{args.model}",
                  f"{args.dataset_name}"],
        )
        print(summary)

    # Load dataset (the name was validated above, so one branch always applies)
    if "orca" in args.dataset_name:
        raw_datasets = load_dataset(args.dataset_name, "default", trust_remote_code=True)
        max_tok_length = 500
    else:  # "gsm8k"
        raw_datasets = load_dataset(args.dataset_name, "main", trust_remote_code=True)
        max_tok_length = 180

    checkpoint = checkpoints[args.model]

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding=True,
        pad_to_multiple_of=8,
        truncation=True,
    )
    if args.model == 2:
        # No dedicated pad token is configured for this checkpoint here, so EOS is reused.
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.pad_token = "<|finetune_right_pad_id|>"

    # Left padding keeps the prompt adjacent to the tokens being generated.
    tokenizer.padding_side = "left"
    # Final length budget is 3 * base + 2; it is used both as the fixed
    # training sequence length and as max_new_tokens at generation time.
    max_tok_length += 2 * max_tok_length + 2

    def preprocess4training_function(sample):
        """Build fixed-length, left-padded training rows from a batch of examples.

        Each row is ``prompt + answer + EOS``. Labels are ``-100`` over the
        prompt and the padding. Everything is left-padded to
        ``max_tok_length`` and then cut to that length from the right.
        """
        sample_size = len(sample["question"])

        # Creating the prompt with the task description for each source sentence
        inputs = [f"{question}\n\nAnswer: " for question in sample["question"]]

        # Appending new line after each sample in the batch
        targets = [f"{answer}\n" for answer in sample["answer"]]

        # Tokenize prompts and targets separately to obtain "input_ids"
        # (token ids) and "attention_mask" for each of them.
        model_inputs = tokenizer(inputs)
        labels = tokenizer(targets)

        for i in range(sample_size):
            prompt_ids = model_inputs["input_ids"][i]
            target_ids = labels["input_ids"][i] + [tokenizer.eos_token_id]

            # Input is prompt followed by target; the label masks the prompt
            # with -100 so only target positions carry a label.
            full_ids = prompt_ids + target_ids
            full_labels = [-100] * len(prompt_ids) + target_ids

            # A negative pad length yields an empty list, so over-long rows
            # get no padding and are only truncated below.
            pad_length = max_tok_length - len(full_ids)
            padded_ids = [tokenizer.pad_token_id] * pad_length + full_ids
            padded_mask = [0] * pad_length + [1] * len(full_ids)
            padded_labels = [-100] * pad_length + full_labels

            model_inputs["input_ids"][i] = torch.tensor(
                padded_ids[:max_tok_length])
            model_inputs["attention_mask"][i] = torch.tensor(
                padded_mask[:max_tok_length])
            labels["input_ids"][i] = torch.tensor(
                padded_labels[:max_tok_length])

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def preprocess4test_function(sample):
        """Tokenize prompts only (no answers), padded within the mapped batch."""
        inputs = [f"{question}\n\nAnswer: " for question in sample["question"]]
        model_inputs = tokenizer(inputs, padding=True)
        return model_inputs

    # Tokenize datasets
    if "gsm8k" in args.dataset_name:
        # First 10% of the train split is held out as validation.
        train_split = raw_datasets['train']
        val_size = len(train_split) // 10
        tokenized_val_dataset = train_split.select(range(val_size)).map(
            preprocess4test_function, batched=True)
        tokenized_train_dataset = train_split.select(range(val_size, len(train_split))).map(
            preprocess4training_function, batched=True)
        tokenized_test_dataset = raw_datasets['test'].map(
            preprocess4test_function, batched=True)
    else:  # "orca"
        # Drop very long examples, then carve fixed-size train/val/test
        # slices out of the train split.
        filtered_train = raw_datasets['train'].filter(
            lambda x: len(x["question"].split()) + len(x["answer"].split()) < 1000)
        tokenized_val_dataset = filtered_train.select(range(10000, 11000)).map(
            preprocess4test_function, batched=True)
        tokenized_test_dataset = filtered_train.select(range(11000, 13000)).map(
            preprocess4test_function, batched=True)
        tokenized_train_dataset = filtered_train.select(range(10000)).map(
            preprocess4training_function, batched=True)

    print(tokenized_train_dataset)
    print(tokenized_val_dataset)
    print(tokenized_test_dataset)

    if args.quantize_model:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        quantization_config = None

    # Obtain model
    model = LlamaForCausalLM.from_pretrained(
        checkpoint, device_map='auto', torch_dtype=torch.float16, quantization_config=quantization_config)

    generation_config = GenerationConfig.from_pretrained(checkpoint)

    if args.model == 3:
        model.config.use_cache = False
        model.config.pretraining_tp = 1
        generation_config.pad_token_id = tokenizer.pad_token_id

    def compute_metrics(sample, output_sequences):
        """Decode generations and return ``{"accuracy": percent}`` by exact match.

        ``sample`` supplies the reference column (``answer``, or ``correct``
        for aqua_rat); ``output_sequences`` are the generated token ids.
        Every prediction/reference pair is printed.
        """
        preds = tokenizer.batch_decode(
            output_sequences, skip_special_tokens=True)

        preds = [pred.strip() for pred in preds]

        for p, l in zip(preds, sample['answer']):
            print('Prediction: ', p)
            print('Reference: ', l)
            print('\n')

        if 'aqua_rat' in args.dataset_name:
            pred_num = [extract_last_option(pred) for pred in preds]
            label_num = [answer for answer in sample['correct']]
        else:
            pred_num = [extract_last_number_or_ratio(pred) for pred in preds]
            label_num = [extract_last_number_or_ratio(answer) for answer in sample['answer']]

        acc_result = np.mean([pred == label for pred, label in zip(pred_num, label_num)])

        result = {"accuracy": acc_result * 100}
        return result

    # Wrap the base model with the requested adapter (peft_type validated above).
    target_modules = ["q_proj", "k_proj", "v_proj"]
    if args.peft_type == 'baseline':
        peft_model = model
    else:
        if args.peft_type in ('hoft', 'shoft'):
            # 'shoft' is the same config with the use_shoft flag switched on.
            variant_kwargs = {"use_shoft": True} if args.peft_type == 'shoft' else {}
            config = HOFTConfig(
                task_type="CAUSAL_LM",
                r=args.r,
                target_modules=target_modules,
                hoft_dropout=args.dropout,
                bias="none",
                **variant_kwargs,
            )
        else:  # 'lora' or 'dora'
            # 'dora' is the same config with the use_dora flag switched on.
            variant_kwargs = {"use_dora": True} if args.peft_type == 'dora' else {}
            config = LoraConfig(
                task_type="CAUSAL_LM",
                r=args.r,
                lora_alpha=args.r * 2,
                target_modules=target_modules,
                lora_dropout=args.dropout,
                bias="none",
                **variant_kwargs,
            )
        peft_model = get_peft_model(model, config)
        peft_model.print_trainable_parameters()

    if args.peft_type != 'baseline':
        # NOTE(review): DataCollatorForLanguageModeling(mlm=False) appears to
        # rebuild `labels` from `input_ids` (masking only pad tokens), which
        # would discard the prompt-masked labels built in
        # preprocess4training_function -- confirm against the collator docs.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            pad_to_multiple_of=8
        )
        train_args = TrainingArguments(
            f"models/llama{args.model}",
            eval_strategy="epoch",
            learning_rate=args.lr,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            weight_decay=0.00001,
            save_total_limit=0,
            warmup_steps=40,
            gradient_accumulation_steps=args.acc_steps,
            num_train_epochs=args.epochs,
            optim="adamw_torch",
            fp16=True,
            # "none" (the string) disables every integration; passing None
            # makes TrainingArguments fall back to all installed integrations,
            # which would keep W&B logging on even without --use_wandb.
            report_to="wandb" if args.use_wandb else "none",
            logging_steps=args.log_steps,  # how often to log to W&B
        )

        trainer = Trainer(
            peft_model,
            train_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        print(peft_model)
        peft_model.train()
        trainer.train()

    peft_model.eval()

    test_batch_size = args.test_batch_size

    for name, dataset in [('eval', tokenized_val_dataset), ('test', tokenized_test_dataset)]:
        batch_tokenized_test = dataset.batch(
            test_batch_size).with_format("torch")
        # Each row of the batched dataset is one batch.
        number_of_batches = len(batch_tokenized_test)
        output_sequences = []

        for i in range(number_of_batches):
            # Fetch (and format) the batch once instead of once per access.
            batch = batch_tokenized_test[i]
            batch_input_ids = batch["input_ids"]
            batch_attention_mask = batch["attention_mask"]

            # Rows of different lengths come back as a list of 1-D tensors
            # rather than one 2-D tensor. Pad them on the LEFT, consistent
            # with tokenizer.padding_side, so each prompt ends right where
            # generation starts.
            if isinstance(batch_input_ids, list):
                max_length = max(
                    tensor.shape[0] for tensor in batch_input_ids)

                inputs_ids = torch.stack([
                    torch.cat([
                        tensor.new_full(
                            (max_length - tensor.shape[0],), tokenizer.pad_token_id),
                        tensor,
                    ])
                    for tensor in batch_input_ids
                ])
                attention_mask = torch.stack([
                    torch.cat([
                        tensor.new_zeros(max_length - tensor.shape[0]),
                        tensor,
                    ])
                    for tensor in batch_attention_mask
                ])
            else:
                inputs_ids = batch_input_ids
                attention_mask = batch_attention_mask

            with torch.no_grad():
                output_batch = peft_model.generate(
                    generation_config=generation_config,
                    input_ids=inputs_ids.cuda(),
                    attention_mask=attention_mask.cuda().half(),
                    max_new_tokens=max_tok_length,
                    num_beams=1,
                    do_sample=False,
                )

            # generate() returns prompt + continuation for this causal LM;
            # keep only the continuation so numbers in the question cannot be
            # picked up as the predicted answer.
            prompt_length = inputs_ids.shape[1]
            generated = output_batch[:, prompt_length:]
            output_sequences.extend(generated.cpu())

            # Release GPU memory between batches.
            del generated
            del output_batch
            gc.collect()
            torch.cuda.empty_cache()

        # Comment if error arise
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        result = compute_metrics(dataset, output_sequences)

        if args.use_wandb:
            # Log the split accuracy, rounded up to two decimals
            wandb.log({
                f"{name}/accuracy": math.ceil(result["accuracy"] * 100) / 100,
            })

    if args.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    # Command-line entry point: parse the run options, print a summary of
    # them and hand both to main().
    parser = argparse.ArgumentParser()
    # Restricting choices makes argparse reject unsupported values at parse
    # time instead of letting the run fail after setup has started.
    parser.add_argument("-model", help="llama model number", type=int, required=True, choices=[2, 3])
    parser.add_argument("-dataset_name", help="dataset name (must contain 'gsm8k' or 'orca')", type=str, required=True)
    parser.add_argument("-lr", help="learning rate", type=float, default=1e-4)
    parser.add_argument("-r", help="rank used peft method", type=int, default=0)
    parser.add_argument("-acc_steps", help="number of gradient accumulations", type=int, default=1)
    parser.add_argument("-epochs", help="number of epochs", type=int, default=3)
    parser.add_argument("-batch_size", help="batch size", type=int, default=16)
    parser.add_argument("-test_batch_size", help="test batch size", type=int, default=64)
    parser.add_argument("-dropout", help="dropout for finetuned layers", type=float, default=0.0)
    parser.add_argument("-peft_type", help="select peft method", type=str, default='baseline',
                        choices=['baseline', 'hoft', 'shoft', 'lora', 'dora'])
    parser.add_argument('--use_wandb',  help="send results to wandb", action='store_true')
    parser.add_argument('--quantize_model',  help="whether to quantize or not the model", action='store_true')
    parser.add_argument("-log_steps", help="number of logging steps", type=int, default=3)

    args = parser.parse_args()
    summary = f"""

    LLaMA-{args.model} for math reasoning
          - dataset : {args.dataset_name}
          - quantize model : {args.quantize_model}
          
          - peft type : {args.peft_type}
          - rank (if needed) : {args.r}
          - dropout : {args.dropout}

          - epochs : {args.epochs}
          - learning rate : {args.lr}
          - test batch size : {args.test_batch_size}
          - batch size : {args.batch_size}
          - accumulation steps : {args.acc_steps}

          - use wandb : {args.use_wandb}
          - use logging steps : {args.log_steps}

    """
    print(summary)
    main(args, summary)
