# -*- coding: utf-8 -*-
# gentraces.py
import ast
import json
import logging
import os
import random
import shutil
import tempfile
from io import StringIO
from pathlib import Path

import datasets
import hydra
import torch
import yaml
from accelerate import Accelerator
from accelerate.utils import gather_object
from hydra.core.hydra_config import HydraConfig
from math_verify import parse, verify
from omegaconf import DictConfig, OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.pretty import Pretty
from rich.syntax import Syntax
from rich.text import Text
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorWithPadding, LogitsProcessor,
                          LogitsProcessorList, RepetitionPenaltyLogitsProcessor)
from transformers import logging as hf_logging

import wandb
from utils import (ANSWER_FORCE_STRING, SYSTEM_PROMPT, init, load_gsm8k,
                   load_hendrycks_math_dataset)

# Module-wide logger and the single Accelerator instance shared by every function below.
log = logging.getLogger(__name__)
accelerator = Accelerator()

if not accelerator.is_main_process:
    # Only the main process should emit library warnings and progress bars;
    # all other ranks stay quiet.
    hf_logging.set_verbosity_error()
    hf_logging.disable_progress_bar()
    datasets.disable_progress_bar()

    def tqdm(x, *args, **kwargs):
        """No-op stand-in for tqdm: hand back the iterable unchanged."""
        return x

def log_color(content, title=""):
    """Print ``content`` in a titled rich panel and mirror it into the log.

    The panel is first printed to the terminal with a cyan border, then
    rendered a second time into an in-memory buffer (no highlighting, no
    border styling) so the same text reaches the log handlers.

    Args:
        content: Anything rich can render inside a ``Panel`` (a string or a
            renderable such as ``Syntax``).
        title: Panel title, left-aligned.

    This function never raises: if rendering fails, the error is logged and
    the content is logged as plain text instead of being dropped.
    """
    try:
        console = Console()
        console.print(Panel(content, title=title, border_style="cyan", title_align="left"))

        # Log the message as plain text for log files
        string_io = StringIO()
        plain_console = Console(file=string_io, highlight=False)
        plain_console.print(Panel(content, title=title, border_style="none", title_align="left"))
        log.info("\n" + string_io.getvalue())
    except Exception as e:
        # Fallback to plain text logging if Console fails: report the failure
        # and still record the content so it is not lost.
        log.info(f"Error logging content: {e}")
        log.info("[%s]\n%s", title, content)

def is_correct(example, trace_colname):
    """Grade one generated trace against the row's reference solution.

    Both the solution and the trace are run through ``parse`` and compared
    with ``verify``. When the trace contains ``ANSWER_FORCE_STRING``, the text
    before and after its last occurrence are also tried, and the trace counts
    as correct if any of the three candidates verifies.

    Args:
        example: Row mapping holding ``"solution"`` and the trace under
            ``trace_colname``.
        trace_colname: Key of the trace text inside ``example``.

    Returns:
        ``{"is_correct": res}`` where ``res`` is the verification outcome, or
        ``False`` if parsing or verification raised.
    """
    trace = example[trace_colname]
    try:
        soln = parse(example["solution"])
        if ANSWER_FORCE_STRING in trace:
            # Split around the LAST occurrence of the answer-forcing marker.
            before, _, after = trace.rpartition(ANSWER_FORCE_STRING)
            res = any(verify(soln, parse(ans)) for ans in [trace, before, after])
        else:
            res = verify(soln, parse(trace))
    except Exception:
        # Deliberately best-effort: an ungradable trace counts as incorrect.
        # ``Exception`` (not a bare ``except``) so Ctrl-C / SystemExit still
        # interrupt a long ``dataset.map`` run.
        print(f"Error parsing trace: {trace} and comparing with solution: {example['solution']}")
        res = False
    return {"is_correct": res}

class CachedModelWrapper:
    """Call a causal LM on a growing sequence while reusing its KV cache.

    The first call (or any call that is not a one-token extension of the
    previous one) runs the model on the full ``input_ids`` and stores the
    returned ``past_key_values``. A call whose ``input_ids`` has the same batch
    size and exactly one more column than the previous call feeds only the last
    token together with the stored cache.
    """

    def __init__(self, model):
        self.model = model
        self.past_key_values = None
        # Sequence length (number of columns) covered by ``past_key_values``.
        self.last_position = 0
        # Batch size the cache was built for; None while no cache is held.
        self.batch_size = None

    def reset(self):
        """Drop the cached state so the next call recomputes from scratch."""
        self.past_key_values = None
        self.last_position = 0
        self.batch_size = None

    def _can_extend_cache(self, input_ids):
        """Return True when ``input_ids`` is the cached sequence plus one new token."""
        return (
            self.past_key_values is not None
            and input_ids.shape[0] == self.batch_size
            and input_ids.shape[1] == self.last_position + 1
        )

    def __call__(self, input_ids, attention_mask=None):
        """Return the model's logits for ``input_ids``, using the cache when valid.

        Args:
            input_ids: Token ids; ``shape[0]`` is the batch size and
                ``shape[1]`` the sequence length.
            attention_mask: Forwarded unchanged to the model on both paths.

        Returns:
            ``outputs.logits`` of the wrapped model: for the full sequence on
            a cache rebuild, for the single new token on a cache hit.
        """
        if not self._can_extend_cache(input_ids):
            # Anything other than "same batch, one more token" (a new prompt,
            # a shorter or much longer sequence, a different batch size) means
            # the stored cache belongs to other sequences: rebuild it.
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True
            )
            self.past_key_values = outputs.past_key_values
            self.last_position = input_ids.shape[1]
            self.batch_size = input_ids.shape[0]
            return outputs.logits

        new_token = input_ids[:, -1:]
        outputs = self.model(
            new_token,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=self.past_key_values,
            return_dict=True
        )
        self.past_key_values = outputs.past_key_values
        self.last_position += 1
        return outputs.logits

def _load_proxy_student(model_name, tokenizer):
    """Load one fp16 copy of the proxy student, resize its embeddings to the tokenizer, and wrap it for KV caching."""
    proxy = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
        use_cache=True,
    ).to(accelerator.device)
    proxy.resize_token_embeddings(len(tokenizer))
    return CachedModelWrapper(proxy)


def _shift_params_by_grads(model, grads, signed_eps, collect_stats=False):
    """Add ``signed_eps * grad`` in place to every parameter of ``model`` that has a gradient entry.

    Gradients are looked up under ``'module.' + parameter_name``
    (NOTE(review): the prefix presumably comes from a wrapped model at
    gradient-saving time; confirm against the script that writes the file).
    The update is computed in float32 and cast back to the parameter's dtype.

    Returns:
        ``(param_sq, grad_sq, num_params)``: sum of squared updated parameters,
        sum of squared gradients, and element count over the touched
        parameters. All zero when ``collect_stats`` is False.

    Raises:
        AssertionError: if some key in ``grads`` matched no parameter.
    """
    used_grads = set()
    param_sq, grad_sq, num_params = 0, 0, 0
    for name, param in model.named_parameters():
        module_name = 'module.' + name
        if module_name not in grads:
            continue
        grad = grads[module_name].to(param.device, dtype=torch.float32)
        param.data = (param.data.to(torch.float32) + signed_eps * grad).to(param.data.dtype)
        if collect_stats:
            param_sq += torch.sum(param.data.to(torch.float32) ** 2).item()
            grad_sq += torch.sum(grad ** 2).item()
            num_params += torch.numel(param.data)
        used_grads.add(module_name)
    assert used_grads == set(grads.keys()), f"Some gradients were not used or set: {set(grads.keys()) ^ used_grads}"
    return param_sq, grad_sq, num_params


@hydra.main(config_path=".", config_name="gen_config", version_base="1.3")
def main(cfg: DictConfig):
    """Generate teacher traces for a math dataset, grade them, and save the results.

    Steps, in order:
      1. Load the tokenizer and the teacher. When ``cfg.lam != 0`` also load
         two copies of the proxy student whose weights are shifted by
         ``+cfg.eps`` and ``-cfg.eps`` times the gradients at ``cfg.grad_path``.
      2. Tokenize the prompts on the main process, cache them on disk, and
         give every process its own shard.
      3. Generate one trace per prompt with the teacher; with antidistillation
         on, a logits processor adds ``lam / (2 * eps)`` times the difference
         of the two shifted students' logits to the teacher's scores.
      4. Grade the traces with ``is_correct``; if ``cfg.answer_force`` is set,
         append ``ANSWER_FORCE_STRING``, generate a short greedy continuation,
         and grade that as well.
      5. Every process saves its shard; the main process concatenates them and
         writes the dataset, a parquet file, a YAML summary, a JSONL registry
         line, and optionally logs accuracies to wandb.

    Raises:
        AssertionError: antidistillation is on but ``cfg.proxy_student`` or
            ``cfg.grad_path`` is missing, or a stored gradient matched no
            student parameter.
        ValueError: ``cfg.trace_name`` is still the placeholder, the data
            split or teacher family is unknown, or wandb logging is requested
            but the teacher config has no ``wandb_run_id``.
    """
    # lam == 0 switches the antidistillation logit adjustment off entirely.
    cfg.antidistillation = cfg.lam != 0
    cfg.wandb_lam = 1e-8 if cfg.lam == 0 else cfg.lam  # Wandb doesn't allow log scale for 0 values
    if cfg.antidistillation:
        assert cfg.proxy_student is not None, "Proxy student model must be specified for antidistillation"
        assert cfg.grad_path is not None, "Grad path must be specified for antidistillation"

    if cfg.trace_name == "REPLACE_ME":
        raise ValueError("Trace name must be specified")

    init(os.getenv("USER"), cfg.seed)

    if accelerator.is_main_process:
        content = Syntax(OmegaConf.to_yaml(cfg, resolve=True), 'yaml', theme="monokai")
        log_color(content, title="Config")

    # ------------------------------------------------------------------
    # Tokenizer and models
    # ------------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer,
        use_fast=True,
        fast_tokenizer=True,
        trust_remote_code=True,
        padding_side="left",
    )
    if "llama" in cfg.tokenizer.lower():
        # Llama tokenizers: pin the pad and eos ids explicitly.
        tokenizer.pad_token_id = 128004
        tokenizer.eos_token_id = 128001
        tokenizer.add_eos_token = False
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    eos_token = tokenizer.eos_token

    teacher = AutoModelForCausalLM.from_pretrained(
        cfg.teacher,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_cache=True,
    ).to(accelerator.device)
    teacher.generation_config.pad_token_id = tokenizer.pad_token_id
    teacher.resize_token_embeddings(len(tokenizer))

    if cfg.antidistillation:
        # Two copies of the proxy student, shifted by +eps*grad and -eps*grad.
        student = _load_proxy_student(cfg.proxy_student, tokenizer)
        dstudent = _load_proxy_student(cfg.proxy_student, tokenizer)

        # NOTE: torch.load unpickles the file; only point grad_path at trusted files.
        grads = torch.load(cfg.grad_path, map_location='cpu')
        if accelerator.is_main_process:
            log.info(f"Using eps: {cfg.eps}")

        param_sq, grad_sq, num_params = _shift_params_by_grads(
            student.model, grads, cfg.eps, collect_stats=True
        )
        if accelerator.is_main_process:
            log_color(f"{param_sq ** 0.5 / num_params ** 0.5:.2e}", title="Param RMSNorm")
            log_color(f"{grad_sq ** 0.5 / num_params ** 0.5:.2e}", title="Grad RMSNorm")

        _shift_params_by_grads(dstudent.model, grads, -cfg.eps)
        del grads
        if accelerator.is_main_process:
            log.info('Calculated grads')

    # ------------------------------------------------------------------
    # Dataset: load, tokenize on the main process, shard across processes
    # ------------------------------------------------------------------
    if "gsm8k" in cfg.data_split:
        dataset = load_gsm8k(split=cfg.data_split.split("_")[1])
    elif "math" in cfg.data_split:
        dataset = load_hendrycks_math_dataset(split=cfg.data_split.split("_")[1])
    else:
        raise ValueError(f"Unknown dataset and split: {cfg.data_split}")

    if cfg.max_samples is not None:
        dataset = dataset.take(min(cfg.max_samples, len(dataset)))

    if accelerator.is_main_process:
        def preprocess_function(examples):
            """Render each problem through the chat template and record its token count."""
            messages = [[{"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": problem.strip() + "\n"}] for problem in examples["problem"]]
            tokens = list(tokenizer.apply_chat_template(messages, add_generation_prompt=True))
            seq_lengths = [len(toks) for toks in tokens]
            return {"input_ids": tokens, "seq_lengths": seq_lengths}

        proc_dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=96,
            desc="Preprocessing dataset",
            load_from_cache_file=True,
        )
        proc_dataset = proc_dataset.filter(lambda x: x["seq_lengths"] <= cfg.max_prompt_length)
        log_color(tokenizer.decode(proc_dataset[0]['input_ids']), title="Example Input")
        seq_length_stats = proc_dataset.to_pandas()["seq_lengths"].describe()
        log_color(str(seq_length_stats.round(2)), title="Sequence Lengths")
        proc_dataset = proc_dataset.remove_columns("seq_lengths")
        # NOTE: fixed path shared by every run on this machine; concurrent
        # runs overwrite each other's cache.
        proc_dataset.save_to_disk("/tmp/cached_proc_dataset")
    accelerator.wait_for_everyone()
    proc_dataset = datasets.load_from_disk("/tmp/cached_proc_dataset")

    num_shards = accelerator.num_processes
    shard_id = accelerator.process_index
    dataset_shard = proc_dataset.shard(num_shards=num_shards, index=shard_id)
    # The loader only needs the tokenized prompt; drop the original columns.
    ptds_shard = dataset_shard.remove_columns(dataset.column_names)

    dataloader = DataLoader(
        ptds_shard,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=DataCollatorWithPadding(tokenizer=tokenizer)
    )

    class LogprobsModifier(LogitsProcessor):
        """Logits processor that adds the antidistillation term to the teacher's scores.

        The term is ``lam / (2 * eps)`` times the difference between the
        next-token logits of the ``+eps`` student and the ``-eps`` student.
        A repetition penalty is then applied to the adjusted scores.
        """

        def __init__(self, lam, eps, attention_mask, repetition_penalty):
            super().__init__()
            self.lam = lam
            self.eps = eps
            # Attention mask of the (left-padded) prompt batch.
            self.attention_mask = attention_mask
            self.repetition_penalty = repetition_penalty

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # Extend the prompt mask with ones for every token generated so far.
            attention_mask = F.pad(self.attention_mask, pad=(0, input_ids.shape[1]-self.attention_mask.shape[1]), value=1)
            out_target = student(input_ids=input_ids, attention_mask=attention_mask)[:, -1]
            out_Dtarget = dstudent(input_ids=input_ids, attention_mask=attention_mask)[:, -1]
            ad_term = (self.lam / (2*self.eps)) * (out_target.float() - out_Dtarget.float())

            scores = scores.float() + ad_term
            if self.repetition_penalty != 1.0:
                score = torch.gather(scores, 1, input_ids)

                # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
                score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)

                scores_processed = scores.scatter(1, input_ids, score)
            else:
                scores_processed = scores
            return scores_processed

    # ------------------------------------------------------------------
    # Trace generation
    # ------------------------------------------------------------------
    sampling = cfg.tau > 0
    # Generation arguments that do not depend on the batch. With
    # antidistillation on, the repetition penalty is applied inside
    # LogprobsModifier, so generate() itself gets the neutral value 1.0.
    base_gen_kwargs = {
        "temperature": cfg.tau if sampling else None,
        "do_sample": True if sampling else False,
        "top_p": 0.95 if sampling else None,
        "renormalize_logits": True if cfg.antidistillation else False,
        "use_cache": True,
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.0 if cfg.antidistillation else cfg.repetition_penalty,
    }

    traces = []
    for batch in tqdm(dataloader, total=len(dataloader), desc=f"tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}"):
        batch = {key: value.to(accelerator.device) for key, value in batch.items()}
        if cfg.antidistillation:
            # Every batch holds new sequences: drop the students' KV caches so
            # nothing from the previous batch is reused.
            for proxy in (student, dstudent):
                proxy.past_key_values = None
                proxy.last_position = 0
        with torch.inference_mode():
            kwargs = dict(base_gen_kwargs)
            kwargs["logits_processor"] = (
                LogitsProcessorList([LogprobsModifier(cfg.lam, cfg.eps, batch["attention_mask"], cfg.repetition_penalty)])
                if cfg.antidistillation else None
            )
            outputs = teacher.generate(
                **batch,
                max_length=cfg.max_length,
                **kwargs,
            )
            generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
            for text in generated_texts:
                text = text.replace(tokenizer.pad_token, "")
                traces.append(text)

    if accelerator.is_main_process and traces:
        log_color(traces[0], title="First trace")
    dataset_shard = dataset_shard.add_column(cfg.trace_colname, traces)
    if cfg.antidistillation:
        del student, dstudent
        torch.cuda.empty_cache()

    dataset_shard = dataset_shard.map(
        is_correct,
        fn_kwargs={"trace_colname": cfg.trace_colname},
        desc="Checking raw correctness"
    )
    dataset_shard = dataset_shard.rename_columns({"is_correct": "is_raw_correct"})

    # ------------------------------------------------------------------
    # Answer forcing
    # ------------------------------------------------------------------
    # The marker that opens the assistant turn depends on the model family
    # that produced the traces, i.e. the teacher. ``cfg.student`` is consulted
    # only as a fallback for configs that relied on the earlier behaviour.
    teacher_name = cfg.teacher.lower()
    legacy_student = cfg.student if "student" in cfg else None
    legacy_student_name = str(legacy_student or "").lower()
    if 'llama' in teacher_name:
        response_string = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    elif "r1" in teacher_name:
        response_string = "<｜Assistant｜>"
    elif "qwen" in teacher_name or "qwen" in legacy_student_name:
        response_string = "<|im_start|>assistant\n"
    else:
        raise ValueError(f"Unknown model {cfg.teacher}")
    if cfg.answer_force:
        traces_ = []
        for text in traces:
            # Close the thinking section if the response did not close it itself.
            if "</think>" in text.split(response_string)[-1]:
                traces_.append(text + ANSWER_FORCE_STRING)
            else:
                traces_.append(text + "\n</think>" + ANSWER_FORCE_STRING)
        # Half the generation batch size, but never 0 (batch_size == 1).
        af_batch_size = max(1, cfg.batch_size // 2)
        traces_batched = [traces_[i:i+af_batch_size] for i in range(0, len(traces_), af_batch_size)]
        traces_af = []
        bos_token = tokenizer.bos_token
        for batch in tqdm(traces_batched, total=len(traces_batched)):
            cleaned = []
            for text in batch:
                # Some tokenizers define no BOS (or EOS) token; replacing
                # ``None`` would raise, so only strip tokens that exist.
                if bos_token:
                    text = text.replace(bos_token, '', 1)
                if eos_token:
                    text = text.replace(eos_token, '')
                cleaned.append(text)
            inputs = tokenizer(cleaned, return_tensors="pt", padding=True).to(accelerator.device)
            with torch.inference_mode():
                outputs = teacher.generate(
                    **inputs,
                    do_sample=False,
                    temperature=None,
                    top_p=None,
                    logits_processor=None,
                    renormalize_logits=False,
                    max_new_tokens=32,
                    use_cache=True,
                    eos_token_id=tokenizer.eos_token_id
                )
                generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
                for text in generated_texts:
                    text = text.replace(tokenizer.pad_token, "")
                    if eos_token and not text.endswith(eos_token):
                        text = text + eos_token
                    traces_af.append(text)

        if accelerator.is_main_process and traces_af:
            log_color(traces_af[0], title="First af trace")
        del teacher
        torch.cuda.empty_cache()

        dataset_shard = dataset_shard.add_column(cfg.trace_colname+"_af", traces_af)
        dataset_shard = dataset_shard.map(
            is_correct,
            fn_kwargs={"trace_colname": cfg.trace_colname+"_af"},
            desc="Checking af correctness"
        )
        dataset_shard = dataset_shard.rename_columns({"is_correct": "is_af_correct"})

    # ------------------------------------------------------------------
    # Persist shards, then merge and report on the main process
    # ------------------------------------------------------------------
    tmp_dir = Path(tempfile.mkdtemp(prefix="tmp_ds_"))
    shard_path = tmp_dir / f"shard_rank_{accelerator.process_index:05d}"
    dataset_shard.save_to_disk(shard_path)

    accelerator.wait_for_everyone()

    all_paths = gather_object([shard_path])
    if accelerator.is_main_process:
        trace_dataset = datasets.concatenate_datasets([datasets.load_from_disk(path) for path in all_paths])

        trace_dataset.save_to_disk(cfg.trace_path)
        trace_dataset.to_parquet(cfg.trace_path + ".parquet")

        for path in all_paths:
            # Remove the whole mkdtemp directory, not only the shard inside it.
            shutil.rmtree(Path(path).parent, ignore_errors=True)

        example_row = trace_dataset[random.randint(0, len(trace_dataset)-1)]
        log_color(example_row["problem"], title="Example Problem")
        log_color(example_row["solution"], title="Example Solution")
        log_color(example_row[cfg.trace_colname], title=f"Example Trace [tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}]")
        if cfg.answer_force:
            log_color(example_row[cfg.trace_colname + "_af"], title=f"Example AF Trace [tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}]")

        trace_df = trace_dataset.to_pandas()
        trace_len_stats = {k:float(v) for k,v in trace_df[cfg.trace_colname].map(lambda x: len(tokenizer.encode(x))).describe().items()}
        raw_accuracy = float(trace_df["is_raw_correct"].mean())
        # The "is_af_correct" column only exists when answer forcing ran.
        af_accuracy = float(trace_df["is_af_correct"].mean()) if cfg.answer_force else None

        full_cfg = OmegaConf.to_container(cfg, resolve=True)
        hydra_cfg = HydraConfig.get()
        full_cfg["hydra"] = {
            "run_dir": hydra_cfg.run.dir,
            "job_name": hydra_cfg.job.name,
            "cwd": hydra_cfg.runtime.cwd,
        }
        full_cfg["stats"] = {
            "raw_accuracy": raw_accuracy,
            "af_accuracy": af_accuracy,
            "trace_len_stats": trace_len_stats,
        }
        yaml_path = cfg.trace_path + ".yaml"
        with open(yaml_path, "w") as f:
            OmegaConf.save(full_cfg, f)
        log.info(f"Configuration saved to {yaml_path}")

        def flatten_dict(d, parent_key='', sep='.'):
            """Flatten nested dicts into one level, joining key paths with ``sep``."""
            items = []
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                if isinstance(v, dict):
                    items.extend(flatten_dict(v, new_key, sep=sep).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        with open(cfg.trace_registry, "a") as f:  # jsonl
            f.write(json.dumps(flatten_dict(full_cfg)) + "\n")
        log.info(f"Metadata appended to {cfg.trace_registry}")

        content = Syntax(OmegaConf.to_yaml(full_cfg, resolve=True), 'yaml', theme="monokai")
        log_color(content, title="Final Config")

        if cfg.use_wandb and cfg.teacher_cfg:
            with open(cfg.teacher_cfg, "r") as f:
                teacher_cfg = yaml.safe_load(f)
            wandb_run_id = teacher_cfg.get("wandb_run_id")
            if wandb_run_id is None:
                raise ValueError("wandb is true but wandb_run_id not found in teacher config")
            wandb.init(
                project="antidistillation",
                id=wandb_run_id,
                resume="allow",
            )
            metrics = {"eval_raw_accuracy": raw_accuracy}
            if af_accuracy is not None:
                metrics["eval_af_accuracy"] = af_accuracy
            wandb.log(metrics)
    accelerator.wait_for_everyone()
    accelerator.end_training()

# Script entry point: main() is wrapped by @hydra.main, which supplies ``cfg``.
if __name__ == "__main__":
    main()
