import json
import logging
import os
import pathlib
import subprocess
from dataclasses import dataclass
from pyrouge import Rouge155

import torch
from argparse_dataclass import ArgumentParser
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from tqdm import tqdm
from transformers import LogitsProcessor, LogitsProcessorList

from quick_extend.dataset.booksum import BookSumDataset
from quick_extend.models.load_model import load_model, ModelConfig
from quick_extend.scripts.train import seed


@dataclass
class EvalConfig(ModelConfig):
    """Options for the BookSum evaluation script, layered on top of ``ModelConfig``.

    Instances are produced by ``ArgumentParser(EvalConfig)`` in ``main``.
    """

    # When False, generate_summary returns an existing out_{idx}.txt instead of regenerating it.
    overwrite: bool = False
    # When True, the decoded document is used as the prompt without any summarization instruction.
    disable_prompt: bool = False
    # Passed to generate() as max_new_tokens; also subtracted from the model's
    # max_position_embeddings when the prompt is truncated.
    max_tokens: int = 256
    # When True, the sampling arguments (do_sample/temperature/top_p/top_k) are not passed to generate().
    no_sample: bool = False


class StopAfterStringIsGenerated(LogitsProcessor):
    """Logits processor that forces EOS once the generated text ends with ``</s>``.

    Only the tokens after the first ``base_len`` positions are decoded, so the
    check applies to the generated continuation and not to the prompt.
    """

    def __init__(self, base_len: int, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        self.base_len = base_len

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Nothing has been generated beyond the prompt yet: leave the scores untouched.
        if input_ids.size(1) <= self.base_len:
            return scores

        continuations = self.tokenizer.batch_decode(input_ids[:, self.base_len:])
        finished_rows = torch.tensor(
            [text.endswith("</s>") for text in continuations],
            device=scores.device,
        )

        # A score row in which only the EOS token remains possible.
        eos_only = torch.full((scores.size(1),), -float("inf"), device=scores.device)
        eos_only[self.tokenizer.eos_token_id] = 0

        # Rows whose decoded continuation ends with the literal "</s>" are overwritten in place.
        scores[finished_rows] = eos_only
        return scores


# Prompt layout switch, read once at import time from the PROMPT_FIRST_ONLY environment
# variable (True unless the variable is set to something other than '1'). When True,
# generate_summary only places the instruction before the document; otherwise it also
# appends a trailing sentence announcing the summary.
PROMPT_FIRST_ONLY = os.getenv('PROMPT_FIRST_ONLY', '1') == '1'


def generate_summary(args, model, tokenizer, device, idx, item, out_dir):
    """Generate the summary for one dataset item, or load it from the cache.

    If ``out_dir / f"out_{idx}.txt"`` exists and ``args.overwrite`` is false,
    the file content is returned and the model is not called.

    Args:
        args: Evaluation options; reads ``overwrite``, ``disable_prompt``,
            ``model``, ``max_tokens`` and ``no_sample``.
        model: Model exposing ``config.max_position_embeddings``, ``modules()``
            and ``generate()``.
        tokenizer: Tokenizer used to decode the document, encode the prompt and
            decode the generation. Its ``truncation_side`` is set to ``'left'``
            as a side effect.
        device: Device the prompt tensors are placed on before generation.
        idx: Item index; selects the ``out_{idx}.txt`` cache file.
        item: ``(inputs, completion)`` pair; only ``inputs`` (the token ids of
            the document) is used here.
        out_dir: ``pathlib.Path`` of the directory holding cached outputs.

    Returns:
        The generated summary text with surrounding whitespace stripped, or
        the cached file content when the cache is hit.
    """
    # Environment switches that tweak attributes on attention modules (see below).
    prompt_always_flash = os.environ.get('PROMPT_ALWAYS_FLASH', '0') == '1'
    last_dense = os.environ.get('LAST_DENSE', '0') == '1'

    inputs, _ = item

    cached_path = out_dir / f"out_{idx}.txt"
    if cached_path.exists() and not args.overwrite:
        with open(cached_path, 'r') as f:
            return f.read()

    # NOTE(review): with left truncation an over-long prompt loses its beginning,
    # which is where the instruction text sits - confirm this is intended.
    tokenizer.truncation_side = 'left'

    assert hasattr(model, 'config')
    assert hasattr(model.config, 'max_position_embeddings')

    document = tokenizer.decode(inputs, skip_special_tokens=True)
    if not args.disable_prompt:
        if 'llama13b_32k' in args.model.lower():
            # NOTE(review): the second "[INST]" after "<</SYS>>" differs from the usual
            # Llama-2 chat layout; kept unchanged so results stay comparable - confirm.
            prompt = f"[INST] <<SYS>>\nSummarize the following text in about 300 words\n<</SYS>>[INST] {document} [/INST]"
        elif PROMPT_FIRST_ONLY:
            prompt = f"Summarize the following text in about 300 words:\n\n{document}"
        else:
            prompt = f"Summarize the following text in about 300 words:\n\n{document} The summary of previously given text is following."
    else:
        prompt = document
    if prompt.endswith('</s>'):
        prompt = prompt[:-4]

    # Leave room for the generated tokens inside the model's context window.
    inputs = tokenizer(
        prompt,
        return_tensors='pt',
        max_length=model.config.max_position_embeddings - args.max_tokens,
        truncation=True,
    )['input_ids'][0]

    seq_len = inputs.shape[-1]
    print(f"seq_len: {seq_len}")

    # The meaning of these attributes is defined by the attention modules, not here.
    if prompt_always_flash:
        for m in model.modules():
            if hasattr(m, 'attention_method'):
                m.tree_dense_queries = seq_len - 1
    if last_dense:
        for m in model.modules():
            if hasattr(m, 'attention_method'):
                m.tree_last_dense_queries = -1

    additional_args = {}
    if not args.no_sample:
        additional_args = dict(
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
        )

    # Honor the caller-supplied device instead of hard-coding CUDA.
    output = model.generate(
        inputs=inputs.unsqueeze(0).to(device),
        attention_mask=torch.ones((1, seq_len), dtype=torch.long, device=device),
        max_new_tokens=args.max_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        logits_processor=LogitsProcessorList([
            StopAfterStringIsGenerated(seq_len, tokenizer)
        ]),
        **additional_args,
    )
    # Keep only the continuation: drop the prompt tokens before decoding.
    output: str = tokenizer.decode(
        output[0][seq_len:].data.cpu(),
        skip_special_tokens=True,
    )
    # A literal "</s>" can still be present as plain text; strip it.
    if output.endswith('</s>'):
        output = output[:-4]
    output = output.strip()

    return output


def install_rogue():
    """Return the ROUGE-1.5.5 directory, downloading it first when it is missing.

    The location is taken from the ``ROUGE_HOME`` environment variable and
    defaults to ``cache/ROUGE-1.5.5``. When that path does not exist, the
    SummEval tarball is downloaded with ``curl``, extracted with ``tar`` and
    its ``ROUGE-1.5.5`` directory is moved to the target location. The
    temporary archive and extraction directory in the current working
    directory are removed afterwards, whether or not the install succeeded.

    Returns:
        The ROUGE directory path as a string.

    Raises:
        subprocess.CalledProcessError: If one of the install commands exits
            with a non-zero status.
        OSError: If a command cannot be started or the parent directory cannot
            be created.
    """
    logger = logging.getLogger()

    rouge_home = os.environ.get('ROUGE_HOME', "cache/ROUGE-1.5.5")
    if "ROUGE_HOME" not in os.environ:
        logger.info("ROUGE_HOME not set, using default location %s", rouge_home)

    if os.path.exists(rouge_home):
        return rouge_home

    logger.info("ROUGE_HOME=%s not a directory.", rouge_home)
    logger.info("Installing rouge Perl script to %s - this will take a few seconds", rouge_home)

    archive = "project.tar.gz"
    extracted = "Yale-LILY-SummEval-7e4330d/"
    try:
        # `mv` fails if the destination's parent (e.g. "cache") does not exist yet.
        parent = os.path.dirname(os.path.normpath(rouge_home))
        if parent:
            os.makedirs(parent, exist_ok=True)

        # check=True turns a failing command into CalledProcessError instead of
        # letting the install continue silently with missing files.
        subprocess.run(
            ["curl", "-L", "https://github.com/Yale-LILY/SummEval/tarball/7e4330d", "-o", archive, "-s"],
            check=True,
        )
        subprocess.run(["tar", "-xzf", archive], check=True)
        subprocess.run(["mv", extracted + "evaluation/summ_eval/ROUGE-1.5.5/", rouge_home], check=True)
    except (subprocess.CalledProcessError, OSError):
        logger.error(
            "Failed to install the rouge Perl script; please install manually and set the ROUGE_HOME environment variable.")
        raise
    finally:
        # Best-effort cleanup of the temporary download artifacts.
        subprocess.run(["rm", "-f", archive])
        subprocess.run(["rm", "-rf", extracted])

    return rouge_home


def generate_samples(args, model, tokenizer, device, out_dir):
    """Generate summaries for the BookSum test split and write them next to their references.

    For every item of the 5% test split, the summary is written to
    ``out_dir / f"out_{idx}.txt"`` and the reference text to
    ``saves/llama_eval/booksum/reference/ref_{idx}.txt``. Both directories are
    created if they do not exist yet.

    Args:
        args: Evaluation options forwarded to ``generate_summary``.
        model: Model forwarded to ``generate_summary``.
        tokenizer: Tokenizer used by the dataset and to decode references that
            are not already strings.
        device: Device forwarded to ``generate_summary``.
        out_dir: ``pathlib.Path`` directory receiving the generated summaries.
    """
    # NOTE(review): presumably seeds the RNGs so the split below is reproducible - confirm.
    seed()

    dataset = BookSumDataset(
        tokenizer=tokenizer,
        for_eval=True,
        need_tokenization=True,
    )
    # Only the test indices are needed; the train part of the split is discarded.
    _, test_idx = train_test_split(list(range(len(dataset))), test_size=0.05)
    test_dataset = Subset(dataset, test_idx)

    # Make the function usable on its own instead of relying on the caller's mkdir.
    reference_dir = pathlib.Path("saves/llama_eval/booksum/reference")
    reference_dir.mkdir(parents=True, exist_ok=True)
    out_dir.mkdir(parents=True, exist_ok=True)

    for idx, item in enumerate(tqdm(test_dataset, dynamic_ncols=True, leave=True, desc="booksum")):
        _, completion = item

        output = generate_summary(args, model, tokenizer, device, idx, item, out_dir)

        # Single-line preview of the first 200 characters for the progress log.
        preview = output.replace('\n', '\\n')[:200]
        tqdm.write(f"[{idx:<7}] Summary: {preview}[...]")
        with open(out_dir / f"out_{idx}.txt", 'w') as f:
            f.write(output)

        with open(reference_dir / f"ref_{idx}.txt", 'w') as f:
            if isinstance(completion, str):
                f.write(completion)
            else:
                f.write(tokenizer.decode(completion, skip_special_tokens=True))


# Module-level token limit. NOTE(review): nothing in the visible code reads this name,
# while evaluate_rouge hard-codes the same value (256) when truncating summaries -
# confirm whether it is meant to drive that limit or can be removed.
MAX_NEW_TOKENS = 256

def evaluate_rouge(args, model, tokenizer, device, out_dir: pathlib.Path):
    """Truncate the generated summaries and score them against the references with ROUGE.

    Every ``out_*.txt`` file in ``out_dir`` is re-tokenized, truncated to 256
    tokens, decoded and rewritten in place. ROUGE is then run with ``out_dir``
    as the system directory and ``saves/llama_eval/booksum/reference`` as the
    model (gold) directory. The scores are printed and stored in
    ``out_dir / "rouge_scores.json"``.

    Args:
        args: Unused; kept for a uniform job signature.
        model: Unused; kept for a uniform job signature.
        tokenizer: Tokenizer used to truncate the summaries.
        device: Unused; kept for a uniform job signature.
        out_dir: Directory containing the ``out_{idx}.txt`` summaries.
    """
    max_eval_tokens = 256

    # Only touch the summary files: a rouge_scores.json left by a previous run
    # (or any other artifact) must not be tokenized and rewritten.
    # NOTE(review): generate_summary sets tokenizer.truncation_side to 'left'; if the
    # same tokenizer is used here, truncation drops leading tokens - confirm intent.
    for node in out_dir.glob('out_*.txt'):
        if not node.is_file():
            continue
        content = node.read_text()
        ids = tokenizer(content, truncation=True, max_length=max_eval_tokens).input_ids
        node.write_text(tokenizer.decode(ids, skip_special_tokens=True))

    rouge_dir = install_rogue()

    r = Rouge155(rouge_dir=rouge_dir)
    r.system_dir = out_dir  # "system" is the one we want to measure
    r.model_dir = "saves/llama_eval/booksum/reference"  # "model" is the gold standard (i.e. human summaries)
    r.system_filename_pattern = r'out_(\d+)\.txt'
    r.model_filename_pattern = r'ref_#ID#\.txt'

    output = r.convert_and_evaluate()
    print("R: Recall, P: Precision, F: F1 score")
    print(output)
    output_dict = r.output_to_dict(output)
    with open(out_dir / "rouge_scores.json", 'w') as f:
        json.dump(output_dict, f, indent=2)


@torch.no_grad()
def job_booksum(args, model, tokenizer, device):
    """Run the full BookSum job: generate summaries, then score them with ROUGE.

    The output directory name encodes the model name and the run settings read
    from ``args``. It and the shared reference directory are created before
    any work starts. Gradient tracking is disabled for the whole job.
    """
    run_path = (
        f"saves/llama_eval/booksum/{args.model.replace('/', '_')}_bq{args.hip_block_size_q}"
        f"_bk{args.hip_block_size_k}_k{args.hip_top_k_elems}_gl{args.max_tokens}_ns{args.no_sample}"
    )
    results_dir = pathlib.Path(run_path)
    references_dir = pathlib.Path("saves/llama_eval/booksum/reference")
    for directory in (results_dir, references_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Both stages share the same argument list.
    for stage in (generate_samples, evaluate_rouge):
        stage(args, model, tokenizer, device, results_dir)
    print(args)


def main():
    """Parse the command-line options, load the model and run the BookSum job on CUDA."""
    args = ArgumentParser(EvalConfig).parse_args()
    print(args)

    loaded_model, loaded_tokenizer = load_model(model_config=args)
    job_booksum(args, loaded_model, loaded_tokenizer, 'cuda')


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
