import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from utils.myllama import MyLlamaForCausalLM
from utils.myqwen2 import MyQwen2ForCausalLM
from peft import PeftModel
from tqdm import tqdm
import itertools
import torch.nn as nn
from utils.config import DEFAULT_TOKENS, SPECIAL_DELM_TOKENS
from accelerate.utils import gather_object
from peft import LoraConfig, get_peft_model
# Query (but do not create) the torch.distributed process group.
def init_distributed():
    """Return ``(rank, world_size)`` for the current process.

    This does not initialize a process group; it only reads one that has
    already been set up. When torch.distributed is unavailable or has not
    been initialized, it falls back to a single process: ``(0, 1)``.
    Note that ``dist.get_rank()`` is the rank in the default (global) group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1  # single-process fallback
    return dist.get_rank(), dist.get_world_size()

# Rank and world size of this process, queried once at import time
# (falls back to (0, 1) when torch.distributed is not initialized).
# NOTE(review): despite its name, `local_rank` holds dist.get_rank(), i.e. the
# global rank rather than the per-node local rank — confirm callers expect this.
local_rank, world_size = init_distributed()


# LoRA adapter configuration shared by get_my_peft_model and save_model:
# adapters only on the attention query/value projections. With peft's default
# (non-rsLoRA) scaling, the adapter update is scaled by lora_alpha / r
# (8 / 64 = 0.125 here).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=64,
    lora_alpha=8,
    lora_dropout=0.1,
)

# Chat-template markers that open the assistant turn for each model family.
_LLAMA3_ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
_QWEN2_ASSISTANT_HEADER = "<|im_start|>assistant\n"


def _model_entry(hf, hf_instruct, cls, assistant_header):
    """Build one ``model_map`` entry.

    ``hf`` is the base checkpoint and ``hf_instruct`` the instruct checkpoint
    whose chat template get_ptm borrows. ``cls`` is the model class used to
    load weights. ``delimiters`` pairs SPECIAL_DELM_TOKENS[1] with the
    assistant header, which also serves as ``generation_prompt``.
    """
    return {
        "hf": hf,
        "hf_instruct": hf_instruct,
        "class": cls,
        "delimiters": [SPECIAL_DELM_TOKENS[1], assistant_header],
        "generation_prompt": assistant_header,
    }


# Supported model names -> checkpoint ids, model class and prompt markers.
model_map = {
    "llama3.1_8b": _model_entry(
        "meta-llama/Llama-3.1-8B",
        "meta-llama/Llama-3.1-8B-Instruct",
        MyLlamaForCausalLM,
        _LLAMA3_ASSISTANT_HEADER,
    ),
    "llama3.2_3b": _model_entry(
        "meta-llama/Llama-3.2-3B",
        "meta-llama/Llama-3.2-3B-Instruct",
        MyLlamaForCausalLM,
        _LLAMA3_ASSISTANT_HEADER,
    ),
    "qwen2.5_7b": _model_entry(
        "Qwen/Qwen2.5-7B",
        "Qwen/Qwen2.5-7B-Instruct",
        MyQwen2ForCausalLM,
        _QWEN2_ASSISTANT_HEADER,
    ),
}

def get_ptm(model_name):
    """Load a base pretrained model and its tokenizer, prepared for fine-tuning.

    The tokenizer comes from the base checkpoint, but its chat template is
    copied from the instruct checkpoint with the "Cutting Knowledge" /
    "Today Date" lines removed. The special delimiter tokens and pad token
    are added, and the embeddings are resized to match. The eos/pad ids are
    written into both the model config and the generation config. For
    Qwen2.5 the token-type embedding tables are re-initialized with N(0, 0.1).

    Args:
        model_name: key into ``model_map``.

    Returns:
        ``(model, tokenizer)`` with the tokenizer set to right padding.
    """
    spec = model_map[model_name]

    # Base-model tokenizer, with the chat template borrowed from the instruct variant.
    tokenizer = AutoTokenizer.from_pretrained(spec["hf"])
    instruct_tokenizer = AutoTokenizer.from_pretrained(spec["hf_instruct"])
    tokenizer.chat_template = instruct_tokenizer.chat_template

    # Strip template lines that emit the default system header text.
    dropped_markers = ("Cutting Knowledge", "Today Date")
    tokenizer.chat_template = "\n".join(
        line
        for line in tokenizer.chat_template.split("\n")
        if not any(marker in line for marker in dropped_markers)
    )

    model = spec["class"].from_pretrained(spec["hf"])
    tokenizer.padding_side = "right"

    # Key order matters: tokens are registered in this order, which fixes their ids.
    tokenizer.add_special_tokens({
        "additional_special_tokens": SPECIAL_DELM_TOKENS,
        "pad_token": DEFAULT_TOKENS["pad_token"],
    })
    model.resize_token_embeddings(len(tokenizer))

    eos_id, pad_id = tokenizer.eos_token_id, tokenizer.pad_token_id
    for cfg in (model.generation_config, model.config):
        cfg.eos_token_id = eos_id
        cfg.pad_token_id = pad_id

    if model_name == "qwen2.5_7b":
        decoder = model.model
        for embedding in (decoder.token_type_embeddings, *decoder.token_type_embeddings_layers):
            nn.init.normal_(embedding.weight, std=0.1)

    return model, tokenizer

def get_my_peft_model(model):
    """Wrap ``model`` with the module-level LoRA config and unfreeze token-type embeddings.

    The token-type embedding table and every per-layer token-type embedding
    module get ``requires_grad_(True)``, so they train alongside the LoRA
    adapters even though they are not LoRA targets.

    Returns:
        The PEFT-wrapped model.
    """
    peft_model = get_peft_model(model, peft_config)
    # Presumably PeftModel -> wrapped causal LM -> its inner decoder (`.model`).
    decoder = peft_model.model.model
    for module in (decoder.token_type_embeddings, *decoder.token_type_embeddings_layers):
        module.requires_grad_(True)
    return peft_model

def save_model(trainer, args, instr_model_dir, model_dir, tokenizer):
    """Persist the trained model to ``model_dir``.

    With ``args.peft`` set:
      1. The trainable tensors (keys containing "lora" or
         "token_type_embeddings") are copied from the trainer's state dict
         into a fresh LoRA-wrapped copy of the model reloaded from
         ``instr_model_dir``.
      2. The adapters are merged into the base weights.
      3. The merged model and ``tokenizer`` are written to ``model_dir``.
    Otherwise the trainer's own ``save_model`` is used.

    Args:
        trainer: Hugging Face ``Trainer`` (reads ``accelerator``, ``model``,
            ``is_fsdp_enabled``).
        args: namespace with ``peft`` (bool) and ``model`` (a ``model_map`` key).
        instr_model_dir: checkpoint directory reloaded via get_trained_model as
            the base model to merge the adapters into.
        model_dir: output directory.
        tokenizer: tokenizer saved next to the merged model (PEFT path only).
    """
    if trainer.is_fsdp_enabled:
        # Request an unsharded state dict so each gathered tensor is complete.
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    if not args.peft:
        trainer.save_model(model_dir)
        return

    state_dict = trainer.accelerator.get_state_dict(trainer.model)
    # NOTE(review): presumably ranks that do not hold the full weights receive an
    # empty state dict here, so only the rank with real weights does the merge/save.
    if len(state_dict) <= 1:
        return

    model, _ = get_trained_model(instr_model_dir, args.model, bf16=False, is_train=True)
    model = get_peft_model(model, peft_config)
    trained = {
        k: v.cpu()
        for k, v in state_dict.items()
        if "lora" in k or "token_type_embeddings" in k
    }
    # strict=False because the frozen base weights are intentionally absent from
    # `trained`; an *unexpected* key, however, means a trained tensor was dropped.
    load_result = model.load_state_dict(trained, strict=False)
    if load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.unexpected_keys)} trained tensors did not match "
            f"the reloaded model and were NOT loaded, e.g. {load_result.unexpected_keys[:3]}"
        )
    model = model.merge_and_unload()
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)

def get_trained_model(model_dir, model_name, adapter_dir=None, bf16=False, is_train=True, device=None):
    """Load a saved model (and optionally a PEFT adapter) with its tokenizer.

    Args:
        model_dir: directory holding the model weights and tokenizer.
        model_name: key into ``model_map``; selects the model class.
        adapter_dir: optional PEFT adapter directory. When given, the adapter
            is applied and the tokenizer is reloaded from this directory instead.
        bf16: load in ``torch.bfloat16`` instead of ``torch.float32``.
        is_train: if True, put the model in train mode with right padding;
            otherwise eval mode with left padding (as generation requires).
        device: forwarded as ``device_map`` to the loaders when not None.

    Returns:
        ``(model, tokenizer)``.
    """
    print(f"Loading model from {model_dir}")
    dtype = torch.bfloat16 if bf16 else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Forward device_map only when a device was requested; otherwise leave the
    # loaders' defaults untouched.
    placement = {} if device is None else {"device_map": device}
    model = model_map[model_name]["class"].from_pretrained(model_dir, torch_dtype=dtype, **placement)

    if adapter_dir is not None:
        print(f"Loading adapter from {adapter_dir}")
        model = PeftModel.from_pretrained(model, adapter_dir, torch_dtype=dtype, **placement)
        tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    if is_train:
        model, side = model.train(), "right"
    else:
        model, side = model.eval(), "left"
    tokenizer.padding_side = side

    return model, tokenizer


# Per-process generation loop (run by each worker inside run_generation_dp)
def run_generation(ds_batched, tokenizer, model, accelerator):
    """Greedily generate completions for each batch of pre-tokenized prompts.

    Args:
        ds_batched: iterable of batches that ``tokenizer.pad`` accepts.
        tokenizer: tokenizer with ``padding_side == "left"``.
        model: model whose ``generate`` is called on ``model.device``.
        accelerator: used only to decide where the progress bar is shown.

    Returns:
        A one-element list ``[{"outputs": [...]}]``. ``outputs`` holds one list
        of decoded completion strings (prompt removed) per batch. The
        single-element list is the shape run_generation_dp passes to
        ``gather_object``.
    """
    # Generation requires left padding so each continuation starts after the prompt.
    assert tokenizer.padding_side == "left"
    results = dict(outputs=[])
    # Show the progress bar only on the local main process (the original logic was inverted).
    disable_pbar = not accelerator.is_local_main_process

    for prompts in tqdm(ds_batched, disable=disable_pbar):
        prompts_tokenized = tokenizer.pad(prompts, padding='longest', return_tensors='pt').to(model.device)
        outputs_tokenized = model.generate(
            **prompts_tokenized, max_new_tokens=1024, do_sample=False, temperature=None, top_p=None
        )
        # Drop the (padded) prompt tokens from each generated sequence.
        completions = [
            tok_out[len(tok_in):]
            for tok_in, tok_out in zip(prompts_tokenized["input_ids"], outputs_tokenized)
        ]
        results["outputs"].append(tokenizer.batch_decode(completions, skip_special_tokens=True))

    return [results]


# Single GPU generation
def run_generation_single(ds_batched, tokenizer, model):
    """Greedily generate completions for each batch on ``model.device``.

    Args:
        ds_batched: sized iterable of batches that ``tokenizer.pad`` accepts.
        tokenizer: tokenizer with ``padding_side == "left"``.
        model: model whose ``generate`` is called.

    Returns:
        A flat list of decoded completion strings (prompt tokens removed), in
        prompt order.
    """
    decoded = []
    for batch in tqdm(ds_batched, total=len(ds_batched)):
        # Generation requires left padding so each continuation starts after the prompt.
        assert tokenizer.padding_side == "left"
        encoded = tokenizer.pad(batch, padding='longest', return_tensors='pt').to(model.device)
        generated = model.generate(
            **encoded, max_new_tokens=1024, do_sample=False, temperature=None, top_p=None
        )
        # All prompt rows are padded to the same width, so every continuation
        # begins at that column.
        prompt_width = encoded["input_ids"].shape[-1]
        completions = [sequence[prompt_width:] for sequence in generated]
        decoded.extend(tokenizer.batch_decode(completions, skip_special_tokens=True))
    return decoded


# Multi-GPU (data-parallel) generation
def run_generation_dp(ds_batched, tokenizer, model, accelerator):
    """Split batches across processes, generate on each, and gather the results.

    Args:
        ds_batched: list of batches, each with an ``"input_ids"`` entry.
        tokenizer: tokenizer with ``padding_side == "left"``.
        model: model used by run_generation on each process.
        accelerator: ``accelerate.Accelerator`` coordinating the processes.

    Returns:
        A flat list of decoded completions, at most one per prompt. Ordering
        follows the concatenation of the per-process shards as returned by
        ``gather_object``.
    """
    # Count the real prompts. Multiplying the first batch's size by the batch
    # count over-counts when the last batch is smaller, and crashes on empty input.
    n = sum(len(batch["input_ids"]) for batch in ds_batched)

    accelerator.wait_for_everyone()

    with accelerator.split_between_processes(ds_batched) as ds_worker:
        resps_worker = run_generation(ds_worker, tokenizer, model, accelerator)

    accelerator.wait_for_everyone()
    resps_gather = gather_object(resps_worker)
    accelerator.wait_for_everyone()

    # Each gathered item is {"outputs": [batch_strings, ...]}; flatten both levels.
    resps = []
    for worker_result in resps_gather:
        resps.extend(itertools.chain.from_iterable(worker_result["outputs"]))

    # Defensive: trim anything beyond the number of input prompts.
    return resps[:n]
