import copy
from dataclasses import dataclass, field
import json
import pathlib
from typing import Dict, Optional, Sequence
import random
import os
import numpy as np
import torch
import re
from datasets import load_dataset
from torch.utils.data import Dataset
import transformers
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq
from transformers.trainer_pt_utils import LabelSmoother
import torch.nn.functional as F

# Ignore-index value exposed by transformers' LabelSmoother.
# NOTE(review): this constant is not referenced in the visible code, which
# hard-codes -100 wherever labels are masked - confirm whether it can be
# used there or removed.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import wandb

# Weights & Biases configuration, read from the process environment.
# NOTE(review): an API key is committed in source. It is kept here only as a
# fallback so existing launch scripts keep working; it should be rotated and
# supplied through the environment instead.
# setdefault (rather than plain assignment) lets a key/project that the user
# already exported take precedence over these defaults.
os.environ.setdefault("WANDB_API_KEY", "local-9b0fa2120a27145f151a8d533cc5cd3208e823c1")
os.environ.setdefault("WANDB_PROJECT", "AFT")

@dataclass
class ModelArguments:
    """Command-line arguments selecting the pretrained model."""

    # Name or path handed to ``from_pretrained`` for both model and tokenizer.
    model_name_or_path: Optional[str] = "facebook/opt-125m"

@dataclass
class DataArguments:
    """Command-line arguments locating the training and evaluation data.

    ``train()`` reads the training set from ``data/<task>/<data_path>.jsonl``
    and the evaluation sets from ``data/<task>/dev.jsonl`` and
    ``data/<task>/test.jsonl``.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Base name (without the .jsonl extension) of the training file inside data/<task>/."
        },
    )
    task: str = field(
        default='gsm8k',
        metadata={
            "help": "Task name; selects the data/<task>/ directory that holds the training, dev.jsonl and test.jsonl files."
        },
    )

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    """Seq2SeqTrainingArguments extended with the options used by this script.

    The integer-typed switches (``neg_detach``, ``rft``, ``save``,
    ``inference``, ``boundary``) are tested for truthiness: 0 disables, any
    non-zero value enables.
    """

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    beta: float = field(
        default=0,
        metadata={
            "help": "Offset subtracted (as 2*beta) inside the boundary term of the aft_binary/aft_rank losses; only used when boundary is non-zero."
        },
    )
    length_penalty: float = field(
        default=1,
        metadata={
            "help": "Exponent applied to the number of answer tokens when normalising a candidate's summed log-probability into its score."
        },
    )
    gap: float = field(
        default=0,
        metadata={"help": "Currently not read by any loss in this script."},
    )
    alignment_type: str = field(
        default="rrhf",
        metadata={
            "help": "Ranking loss added to the SFT loss: one of rrhf, pro, aft_binary, aft_rank, or no (SFT term only)."
        },
    )
    seed: int = field(
        default=1,
        metadata={
            "help": "Random seed passed to transformers.set_seed at the start of training."
        },
    )
    rank_start: int = field(
        default=0,
        metadata={"help": "Currently not read by this script."},
    )

    neg_detach: int = field(
        default=1,
        metadata={
            "help": "If non-zero, detach the score of the lower-reward candidate in the pairwise ranking terms so their gradient only flows through the higher-reward candidate."
        },
    )

    rft: int = field(
        default=0,
        metadata={
            "help": "If non-zero, the SFT term uses every candidate tied for the highest reward of its question; otherwise a single highest-reward candidate per question."
        },
    )

    mlce_avg: int = field(
        default=0,
        metadata={"help": "Currently not read by this script."},
    )

    save: int = field(
        default=1,
        metadata={
            "help": "If non-zero, save the trainer state and the model to output_dir after training."
        },
    )

    inference: int = field(
        default=1,
        metadata={
            "help": "If non-zero, evaluate on the dev and test sets after training."
        },
    )

    boundary: int = field(
        default=0,
        metadata={
            "help": "If non-zero, add the boundary-constraint term to the aft_binary/aft_rank losses."
        },
    )

    temperature: float = field(
        default=1,
        metadata={
            "help": "Temperature dividing the score differences (aft_binary, aft_rank) or reward differences (pro) inside the ranking losses."
        },
    )


    
    

# Module-level placeholder; train() overwrites it with training_args.local_rank.
local_rank = None
class RankTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose training loss is an SFT term plus a ranking term.

    ``compute_loss`` scores every candidate answer in the batch by its
    length-normalised log-likelihood and, depending on
    ``self.args.alignment_type``, adds one of the ranking losses defined below
    to the SFT/RFT loss.  ``prediction_step`` generates continuations so that
    ``compute_metrics`` can compare them with the reference labels.

    NOTE(review): the ranking losses receive ``idxs`` but do not use it, so
    every candidate in a batch is ranked against every other one.  That is
    only meaningful when a batch holds the candidates of a single question -
    confirm the per-device batch size used for training.
    """

    # Generation / padding sizes used by ``prediction_step``.
    _GEN_MAX_NEW_TOKENS = 256
    _EVAL_PAD_LENGTH = 2048 + 1

    def gather_logits_labels(self, logits, labels):
        """Select, at every position, the entry of ``logits`` at the label id.

        Args:
            logits: per-token values over the vocabulary (``compute_loss``
                passes log-probabilities), last dimension indexed by token id.
            labels: token ids with -100 marking positions to ignore.

        Returns:
            Tensor shaped like ``labels`` holding the selected values, with
            ignored positions set to 0.
        """
        mask = (labels != -100).float()
        # -100 is not a valid gather index: point ignored positions at id 0
        # and zero their contribution with the mask afterwards.  gather does
        # not modify its input, so the (large) logits tensor is not copied.
        safe_labels = labels.masked_fill(labels == -100, 0)
        output = torch.gather(logits, dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
        return output * mask

    def get_score(self, logit_label, labels):
        """Return one length-normalised score per row.

        The score is the row sum of ``logit_label`` divided by
        ``(number of non-ignored labels) ** self.args.length_penalty``.
        """
        mask = (labels != -100).float()
        # A row without any supervised token would otherwise produce 0/0 = NaN
        # and poison the loss; its summed log-prob is 0, so the score stays 0.
        length = mask.sum(-1).clamp(min=1)
        return logit_label.sum(-1) / (length ** self.args.length_penalty)

    @staticmethod
    def _sort_by_reward(scores, rw_scores):
        """Return ``(scores, rw_scores)`` reordered by descending reward."""
        order = rw_scores.argsort(dim=-1, descending=True)
        return scores[order], rw_scores[order]

    @staticmethod
    def _first_lower_reward(sorted_rw_scores, i):
        """Index of the first entry after ``i`` with a strictly lower reward, or None."""
        pos_rw_score = sorted_rw_scores[i].item()
        for kk in range(i + 1, len(sorted_rw_scores)):
            if sorted_rw_scores[kk].item() < pos_rw_score:
                return kk
        return None

    def pro_loss(self, scores, idxs, rw_scores):
        """PRO-style ranking loss with reward-gap temperatures.

        For every candidate that has strictly lower-reward candidates, adds
        ``log(1 + sum_neg exp(s_neg * t_neg - s_pos * t_pos))`` where
        ``t_neg = (reward_pos - reward_neg) / temperature`` and ``t_pos`` is
        the largest ``t_neg``.  With ``alignment_type == 'pro_no_temp'`` both
        temperatures are fixed to 1.  ``idxs`` is accepted but unused.

        Returns:
            A scalar tensor, or the int 0 when no candidate pair differs in reward.
        """
        sorted_scores, sorted_rw_scores = self._sort_by_reward(scores, rw_scores)
        no_temp = self.args.alignment_type == 'pro_no_temp'
        loss = 0
        for i in range(len(sorted_rw_scores) - 1):
            neg_start = self._first_lower_reward(sorted_rw_scores, i)
            if neg_start is None:
                # Every remaining candidate ties with this one: nothing to rank.
                continue

            pos_scores = sorted_scores[i]
            pos_rw_score = sorted_rw_scores[i].item()
            neg_scores = sorted_scores[neg_start:]
            neg_temperatures = (pos_rw_score - sorted_rw_scores[neg_start:]) / self.args.temperature
            pos_temperature = neg_temperatures.max()

            tmp = 0
            for neg_score, neg_temperature in zip(neg_scores, neg_temperatures):
                if no_temp:
                    neg_temperature = pos_temperature = 1
                assert pos_temperature >= neg_temperature

                if self.args.neg_detach:
                    neg_score = neg_score.detach()
                tmp += torch.exp(neg_score * neg_temperature - pos_scores * pos_temperature)

            if tmp != 0:
                loss += torch.log(1 + tmp)

        return loss

    def rrhf_loss(self, scores, idxs, rw_scores):
        """RRHF-style hinge ranking loss.

        Sums ``max(s_neg - s_pos, 0)`` over every pair in which the negative
        has a strictly lower reward than the positive.  ``idxs`` is accepted
        but unused.

        Returns:
            A scalar tensor, or the int 0 when no candidate pair differs in reward.
        """
        sorted_scores, sorted_rw_scores = self._sort_by_reward(scores, rw_scores)
        loss = 0
        for i in range(len(sorted_rw_scores) - 1):
            neg_start = self._first_lower_reward(sorted_rw_scores, i)
            if neg_start is None:
                continue
            pos_scores = sorted_scores[i]
            for neg_score in sorted_scores[neg_start:]:
                if self.args.neg_detach:
                    neg_score = neg_score.detach()
                loss += torch.clamp(neg_score - pos_scores, min=0)
        return loss

    def aft_binary_loss(self, scores, idxs, rw_scores):
        """AFT loss with a binary split into positives and negatives.

        Candidates whose reward equals the maximum reward are positives, all
        others are negatives.  The loss is ``log(1 + S)`` where ``S`` sums
        ``exp((s_neg - s_pos) / temperature)`` over all negative/positive
        pairs.  When ``self.args.boundary`` is set, each pair additionally
        contributes ``exp(2*B - 2*beta - s_pos - s_neg)`` with ``B`` the
        lowest positive score; ``B`` and ``s_pos`` enter as plain floats, so
        only ``s_neg`` receives gradient from that term.  ``idxs`` is
        accepted but unused.

        Returns:
            A scalar tensor, or the int 0 when there is no negative/positive pair.
        """
        pos_indexs = []
        neg_indexs = []
        max_score = torch.max(rw_scores, dim=-1)[0].item()
        positive_lower_boundary = 10000
        for idx, rw_score in enumerate(rw_scores):
            if rw_score.item() == max_score:
                pos_indexs.append(idx)
                positive_lower_boundary = min(scores[idx].item(), positive_lower_boundary)
            else:
                neg_indexs.append(idx)

        tmp = 0
        for neg_index in neg_indexs:
            for pos_index in pos_indexs:
                neg_score = scores[neg_index]
                if self.args.neg_detach:
                    tmp += torch.exp((neg_score.detach() - scores[pos_index]) / self.args.temperature)
                else:
                    tmp += torch.exp((neg_score - scores[pos_index]) / self.args.temperature)

                if self.args.boundary:
                    tmp += torch.exp(
                        2 * positive_lower_boundary - 2 * self.args.beta
                        - scores[pos_index].item() - neg_score
                    )

        if tmp != 0:
            return torch.log(1 + tmp)
        return 0

    def aft_rank_loss(self, scores, idxs, rw_scores):
        """AFT loss over the full reward ranking.

        Like ``aft_binary_loss``, but every candidate acts as a positive for
        all candidates with a strictly lower reward.  The boundary term uses
        the lowest score among the positives visited so far (candidates are
        visited in descending reward order).  ``idxs`` is accepted but unused.

        Returns:
            A scalar tensor, or the int 0 when no candidate pair differs in reward.
        """
        sorted_scores, sorted_rw_scores = self._sort_by_reward(scores, rw_scores)

        tmp = 0
        positive_lower_boundary = 10000
        for i in range(len(sorted_rw_scores) - 1):
            pos_scores = sorted_scores[i]
            # Updated before the tie check so tied candidates still lower the boundary.
            positive_lower_boundary = min(pos_scores.item(), positive_lower_boundary)

            neg_start = self._first_lower_reward(sorted_rw_scores, i)
            if neg_start is None:
                continue

            for neg_score in sorted_scores[neg_start:]:
                if self.args.neg_detach:
                    tmp += torch.exp((neg_score.detach() - pos_scores) / self.args.temperature)
                else:
                    tmp += torch.exp((neg_score - pos_scores) / self.args.temperature)
                if self.args.boundary:
                    tmp += torch.exp(
                        2 * positive_lower_boundary - 2 * self.args.beta
                        - pos_scores.item() - neg_score
                    )

        if tmp != 0:
            return torch.log(1 + tmp)
        return 0

    def sft_loss(self, logit_label, idxs, rw_scores):
        """SFT term on one highest-reward candidate per question.

        Rows are grouped by ``idxs``; within each group the row picked by
        ``argmax`` over the rewards is kept.  Returns the negated mean of the
        kept ``logit_label`` entries (the mean runs over all positions,
        including the zeroed ignored ones).
        """
        selected = []
        for idx in set(idxs.cpu().tolist()):
            in_group = idxs == idx
            best = torch.argmax(rw_scores[in_group], dim=-1)
            selected.append(logit_label[in_group][best])
        return -torch.stack(selected, dim=0).mean()

    def rft_loss(self, logit_label, idxs, rw_scores):
        """SFT term on every candidate tied for the highest reward of its question.

        Same as ``sft_loss`` except that all rows whose reward equals the
        group's maximum are kept instead of a single one.
        """
        selected = []
        for idx in set(idxs.cpu().tolist()):
            in_group = idxs == idx
            group_rewards = rw_scores[in_group]
            group_logit_label = logit_label[in_group]
            max_score = torch.max(group_rewards, dim=-1)[0].item()
            for idxj, rw_score in enumerate(group_rewards):
                if rw_score.item() == max_score:
                    selected.append(group_logit_label[idxj])

        return -torch.stack(selected, dim=0).mean()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the SFT/RFT loss plus the configured ranking loss.

        Args:
            model: the model being trained; called with ``input_ids`` and
                ``attention_mask`` only.
            inputs: batch dict with ``input_ids``, ``attention_mask``,
                ``labels``, ``idxs`` and ``scores`` (reward per candidate).
            return_outputs: when true, return ``(loss, scores)``.
            num_items_in_batch: accepted for compatibility with Trainer
                versions that pass it; not used here.

        Raises:
            ValueError: if ``self.args.alignment_type`` is not recognised.
        """
        logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0]  # (batch * cand) * L * V
        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = inputs.get("labels")[..., 1:].contiguous()
        log_probs = F.log_softmax(shift_logits, dim=-1)
        logit_label = self.gather_logits_labels(log_probs, shift_labels)
        scores = self.get_score(logit_label, shift_labels)

        idxs = inputs.get("idxs")
        rw_scores = inputs.get("scores")

        ranking_losses = {
            'rrhf': self.rrhf_loss,
            'pro': self.pro_loss,
            # pro_loss itself switches the temperatures off for this variant.
            'pro_no_temp': self.pro_loss,
            'aft_binary': self.aft_binary_loss,
            'aft_rank': self.aft_rank_loss,
        }
        alignment_type = self.args.alignment_type
        if alignment_type == 'no':
            alignment_loss = 0
        elif alignment_type in ranking_losses:
            alignment_loss = ranking_losses[alignment_type](scores, idxs, rw_scores)
        else:
            raise ValueError(f"Unknown alignment_type: {alignment_type!r}")

        if self.args.rft:
            sft_loss = self.rft_loss(logit_label, idxs, rw_scores)
        else:
            sft_loss = self.sft_loss(logit_label, idxs, rw_scores)

        loss = sft_loss + alignment_loss

        return (loss, scores) if return_outputs else loss

    @torch.no_grad()
    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only: bool,
        ignore_keys,
        **gen_kwargs,
    ):
        """Generate a continuation for each prompt in ``inputs``.

        ``prediction_loss_only``, ``ignore_keys`` and ``gen_kwargs`` are
        accepted for interface compatibility but not used.

        Returns:
            ``(None, generated_tokens, labels)``: no loss is computed; the
            generated tokens (prompt stripped) and the labels are both padded
            to a fixed length of 2049 via ``_pad_tensors_to_max_len``.
        """
        inputs = self._prepare_inputs(inputs)

        gen_config = {
            "max_new_tokens": self._GEN_MAX_NEW_TOKENS,
            "pad_token_id": self.tokenizer.eos_token_id,
            # Greedy decoding.  Requested explicitly instead of through
            # temperature=0, which is not a valid sampling temperature.
            "do_sample": False,
        }

        generated_tokens = self.model.generate(**inputs, **gen_config)
        # generate() returns prompt + continuation; keep the continuation only.
        generated_tokens = generated_tokens[:, inputs['input_ids'].shape[1]:]
        generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self._EVAL_PAD_LENGTH)
        labels = self._pad_tensors_to_max_len(inputs["labels"], self._EVAL_PAD_LENGTH)

        return None, generated_tokens, labels

import json
class ScoreDataset(Dataset):
    """Candidate answers read from a JSON-lines file, grouped by question ``idx``.

    Each non-empty line of ``data_path`` is one JSON record.  Records sharing
    the same ``idx`` form one dataset item; a record without ``idx`` gets its
    position among the parsed records as ``idx``.  ``self.data`` maps each
    ``idx`` to the list of its records.
    """

    def __init__(self, data_path: str, tokenizer: "transformers.PreTrainedTokenizer"):
        """Load and group the records.  ``tokenizer`` is accepted but not used."""
        super().__init__()
        self.data = {}
        with open(data_path, 'r', encoding='utf-8') as f:
            record_no = 0
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                record = json.loads(line)
                record['idx'] = record.get('idx', record_no)
                record_no += 1
                self.data.setdefault(record['idx'], []).append(record)

        # Positional index -> group key.  When the keys are exactly 0..N-1 the
        # sorted order makes item i the group with idx == i, as before; any
        # other key set (gaps, offsets) no longer raises KeyError.
        try:
            self._keys = sorted(self.data)
        except TypeError:
            # Keys of mixed, non-comparable types: fall back to file order.
            self._keys = list(self.data)

    def __len__(self):
        """Number of question groups."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``{'input_ids': [record, ...]}`` for the ``i``-th group."""
        return dict(input_ids=self.data[self._keys[i]])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collator that serves both the grouped training set and the eval sets.

    Training items carry a single key (``input_ids``) holding a list of
    candidate records; they are tokenised and padded here.  Any item with
    more than one key is delegated to a ``DataCollatorForSeq2Seq``.
    """

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.eval_data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=8,
        )

    def __call__(self, batch_instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build a batch.

        For training items every candidate becomes one row: question tokens
        followed by answer tokens (answer text plus EOS, first token of its
        encoding dropped).  Labels equal the input ids with the question part
        and the padding set to -100.
        """
        # Items with more than one key come from the tokenised eval datasets.
        if len(batch_instances[0].keys()) != 1:
            return self.eval_data_collator(batch_instances)

        tok = self.tokenizer
        idx_list, score_list, id_rows, label_rows = [], [], [], []
        for group in batch_instances:
            for candidate in group['input_ids']:
                question = candidate['question']
                target = candidate['answer'] + tok.eos_token
                idx = candidate['idx']
                reward = candidate.get('scores', 1)

                prompt_ids = tok(
                    question, truncation=True, max_length=tok.model_max_length
                ).input_ids
                # The answer may only use what the question left of the budget.
                answer_budget = tok.model_max_length - len(prompt_ids)
                answer_ids = tok(
                    target, truncation=True, max_length=answer_budget
                ).input_ids[1:]

                row = torch.tensor(prompt_ids + answer_ids).long()
                row_labels = row.clone()
                row_labels[: len(prompt_ids)] = -100

                idx_list.append(idx)
                score_list.append(reward)
                id_rows.append(row)
                label_rows.append(row_labels)

        pad_id = tok.pad_token_id
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            id_rows, batch_first=True, padding_value=pad_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_rows, batch_first=True, padding_value=-100
        )
        return {
            "input_ids": padded_ids,
            "attention_mask": padded_ids.ne(pad_id),
            "labels": padded_labels,
            "idxs": torch.LongTensor(idx_list),
            "scores": torch.FloatTensor(score_list),
        }


def eval_tokenize_function(examples, tokenizer):
    """Tokenise a batch of evaluation examples.

    Args:
        examples: mapping with parallel ``'question'`` and ``'answer'`` lists.
        tokenizer: callable returning an object with ``input_ids``; must
            expose ``model_max_length``.

    Returns:
        Dict of lists: ``input_ids`` (question ids, truncated to
        ``model_max_length``), ``labels`` (untruncated answer ids without the
        first token of the encoding) and ``attention_mask`` (all ones, one
        per question token).
    """
    prompt_rows, label_rows, mask_rows = [], [], []
    for question, target in zip(examples['question'], examples['answer']):
        prompt_ids = tokenizer(
            question, truncation=True, max_length=tokenizer.model_max_length
        ).input_ids
        answer_ids = tokenizer(target, truncation=False).input_ids[1:]
        prompt_rows.append(prompt_ids)
        label_rows.append(answer_ids)
        mask_rows.append([1] * len(prompt_ids))
    return {
        "input_ids": prompt_rows,
        "labels": label_rows,
        "attention_mask": mask_rows,
    }


def parse_answer(predicted_answer):
    """Extract the final answer that follows the first ``####`` marker.

    Returns the text between the first and second ``####`` (or the end of
    the string), stripped of surrounding whitespace and dots.  Returns the
    sentinel ``'[Invaliad]'`` when the marker is missing or the input is not
    a string.
    """
    try:
        return predicted_answer.split('####')[1].strip().strip('.')
    except (IndexError, AttributeError):
        # IndexError: no '####' in the text; AttributeError: non-string input.
        # The sentinel spelling is compared against elsewhere - keep it as is.
        return '[Invaliad]'

def train():
    """Parse the command line, fine-tune the model and optionally evaluate it.

    Steps: parse ``ModelArguments``/``DataArguments``/``TrainingArguments``,
    load model and tokenizer, read the training set from
    ``data/<task>/<data_path>.jsonl`` and the dev/test sets from
    ``data/<task>/dev.jsonl`` / ``test.jsonl``, train with ``RankTrainer``
    (resuming when ``output_dir`` already holds a ``checkpoint-*``), then save
    and/or evaluate depending on ``training_args.save`` and
    ``training_args.inference``.

    Side effects: sets the ``WANDB_NAME`` environment variable, updates the
    module-level ``local_rank`` and writes per-evaluation prediction logs
    under ``./log/``.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    run_name = os.path.basename(training_args.output_dir)
    os.environ['WANDB_NAME'] = f'{run_name}'
    local_rank = training_args.local_rank
    transformers.set_seed(training_args.seed)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.config.use_cache = False
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="left",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token

    train_dataset = ScoreDataset(f"data/{data_args.task}/{data_args.data_path}.jsonl", tokenizer)
    raw_eval_dataset = load_dataset('json', data_files=f"data/{data_args.task}/dev.jsonl")['train']
    raw_test_dataset = load_dataset('json', data_files=f"data/{data_args.task}/test.jsonl")['train']

    print(f"train: {len(train_dataset)}; dev: {len(raw_eval_dataset)}; test: {len(raw_test_dataset)}")

    with training_args.main_process_first(desc="validation dataset map pre-processing"):
        eval_dataset = raw_eval_dataset.map(
            eval_tokenize_function,
            batched=True,
            num_proc=1,
            remove_columns=raw_eval_dataset.column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on validation dataset",
            fn_kwargs={"tokenizer": tokenizer},
        )
        test_dataset = raw_test_dataset.map(
            eval_tokenize_function,
            batched=True,
            num_proc=1,
            remove_columns=raw_test_dataset.column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on test dataset",
            fn_kwargs={"tokenizer": tokenizer},
        )

    def compute_metrics(eval_preds):
        """Exact-match accuracy on the parsed answers, plus a prediction log.

        Raises:
            ValueError: if a reference answer lacks the ``####`` marker.
        """
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # -100 cannot be decoded; map it to the pad token, which
        # skip_special_tokens then drops.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        right = 0
        for pred, label in zip(decoded_preds, decoded_labels):
            gold = parse_answer(label)
            if gold == '[Invaliad]':
                # Raised explicitly: an assert would vanish under ``python -O``.
                raise ValueError(f"Reference answer has no '####' marker: {label!r}")
            if parse_answer(pred) == gold:
                right += 1

        result = {}
        result["accuracy"] = right / len(decoded_preds) if decoded_preds else 0.0

        log_path = f'./log/{os.path.basename(training_args.output_dir)}_{len(preds)}.json'
        # Without this the first evaluation fails with FileNotFoundError on a
        # fresh checkout, after the whole generation pass has already run.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, 'w') as f:
            f.write(json.dumps(result) + '\n')
            for i, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
                try:
                    line = json.dumps({'idx': i, 'prediction': pred, 'answer': label})
                except (TypeError, ValueError) as err:
                    # Logging is best effort: report and keep going.
                    print(f"skipping log entry {i}: {err}")
                    continue
                f.write(line + '\n')
        return result

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, model=model)
    trainer = RankTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    # Resume automatically when a previous run left checkpoints behind.
    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    if training_args.save:
        trainer.save_state()
        trainer.save_model(output_dir=training_args.output_dir)

    if training_args.inference:
        dev_results = trainer.evaluate(eval_dataset, metric_key_prefix='eval')
        print(dev_results)
        test_results = trainer.evaluate(test_dataset, metric_key_prefix='test')
        print(test_results)


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    train()
