import os
import torch.nn as nn
os.environ['LD_LIBRARY_PATH'] = 'YOUR_CONDA_ENV/lib'
import sys
from typing import List
from packaging import version
import importlib.metadata
import importlib
import numpy as np
import fire
import json
import torch
import wandb
import cvxpy as cp
from tqdm import tqdm, trange

from fastchat.train.llama2_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()

import transformers
from datasets import load_dataset, concatenate_datasets
from transformers import EarlyStoppingCallback, Trainer, GenerationConfig, TrainerCallback
from transformers.trainer_pt_utils import dataclass
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.utils import is_peft_available
from transformers.modeling_utils import unwrap_model
from peft import PeftModel
from sample_utils.Dual_sample import Dual_Sampler
from sample_utils.Prop_sample import Prop_Sampler
from sample_utils.MinGap_sample import MinGap_Sampler
# from transformers import AutoModel, AutoTokenizer
"""
Unused imports:`
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
import ipdb

def _is_peft_model(model):
    """Return whether *model* is a PEFT wrapper model.

    Returns False when peft is not available. Otherwise checks against
    ``PeftModel`` and, when the installed peft is >= 0.7.0, also against
    ``PeftMixedModel`` (see https://github.com/huggingface/transformers/pull/28321).
    """
    if not is_peft_available():
        return False

    peft_classes = [PeftModel]
    installed_peft = version.parse(importlib.metadata.version("peft"))
    if installed_peft >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        peft_classes.append(PeftMixedModel)
    return isinstance(model, tuple(peft_classes))

@dataclass
class WeightLabelSmoother:
    """
    Label-smoothed token loss with a per-example weight, computed from a
    pre-computed Transformers model output.

    Each example's token losses are multiplied by its entry of ``batch_weight``.
    The weighted sum is then divided by the number of non-ignored tokens in the
    whole batch, not by the sum of the weights.

    Args:
        epsilon (`float`, *optional*, defaults to 0.1):
            The label smoothing factor.
        ignore_index (`int`, *optional*, defaults to -100):
            The index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self, model_output, labels, batch_weight, shift_labels=False):
        """Return the weighted, label-smoothed loss as a scalar tensor.

        Args:
            model_output: a dict holding ``"logits"``, or a sequence whose first
                element is the logits.
            labels: target token ids. Entries equal to ``ignore_index`` are excluded.
            batch_weight: one weight per example (list, numpy array or tensor).
                It is reshaped to (batch, 1, 1), which assumes logits shaped
                (batch, seq, vocab).
            shift_labels: if True, drop the last logit position and the first
                label so that position t is scored against token t + 1.

        Returns:
            ``(1 - epsilon) * nll + epsilon * smoothed``. It is 0 (not NaN) when
            every label is ``ignore_index``.
        """
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        if shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        log_probs = -nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        # gather() cannot index with ignore_index (e.g. -100), so clamp those
        # labels to 0. The positions are zeroed through padding_mask below anyway.
        labels = torch.clamp(labels, min=0)
        nll_loss = log_probs.gather(dim=-1, index=labels)
        # works for fp16 input tensor too, by internally upcasting it to fp32
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        # as_tensor accepts lists, numpy arrays (any float dtype) and tensors,
        # and places the weights on the loss device in one step.
        weight = torch.as_tensor(batch_weight, dtype=torch.float32, device=nll_loss.device)
        weight = weight.unsqueeze(-1).unsqueeze(-1)  # (batch,) -> (batch, 1, 1)

        nll_loss = nll_loss * weight
        smoothed_loss = smoothed_loss * weight

        # Normalise by the number of non-padded tokens. Clamp to 1 so that an
        # all-padding batch gives 0 instead of 0 / 0 = NaN.
        num_active_elements = (padding_mask.numel() - padding_mask.long().sum()).clamp(min=1)
        nll_loss = nll_loss.sum() / num_active_elements
        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss



def get_embedding(text, tokenizer, model):
    """Embed *text* as the attention-masked mean of the model's last hidden layer.

    Args:
        text: a string or list of strings. It is tokenized with padding and
            truncation to at most 128 tokens.
        tokenizer: a tokenizer whose output provides ``input_ids`` and
            ``attention_mask`` and supports ``.to(device)``.
        model: a model called with ``output_hidden_states=True`` whose output
            has a ``"hidden_states"`` sequence.

    Returns:
        A float16 tensor with one pooled row per input text.
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(model.device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    # This is also called from inside the training loop. Switch to eval mode so
    # dropout (e.g. LoRA dropout) does not make the embeddings random, and
    # restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            # Pool in fp32 to avoid fp16 overflow when summing over tokens.
            hidden = output["hidden_states"][-1].float()
            # Average over real tokens only. Padded positions would otherwise
            # pull shorter texts' embeddings toward the pad-token state.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            embedding = summed / counts
    finally:
        if was_training:
            model.train()

    return embedding.half()
    #return emb




def _load_dataset_dicts(paths):
    """Load each entry of *paths* with ``datasets.load_dataset``, preserving order.

    Paths ending in ``.json`` are loaded with the ``json`` builder. Any other
    string is passed to ``load_dataset`` unchanged.
    """
    loaded = []
    for path in paths:
        if path.endswith(".json"):
            loaded.append(load_dataset("json", data_files=path))
        else:
            loaded.append(load_dataset(path))
    return loaded


def _load_item_groups(path):
    """Read an item -> group-id JSON mapping and build a one-hot membership matrix.

    Returns ``(item_corpus, group_id, adjacency)``:
        - ``item_corpus`` and ``group_id`` are the mapping's keys and values as
          lists, in file order.
        - ``adjacency`` has shape ``(len(item_corpus), max(group_id) + 1)``, with
          a 1 at ``[i, group_id[i]]`` and 0 elsewhere.

    Group ids are expected to be non-negative integers.
    """
    with open(path, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)

    item_corpus = list(data.keys())
    group_id = list(data.values())
    print(max(group_id))
    data_size = len(group_id)
    print(data_size)
    group_size = max(group_id) + 1
    adjacency = np.zeros((data_size, group_size))
    # Vectorised form of: for i, g in enumerate(group_id): adjacency[i, g] = 1
    adjacency[np.arange(data_size), group_id] = 1
    return item_corpus, group_id, adjacency


def _response_start_positions(input_ids, newline_token_id=13, newlines_before_response=9):
    """Return, per row of *input_ids*, the position of the newline that ends the prompt.

    The defaults follow this script's conventions:
        - token id 13 is the newline token.
        - ``generate_prompt``'s with-input template contains 9 newlines, the
          last one right after ``### Response:``. This holds only when the
          instruction and input themselves contain no newlines.

    Each row is searched independently, so extra newlines inside one row's
    response cannot shift the positions found for other rows.

    Raises:
        ValueError: if a row has fewer than ``newlines_before_response`` newlines.
    """
    positions = []
    for row_idx, row in enumerate(input_ids):
        newline_pos = torch.nonzero(row == newline_token_id, as_tuple=True)[0]
        if newline_pos.numel() < newlines_before_response:
            raise ValueError(
                f"Row {row_idx} has {newline_pos.numel()} newline tokens (id {newline_token_id}); "
                f"expected at least {newlines_before_response} from the prompt template."
            )
        positions.append(int(newline_pos[newlines_before_response - 1]))
    return positions


def train(
        # model/data params
        base_model: str = "",  # the only required argument
        train_data_path: List[str] = [""],
        val_data_path: List[str] = [""],
        output_dir: str = "./lora-alpaca",
        sample: int = -1,
        seed: int = 0,
        # training hyperparams
        batch_size: int = 128,
        micro_batch_size: int = 4,
        num_epochs: int = 3,
        learning_rate: float = 3e-4,
        cutoff_len: int = 512,
        # lora hyperparams
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: List[str] = [
            "q_proj",
            "v_proj",
        ],
        # llm hyperparams
        train_on_inputs: bool = True,  # if False, masks out inputs in loss
        group_by_length: bool = False,  # faster, but produces an odd training loss curve
        # wandb params
        wandb_project: str = "",
        wandb_run_name: str = "",
        wandb_watch: str = "",  # options: false | gradients | all
        wandb_log_model: str = "",  # options: false | true
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter

        #sample strategies
        category_type: str = "", #options: Dual | FairCo
        topk: int = 20,
        category_path: str = "data/utils/output_categoryid.json",
):
    """Fine-tune ``base_model`` with LoRA (8-bit base weights) and a fairness-weighted loss.

    During every loss computation, the mean last-layer hidden state of each
    example's response is scored against embeddings of 100 randomly sampled
    catalogue items. The ``topk`` best item ids go to a ``Dual_Sampler``, which
    returns per-example loss weights. When the sampler signals an update, the
    item embeddings are recomputed.

    Args (beyond the usual alpaca-lora ones):
        sample: if > -1, keep only this many shuffled examples per training file.
        topk: number of top-scoring items passed to the fairness sampler.
        category_type: currently not read by this function (only Dual_Sampler
            is used).
        category_path: JSON file mapping item text -> integer group id.
    """
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"train_data_path: {train_data_path}\n"
        f"val_data_path: {val_data_path}\n"
        f"sample: {sample}\n"
        f"seed: {seed}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    # The integer divisions above can reach 0 (e.g. a world_size larger than
    # batch_size // micro_batch_size). 0 is not a usable accumulation count.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
            "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    # Left padding keeps each response at the end of its row. compute_loss
    # below relies on this when it pools from the prompt's last newline to the end.
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        """Tokenize *prompt* (truncated to cutoff_len), optionally append EOS, and copy ids to labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
                result["input_ids"][-1] != tokenizer.eos_token_id
                and len(result["input_ids"]) < cutoff_len
                and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        """Build and tokenize the full prompt. Unless train_on_inputs, mask the prompt part of labels."""
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
            )  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    train_data_list = _load_dataset_dicts(train_data_path)
    val_data_list = _load_dataset_dicts(val_data_path)

    for i, dataset in enumerate(train_data_list):
        split = dataset["train"].shuffle(seed=seed)
        if sample > -1:
            # Guard against asking for more rows than the file has.
            split = split.select(range(min(sample, len(split))))
        # The second shuffle with the same seed is redundant in principle. It is
        # kept so the example order matches earlier runs of this script.
        dataset["train"] = split.shuffle(seed=seed)
        train_data_list[i] = dataset.map(generate_and_tokenize_prompt)
    for i, dataset in enumerate(val_data_list):
        val_data_list[i] = dataset.map(generate_and_tokenize_prompt)
    # Validation files are read from their "train" split too.
    train_data = concatenate_datasets([d["train"] for d in train_data_list])
    val_data = concatenate_datasets([d["train"] for d in val_data_list])

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    # ---- item -> provider-group membership matrix ----
    item_corpus, group_id, AdjecentMatrix = _load_item_groups(category_path)
    data_size = len(group_id)
    group_size = AdjecentMatrix.shape[1]  # == max(group_id) + 1

    # ---- fairness sampler state ----
    data_ids = np.arange(0, data_size)
    train_len = len(train_data)
    item_corpus_len = len(item_corpus)

    Fair_Sampler = Dual_Sampler(p_size=group_size,
                                train_len=train_len,
                                AdjecentMatrix=AdjecentMatrix,
                                batch_size=micro_batch_size,
                                data_size=data_size)

    class FairTrainer(Trainer):
        """Trainer whose per-example loss weights come from the fairness sampler."""

        def __init__(self, *args, **kwargs):
            super(FairTrainer, self).__init__(*args, **kwargs)
            self.label_smoother = WeightLabelSmoother(epsilon=self.args.label_smoothing_factor)

            self.tokenizer = tokenizer
            # **kwargs here are the Trainer constructor arguments (model,
            # datasets, args, callbacks, ...). They are not generation options
            # and must not be forwarded into GenerationConfig.
            self.generation_config = GenerationConfig(
                num_beams=1,
                num_return_sequences=1,
            )

            self.data_ids = data_ids
            self.get_item_emb()

        def sampler(self, sample_num=100):
            """Draw ``sample_num`` distinct item ids uniformly at random."""
            return np.random.choice(self.data_ids, size=sample_num, replace=False)

        def get_item_emb(self):
            """(Re)compute ``self.item_embeddings`` for every catalogue item, in batches of 128."""
            batch_size = 128
            store_emb = []
            for start in trange(0, item_corpus_len, batch_size):
                text = item_corpus[start:start + batch_size]
                store_emb.append(get_embedding(text, self.tokenizer, self.model))
            self.item_embeddings = torch.cat(store_emb, dim=0)

        def compute_loss(self, model, inputs, return_outputs=False):
            """Compute the fairness-weighted label-smoothed loss.

            The per-example weights come from ``Fair_Sampler.update_weight``,
            which is fed the top-k items most similar to each example's
            response embedding.
            """
            if self.label_smoother is not None and "labels" in inputs:
                labels = inputs.pop("labels")
            else:
                labels = None
            outputs = model(**inputs, output_hidden_states=True)

            # The ranking only yields item indices for the sampler, so no
            # gradient graph is needed for pooling or scoring.
            with torch.no_grad():
                last_hidden = outputs["hidden_states"][-1]
                # Mean-pool the tokens after the prompt's last newline (the response).
                starts = _response_start_positions(inputs["input_ids"])
                out_embeddings = torch.cat(
                    [
                        last_hidden[row, start + 1:, :].mean(dim=0, keepdim=True)
                        for row, start in enumerate(starts)
                    ],
                    dim=0,
                )
                sample_id = self.sampler()
                pre_item_emb = self.item_embeddings[sample_id, :]
                ranking_scores = torch.matmul(out_embeddings.half(), pre_item_emb.t().half())
                _, items = torch.topk(ranking_scores, k=topk, dim=-1)
            # Map positions within the random sample back to global item ids.
            items = sample_id[items.cpu().numpy()]

            batch_weight, update_flag = Fair_Sampler.update_weight(items)
            if update_flag:
                self.get_item_emb()

            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]

            if labels is not None:
                unwrapped_model = unwrap_model(model)
                if _is_peft_model(unwrapped_model):
                    model_name = unwrapped_model.base_model.model._get_name()
                else:
                    model_name = unwrapped_model._get_name()
                if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                    loss = self.label_smoother(outputs, labels, shift_labels=True, batch_weight=batch_weight)
                else:
                    loss = self.label_smoother(outputs, labels, batch_weight=batch_weight)
            else:
                if isinstance(outputs, dict) and "loss" not in outputs:
                    raise ValueError(
                        "The model did not return a loss from the inputs, only the following keys: "
                        f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                    )
                # We don't use .loss here since the model may return tuples instead of ModelOutput.
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

            return (loss, outputs) if return_outputs else loss

    trainer = FairTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            per_device_eval_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=20,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=20,
            optim="adamw_torch",
            evaluation_strategy="epoch",
            eval_steps=5,
            save_strategy="epoch",
            save_steps=20,
            output_dir=output_dir,
            save_total_limit=1,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # report_to=None means "all installed integrations" in transformers
            # 4.x, so wandb would still be used. "none" really disables reporting.
            report_to="wandb" if use_wandb else "none",
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )
    model.config.use_cache = False

    # Make state_dict() (and thus saving) return only the PEFT adapter weights.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    wandb.finish()

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


def generate_prompt(data_point):
    """Render an Alpaca-style prompt for *data_point*.

    *data_point* must provide ``"instruction"``, ``"input"`` and ``"output"``.
    When ``"input"`` is truthy, an ``### Input:`` section is included between
    the instruction and the response; otherwise it is omitted and a shorter
    header is used. The output text is appended right after
    ``### Response:\\n`` with nothing following it.
    """
    context = data_point["input"]
    instruction = data_point["instruction"]
    response = data_point["output"]

    if context:
        header = (
            "Below is an instruction that describes a task, paired with an input that "
            "provides further context. Write a response that appropriately completes the request. "
        )
        sections = [("Instruction", instruction), ("Input", context), ("Response", response)]
    else:
        header = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.  "
        )
        sections = [("Instruction", instruction), ("Response", response)]

    return header + "".join(f"\n\n### {title}:\n{body}" for title, body in sections)


if __name__ == "__main__":
    # Command-line entry point: python-fire turns train()'s keyword
    # parameters into command-line flags.
    fire.Fire(component=train)