from typing import Any, Dict, List
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    BitsAndBytesConfig
)
import sys
import torch

from transformers import AutoTokenizer
import torch.distributed as dist

import gc

from peft import (
    PeftModel,
    get_peft_model, 
    prepare_model_for_int8_training, 
    prepare_model_for_kbit_training, 
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig, 
    TaskType
)
from vllm import LLM
from .utils import check_bf16_support
from .llm_worker import LLMModel

# Public API of this module: the shared LLaMA loader followed by the thin
# per-checkpoint subclasses that only override class-level settings.
__all__ = [
    "LLAMAModel", "Vicuna", "Wizard", "GPT4ALL", "Guanaco", "Llama2", "Alpaca",
]

class LLAMAModel(LLMModel):
    """LLaMA-family wrapper around :class:`LLMModel`.

    Depending on ``self.mode`` this loads either a Hugging Face
    ``LlamaForCausalLM`` set up for fine-tuning (``'train'``: optional
    8-bit/4-bit bitsandbytes quantization, optional LoRA, optional delta
    weights) or a vLLM ``LLM`` engine (``'inference'``).

    Instance attributes read here (presumably populated by
    ``LLMModel.__init__`` -- confirm there):
        model_name_or_path: hub id or local path handed to the loaders.
        config: dict-like options; keys consulted are ``auth_token``,
            ``load_8bit``, ``load_4bit``, ``peft``, ``delta_weights`` and
            ``trust_remote_code``.
        kwargs: extra options; ``tensor_parallel_size`` is read for vLLM.
        mode: ``'train'`` or ``'inference'``.
    """

    def load_tokenizer(self):
        """Load the slow tokenizer and configure it for batched generation.

        Returns:
            The tokenizer, which is also stored on ``self.tokenizer``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            use_fast=False,
            token=self.config.get('auth_token'),
        )

        if self.tokenizer.pad_token is None:
            # LLaMA does not ship a pad token
            # (https://github.com/huggingface/transformers/issues/22312).
            # Pad with <unk> (id 0 is assumed to be <unk> in the LLaMA vocab)
            # so the pad id stays different from the EOS id.
            self.tokenizer.pad_token = "<unk>"
            self.tokenizer.pad_token_id = 0

        # Left padding so batched prompts end at the same position and
        # generated tokens follow each prompt directly.
        self.tokenizer.padding_side = "left"
        return self.tokenizer

    def load_model(self):
        """Load the model appropriate for ``self.mode``.

        Returns:
            ``self.model``. In ``'train'`` mode this is a ``LlamaForCausalLM``
            (possibly modified by ``apply_lora`` / ``apply_delta``) switched
            to training mode. In ``'inference'`` mode it is a vLLM ``LLM``.
            For any other mode ``self.model`` is returned untouched.
        """
        if self.mode == 'train':
            self._load_for_training()
        elif self.mode == 'inference':
            self._load_for_inference()
        return self.model

    def _make_bnb_config(self, compute_dtype):
        """Return a ``BitsAndBytesConfig`` for the requested quantization.

        ``load_8bit`` wins when both ``load_8bit`` and ``load_4bit`` are
        truthy. Returns ``None`` when neither is set.

        Args:
            compute_dtype: torch dtype used for 4-bit matmuls; callers pass
                the same dtype the model weights are loaded in so the two
                agree.
        """
        if self.config.get('load_8bit'):
            return BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
            )
        if self.config.get('load_4bit'):
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                # Match the load dtype instead of hard-coding fp16, which
                # mixed fp16 compute with bf16 weights on bf16-capable GPUs.
                bnb_4bit_compute_dtype=compute_dtype,
            )
        return None

    def _load_for_training(self):
        """Load ``LlamaForCausalLM`` for fine-tuning and store it on ``self.model``."""
        dtype = torch.bfloat16 if check_bf16_support() else torch.float16
        quantization_config = self._make_bnb_config(dtype)

        self.model = LlamaForCausalLM.from_pretrained(
            self.model_name_or_path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            token=self.config.get('auth_token'),
        )

        if quantization_config is not None:
            # PEFT helper that readies a k-bit quantized model for training.
            self.model = prepare_model_for_kbit_training(self.model)

        if self.config.get('peft'):
            self.apply_lora()

        if "delta_weights" in self.config:
            self.apply_delta(dtype, quantization_config)

        self.model.train()

    def _load_for_inference(self):
        """Build a vLLM ``LLM`` engine and store it on ``self.model``."""
        dtype = "bfloat16" if check_bf16_support() else "float16"

        if dist.is_initialized():
            # An existing torch.distributed group is torn down before vLLM
            # starts (presumably so vLLM can create its own -- confirm), and
            # the engine is sharded across every rank of that former group.
            tensor_parallel_size = dist.get_world_size()
            gc.collect()
            dist.destroy_process_group()
        else:
            tensor_parallel_size = self.kwargs.get("tensor_parallel_size", 1)

        self.model = LLM(
            model=self.model_name_or_path,
            trust_remote_code=self.config.get("trust_remote_code", False),
            dtype=dtype,
            tensor_parallel_size=tensor_parallel_size,
            tokenizer_mode="slow",
        )

        print(self.model)
    
class Alpaca(LLAMAModel):
    """Wrapper exported as ``Alpaca``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False

class Wizard(LLAMAModel):
    """Wrapper exported as ``Wizard``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False

class GPT4ALL(LLAMAModel):
    """Wrapper exported as ``GPT4ALL``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False

class Guanaco(LLAMAModel):
    """Wrapper exported as ``Guanaco``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False

class Vicuna(LLAMAModel):
    """Wrapper exported as ``Vicuna``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False

class Llama2(LLAMAModel):
    """Wrapper exported as ``Llama2``; all loading is inherited from LLAMAModel."""

    # Presumably consulted by LLMModel when formatting prompts -- confirm.
    require_system_prompt: bool = False
