import os
import yaml
from pynvml import *
import torch
from peft import LoraConfig
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorWithPadding, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM


class Llama:
    """Builds a Llama causal-LM, its tokenizer, optional PEFT / 4-bit configs,
    a training data collator and chat-formatted dataset splits."""

    def __init__(self, dataset_name, data_size, task, use_quantized, use_peft, peft_method, r, density, model_name, seed):
        """Store run settings and pre-build the optional PEFT and quantization configs.

        Args:
            dataset_name: A path containing ``'json'`` (loaded with the JSON
                builder) or a Hugging Face dataset id. Also the top-level key
                into the prompt YAML.
            data_size: If truthy, the number of training rows to keep (capped at
                the rows actually available).
            task: Dataset config name for ``load_dataset`` and the second-level
                key into the prompt YAML.
            use_quantized: Build a 4-bit NF4 ``BitsAndBytesConfig``.
            use_peft: Build a PEFT config according to ``peft_method``.
            peft_method: ``'lora'`` or ``'spiel'``. Any other value leaves
                ``peft_config`` as ``None``.
            r: LoRA rank (only used for ``'lora'``).
            density: Forwarded as ``density`` to ``SftConfig`` (only used for ``'spiel'``).
            model_name: Hub id or local path of the model and tokenizer.
            seed: Seed for shuffling and split generation.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.data_size = data_size
        self.task = task
        self.seed = seed
        self.use_peft = use_peft
        self.peft_method = peft_method
        self.use_quantized = use_quantized
        self.peft_config = None
        self.bnb_config = None

        if use_peft:
            if peft_method == 'lora':
                self.peft_config = LoraConfig(
                    r=r,
                    # jje: rank 1 did well with 16/1. Try 4x as much -- which
                    # could be bad for rank 1 but good for higher ranks.
                    lora_alpha=32,
                    lora_dropout=0.05,
                    bias="none",
                    task_type="CAUSAL_LM",
                    # Other candidates tried: v_proj, up_proj, k_proj, gate_proj,
                    # o_proj, down_proj.
                    target_modules=['q_proj'],
                )
            elif peft_method == 'spiel':
                # Imported lazily: only needed for this method.
                from peft import SftConfig, TaskType
                self.peft_config = SftConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    reselection_steps=40,
                    density=density,  # values tested: 1e-7, 1e-6, 1e-5, 1e-4, 1e-3
                    selection_algorithm="rigl",  # or "sm3" for moment approximation SFT
                    target_modules=["q_proj", "v_proj"],
                )

        if use_quantized:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_storage=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

    def get_peft_config(self):
        """Return the PEFT config built in ``__init__`` (``None`` if PEFT is off)."""
        return self.peft_config

    def get_datacollator(self, tokenizer, use_flash_attention=False, completion_only=True):
        """Return the data collator used for training.

        With ``completion_only``, TRL's ``DataCollatorForCompletionOnlyLM`` is
        used with the Llama-3 assistant header as the response template.
        Padding-free batching is enabled only when flash attention is used.
        Otherwise a plain ``DataCollatorForSeq2Seq`` is returned.
        """
        if not completion_only:
            return DataCollatorForSeq2Seq(tokenizer=tokenizer)
        return DataCollatorForCompletionOnlyLM(
            response_template="<|start_header_id|>assistant<|end_header_id|>",
            tokenizer=tokenizer,
            padding_free=bool(use_flash_attention),
        )

    def get_model_and_tokenizer(self, use_flash_attention=True, use_safetensors=False, load_dtype=None, on_vector=False):
        """Load the causal-LM and a tokenizer from the same location.

        Args:
            use_flash_attention: Request the ``flash_attention_2`` attention implementation.
            use_safetensors: If truthy, forward ``use_safetensors`` to ``from_pretrained``.
            load_dtype: Weight dtype. Falls back to ``torch.bfloat16`` when falsy.
            on_vector: Load from the fixed local path
                ``/model-weights/Meta-Llama-3.1-8B-Instruct`` instead of ``self.model_name``.

        Returns:
            ``(model, tokenizer)``. If the tokenizer has no pad token, its pad
            token id is set to the EOS token id.
        """
        model_path = "/model-weights/Meta-Llama-3.1-8B-Instruct" if on_vector else self.model_name
        args = {'pretrained_model_name_or_path': model_path}

        if use_flash_attention:
            args['attn_implementation'] = "flash_attention_2"
        if use_safetensors:
            args['use_safetensors'] = use_safetensors
        if self.use_quantized:
            args['quantization_config'] = self.bnb_config
        # Same default dtype whether or not the model is quantized.
        args['torch_dtype'] = load_dtype or torch.bfloat16

        print('using args: ', args)
        model = AutoModelForCausalLM.from_pretrained(**args)
        # Load the tokenizer from the same path as the model so the two always match.
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        if tokenizer.pad_token is None:
            print(f"Padding tokenizer pad token to {tokenizer.eos_token_id}")
            tokenizer.pad_token_id = tokenizer.eos_token_id

        return model, tokenizer

    def get_data(self, tokenizer, max_seq_len):
        """Load, split, subsample and chat-format the dataset.

        Steps:
            - Ensure ``train``/``valid``/``test`` splits exist. A ``validation``
              split is renamed to ``valid``. A missing valid split is 20% of
              train. A missing test split is 10% of valid. All splits use
              ``self.seed``.
            - Optionally limit ``train`` to ``self.data_size`` rows.
            - Cap ``valid`` at 20% of ``train``.
            - Add a ``text`` column containing the rendered chat template.
            - Drop rows whose tokenized ``text`` is ``max_seq_len`` tokens or longer.

        Returns:
            The processed ``DatasetDict``.
        """
        if 'json' in self.dataset_name:
            data = load_dataset('json', data_files=self.dataset_name)
        else:  # Hugging Face dataset path
            data = load_dataset(self.dataset_name, self.task)

        if 'validation' in data:
            data['valid'] = data.pop('validation')
        elif 'valid' not in data:
            split = data['train'].train_test_split(0.2, seed=self.seed)
            data['train'], data['valid'] = split['train'], split['test']
        if 'test' not in data:
            split = data['valid'].train_test_split(0.1, seed=self.seed)
            data['valid'], data['test'] = split['train'], split['test']

        if self.data_size:
            # Clamp so requesting more rows than exist does not raise IndexError.
            n_train = min(self.data_size, data['train'].num_rows)
            data['train'] = data['train'].shuffle(seed=self.seed).select(range(n_train))

        if data['valid'].num_rows > data['train'].num_rows * 0.2:
            n_valid = int(data['train'].num_rows * 0.2)
            data['valid'] = data['valid'].shuffle(seed=self.seed).select(range(n_valid))

        prompt = self.load_task_prompt()

        def to_text(row):
            # Render the full system/user/assistant conversation as one string.
            return {"text": tokenizer.apply_chat_template(
                self.format_chat_template(prompt=prompt, row=row),
                tokenize=False,
                add_generation_prompt=False,
            )}

        def fits(row):
            # apply_chat_template(tokenize=True) renders then tokenizes with
            # add_special_tokens=False, so tokenizing the rendered text the same
            # way gives the same length without rendering the template twice.
            return len(tokenizer(row["text"], add_special_tokens=False)["input_ids"]) < max_seq_len

        data = data.map(to_text, load_from_cache_file=False, num_proc=os.cpu_count())
        data = data.filter(fits, num_proc=os.cpu_count())

        return data

    def load_task_prompt(self, prompt_path='../prompts/prompts_by_task.yaml'):
        """Return the prompt spec for ``self.dataset_name`` / ``self.task``.

        The YAML file is indexed as ``[dataset_name][task]``. The resulting
        mapping must provide the ``instruction``, ``user_input`` and
        ``assistant_output`` keys read by ``format_chat_template``.
        ``prompt_path`` is resolved relative to the current working directory.
        """
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompts = yaml.safe_load(f)

        return prompts[self.dataset_name][self.task]

    def format_chat_template(self, prompt, row):
        """Build the system/user/assistant message list for one dataset row.

        ``prompt['instruction']`` is used verbatim as the system message.
        ``prompt['user_input']`` and ``prompt['assistant_output']`` name the
        ``row`` columns holding the user turn and the target assistant turn.
        """
        return [
            {"role": "system", "content": prompt['instruction']},
            {"role": "user", "content": row[prompt['user_input']]},
            {"role": "assistant", "content": row[prompt['assistant_output']]},
        ]
