import argparse
import gc
import math
import re

import torch
import wandb
from datasets import load_dataset, concatenate_datasets, DatasetDict, Value
from peft import HOFTConfig, LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)



# ======================================================================================
# Data related functions
# ======================================================================================

def process_winogrande_dataset(sample):
    """Build the WinoGrande prompt: the sentence followed by two numbered options.

    Adds a 'user_prompt' key to ``sample`` (mutated in place) and returns it.
    """
    prompt_lines = [
        f"{sample['sentence']}",
        "What does the _ in the above sentence refer to?",
        f"1. {sample['option1']}",
        f"2. {sample['option2']}",
    ]
    sample['user_prompt'] = "\n".join(prompt_lines)
    return sample

def process_boolq_dataset(sample):
    """Build the BoolQ prompt from the passage and the yes/no question.

    Adds a 'user_prompt' key to ``sample`` (mutated in place) and returns it.
    """
    passage, question = sample['passage'], sample['question']
    template = "Text: {}\nAnswer the following yes/no question: {}?"
    sample['user_prompt'] = template.format(passage, question)
    return sample

def process_arc_dataset(sample):
    """Build the ARC prompt (question plus labelled options) and copy the gold key.

    Adds 'user_prompt' and sets 'answer' to ``sample['answerKey']``; mutates
    ``sample`` in place and returns it.
    """
    options = sample['choices']
    option_blocks = [
        f"\n\n{label}: {text}"
        for label, text in zip(options['label'], options['text'])
    ]
    sample['user_prompt'] = f"{sample['question']}\n\nOptions:" + "".join(option_blocks)
    sample['answer'] = sample['answerKey']
    return sample

def process_piqa_dataset(sample):
    """Build the PIQA prompt and convert the 0-based label into a 1-based option string.

    Adds 'user_prompt' and 'answer' ("1" or "2" for labels 0/1) to ``sample``
    (mutated in place) and returns it.
    """
    goal, first, second = sample['goal'], sample['sol1'], sample['sol2']
    sample['user_prompt'] = "\n".join([
        f"Goal: {goal}",
        "",
        "Which is the correct ending?",
        f"1. {first}",
        f"2. {second}",
    ])
    sample['answer'] = f"{int(sample['label']) + 1}"
    return sample

def process_social_i_qa_dataset(sample):
    """Build the Social IQa prompt with options A/B/C and resolve the gold answer.

    ``sample['label']`` is 1-based ("1".."3"); 'answer' is set to the matching
    option text. Mutates ``sample`` in place and returns it.
    NOTE(review): the prompt labels options A/B/C but the target is the option
    *text*, unlike the letter/number targets of the other benchmarks — confirm
    this is intended.
    """
    options = (sample['answerA'], sample['answerB'], sample['answerC'])
    listing = "\n".join(f"{letter}. {text}" for letter, text in zip("ABC", options))
    sample['user_prompt'] = (
        f"Context: {sample['context']}\n"
        f"Question: {sample['question']}\n"
        "Which one of these answers best answers the question according to the context?\n"
        + listing
    )
    sample['answer'] = options[int(sample['label']) - 1]
    return sample

def process_hellaswag_dataset(sample):
    """Build the HellaSwag prompt, one lettered option per candidate ending.

    Each option is ``ctx_b`` followed by the ending. 'answer' is the letter at
    index ``int(sample['label'])``. Mutates ``sample`` in place and returns it.
    """
    letters = ['A', 'B', 'C', 'D']
    shared_start = sample['ctx_b']
    option_blocks = [
        f"\n\n{letters[idx]}. {shared_start} {ending}"
        for idx, ending in enumerate(sample['endings'])
    ]
    header = f"{sample['ctx_a']}\nWhich could be the most possible context for this action?"
    sample['user_prompt'] = header + "".join(option_blocks)
    sample['answer'] = letters[int(sample['label'])]
    return sample

def process_obqa_dataset(sample):
    """Build the OpenBookQA prompt (question stem plus labelled options).

    Adds 'user_prompt' and sets 'answer' to ``sample['answerKey']``; mutates
    ``sample`` in place and returns it.
    """
    options = sample['choices']
    listing = ''.join(
        '\n\n{}: {}'.format(label, text)
        for label, text in zip(options['label'], options['text'])
    )
    sample['user_prompt'] = '{}\n\nOptions:{}'.format(sample['question_stem'], listing)
    sample['answer'] = sample['answerKey']
    return sample

def process_general_dataset(sample):
    """Attach the chat conversation and the supervised answer line to a sample.

    Reads 'user_prompt' and 'answer', adds 'answer_prompt'
    ("Answer: <answer>" followed by a newline) and 'messages' (a system turn
    and a user turn). Mutates ``sample`` in place and returns it.
    """
    instruction = (
        'You are a helpful assistant. Below is an instruction that describes a task, '
        'paired with an input that provides further context. '
        'Write a response that appropriately completes the request.'
    )
    answer_line = 'Answer: {}\n'.format(sample['answer'])
    conversation = [
        dict(role='system', content=instruction),
        dict(role='user', content=sample['user_prompt']),
    ]
    sample['answer_prompt'] = answer_line
    sample['messages'] = conversation
    return sample

def load_hf_dataset(dataset_name):
    """Load and format one commonsense benchmark, or all of them with ``'all'``.

    For a single benchmark the result has 'train', 'validation' and 'test'
    splits whose rows carry 'answer', 'user_prompt', 'answer_prompt' and
    'messages' (all original hub columns except 'answer' are dropped).
    For ``'all'`` the train/validation splits of every benchmark are
    concatenated and each benchmark's test split is stored under its own name.
    Unknown names print a message and return None.
    """
    if dataset_name == 'all':
        names = ["boolq", "piqa", "social_i_qa", "hellaswag", "winogrande", "arcc", "arce", "openbookqa"]
        loaded = [load_hf_dataset(n) for n in names]

        merged = DatasetDict({
            'train': concatenate_datasets([d['train'] for d in loaded]),
            'validation': concatenate_datasets([d['validation'] for d in loaded]),
        })
        for n, d in zip(names, loaded):
            merged[n] = d['test']
        return merged

    # name -> (positional load_dataset arguments, per-sample formatter)
    sources = {
        'winogrande': (("allenai/winogrande", "winogrande_l"), process_winogrande_dataset),
        'boolq': (("google/boolq",), process_boolq_dataset),
        'arce': (("allenai/ai2_arc", "ARC-Easy"), process_arc_dataset),
        'arcc': (("allenai/ai2_arc", "ARC-Challenge"), process_arc_dataset),
        'piqa': (("ybisk/piqa",), process_piqa_dataset),
        'social_i_qa': (("allenai/social_i_qa",), process_social_i_qa_dataset),
        'hellaswag': (("Rowan/hellaswag",), process_hellaswag_dataset),
        'openbookqa': (("allenai/openbookqa", "main"), process_obqa_dataset),
    }
    if dataset_name not in sources:
        print(f'An unexpected error has ocurred while choosing hf dataset {dataset_name}.')
        return None

    hub_args, formatter = sources[dataset_name]
    dataset = load_dataset(*hub_args, trust_remote_code=True)
    original_columns = [c for c in dataset['train'].column_names if c != 'answer']

    def prepare(split):
        # Format every row, then drop the raw hub columns.
        return split.map(formatter).remove_columns(original_columns)

    if dataset_name in ('winogrande', 'piqa'):
        # The hub validation split doubles as the test split
        # (presumably because the hub test split is unlabelled).
        dataset['train'] = prepare(dataset['train'])
        dataset['validation'] = prepare(dataset['validation'])
        dataset['test'] = dataset['validation']
    elif dataset_name in ('boolq', 'social_i_qa'):
        # Hub validation becomes test; the tail of train becomes validation.
        cut = 8000 if dataset_name == 'boolq' else 30000
        full_train = dataset['train']
        dataset['test'] = prepare(dataset['validation'])
        dataset['validation'] = prepare(full_train.select(range(cut, len(full_train))))
        dataset['train'] = prepare(full_train.select(range(0, cut)))
    elif dataset_name == 'hellaswag':
        # Hub validation is split: first 2000 rows -> validation, rest -> test.
        held_out = dataset['validation']
        dataset['train'] = prepare(dataset['train'])
        dataset['test'] = prepare(held_out.select(range(2000, len(held_out))))
        dataset['validation'] = prepare(held_out.select(range(2000)))
    else:
        # arce / arcc / openbookqa ship train, validation and test splits.
        for split in ('train', 'validation', 'test'):
            dataset[split] = prepare(dataset[split])

    if dataset_name in ('winogrande', 'boolq'):
        # Make 'answer' a string column like the other benchmarks.
        for split in dataset.keys():
            dataset[split] = dataset[split].cast_column("answer", Value("string"))

    for split in ('train', 'validation', 'test'):
        dataset[split] = dataset[split].map(process_general_dataset)
    return dataset


def tokenize_dataset(dataset, tokenizer, max_tok_length=100):
    """Tokenize every split of ``dataset`` in place and return it.

    The 'train' split becomes fixed-length, left-padded causal-LM examples of
    ``prompt + answer + eos`` whose labels are -100 on the prompt and on the
    padding, so only the answer and the eos token are supervised. Every other
    split only receives the tokenized chat prompt, padded to the longest
    prompt of each ``map`` batch.

    Args:
        dataset: mapping of split name -> dataset exposing ``map``; rows need a
            'messages' column and, for 'train', an 'answer_prompt' column.
        tokenizer: Hugging Face tokenizer with a chat template.
        max_tok_length: base length; training sequences are
            ``3 * max_tok_length + 2`` tokens long (302 for the default 100).

    Returns:
        The same ``dataset`` object, with tokenized columns added to each split.
    """
    # Same value the previous `max_tok_length += 2 * max_tok_length + 2` produced.
    seq_len = 3 * max_tok_length + 2

    def preprocess4training_function(sample):
        # Chat-formatted prompts that end with the assistant generation header.
        prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                   for message in sample['messages']]
        # Prompts keep the tokenizer's default special-token handling, matching
        # how prompts are tokenized for generation at evaluation time.
        model_inputs = tokenizer(prompts)
        # The answer is a continuation of the prompt: special tokens added here
        # (e.g. a BOS) would end up inside the supervised target.
        targets = tokenizer(sample["answer_prompt"], add_special_tokens=False)

        input_ids, attention_mask, labels = [], [], []
        for prompt_ids, answer_ids in zip(model_inputs["input_ids"], targets["input_ids"]):
            answer_ids = list(answer_ids) + [tokenizer.eos_token_id]
            ids = list(prompt_ids) + answer_ids
            target = [-100] * len(prompt_ids) + answer_ids
            n_pad = max(seq_len - len(ids), 0)
            # Left-pad to seq_len. Over-long examples keep their *last* seq_len
            # tokens: cutting on the right would drop the answer and leave the
            # example without any supervised token.
            input_ids.append(([tokenizer.pad_token_id] * n_pad + ids)[-seq_len:])
            attention_mask.append(([0] * n_pad + [1] * len(ids))[-seq_len:])
            labels.append(([-100] * n_pad + target)[-seq_len:])

        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = attention_mask
        model_inputs["labels"] = labels
        return model_inputs

    def preprocess4test_function(sample):
        # Evaluation splits only need the prompt; the answer is generated.
        prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                   for message in sample['messages']]
        return tokenizer(prompts, padding=True)

    dataset['train'] = dataset['train'].map(preprocess4training_function, batched=True)
    for split in dataset.keys():
        if split != 'train':
            dataset[split] = dataset[split].map(preprocess4test_function, batched=True)

    return dataset


# ======================================================================================
# Model related functions
# ======================================================================================


def load_hf_model(model_name):
    """Load a supported checkpoint and its fast tokenizer.

    The model is loaded in float16 with ``device_map='cuda'``. For
    'llama3.1' the pad token is set to ``<|finetune_right_pad_id|>`` (id 128004).

    Returns:
        ``(model, tokenizer)``, or ``(None, None)`` after printing a message when
        ``model_name`` is not recognised.
    """
    checkpoints = {
        'llama3.1': 'meta-llama/Llama-3.1-8B-Instruct',
        'qwen2.5': 'Qwen/Qwen2.5-7B-Instruct',
        'phi4': 'microsoft/phi-4',
        'qwen2.5-14B': 'Qwen/Qwen2.5-14B',
    }
    if model_name not in checkpoints:
        print(f'An unexpected error has ocurred while choosing hf model {model_name}.')
        return None, None

    checkpoint = checkpoints[model_name]
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map='cuda')

    if model_name == 'llama3.1':
        # Use Llama 3.1's reserved right-padding token as the pad token.
        tokenizer.pad_token = "<|finetune_right_pad_id|>"
        tokenizer.pad_token_id = 128004

    return model, tokenizer

def obtain_peft_model(model, model_name, peft_type, rank, dropout, init_weights, verbose=True):
    """Wrap ``model`` with the requested PEFT adapter.

    Args:
        model: base causal LM.
        model_name: model identifier; selects the adapted projection modules.
        peft_type: 'base' (no adapter), 'lora', 'dora', 'hoft' or 'shoft'.
        rank: adapter rank (LoRA/DoRA use ``lora_alpha = 2 * rank``).
        dropout: adapter dropout probability.
        init_weights: adapter initialisation option, forwarded to the config.
        verbose: print the trainable-parameter summary when True.

    Returns:
        ``model`` unchanged for 'base', the PEFT-wrapped model otherwise, or
        None (after printing a message) for an unsupported model or PEFT type.
    """
    if peft_type == 'base':
        return model

    attention_mlp = ["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"]
    modules_by_model = {
        'llama3.1': attention_mlp,
        'qwen2.5': attention_mlp,
        'qwen2.5-14B': attention_mlp,
        'phi4': ["qkv_proj", "gate_up_proj", "down_proj"],
    }
    if model_name not in modules_by_model:
        print(f'An unexpected error has ocurred while choosing target modules for {model_name}.')
        return None

    shared = dict(
        task_type="CAUSAL_LM",
        r=rank,
        target_modules=modules_by_model[model_name],
        bias="none",
    )
    if peft_type in ('lora', 'dora'):
        config = LoraConfig(
            **shared,
            lora_alpha=rank * 2,
            lora_dropout=dropout,
            init_lora_weights=init_weights,
            use_dora=(peft_type == 'dora'),
        )
    elif peft_type in ('hoft', 'shoft'):
        config = HOFTConfig(
            **shared,
            hoft_dropout=dropout,
            init_weights=init_weights,
            use_shoft=(peft_type == 'shoft'),
        )
    else:
        print(f'An unexpected error has ocurred while choosing peft {peft_type}.')
        return None

    wrapped = get_peft_model(model, config)
    if verbose:
        wrapped.print_trainable_parameters()
    return wrapped
    
# ======================================================================================
# Metric related functions
# ======================================================================================

def obtain_predictions_and_references(model, tokenizer, eval_dataset, batch_size, name='eval'):
    """Greedily generate an answer for every row of ``eval_dataset``.

    Puts the model in eval mode and switches the tokenizer to left padding
    (the tokenizer is modified in place). Prompts come from the 'messages'
    column; at most 256 new tokens are generated per row, and only the
    generated continuation is decoded (special tokens skipped).

    Returns:
        ``(predictions, references)`` where ``references`` is the dataset's
        'answer' column.
    """
    model.eval()
    tokenizer.padding_side = 'left'
    total = len(eval_dataset)
    decoded = []

    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size), desc=f"Computing {name} predictions..."):
            conversations = eval_dataset[start : min(start + batch_size, total)]['messages']
            prompts = [
                tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
                for conv in conversations
            ]
            encoded = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to('cuda')
            generated = model.generate(**encoded, pad_token_id=tokenizer.pad_token_id, max_new_tokens=256, do_sample=False)

            # With left padding every prompt row has the same width, so the
            # continuation starts at the same column for the whole batch.
            prompt_width = encoded["input_ids"].shape[-1]
            continuations = generated[:, prompt_width:]
            decoded.extend(tokenizer.batch_decode(continuations, skip_special_tokens=True))

            del generated, continuations
            torch.cuda.empty_cache()
            gc.collect()

    return decoded, eval_dataset['answer']

def compute_accuracy(predictions, references, verbose=True):
    """Return the exact-match accuracy, in percent, of predictions against references.

    Each prediction is normalised by deleting every "Answer:" marker, keeping
    only its first line and stripping surrounding whitespace; each reference
    is only stripped.

    Args:
        predictions: generated strings, one per example.
        references: gold answer strings, aligned with ``predictions``.
        verbose: print every normalised prediction/reference pair when True.

    Returns:
        Accuracy in [0, 100]; 0.0 when there are no predictions.

    Raises:
        ValueError: if ``predictions`` and ``references`` differ in length
            (zip would otherwise silently ignore the extra items).
    """
    if len(predictions) != len(references):
        raise ValueError(
            f'Got {len(predictions)} predictions but {len(references)} references.'
        )
    if not predictions:
        return 0.0

    hits = 0
    for p, r in zip(predictions, references):
        clean_p = re.sub(r'Answer:', '', p).split('\n')[0].strip()
        clean_r = r.strip()

        if verbose:
            print(f'Prediction: -{clean_p}-')
            print(f'Reference:  -{clean_r}-')
            print()

        hits += int(clean_r == clean_p)

    return 100 * hits / len(predictions)

# ======================================================================================
# Main part
# ======================================================================================

def main(args, summary):
    """Fine-tune (unless ``args.peft_type == 'base'``) and evaluate on the chosen data.

    Loads the dataset and model named in ``args``, wraps the model with the
    requested PEFT adapter, trains for ``args.epochs`` epochs and reports
    exact-match accuracy for every non-train split (optionally to wandb).

    Args:
        args: namespace produced by this script's argument parser.
        summary: human-readable run description, printed before anything runs.

    Raises:
        ValueError: if the dataset, model or PEFT configuration is unsupported.
    """
    # Initialize wandb
    if args.use_wandb:
        wandb.init(
            project="Commonsense reasoning",
            name=f"{args.model}-{args.peft_type}-{args.dataset}",
            tags=[f"{args.peft_type}", f"{args.model}",
                  f"{args.dataset}", f"{args.init_weights}"],
        )

    print(summary)

    # Load dataset
    dataset = load_hf_dataset(args.dataset)
    if dataset is None:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    print(dataset)
    print(dataset['train'][0])

    # Load model and tokenizer
    model, tokenizer = load_hf_model(args.model)
    if model is None:
        raise ValueError(f"Unsupported model: {args.model}")

    # Tokenize dataset. tokenize_dataset assigns the tokenized splits back into
    # `dataset` and returns that same object; evaluation below only reads its
    # 'messages' and 'answer' columns.
    tokenized_dataset = tokenize_dataset(dataset, tokenizer)

    # Load peft model
    model = obtain_peft_model(model, args.model, args.peft_type, args.r, args.dropout, args.init_weights)
    if model is None:
        raise ValueError(f"Unsupported PEFT configuration: {args.peft_type} for {args.model}")
    print(model)

    # Train model
    if args.peft_type != 'base':

        # DataCollatorForLanguageModeling(mlm=False) rebuilds labels from
        # input_ids and would discard the -100 prompt masking prepared by
        # tokenize_dataset; DataCollatorForSeq2Seq keeps the given labels and
        # pads them with -100.
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, pad_to_multiple_of=8, label_pad_token_id=-100)

        train_args = TrainingArguments(
            f"models/{args.model}",
            eval_strategy="no",
            # Checkpoints are never reloaded; save_total_limit=0 alone does not
            # stop the Trainer from writing them periodically.
            save_strategy="no",
            learning_rate=args.lr,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            weight_decay=1e-3,
            save_total_limit=0,
            warmup_steps=100,
            gradient_accumulation_steps=args.acc_steps,
            num_train_epochs=args.epochs,
            optim="adamw_torch",
            lr_scheduler_type='linear',
            fp16=True,
            # "none" rather than None: None means "all installed integrations"
            # in transformers 4.x, which could start wandb anyway.
            report_to="wandb" if args.use_wandb else "none",
            logging_steps=args.log_steps,
        )

        trainer = Trainer(
            model=model,
            args=train_args,
            processing_class=tokenizer,
            data_collator=data_collator,
            train_dataset=tokenized_dataset['train'],
        )

        trainer.train()

    # Evaluate model on every non-train split
    named_eval_datasets = [(name, dataset[name]) for name in dataset.keys() if name not in ['train']]

    for name, eval_dataset in named_eval_datasets:

        predictions, references = obtain_predictions_and_references(model, tokenizer, eval_dataset, args.test_batch_size, name)
        accuracy = compute_accuracy(predictions, references)

        if args.use_wandb:
            wandb.log({
                f"{name}/accuracy": math.ceil(accuracy * 100) / 100,
            })

    if args.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-dataset", choices=["boolq", "piqa", "social_i_qa", "hellaswag", "winogrande", "arcc", "arce", "openbookqa", "all"], help="execution path", required=True)
    parser.add_argument("-model", choices=['llama3.1', 'phi4', 'qwen2.5', 'qwen2.5-14B'], help="LLM identifier for loading it", required=True)

    parser.add_argument("-lr", help="learning rate", type=float, default=1e-4)
    parser.add_argument("-r", help="rank used peft method",type=int, default=0)
    parser.add_argument("-acc_steps", help="number of gradient accumulations", type=int, default=1)
    # Default 1: the number of training epochs this script has actually been
    # running, so the printed summary matches the training run.
    parser.add_argument("-epochs", help="number of epochs",type=int, default=1)
    parser.add_argument("-batch_size", help="batch size", type=int, default=16)
    parser.add_argument("-dropout", help="dropout for finetuned layers", type=float, default=0.0)
    # argparse does not check defaults against `choices`; the default must be
    # one of them ('base' = evaluate the model without an adapter).
    parser.add_argument("-peft_type", choices=['base', 'lora', 'dora', 'hoft', 'shoft', 'hra', 'oft'], help="select peft method", default='base')
    parser.add_argument("-init_weights", help="choose init_weights method", type=str, default='gaussian')

    parser.add_argument("-test_batch_size",help="test batch size", type=int, default=64)
    parser.add_argument("-log_steps", help="number of logging steps", type=int, default=3)
    parser.add_argument('--use_wandb',  help="send results to wandb", action='store_true')

    args = parser.parse_args()
    summary = f"""

    {args.model} for commonsense reasoning
    --------------------------------------------
          - dataset : {args.dataset}
          - model : {args.model}

          - peft type : {args.peft_type}
          - initialization : {args.init_weights}
          - rank : {args.r}
          - dropout : {args.dropout}

          - epochs : {args.epochs}
          - learning rate : {args.lr}
          - test batch size : {args.test_batch_size}
          - batch size : {args.batch_size}
          - accumulation steps : {args.acc_steps}

          - use wandb : {args.use_wandb}
          - use logging steps : {args.log_steps}
    """

    main(args, summary)