from enum import Enum
from re import A

import peft

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoConfig
from llm_logger import model_logger

from torch import bfloat16

import numpy as np


class TaskType(Enum):
    """Kind of task a model is being trained for.

    Member values are plain ints. The previous definition had trailing commas,
    which turned the first two values into the tuples ``(1,)`` and ``(2,)``.
    """

    TOKEN_CLASSIFICATION = 1
    TEXT_GENERATION = 2
    SEQ_TO_SEQ = 3


def count_trainable_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    A parameter counts as trainable when ``requires_grad`` is True.
    ``Tensor.numel()`` is used instead of ``np.prod(p.size())``, which yields
    the float ``1.0`` for 0-dim (scalar) parameters. The result is always a
    Python ``int`` (0 when nothing is trainable).
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)



def _bnb_4bit_config(model_args):
    """Return the NF4 double-quant 4-bit config if ``model_args.loading_4_bit`` is set, else None."""
    # Built lazily so branches that never quantize do not depend on bitsandbytes.
    if not model_args.loading_4_bit:
        return None
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=bfloat16,
    )


def _load_pretrained(model_args, use_bart):
    """Load the pretrained backbone at ``model_args.model_name_or_path``.

    Seq2Seq models (``use_bart``) are always loaded unquantized. Causal LMs get
    the 4-bit config when ``model_args.loading_4_bit`` is set.
    """
    if use_bart:
        return AutoModelForSeq2SeqLM.from_pretrained(model_args.model_name_or_path)
    return AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=_bnb_4bit_config(model_args),
    )


def _attach_peft(model, peft_config):
    """Wrap ``model`` with the adapter in ``peft_config``, move it to CUDA and report trainable params."""
    model = peft.get_peft_model(model, peft_config).to("cuda")
    model.print_trainable_parameters()
    return model


def get_model(model_args, use_bart: bool = False):
    """Build the model to train according to the strategy selected in ``model_args``.

    Strategy flags are checked in this order: ``prefix``, ``lora``, ``mixed``,
    ``soft_prompt``, ``last_layer``, ``train_from_scratch``. If none is set,
    the whole pretrained causal LM is fine-tuned.

    Args:
        model_args: Argument object. It always supplies ``model_name_or_path``
            and the strategy flags above. Depending on the branch taken it also
            supplies ``pre_seq_len``, ``prefix_projection``, ``loading_4_bit``
            and ``lora_rank``.
        use_bart: Load an ``AutoModelForSeq2SeqLM`` instead of a causal LM.
            Honoured only by the prefix, LoRA and soft-prompt strategies.

    Returns:
        A ``(model, config)`` tuple. ``config`` is the PEFT config that was
        used. For ``mixed`` it is the prefix-tuning config, which is NOT
        applied to the model here. It is ``None`` for the non-PEFT strategies.
    """
    if model_args.prefix:
        if use_bart:
            model_logger.info("Training with prefix for Seq2Seq")
            task_type = peft.TaskType.SEQ_2_SEQ_LM
        else:
            model_logger.info("Training with prefix for Causal LM")
            task_type = peft.TaskType.CAUSAL_LM
        peft_config = peft.PrefixTuningConfig(
            task_type=task_type,
            inference_mode=False,
            num_virtual_tokens=model_args.pre_seq_len,
            prefix_projection=model_args.prefix_projection,
        )
        model = _load_pretrained(model_args, use_bart)
        return _attach_peft(model, peft_config), peft_config

    if model_args.lora:
        if use_bart:
            model_logger.info("LoRA for Seq2Seq")
            # NOTE(review): Seq2Seq uses lora_alpha=32 while causal uses 16 — confirm this is intended.
            peft_config = peft.LoraConfig(
                task_type=peft.TaskType.SEQ_2_SEQ_LM,
                inference_mode=False,
                r=8,
                lora_alpha=32,
                lora_dropout=0,
            )
        else:
            model_logger.info("LoRA for Causal LM")
            model_logger.info(f"Training {'WITH' if model_args.loading_4_bit else 'WITHOUT'} 4 bit loading")
            peft_config = peft.LoraConfig(
                task_type=peft.TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                lora_alpha=16,
                lora_dropout=0,
            )
        model = _load_pretrained(model_args, use_bart)
        model_logger.info(peft_config)
        return _attach_peft(model, peft_config), peft_config

    if model_args.mixed:
        model_logger.info("Mixed model chosen")
        # Prefix config: returned to the caller, not applied in this function.
        prefix_config = peft.PrefixTuningConfig(
            task_type=peft.TaskType.CAUSAL_LM,
            inference_mode=False,
            num_virtual_tokens=model_args.pre_seq_len,
            prefix_projection=model_args.prefix_projection,
        )
        lora_config = peft.LoraConfig(
            task_type=peft.TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            lora_alpha=16,
            lora_dropout=0,
        )
        # Always a causal LM; use_bart is ignored for this strategy.
        model = _load_pretrained(model_args, use_bart=False)
        # NOTE(review): only LoRA is attached here; presumably the caller applies prefix_config — confirm.
        return _attach_peft(model, lora_config), prefix_config

    if model_args.soft_prompt:
        if use_bart:
            model_logger.info("Soft Prompt model chosen for Seq2Seq")
            task_type = peft.TaskType.SEQ_2_SEQ_LM
        else:
            model_logger.info("Soft Prompt model chosen for Causal LM")
            task_type = peft.TaskType.CAUSAL_LM
        model = _load_pretrained(model_args, use_bart)
        soft_prompt_config = peft.PromptTuningConfig(
            task_type=task_type,
            prompt_tuning_init=peft.PromptTuningInit.RANDOM,
            inference_mode=False,
            num_virtual_tokens=model_args.pre_seq_len,
        )
        model_logger.info(soft_prompt_config)
        return _attach_peft(model, soft_prompt_config), soft_prompt_config

    if model_args.last_layer:
        model_logger.info("LAST LAYER FINE-TUNE")
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path).to("cuda")

        # Freeze everything, then unfreeze only the output projection.
        for param in model.parameters():
            param.requires_grad = False

        # get_output_embeddings() is the model-agnostic accessor for the output head
        # (`embed_out` on GPT-NeoX, which was previously hard-coded).
        # NOTE(review): on models whose output weights are tied to the input
        # embeddings, this also unfreezes the input embeddings.
        output_layer = model.get_output_embeddings()
        if output_layer is None:
            raise ValueError(
                f"{type(model).__name__} exposes no output embedding layer to fine-tune"
            )
        for param in output_layer.parameters():
            param.requires_grad = True

        for name, param in model.named_parameters():
            if param.requires_grad:
                model_logger.info(f"{name} {param.size()}")
        return model, None

    if model_args.train_from_scratch:
        model_logger.info("TRAINING FROM SCRATCH model from scratch - we are cooler B-)")
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            quantization_config=_bnb_4bit_config(model_args),
        )
        # NOTE(review): from_config builds freshly initialised weights. Whether
        # the quantization_config stored on the config has any effect here is
        # unclear — confirm.
        model = AutoModelForCausalLM.from_config(config).to("cuda")
        return model, None

    model_logger.info("Doing FULL finetuning because we are cool B-D")
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path).to("cuda")
    return model, None