import os
import abc
import torch
import accelerate
from torch import nn
import torch.nn.functional as F
import numpy as np
import random
import re
from typing import Any, Callable, Optional, Union, Sequence
from itertools import islice
from transformers import PreTrainedTokenizer, AutoModelForCausalLM
from trl import GRPOTrainer
from trl import GRPOConfig
from .math_parsing_util import (
    extract_answer,
    math_equal,
    strip_answer_string,
)


def is_tensor(t):
    """Return ``True`` when ``t`` is a ``torch.Tensor`` instance, else ``False``."""
    return isinstance(t, torch.Tensor)


def log_tensor_info(tensor):
    """Print the shape, max, min, mean and NaN-presence of ``tensor`` to stdout.

    One ``"<Label>: <value>"`` line is printed per statistic. The mean is
    computed on ``tensor.float()``; the other statistics use ``tensor`` as is.
    """
    # Each statistic is wrapped in a callable so it is evaluated only when its
    # line is printed, i.e. in the order listed here.
    summary = (
        ("Shape", lambda: tensor.shape),
        ("Max", lambda: tensor.max().item()),
        ("Min", lambda: tensor.min().item()),
        ("Mean", lambda: tensor.float().mean().item()),
        ("Nan", lambda: torch.isnan(tensor).any().item()),
    )
    for label, compute in summary:
        print(f"{label}: {compute()}")


class CustomRewardTrainer(abc.ABC):
    """Mixin that links custom rewards and adds completion/metric logging.

    The methods read attributes this class never assigns itself:
    ``self.state.global_step`` (``log_to_file``), ``self.accelerator`` and
    ``self._metrics`` (``log_metric``) and, when ``reward_functions`` is
    omitted, ``self.reward_funcs``. They must be provided by the class this
    mixin is combined with.
    """

    def __init__(
        self,
        tokenizer,
        reward_functions=None,
        output_dir=None,
        logging_prob=0.0,
    ):
        """Link ``CustomReward`` instances and configure optional file logging.

        Args:
            tokenizer: Forwarded to ``CustomReward.link_with_trainer``.
            reward_functions: Iterable of reward callables. When ``None``,
                ``self.reward_funcs`` is used. Only ``CustomReward`` instances
                are linked; other entries are left untouched.
            output_dir: Base output directory. File logging writes into
                ``<output_dir>/completions_logs``.
            logging_prob: Probability with which a ``log_to_file`` call writes
                its values. File logging is enabled only when this is greater
                than zero and ``output_dir`` is not ``None``.
        """
        if reward_functions is None:
            reward_functions = self.reward_funcs
        for rw in reward_functions:
            if isinstance(rw, CustomReward):
                rw.link_with_trainer(trainer=self, tokenizer=tokenizer)

        self.logging_dir = output_dir
        self.logging_prob = logging_prob
        # -1 never equals a real step, so the first logged step gets a header.
        self.last_logged_iter = -1
        self.do_log_to_file = False
        if logging_prob > 0.0 and output_dir is not None:
            self.do_log_to_file = True
            self.logging_dir = f"{output_dir}/completions_logs"
            os.makedirs(self.logging_dir, exist_ok=True)
            self.logging_file = f"{self.logging_dir}/log.txt"

    def log_to_file(self, *args, **kwargs):
        """Append string values to the completion log files.

        Does nothing when file logging is disabled. Otherwise a
        ``Global step`` header is appended to the main log file the first
        time a new ``self.state.global_step`` is seen, and then, with
        probability ``self.logging_prob``, every value is appended:

        * positional values go to the main log file;
        * each keyword value goes to ``<logging_dir>/<keyword>.txt``.

        Args:
            *args: Strings appended to the main log file.
            **kwargs: Strings appended to a file named after their keyword.

        Raises:
            TypeError: If a sampled call contains a non-string value. All
                values are checked before anything is written, so a bad call
                leaves no partial entries behind.
        """
        if not self.do_log_to_file:
            return

        step = self.state.global_step
        if step != self.last_logged_iter:
            # Completions are free-form model text, so the encoding is pinned
            # rather than left to the platform default.
            with open(self.logging_file, "a", encoding="utf-8") as f:
                f.write("\n\n============================\n" +
                        f"Global step: {step}")
            self.last_logged_iter = step

        if random.random() < self.logging_prob:
            for log_value in (*args, *kwargs.values()):
                if not isinstance(log_value, str):
                    raise TypeError(
                        "log_to_file expects str values, got "
                        f"{type(log_value).__name__}")

            separator = "\n\n==============\n"
            if args:
                # One open per destination instead of one per value.
                with open(self.logging_file, "a", encoding="utf-8") as f:
                    f.writelines(separator + log_value for log_value in args)
            for log_name, log_value in kwargs.items():
                with open(f"{self.logging_dir}/{log_name}.txt", "a",
                          encoding="utf-8") as f:
                    f.write(separator + log_value)

    def log_metric(self, **kwargs):
        """Reduce each value to a scalar and record it under its keyword.

        Tensors are reduced with ``mean()``; integer and boolean tensors are
        cast to float first because ``mean()`` is not defined for them. Lists
        and tuples are averaged with ``np.mean`` and converted to a plain
        ``float``. Anything else is passed through ``float()``.

        On the main process each scalar is also appended to
        ``self._metrics[<keyword>]``.

        Args:
            **kwargs: Metric name to value (tensor, list/tuple or scalar).

        Returns:
            dict mapping each metric name to its reduced scalar.
        """
        logged_dict = {}
        for log_name, log_value in kwargs.items():
            if isinstance(log_value, torch.Tensor):
                if not (log_value.is_floating_point()
                        or log_value.is_complex()):
                    log_value = log_value.float()
                log_value = log_value.mean().item()
            elif isinstance(log_value, (list, tuple)):
                log_value = float(np.mean(log_value))
            else:
                log_value = float(log_value)
            logged_dict[log_name] = log_value
            if self.accelerator.is_main_process:
                self._metrics[log_name].append(log_value)
        return logged_dict


class CustomReward(abc.ABC):
    """Abstract callable reward with shared think/solution tag helpers.

    Subclasses implement ``__call__``. ``link_with_trainer`` must be called
    before ``self._format_rw`` and ``self._extract_solution`` exist.
    """

    def link_with_trainer(
            self, trainer, tokenizer, numeric_answer=False):
        """Attach the helper functions and a ``__name__`` to this instance.

        Args:
            trainer: Trainer this reward is linked to. Not stored by this
                base implementation.
            tokenizer: Tokenizer of the trainer. Not stored by this base
                implementation.
            numeric_answer: When true, extracted solutions are additionally
                passed through ``extract_answer`` and ``strip_answer_string``.
        """
        self.__name__ = self.__class__.__name__
        self._numeric_answer = numeric_answer
        self._format_rw, self._extract_solution = (
            self.make_answer_check_and_functions(numeric_answer=numeric_answer))

    def make_answer_check_and_functions(self, numeric_answer):
        """Build the format-reward and solution-extraction helpers.

        Args:
            numeric_answer: Captured by the extraction helper; see
                ``link_with_trainer``.

        Returns:
            A two-element list ``[format_reward_fn, solution_extraction_fn]``.
        """
        def format_reward_fn(completions, start_think_tag, end_think_tag, start_solution_tag, end_solution_tag, **kwargs):
            """Score 1.0 for each completion that follows the tag layout.

            ``start_think_tag`` is prepended to every completion, which must
            then read, from start to end: think block, two newlines, solution
            block. Anything that cannot be checked (non-string completion or
            tags) scores 0.0 instead of raising.
            """
            try:
                # Raw f-strings keep \s, \S and \n as regex escapes rather
                # than (invalid) Python string escapes. [\s\S] already spans
                # newlines, so no DOTALL flag is needed.
                layout = re.compile(
                    rf"^{re.escape(start_think_tag)}([\s\S]*?){re.escape(end_think_tag)}\n\n"
                    rf"{re.escape(start_solution_tag)}([\s\S]*?){re.escape(end_solution_tag)}$")
            except Exception:
                # Deliberate best-effort: unusable tags zero every reward.
                layout = None

            rewards = []
            for completion in completions:
                try:
                    well_formed = (
                        layout is not None
                        and layout.search(
                            f"{start_think_tag}" + completion) is not None)
                except Exception:
                    # Deliberate best-effort: a malformed completion scores 0
                    # rather than aborting the whole batch.
                    well_formed = False
                rewards.append(1.0 if well_formed else 0.0)
            return rewards

        def solution_extraction_fn(completion, start_think_tag, end_think_tag, start_solution_tag, end_solution_tag, **kwargs):
            """Return the stripped text between the first solution tag pair.

            Raises:
                AttributeError: If ``completion`` has no solution tag pair
                    (the regex search returns ``None``).
            """
            match = re.search(
                f"{re.escape(start_solution_tag)}(.*?){re.escape(end_solution_tag)}", completion, re.DOTALL)
            hyp = match.group(1).strip()

            if numeric_answer:
                hyp = extract_answer(hyp)
                hyp = strip_answer_string(hyp)
            return hyp

        return [format_reward_fn, solution_extraction_fn]

    @abc.abstractmethod
    def __call__(
        self,
        prompts,
        completions,
        start_think_tag,
        end_think_tag,
        start_solution_tag,
        end_solution_tag,
        **kwargs,
    ):
        """Compute rewards for ``completions``; must be overridden."""
        raise NotImplementedError


class CustomGRPOTrainer(GRPOTrainer, CustomRewardTrainer):
    """``GRPOTrainer`` combined with the ``CustomRewardTrainer`` mixin."""

    def __init__(self, *args, logging_prob=0.0, **kwargs):
        """Initialise both bases explicitly, ``GRPOTrainer`` first.

        Args:
            *args: Forwarded unchanged to ``GRPOTrainer.__init__``.
            logging_prob: Forwarded to ``CustomRewardTrainer.__init__``.
            **kwargs: Forwarded unchanged to ``GRPOTrainer.__init__``.
        """
        GRPOTrainer.__init__(self, *args, **kwargs)

        # These attributes are read from ``self`` only after the GRPO
        # initialiser has run (presumably it is what populates them).
        mixin_settings = {
            "tokenizer": self.processing_class,
            "reward_functions": self.reward_funcs,
            "output_dir": self.args.output_dir,
            "logging_prob": logging_prob,
        }
        CustomRewardTrainer.__init__(self, **mixin_settings)
