import logging
import os
# import openai
import pickle
import re
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
from bs4 import BeautifulSoup
from datasets import load_dataset
from filelock import FileLock
from tqdm import tqdm
from transformers import GPT2Tokenizer


# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def generate_w_ppo_trainer(ppo_trainer, tokenizer, input_str, gen_kwargs=None):
    """Generate text for ``input_str`` through a PPO trainer and decode it.

    Args:
        ppo_trainer: Object exposing ``generate(query, batch_size=...,
            return_prompt=..., **kwargs)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(tokens, skip_special_tokens=True)``.
        input_str: Prompt text that is tokenized and handed to the trainer.
        gen_kwargs: Optional mapping of extra keyword arguments forwarded to
            ``ppo_trainer.generate``. ``None`` (the default) forwards nothing,
            so callers that have no generation settings can omit it.

    Returns:
        The decoded output of ``ppo_trainer.generate`` with special tokens
        skipped. ``return_prompt=False`` is requested from the trainer.
    """
    if gen_kwargs is None:
        gen_kwargs = {}

    # Keep only the first row of the encoded result as the query.
    input_tokens = tokenizer.encode(input_str, return_tensors="pt")[0]
    with torch.no_grad():
        # NOTE(review): batch_size is set to the number of query tokens; this
        # looks unintended but is kept so the call stays unchanged -- confirm
        # against the trainer's generate() contract.
        output_tokens = ppo_trainer.generate(input_tokens,
                                             batch_size=len(input_tokens),
                                             return_prompt=False,
                                             **gen_kwargs)
    return tokenizer.decode(output_tokens.squeeze(), skip_special_tokens=True)


@dataclass
class ArxivArticle:
    """One article together with its text split into sections.

    After construction the instance also carries ``id``: the part of
    ``entry_id`` that follows the last ``/`` (the whole ``entry_id`` when it
    contains no ``/``).
    """
    text: str  # article body
    entry_id: str  # identifier; only the segment after the last "/" becomes ``id``
    title: str
    sections: List[str]  # section texts in order; consumers read sections[0]
    context_type: str = None  # optional tag; None when not provided

    def __post_init__(self):
        # Derived, not a dataclass field: it never appears in __init__.
        self.id = self.entry_id.rsplit("/", 1)[-1]

    def __repr__(self):
        return f"ArxivArticle: {self.title}\n\n"


@dataclass
class ConversationArticle(ArxivArticle):
    """An ``ArxivArticle`` specialisation that carries conversation metadata.

    All inherited fields keep their meaning; the extra fields below are all
    optional and default to "empty" values.
    """
    num_rounds: int = 0  # number of turns recorded for the conversation
    prompt: str = None  # optional prompt text; None when not provided
    context: List[Tuple[str, str]] = None  # optional list of 2-tuples of strings

    def __repr__(self):
        # Identified by entry_id rather than by title.
        return "ConversationArticle: {}\n\n".format(self.entry_id)


@dataclass
class Conversation:
    """A parsed conversation: an identifier plus its ordered turns."""
    id: str  # conversation identifier
    context: List[Tuple[str, str]]  # ordered turns, each a 2-tuple of strings


@dataclass
class ArxivContext:
    """A context string tied to the article it came from.

    After construction the instance also carries ``id``: the part of
    ``entry_id`` that follows the last ``/`` (the whole ``entry_id`` when it
    contains no ``/``).
    """
    text: str
    entry_id: str
    context: str

    def __post_init__(self):
        # Only the trailing path segment of entry_id is kept as the id.
        self.id = self.entry_id.rsplit("/", 1)[-1]

    def __repr__(self):
        return "ArxivContext:\n --{}\n\n".format(self.context)


class ArxivContextManager:
    """
        Load arxiv articles, split every article into sections, and build
        (original, compressed) context pairs from each article's first section.

        Args:
            - dataset: iterable of records indexable by "entry_id", "title" and "text".
            - compressor: compression backend. With 'selective_content' it is called as
              ``compressor(context, reduce_ratio=0.5)``; with 'sp-compressor' it is passed
              to ``generate_w_ppo_trainer`` as the trainer.
            - tokenizer: callable returning a mapping with an 'input_ids' entry; it must
              also expose ``eos_token_id``.
    """

    # Compressor names generate_context can actually run.
    _SUPPORTED_COMPRESSORS = ('selective_content', 'sp-compressor')

    # Everything up to the first section marker; the marker itself is captured and kept.
    _PREAMBLE_RE = re.compile(r"^.*?(§)", flags=re.DOTALL)
    # A section marker followed by whitespace, unless the marker is preceded by "§.".
    _SECTION_SPLIT_RE = re.compile(r"(?<!§\.)§\s")
    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(
            self,
            dataset,
            compressor,
            tokenizer,
    ):
        self.dataset = dataset
        self.compressor = compressor
        self.tokenizer = tokenizer
        # Upper bound used by varify_context_length.
        self.max_token_len = 10_000_000

        self.llama_gen_kwargs = {
            "do_sample": False,  # sampling is switched off
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "repetition_penalty": 1.1,
            "max_new_tokens": 512,  # specify how many tokens you want to generate at most
        }

        self.load_articles(dataset)

    def load_articles(self, dataset) -> None:
        """Fill ``self.articles`` with one ``ArxivArticle`` per dataset record.

        The text before the first section marker is dropped, the remainder is
        split into sections, and blank sections are discarded. Nothing is
        returned; the result is stored on the instance.
        """
        self.articles = []
        for article in dataset:
            entry_id = article["entry_id"]
            title = article["title"]

            # remove anything before introduction
            text = self._PREAMBLE_RE.sub(r"\1", article["text"])

            # split article into sections
            sections = [
                self.beautify_context(section)
                for section in self._SECTION_SPLIT_RE.split(text)
                if section.strip()
            ]

            self.articles.append(ArxivArticle(text=text, entry_id=entry_id, title=title, sections=sections))

        logging.info(f">> Finish preprocessing Arxiv articles. Loaded {len(self.articles)} documents.")

    def beautify_context(self, context: str) -> str:
        """Strip ``<cit.>`` / ``<ref>`` placeholders and collapse whitespace runs to one space."""
        context = context.replace("<cit.>", '').replace('<ref>', '')
        # \s already matches newlines, so one pass covers every whitespace run.
        return self._WHITESPACE_RE.sub(" ", context)

    def varify_context_length(self, context: str) -> bool:
        """Return True when ``context`` is not None and has at most ``max_token_len`` tokens."""
        if context is None:
            return False
        num_tokens = len(self.tokenizer(context)['input_ids'])
        return num_tokens <= self.max_token_len

    def generate_context(self, num_articles: int = None, compressor_name='sp-compressor') -> dict[str, list[str]]:
        """Compress the first section of up to ``num_articles`` articles.

        Args:
            num_articles: maximum number of contexts to produce; ``None`` or a
                value larger than the number of loaded articles means "all".
            compressor_name: 'selective_content' or 'sp-compressor'.

        Returns:
            ``{'orig': [...], 'compressed': [...]}`` with parallel lists of
            strings. For 'sp-compressor' the 'orig' entry is the instruction
            prompt wrapped around the section, i.e. exactly what was fed to
            the model.

        Raises:
            NotImplementedError: for 'lingua', which is not supported yet.
            ValueError: for any other unknown ``compressor_name``.
        """
        # Fail fast instead of hitting an unbound ``compressed_context`` mid-loop.
        if compressor_name == 'lingua':
            raise NotImplementedError
        if compressor_name not in self._SUPPORTED_COMPRESSORS:
            raise ValueError(
                f"Unknown compressor_name {compressor_name!r}; expected one of {self._SUPPORTED_COMPRESSORS}")

        orig_contexts = []
        compressed_contexts = []

        if num_articles is None or num_articles > len(self.articles):
            num_articles = len(self.articles)

        for idx, article in enumerate(tqdm(self.articles, desc="Generating contexts"), start=1):
            if len(orig_contexts) >= num_articles:
                break
            if not article.sections:
                continue

            # NOTE: weird cuda error for a certain sample between 70-100, skip for now.
            # idx is the 1-based position of the article in self.articles.
            if 70 <= idx < 100:
                continue

            context = self.beautify_context(article.sections[0])  # introduction section

            # TODO: may need an initial prompt
            # TODO: the code differs according to compressor
            if compressor_name == 'selective_content':
                compressed_context = self.compressor(context, reduce_ratio=0.5)[0]
            else:  # 'sp-compressor': ppo trainer in default
                prompt_tmp = "[INST]Compress following content while maintaining the meaning.\n{}[/INST]"
                context = prompt_tmp.format(context)
                compressed_context = generate_w_ppo_trainer(self.compressor, self.tokenizer, context,
                                                            self.llama_gen_kwargs)

            orig_contexts.append(context)
            compressed_contexts.append(compressed_context)

        logging.info(f"Finish generating {len(orig_contexts)} contexts.")
        return {'orig': orig_contexts, 'compressed': compressed_contexts}

    def _check_point(self, message='') -> None:
        """Pickle this manager to ``<self.path>/<ClassName>_<sent|paragraph>.pkl``.

        Relies on ``self.path`` and ``self.sent_level_self_info``; ``__init__``
        does not set either, so they must be assigned before calling this
        method, otherwise an AttributeError is raised.
        """
        granularity = 'sent' if self.sent_level_self_info else 'paragraph'
        pickle_file = os.path.join(self.path, f"{self.__class__.__name__}_{granularity}.pkl")
        with FileLock(pickle_file + ".lock", timeout=100):
            with open(pickle_file, "wb") as f:
                pickle.dump(self, f)
        # Report only once the file has actually been written.
        logging.info(f"saved to {pickle_file}. {message}")
        print(f"saved to {pickle_file}. {message}")

    @classmethod
    def from_checkpoint(cls, pickle_path, **kwargs):
        """Restore a pickled manager and override attributes with ``kwargs``.

        Every keyword argument is assigned onto the restored object with
        ``setattr``. If the restored object defines ``_prepare_self_info`` it
        is invoked afterwards; objects without that hook are returned as-is.
        """
        with FileLock(pickle_path + ".lock", timeout=100):
            with open(pickle_path, 'rb') as f:
                # SECURITY: unpickling can execute arbitrary code; only load
                # checkpoints that come from a trusted source.
                manager = pickle.load(f)
        for k, v in kwargs.items():
            setattr(manager, k, v)
        # This class does not define the hook, so calling it unconditionally
        # would raise AttributeError.
        prepare = getattr(manager, "_prepare_self_info", None)
        if callable(prepare):
            prepare()
        return manager


class ConversationContextManager(ArxivContextManager):
    """Context manager for chat records instead of arxiv papers.

    Each dataset record must provide ``'id'`` and ``'chat'``; every chat item
    is read as ``item[0]`` (role) and ``item[1]`` (text). A conversation is
    turned into a ``ConversationArticle`` whose first section is the
    transcript without its final turn and whose second section is the text of
    that final turn.
    """

    # Compressor names generate_context can actually run.
    _KNOWN_COMPRESSORS = ('selective_content', 'sp-compressor')

    _SP_PROMPT = '[INST]Compress following prompt while maintaining the meaning:\n{}[/INST]'

    # The second range contains everything the previous two-pattern check matched.
    _CJK_RE = re.compile(r"[\u3400-\u4DB5\u4E00-\u9FFF]+")

    def load_articles(self, dataset):
        """Parse every record and keep conversations with at least 4 turns."""
        self.articles = []
        for data in dataset:
            conversation = self._parse_conversation(data)
            article = self._build_article(conversation)
            if article is not None and article.num_rounds >= 4:
                self.articles.append(article)

    def generate_context(self, num_articles: int = None, compressor_name='sp-compressor'):
        """Compress the transcript of the first ``num_articles`` conversations.

        Conversations whose transcript fails ``varify_context_length`` or
        matches ``contain_zh`` are skipped, so fewer than ``num_articles``
        results may be returned.

        Args:
            num_articles: how many loaded articles to consider; ``None`` or a
                value larger than the number of loaded articles means "all".
            compressor_name: 'selective_content' or 'sp-compressor'.

        Returns:
            ``{'orig': [...], 'compressed': [...]}`` with parallel lists. For
            'sp-compressor' the 'orig' entry is the instruction prompt wrapped
            around the transcript, i.e. exactly what was fed to the model.

        Raises:
            NotImplementedError: for 'lingua', which is not supported yet.
            ValueError: for any other unknown ``compressor_name``.
        """
        # Fail fast instead of hitting an unbound ``compressed_context`` mid-loop.
        if compressor_name == 'lingua':
            raise NotImplementedError
        if compressor_name not in self._KNOWN_COMPRESSORS:
            raise ValueError(
                f"Unknown compressor_name {compressor_name!r}; expected one of {self._KNOWN_COMPRESSORS}")

        orig_contexts = []
        compressed_contexts = []

        if num_articles is None or num_articles > len(self.articles):
            num_articles = len(self.articles)

        for article in tqdm(self.articles[:num_articles], desc="Generating contexts"):
            # Clean the transcript once and reuse it for both filters.
            context = self.beautify_context(article.sections[0])
            if not self.varify_context_length(context):
                continue
            if self.contain_zh(context):
                continue

            # TODO: may need an initial prompt
            # TODO: the code differs according to compressor
            if compressor_name == 'selective_content':
                compressed_context = self.compressor(context, reduce_ratio=0.5)[0]
            else:  # 'sp-compressor': ppo trainer in default
                context = self._SP_PROMPT.format(context)
                # The generation settings must be passed explicitly.
                compressed_context = generate_w_ppo_trainer(self.compressor, self.tokenizer, context,
                                                            self.llama_gen_kwargs)

            orig_contexts.append(context)
            compressed_contexts.append(compressed_context)

        logging.info(f"Finish generating {len(orig_contexts)} contexts.")
        return {'orig': orig_contexts, 'compressed': compressed_contexts}

    def _build_article(self, conversation: Conversation):
        """Turn a parsed conversation into a ``ConversationArticle``.

        Returns ``None`` when the conversation has no turns or when the
        transcript (all turns but the last) fails ``varify_context_length``.
        """
        if not conversation.context:
            # No turns at all: there is no final response to split off.
            return None

        lines = []
        for utterance in conversation.context:
            line = f"{utterance[0]}: {utterance[1]}"
            # Make sure every rendered turn ends with sentence punctuation.
            if not line.endswith(('.', '!', '?')):
                line += '.'
            lines.append(line)

        # The final turn is held back and stored as the second section.
        content = '\n'.join(lines[:-1])
        if not self.varify_context_length(content):
            return None

        last_response = conversation.context[-1][1]
        return ConversationArticle(text=content, entry_id=conversation.id, title='',
                                   sections=[content, last_response], num_rounds=len(conversation.context),
                                   context=conversation.context)

    def _parse_conversation(self, conversation):
        """Convert a raw record into a ``Conversation`` of (role, text) turns.

        Turns whose role is ``'human'`` are kept verbatim; every other turn is
        treated as HTML and reduced to its text content.
        """
        conversation_id = conversation['id']
        convs = []
        for sent in conversation['chat']:
            role = sent[0]
            if role != 'human':
                # Name the parser explicitly: without it bs4 picks whichever
                # parser is installed, so results vary between environments
                # and a GuessedAtParserWarning is emitted.
                bsobj = BeautifulSoup(sent[1], "html.parser")
                for tag_name in ['p', 'br', 'div', 'li', 'h1', 'h2', 'h3', ]:
                    for tag in bsobj.find_all(tag_name):
                        if tag.string is not None:
                            # Trailing space keeps adjacent blocks from fusing in get_text().
                            tag.string.replace_with(tag.string + ' ')
                value = bsobj.get_text()
            else:
                value = sent[1]
            convs.append((role, value))
        return Conversation(conversation_id, convs)

    def contain_zh(self, text):
        """Return True when ``text`` has a character in U+3400-U+4DB5 or U+4E00-U+9FFF."""
        return self._CJK_RE.search(text) is not None


class GSMContextManager:
    """Placeholder with no behaviour yet (presumably for GSM-style math data -- TODO confirm)."""


class BBHContextManager:
    """Placeholder with no behaviour yet (presumably for BBH data -- TODO confirm)."""


if __name__ == "__main__":
    # Manual smoke run: compress ShareGPT conversations with SelectiveContext
    # and print the resulting original/compressed pairs.
    from selective_context import SelectiveContext

    sc = SelectiveContext(model_type='gpt2', lang='en')
    dataset_type = "conversation"
    dataset_name = "ShareGPT"

    # Local dataset locations. Only the selected one is loaded below, so a
    # missing or slow unrelated dataset cannot break this run.
    dataset_paths = {
        "Alpaca": '/home/hmp/hmp-mh/trl/hf_hub/datasets/alpaca-gpt4',
        "ShareGPT": '/home/hmp/hmp-mh/trl/hf_hub/datasets/sharegpt-500',
        "Arxiv": '/home/hmp/hmp-mh/trl/hf_hub/datasets/arxiv-march-2023',
        "GSM8K": '/home/hmp/hmp-mh/trl/hf_hub/datasets/qwedsacf-grade-school-math-instructions',
        # "BBH": '/home/hmp/hmp-mh/trl/hf_hub/datasets/lukaemon-bbh/bbh.py' (config 'boolean_expressions'),
    }

    dataset_manager = {
        'arxiv': ArxivContextManager,
        'conversation': ConversationContextManager,
        'math': GSMContextManager,
        'bbh': BBHContextManager,
    }

    dataset = load_dataset(dataset_paths[dataset_name], split="train")
    tokenizer = GPT2Tokenizer.from_pretrained('/home/hmp/hmp-mh/trl/hf_hub/models/gpt2')

    context_manager = dataset_manager[dataset_type](dataset, sc, tokenizer)
    contexts = context_manager.generate_context(num_articles=300, compressor_name='selective_content')
    print(contexts)
