import os
from abc import ABC

import torch
from torch.optim import Optimizer
from tqdm import tqdm

from openrlhf.models import GPTLMLoss
from openrlhf.utils.distributed_sampler import DistributedSampler
from torch import distributed as dist

from openrlhf.utils import match_with_answer_labels_v2, extract_last_answer
import re
import json
from peft import PeftModelForCausalLM

class SFTTrainer(ABC):
    """
    Trainer for supervised fine-tuning (SFT).

    Args:
        model (torch.nn.Module): The model to be trained.
        strategy (Strategy): The training strategy to be applied.
        optim (Optimizer): The optimizer for model training.
        train_dataloader (DataLoader): The dataloader for the training dataset.
        eval_dataloader (DataLoader): The dataloader for the evaluation dataset.
        scheduler (Scheduler): The learning rate scheduler to adjust training rates.
        max_norm (float, defaults to 1): Maximum gradient norm for clipping to prevent exploding gradients.
        pretrain_mode (bool, defaults to False): Flag to indicate if the trainer is in pre-training mode.
        batch_size (int, defaults to 1): Batch size for training.
        max_epochs (int, defaults to 2): The maximum number of training epochs.
        tokenizer (Tokenizer, optional): The tokenizer for processing input data.
        save_hf_ckpt (bool): Whether to save huggingface-format model weights.
        disable_ds_ckpt (bool): Whether to skip saving deepspeed-format model weights.
            (Deepspeed model weights are used for training recovery.)
    """

    def __init__(
        self,
        model,
        strategy,
        optim: Optimizer,
        train_dataloader,
        eval_dataloader,
        scheduler,
        max_norm: float = 1,
        pretrain_mode: bool = False,
        batch_size: int = 1,
        max_epochs: int = 2,
        tokenizer=None,
        save_hf_ckpt: bool = False,
        disable_ds_ckpt: bool = False,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.batch_size = batch_size
        self.max_norm = max_norm
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.scheduler = scheduler
        self.pretrain_mode = pretrain_mode
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optim
        self.args = strategy.args
        self.save_hf_ckpt = save_hf_ckpt
        self.disable_ds_ckpt = disable_ds_ckpt

        self.loss_fn = GPTLMLoss(ring_attn_group=self.strategy.ring_attn_group)

        # Mixtral 8*7b: auxiliary (load-balancing) loss is only added when its coefficient is non-zero.
        self.aux_loss = self.args.aux_loss_coef > 1e-8

        # packing samples
        self.packing_samples = strategy.args.packing_samples

        # Ratio-based checkpointing state; populated in fit().
        self.max_global_step = None
        self.ratio_save_steps = None
        self.saved_ratio_steps = set()

        # wandb/tensorboard setting
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # `use_wandb` carries the API key.
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )

            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
        """Run the SFT training loop.

        Args:
            args: Training arguments. Reads ``eval_steps``, ``save_steps``, ``train_batch_size``,
                ``logging_steps``, checkpoint settings and, optionally, ``save_ratios`` (a list of
                fractions of total optimizer steps at which to checkpoint, e.g. ``[0.5]``).
                ``eval_steps``/``save_steps`` of -1 mean "once per epoch", -2 means "never";
                these values are rewritten in place on ``args``.
            consumed_samples (int): Samples already consumed, used to resume mid-training.
            num_update_steps_per_epoch (int, optional): Optimizer steps per epoch. When ``None`` it
                is derived as ``len(train_dataloader) // strategy.accumulated_gradient``.
        """
        print_examples = True

        if num_update_steps_per_epoch is None:
            # Each dataloader item is one micro-batch; an optimizer step happens every
            # `accumulated_gradient` micro-batches (see the `step % accumulated_gradient` check below).
            num_update_steps_per_epoch = max(1, len(self.train_dataloader) // self.strategy.accumulated_gradient)

        if args.eval_steps == -1:
            args.eval_steps = num_update_steps_per_epoch  # evaluate once per epoch
        if args.eval_steps == -2:
            args.eval_steps = float("inf")  # never evaluate

        if args.save_steps == -1:
            args.save_steps = num_update_steps_per_epoch  # save once per epoch
        if args.save_steps == -2:
            args.save_steps = float("inf")  # never save by step count

        # Restore step and start_epoch
        step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1
        start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch
        consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size)

        # global_step counts optimizer steps; one epoch has num_update_steps_per_epoch of them.
        self.max_global_step = self.epochs * num_update_steps_per_epoch
        # Ratio-based saving stays disabled unless `args.save_ratios` is provided.
        self._init_ratio_save_steps(getattr(args, "save_ratios", None))

        epoch_bar = tqdm(
            range(start_epoch, self.epochs),
            desc="Train epoch",
            disable=not self.strategy.is_rank_0(),
        )
        loss_sum = 0
        for epoch in range(start_epoch, self.epochs):
            if isinstance(self.train_dataloader.sampler, DistributedSampler):
                self.train_dataloader.sampler.set_epoch(
                    epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples
                )

            step_bar = tqdm(
                range(self.train_dataloader.__len__()),
                desc="Train step of epoch %d" % epoch,
                disable=not self.strategy.is_rank_0(),
            )

            # train
            self.model.train()
            for prompt_id_lens, inputs, attention_masks, answers, infos in self.train_dataloader:
                if print_examples:
                    self._print_example(inputs)
                    print_examples = False

                output, labels = self._forward_batch(inputs, attention_masks, infos)

                # mixtral
                aux_loss = output.aux_loss if self.aux_loss else 0

                if not self.pretrain_mode:
                    # NOTE(review): training masks one extra token at every prompt/response boundary
                    # (boundary_offset=1) while evaluate() uses boundary_offset=0, so train and eval
                    # losses are computed over slightly different tokens. Kept as-is; confirm intent.
                    labels = self._mask_non_response_labels(labels, prompt_id_lens, infos, boundary_offset=1)

                gpt_loss = self.loss_fn(output.logits, labels)
                loss = gpt_loss + aux_loss * self.args.aux_loss_coef
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)

                loss_sum += gpt_loss.item()
                logs_dict = {
                    "gpt_loss": gpt_loss.item(),
                    "lr": self.scheduler.get_last_lr()[0],
                }
                if self.aux_loss:
                    logs_dict["aux_loss"] = aux_loss.item()
                # step bar
                logs_dict = self.strategy.all_reduce(logs_dict)
                step_bar.set_postfix(logs_dict)
                step_bar.update()

                # logs/checkpoints/evaluation, once per optimizer step
                if step % self.strategy.accumulated_gradient == 0:
                    logs_dict["loss_mean"] = loss_sum / self.strategy.accumulated_gradient
                    loss_sum = 0
                    global_step = step // self.strategy.accumulated_gradient
                    client_states = {"consumed_samples": global_step * args.train_batch_size}
                    self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states)

                step += 1

            epoch_bar.update()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def _init_ratio_save_steps(self, save_ratios):
        """Convert fractions of total training (e.g. ``[0.5]``) into global steps at which to checkpoint.

        A falsy ``save_ratios`` disables ratio-based saving (both tracking attributes become ``None``).
        Requires ``self.max_global_step`` to be set.
        """
        if not save_ratios:
            self.ratio_save_steps = None
            self.saved_ratio_steps = None
            return

        # int() rounds down; never schedule a save before the first optimizer step. The set drops duplicates.
        self.ratio_save_steps = sorted({max(1, int(self.max_global_step * r)) for r in save_ratios})
        self.saved_ratio_steps = set()

        if self.strategy.is_rank_0():
            self.strategy.print(
                f"[SFTTrainer] max_global_step = {self.max_global_step}, "
                f"ratio_save_steps ({list(save_ratios)}) = {self.ratio_save_steps}"
            )

    def _print_example(self, inputs):
        """Print the first sequence of a training batch on rank 0 as a sanity check."""
        if self.tokenizer is None or not self.strategy.is_rank_0():
            return
        # Packed batches are indexed as (1, total_len) elsewhere in this class; unpacked batches carry
        # an extra singleton dim (they are squeezed on dim 1 before the forward pass).
        sample_ids = inputs[0] if self.packing_samples else inputs[0][0]
        self.strategy.print(
            f"\n Training data example: \n{self.tokenizer.decode(sample_ids, skip_special_tokens=True)}\n"
        )

    def _forward_batch(self, inputs, attention_masks, infos):
        """Move a batch to the current CUDA device, run the model, and build token-level labels.

        Returns:
            (output, labels): the model output (with ``return_output=True``) and a label tensor equal
            to ``inputs`` with padding positions (``attention_mask == 0``) set to ``IGNORE_INDEX``.
        """
        device = torch.cuda.current_device()
        if self.packing_samples:
            inputs = inputs.to(device)
            attention_mask = attention_masks.to(device)
        else:
            inputs = inputs.to(device).squeeze(1)
            attention_mask = attention_masks.to(device).squeeze(1)

        if self.strategy.ring_attn_group is None:
            output = self.model(inputs, attention_mask=attention_mask, return_output=True)
        else:
            output = self.model(
                inputs,
                attention_mask=attention_mask,
                return_output=True,
                ring_attn_group=self.strategy.ring_attn_group,
                packed_seq_lens=infos["input_length"],
            )

        labels = torch.where(
            attention_mask.bool(),
            inputs,
            self.loss_fn.IGNORE_INDEX,
        )
        return output, labels

    def _mask_non_response_labels(self, labels, prompt_id_lens, infos, boundary_offset=0):
        """Set every non-response label to ``IGNORE_INDEX`` so only response tokens contribute to loss.

        Args:
            labels: Label tensor from ``_forward_batch``; modified in place in the prompt-masking paths.
            prompt_id_lens: Prompt length of each sample in the batch.
            infos: Batch info dict; reads ``response_ranges`` and ``input_length`` when packing.
            boundary_offset (int): Extra tokens added at each boundary. With offset ``k`` the masked
                prompt prefix is ``source_len + k`` long, and each kept response range is
                ``[start, end + k)``.

        Returns:
            The masked label tensor (a new tensor when ``response_ranges`` are used).
        """
        ignore_index = self.loss_fn.IGNORE_INDEX
        if not self.packing_samples:
            for label, source_len in zip(labels, prompt_id_lens):
                label[: source_len + boundary_offset] = ignore_index
            return labels

        # As response_ranges need to constrain the dataset organization strictly, we handle multiturn feature separately.
        if infos["response_ranges"]:
            dump_labels = torch.full(labels.size(), ignore_index).to(labels.device)
            for response_ranges in infos["response_ranges"]:
                for start, end in ((r[0], r[1] + boundary_offset) for r in response_ranges):
                    dump_labels[0][start:end] = labels[0][start:end]
            return dump_labels

        index = 0
        for input_length, source_len in zip(infos["input_length"], prompt_id_lens):
            labels[0][index : index + source_len + boundary_offset] = ignore_index
            index += input_length
        return labels

    # logs/checkpoints/evaluation
    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
        """Log metrics, run evaluation, and save checkpoints when ``global_step`` hits their schedules.

        Args:
            args: Training arguments (``logging_steps``, ``eval_steps``, ``save_steps``, checkpoint paths).
            global_step (int): Current optimizer step (1-based).
            step_bar: Progress bar of the current epoch (unused here; kept for interface compatibility).
            logs_dict (dict, optional): Metrics to log. Defaults to an empty dict.
            client_states (dict, optional): Extra state stored with deepspeed checkpoints.
        """
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)

        # eval (-2 means "never"; fit() already maps it to inf, but direct callers may pass -2)
        if args.eval_steps != -2 and global_step % args.eval_steps == 0:
            # do eval only when the dataloader is non-empty, avoiding zero division in eval.
            if self.eval_dataloader is not None and len(self.eval_dataloader) > 0:
                self.evaluate(self.eval_dataloader, global_step)

        # Ratio-based save: each scheduled step is saved at most once.
        save_by_ratio = False
        if self.ratio_save_steps is not None:
            if global_step in self.ratio_save_steps and global_step not in self.saved_ratio_steps:
                save_by_ratio = True
                self.saved_ratio_steps.add(global_step)

        # Step-based save; disabled when save_steps is inf, 0 or negative.
        save_by_step = False
        if args.save_steps not in (float("inf"), 0) and args.save_steps > 0:
            if global_step % args.save_steps == 0:
                save_by_step = True

        if save_by_ratio or save_by_step:
            tag = f"global_step{global_step}"
            if self.strategy.is_rank_0():
                self.strategy.print(
                    f"[SFTTrainer] Saving checkpoint at global_step={global_step} "
                    f"(ratio-based={save_by_ratio}, step-based={save_by_step})"
                )
            if not self.disable_ds_ckpt:
                self.strategy.save_ckpt(
                    self.model.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
                )
            if self.save_hf_ckpt:
                save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
                self.strategy.save_model(self.model, self.tokenizer, save_path)

    def evaluate(self, eval_dataloader, steps=0):
        """Compute the eval loss (and optionally sampled-generation accuracy) and log it.

        Args:
            eval_dataloader: Yields ``(prompt_id_lens, inputs, attention_masks, answers, infos)`` batches.
            steps (int): Global step used for progress-bar text, log names and the generation log file.
        """
        times = 0
        acc = 0
        logs = None  # stays None if the dataloader yields no batches
        self.model.eval()
        with torch.no_grad():
            loss_sum = 0
            step_bar = tqdm(
                range(eval_dataloader.__len__()),
                desc="Eval stage of steps %d" % steps,
                disable=not self.strategy.is_rank_0(),
            )

            for prompt_id_lens, inputs, attention_masks, answers, infos in eval_dataloader:
                output, labels = self._forward_batch(inputs, attention_masks, infos)

                if not self.pretrain_mode:
                    labels = self._mask_non_response_labels(labels, prompt_id_lens, infos, boundary_offset=0)

                loss = self.loss_fn(output.logits, labels)

                times += 1
                loss_sum += loss.item()
                bar_dict = {"eval gpt_loss": loss_sum / times}

                # accuracy if answers exist
                if self.strategy.args.eval_acc:
                    acc += self._generate_and_score(infos, answers, steps)
                    # Add accuracy next to the loss rather than replacing the dict, so the eval loss
                    # is still reported when accuracy is enabled.
                    bar_dict["eval accuracy"] = acc / times

                step_bar.update()
                logs = self.strategy.all_reduce(bar_dict)
                step_bar.set_postfix(logs)

            if logs is not None and self.strategy.is_rank_0():
                if self._wandb is not None:
                    logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()}
                    self._wandb.log(logs)
                elif self._tensorboard is not None:
                    for k, v in logs.items():
                        self._tensorboard.add_scalar(f"eval/{k}", v, steps)
        self.model.train()  # reset model state

    def _generate_and_score(self, infos, answers, steps):
        """Sample completions for the batch prompts, optionally log them, and score them against ``answers``.

        Returns:
            The value of ``match_with_answer_labels_v2`` for this batch (presumably a per-batch
            accuracy; TODO confirm against its definition).
        """
        prompt_ids = infos["input"].squeeze(1).to(torch.cuda.current_device())

        # TODO: pass the prompt attention mask to generate(); refer to the data generation step in RePO old code.
        generated_outputs, _, _ = self.model.generate(
            input_ids=prompt_ids,
            use_cache=True,
            max_length=self.strategy.args.max_len,
            do_sample=True,
            top_p=self.strategy.args.top_p,
            early_stopping=False,
            num_beams=1,
            temperature=self.strategy.args.temperature,
            repetition_penalty=self.strategy.args.repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        tokenized_output = self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
        answer_trigger = self.strategy.args.answer_trigger

        log_dir = self.strategy.args.generation_log_path
        if log_dir and self.strategy.is_rank_0():
            # Only rank 0 writes, and only its local batch: other ranks' generations are not gathered.
            os.makedirs(log_dir, exist_ok=True)
            save_path = os.path.join(log_dir, f"eval_{steps}.jsonl")
            # Explicit UTF-8 because records are dumped with ensure_ascii=False.
            with open(save_path, "a", encoding="utf-8") as f:
                for generation, answer in zip(tokenized_output, answers):
                    generation_dict = {
                        "generation": generation,
                        "extracted_answers": extract_last_answer(generation, answer_trigger),
                        "gold_answers": answer,
                    }
                    f.write(json.dumps(generation_dict, ensure_ascii=False) + "\n")

        return match_with_answer_labels_v2(tokenized_output, answers, answer_trigger)

    def match_with_answer_labels(self, tokenized_output, answers):
        """Return the fraction of outputs whose extracted numeric answer matches the gold answer.

        Outputs with a ``None`` gold answer, or with no extractable numeric answer, are skipped and
        do not count toward the denominator. Returns 0 when nothing is scorable.
        """
        # TODO: currently only suited to gsm8k-style numeric answers; find an appropriate match for MATH etc.
        answer_trigger = self.strategy.args.answer_trigger
        correct_count = 0
        valid_count = 0
        for output, answer in zip(tokenized_output, answers):
            if answer is None:
                continue
            predicted_answer = self.extract_first_numeric_answer(output, answer_trigger)
            if predicted_answer is None:
                continue
            correct_count += self.check_correctness(predicted_answer, answer)
            valid_count += 1

        return correct_count / valid_count if valid_count > 0 else 0

    def extract_first_numeric_answer(self, text: str, answer_trigger: str):
        """Extract the numeric answer following ``answer_trigger`` in ``text``.

        Despite the method name, the LAST occurrence of ``answer_trigger`` followed by an unsigned
        integer or decimal (optionally wrapped in quotes) is used.

        Returns:
            The number as a float, or ``None`` if no match is found.
        """
        pattern = re.escape(answer_trigger) + r"\s*['\"]?(\d+(?:\.\d+)?)['\"]?"
        matches = re.findall(pattern, text)
        if not matches:
            return None
        # The capture group only admits digits with an optional fractional part, so float() is safe.
        return float(matches[-1])

    def check_correctness(self, prediction, target):
        """Return True when ``prediction`` and ``target`` are numerically within 1e-3 of each other.

        Values that cannot be converted to float (e.g. symbolic gold answers) count as incorrect.
        """
        try:
            return abs(float(prediction) - float(target)) <= 1e-3
        except (TypeError, ValueError):
            return False
        
