# -*- coding: utf-8 -*-
# gentraces.py
import ast
import json
import logging
import os
import random
import shutil
import tempfile
from io import StringIO
from pathlib import Path

import datasets
import hydra
import torch
import yaml
from accelerate import Accelerator
from accelerate.utils import gather_object
from hydra.core.hydra_config import HydraConfig
from math_verify import parse, verify
from omegaconf import DictConfig, OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.pretty import Pretty
from rich.syntax import Syntax
from rich.text import Text
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorWithPadding, LogitsProcessor,
                          LogitsProcessorList, RepetitionPenaltyLogitsProcessor)
from transformers import logging as hf_logging

import wandb
from utils import (ANSWER_FORCE_STRING, SYSTEM_PROMPT, init, load_gsm8k,
                   load_hendrycks_math_dataset)
from jvp_utils import stack_kv_cache, unstack_kv_tensor, extract_and_stack_last_token_kv, concat_stacked_kv_caches, make_logp_fn
from functorch import make_functional_with_buffers
from torch.func import jvp
# Process-wide handles shared by everything below: the Accelerator describing
# this rank, and the module logger.
accelerator = Accelerator()
log = logging.getLogger(__name__)

if not accelerator.is_main_process:
    # Only the main process should produce progress bars and library chatter;
    # every other rank is silenced so multi-process output stays readable.
    datasets.disable_progress_bar()
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()

    def tqdm(x, *args, **kwargs):
        """Stand-in for ``tqdm`` on non-main ranks: return the iterable unchanged."""
        return x

def log_color(content, title=""):
    """Show ``content`` in a titled panel on the console and mirror it to the log.

    The panel is printed once with a cyan border through a default rich
    ``Console`` and once, without a border style, into an in-memory buffer
    whose text is sent to the module logger so it also lands in log files.

    Args:
        content: Anything rich can render inside a ``Panel`` (a string or a
            rich renderable).
        title: Optional panel title.

    If rendering fails for any reason the error is logged and the content is
    still written to the log as plain text, so the message is never lost.
    """
    try:
        console = Console()
        console.print(Panel(content, title=title, border_style="cyan", title_align="left"))

        # Log the message as plain text for log files
        string_io = StringIO()
        plain_console = Console(file=string_io, highlight=False)
        plain_console.print(Panel(content, title=title, border_style="none", title_align="left"))
        log.info("\n" + string_io.getvalue())
    except Exception as e:
        # Best-effort logging must never take the run down: report the
        # failure, then fall back to the plain string form of the content.
        log.info(f"Error logging content: {e}")
        log.info(f"{title}: {content}" if title else f"{content}")

def is_correct(example, trace_colname):
    """Check whether the trace stored in ``example`` matches its reference solution.

    Both the reference (``example["solution"]``) and the candidates are parsed
    with ``math_verify.parse`` and compared with ``math_verify.verify``. When
    the trace contains ``ANSWER_FORCE_STRING``, three candidates are tried: the
    whole trace, the text before the last occurrence of the marker, and the
    text after it. The example counts as correct if any candidate verifies.

    Args:
        example: Mapping with a ``"solution"`` entry and a ``trace_colname`` entry.
        trace_colname: Name of the column holding the generated trace.

    Returns:
        ``{"is_correct": bool}``. Any error raised while parsing or verifying
        is reported on stdout and treated as an incorrect answer.
    """
    trace = example[trace_colname]
    try:
        soln = parse(example["solution"])
        candidates = [trace]
        if ANSWER_FORCE_STRING in trace:
            # Split on the LAST marker only: everything before it and
            # everything after it are each tried as standalone answers.
            before_marker, _, after_marker = trace.rpartition(ANSWER_FORCE_STRING)
            candidates += [before_marker, after_marker]
        res = any(verify(soln, parse(ans)) for ans in candidates)
    except Exception:
        # Deliberately broad (the checker is best-effort), but no longer a
        # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
        print(f"Error parsing trace: {trace} and comparing with solution: {example['solution']}")
        res = False
    return {"is_correct": res}

class CachedModelWrapper:
    """Callable wrapper that evaluates a causal LM incrementally via its KV cache.

    The wrapper is meant to be called once per decoding step with the full
    ``input_ids`` generated so far. The first call of a sequence runs the model
    on the whole input and stores ``past_key_values``; each following call that
    extends the same batch by exactly one token feeds only that new token
    together with the stored cache.

    A call is treated as the start of a new sequence (and the cache is rebuilt
    from the full input) whenever it does not look like such a one-token
    continuation, i.e. the batch size changed or the length is anything other
    than ``last_position + 1``. This matters because one wrapper instance is
    reused across several ``generate()`` calls: a later batch whose prompt is
    longer than the previous batch's final length must not be decoded against
    the previous batch's cache.

    NOTE: a new batch with the same batch size whose prompt happens to be
    exactly one token longer than the previous final length cannot be told
    apart from a continuation by shape alone; call ``reset()`` between
    independent generations to rule that case out.
    """

    def __init__(self, model):
        self.model = model
        self.past_key_values = None
        # Sequence length covered by ``past_key_values``.
        self.last_position = 0
        # Batch size the cache was built for (None while there is no cache).
        self._cached_batch_size = None

    def reset(self):
        """Discard the cached state so the next call re-encodes its full input."""
        self.past_key_values = None
        self.last_position = 0
        self._cached_batch_size = None

    def __call__(self, input_ids, attention_mask=None):
        """Return the model logits for ``input_ids``, reusing the cache when valid.

        Args:
            input_ids: 2-D tensor of token ids; dimension 0 is read as the
                batch size and dimension 1 as the sequence length.
            attention_mask: Forwarded unchanged to the model.

        Returns:
            ``outputs.logits`` of the underlying model call: logits for every
            input position on a full pass, or for the single new token on an
            incremental pass.
        """
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        continues_cache = (
            self.past_key_values is not None
            and batch_size == self._cached_batch_size
            and seq_len == self.last_position + 1
        )

        if not continues_cache:
            # Fresh sequence (or stale cache): encode everything from scratch.
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True
            )
            self.past_key_values = outputs.past_key_values
            self.last_position = seq_len
            self._cached_batch_size = batch_size
            return outputs.logits

        # One-token continuation: only the newest token needs to be processed.
        new_token = input_ids[:, -1:]
        outputs = self.model(
            new_token,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=self.past_key_values,
            return_dict=True
        )
        self.past_key_values = outputs.past_key_values
        self.last_position += 1
        return outputs.logits

def _shift_params_along_grads(model, grads, scale):
    """Add ``scale * grad`` in place to every parameter of ``model`` found in ``grads``.

    Gradients are looked up under ``'module.' + parameter_name``. The update is
    computed in float32 and cast back to the parameter's own dtype.

    Args:
        model: Module whose ``named_parameters()`` are updated.
        grads: Mapping from ``'module.<param name>'`` to a gradient tensor.
        scale: Step size; pass a negative value to move against the gradient.

    Returns:
        ``(param_sq, grad_sq, num_params)``: sum of squares of the updated
        parameters, sum of squares of the gradients used, and the number of
        parameter elements that were updated.

    Raises:
        AssertionError: if some key of ``grads`` matches no parameter.
    """
    used_grads = set()
    param_sq, grad_sq, num_params = 0, 0, 0
    for name, param in model.named_parameters():
        module_name = 'module.' + name
        if module_name not in grads:
            continue
        grad = grads[module_name].to(param.device, dtype=torch.float32)
        param.data = (param.data.to(torch.float32) + scale * grad).to(param.data.dtype)
        param_sq += torch.sum(param.data.to(torch.float32) ** 2).item()
        grad_sq += torch.sum(grad ** 2).item()
        num_params += torch.numel(param.data)
        used_grads.add(module_name)
    assert used_grads == set(grads.keys()), f"Some gradients were not used or set: {set(grads.keys()) ^ used_grads}"
    return param_sq, grad_sq, num_params


@hydra.main(config_path=".", config_name="gen_config", version_base="1.3")
def main(cfg: DictConfig):
    """Generate teacher traces for a math dataset, optionally with antidistillation sampling.

    Steps, each run on this process's shard of the dataset:

    1. Load the tokenizer and the teacher model.
    2. If ``cfg.lam != 0`` (antidistillation), load the proxy student and the
       saved gradients from ``cfg.grad_path``. With ``cfg.use_jvp`` false, two
       copies of the student are perturbed by ``+eps*grad`` / ``-eps*grad``
       (finite differences); otherwise a functional copy is prepared for
       ``torch.func.jvp``.
    3. Sample one trace per prompt with ``teacher.generate``, adding the
       antidistillation term to the teacher logits via a ``LogitsProcessor``.
    4. Score the traces with ``is_correct``; when ``cfg.answer_force`` is set,
       append ``ANSWER_FORCE_STRING``, greedily decode up to 32 more tokens and
       score again.
    5. Save each shard to a temporary directory; the main process concatenates
       the shards, writes the dataset (plus parquet, yaml and a registry
       line), and optionally logs accuracies to wandb.

    Args:
        cfg: Hydra config. ``antidistillation`` and ``wandb_lam`` are written
            into it here, derived from ``cfg.lam``.

    Raises:
        ValueError: if ``cfg.trace_name`` is still the placeholder, the data
            split or teacher model is not recognised, or wandb logging is
            requested but the teacher config has no ``wandb_run_id``.
    """

    cfg.antidistillation = cfg.lam != 0
    cfg.wandb_lam = 1e-8 if cfg.lam == 0 else cfg.lam  # Wandb doesn't allow log scale for 0 values
    if cfg.antidistillation:
        assert cfg.proxy_student is not None, "Proxy student model must be specified for antidistillation"
        assert cfg.grad_path is not None, "Grad path must be specified for antidistillation"

    if cfg.trace_name == "REPLACE_ME":
        raise ValueError("Trace name must be specified")

    # Project helper from utils; presumably seeds RNGs / sets up the run — see utils.init.
    init(os.getenv("USER"), cfg.seed)

    if accelerator.is_main_process:
        content = Syntax(OmegaConf.to_yaml(cfg, resolve=True), 'yaml', theme="monokai")
        log_color(content, title="Config")

    # ------------------------------------------------------------------
    # Tokenizer and teacher
    # ------------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer,
        use_fast=True,
        fast_tokenizer=True,
        trust_remote_code=True,
        padding_side="left",
    )
    if "llama" in cfg.tokenizer.lower():
        # Hard-coded special-token ids for Llama tokenizers.
        eos_token_id = 128001
        tokenizer.pad_token_id = 128004
        tokenizer.eos_token_id = eos_token_id
        tokenizer.add_eos_token = False
        eos_token = tokenizer.eos_token
    else:
        eos_token = tokenizer.eos_token
        special_tokens = {"pad_token": "[PAD]"}
        tokenizer.add_special_tokens(special_tokens)

    teacher = AutoModelForCausalLM.from_pretrained(
        cfg.teacher,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_cache=True,
    ).to(accelerator.device)
    teacher.generation_config.pad_token_id = tokenizer.pad_token_id
    # The tokenizer may have gained a pad token above, so keep the embedding table in sync.
    teacher.resize_token_embeddings(len(tokenizer))

    # ------------------------------------------------------------------
    # Proxy student(s) for the antidistillation term
    # ------------------------------------------------------------------
    if cfg.antidistillation:
        if not cfg.use_jvp:  # use finite difference method
            proxy_kwargs = dict(
                trust_remote_code=True,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.float16,
                use_cache=True,
            )
            student = CachedModelWrapper(
                AutoModelForCausalLM.from_pretrained(cfg.proxy_student, **proxy_kwargs).to(accelerator.device)
            )
            dstudent = CachedModelWrapper(
                AutoModelForCausalLM.from_pretrained(cfg.proxy_student, **proxy_kwargs).to(accelerator.device)
            )

            student.model.resize_token_embeddings(len(tokenizer))
            dstudent.model.resize_token_embeddings(len(tokenizer))

            grads = torch.load(cfg.grad_path, map_location='cpu')
            if accelerator.is_main_process:
                log.info(f"Using eps: {cfg.eps}")

            # student <- theta + eps * g ; dstudent <- theta - eps * g
            param_sq, grad_sq, num_params = _shift_params_along_grads(student.model, grads, cfg.eps)
            if accelerator.is_main_process:
                log_color(f"{param_sq ** 0.5 / num_params ** 0.5:.2e}", title="Param RMSNorm")
                log_color(f"{grad_sq ** 0.5 / num_params ** 0.5:.2e}", title="Grad RMSNorm")
            _shift_params_along_grads(dstudent.model, grads, -cfg.eps)

            del grads
            if accelerator.is_main_process:
                log.info('Calculated grads')
        else:
            if accelerator.is_main_process:
                log.info("Using JVP method for antidistillation")
            student = CachedModelWrapper(AutoModelForCausalLM.from_pretrained(
                cfg.proxy_student,
                trust_remote_code=True,
                torch_dtype=torch.float32,  # We use float32 for JVP, but we need to turn off flash_attention or sdpa for forward mode AD
                use_cache=True,
            ).to(accelerator.device))

            student.model.resize_token_embeddings(len(tokenizer))

            f_student, student_params, student_buffers = make_functional_with_buffers(student.model)

            student_grads = torch.load(cfg.grad_path, map_location='cpu')
            if accelerator.is_main_process:
                log.info(f"Using eps: {cfg.eps}")
            # Move grads to the same device as the matching parameter.
            used_grads = set()
            for name, param in student.model.named_parameters():
                module_name = 'module.' + name
                if module_name in student_grads:
                    student_grads[module_name] = student_grads[module_name].to(param.device, dtype=torch.float)
                    used_grads.add(module_name)
            assert used_grads == set(student_grads.keys()), f"Some gradients were not used or set: {set(student_grads.keys()) ^ used_grads}"
            # NOTE(review): jvp pairs these tangents with student_params by
            # position, so this relies on the saved dict being ordered like
            # named_parameters() and covering every parameter — confirm.
            student_grads = tuple(student_grads.values())
            del student.model
            torch.cuda.empty_cache()
            if accelerator.is_main_process:
                log.info('Calculated grads')

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    if "gsm8k" in cfg.data_split:
        dataset = load_gsm8k(split=cfg.data_split.split("_")[1])
    elif "math" in cfg.data_split:
        dataset = load_hendrycks_math_dataset(split=cfg.data_split.split("_")[1])
    else:
        raise ValueError(f"Unknown dataset and split: {cfg.data_split}")

    if cfg.max_samples is not None:
        dataset = dataset.take(cfg.max_samples)

    # Every process handles its own shard.
    num_shards = accelerator.num_processes
    shard_id = accelerator.process_index
    dataset_shard = dataset.shard(num_shards=num_shards, index=shard_id)

    def preprocess_function(examples, indices):
        """Tokenize a batch of problems into chat prompts, dropping over-length ones.

        ``source_index`` records which row of ``dataset_shard`` each kept
        prompt came from so the shard can be filtered to match.
        """
        messages = [[{"role": "system", "content": SYSTEM_PROMPT},
                     {"role": "user", "content": problem.strip() + "\n"}] for problem in examples["problem"]]
        tokenized = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        kept = [(idx, toks) for idx, toks in zip(indices, tokenized) if len(toks) <= cfg.max_length]
        tokens = [toks for _, toks in kept]
        # default=0 keeps a fully filtered batch from raising on max([]).
        max_length = max((len(toks) for toks in tokens), default=0)
        return {
            "input_ids": tokens,
            "max_length": [max_length] * len(tokens),
            "source_index": [idx for idx, _ in kept],
        }

    ptds_shard = dataset_shard.map(
        preprocess_function,
        batched=True,
        with_indices=True,
        remove_columns=dataset_shard.column_names,
        desc="Preprocessing dataset",
        load_from_cache_file=True,
    )

    if len(ptds_shard) != len(dataset_shard):
        # Some prompts exceeded cfg.max_length. Drop the same rows from the
        # source shard so the generated traces can be attached row-for-row.
        log.warning(
            f"Dropping {len(dataset_shard) - len(ptds_shard)} prompts longer than max_length={cfg.max_length}"
        )
        dataset_shard = dataset_shard.select(list(ptds_shard["source_index"]))

    if accelerator.is_main_process:
        max_train_token_length = max(ptds_shard["max_length"])
        log_color(f"Max train token length: {max_train_token_length}")
        log_color(tokenizer.decode(ptds_shard[0]["input_ids"]), title="Example prompt")
    # Only input_ids may reach the collator / generate().
    ptds_shard = ptds_shard.remove_columns(["max_length", "source_index"])

    dataloader = DataLoader(
        ptds_shard,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=DataCollatorWithPadding(tokenizer=tokenizer)
    )

    class LogprobsModifier(LogitsProcessor):
        """Logits processor that adds the antidistillation term to the teacher scores.

        A new instance is created for every ``generate()`` call; it receives
        the prompt attention mask and extends it with ones as tokens are added.
        """

        def __init__(self, lam, eps, attention_mask, repetition_penalty):
            super().__init__()
            self.lam = lam
            self.eps = eps
            self.attention_mask = attention_mask
            self.repetition_penalty = repetition_penalty
            self.past_key_values = None  # past kv cache if we use JVP
            self.d_past_key_values = None  # the differential of past kv cache if we use JVP

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # Generated tokens are always attended to: right-pad the prompt mask with ones.
            attention_mask = F.pad(self.attention_mask, pad=(0, input_ids.shape[1]-self.attention_mask.shape[1]), value=1)
            if not cfg.use_jvp:
                # Central finite difference of the student outputs along the gradient direction.
                out_target = student(input_ids=input_ids, attention_mask=attention_mask)[:, -1]
                out_Dtarget = dstudent(input_ids=input_ids, attention_mask=attention_mask)[:, -1]
                ad_term = (self.lam / (2*self.eps)) * (out_target.float() - out_Dtarget.float())

                scores = scores.float() + ad_term

                # Repetition penalty on tokens already present in input_ids.
                score = torch.gather(scores, 1, input_ids)
                # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
                score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
                return scores.scatter(1, input_ids, score)

            # JVP branch: after the first step only the newest token is fed,
            # together with the stacked KV cache and its tangent.
            use_input_ids = input_ids
            if self.past_key_values is not None:
                use_input_ids = input_ids[:, -1:]
            logp_fn = make_logp_fn(f_student, student_buffers, use_input_ids, attention_mask)
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                if self.past_key_values is None:
                    output, jvp_output = jvp(logp_fn, (student_params,), (student_grads,))  # has shape (batch_size, vocab_size)
                else:
                    output, jvp_output = jvp(logp_fn, (student_params, self.past_key_values), (student_grads, self.d_past_key_values))  # has shape (batch_size, vocab_size)
            # keep in mind that output is a tuple of (log_probs, last_key_values)
            self.past_key_values = concat_stacked_kv_caches(self.past_key_values, output[1])
            self.d_past_key_values = concat_stacked_kv_caches(self.d_past_key_values, jvp_output[1])
            ad_term = self.lam * jvp_output[0][:, -1]
            # NOTE(review): unlike the finite-difference branch, no repetition
            # penalty is applied here even though generate() is called with
            # repetition_penalty=1.0 under antidistillation — confirm intended.
            return scores.float() + ad_term

    # ------------------------------------------------------------------
    # Trace generation
    # ------------------------------------------------------------------
    sampling = cfg.tau > 0
    traces = []
    for batch in tqdm(dataloader, total=len(dataloader), desc=f"tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}"):
        batch = {key: value.to(accelerator.device) for key, value in batch.items()}
        outputs = teacher.generate(
            **batch,
            max_length=cfg.max_length,
            temperature=cfg.tau if sampling else None,
            do_sample=sampling,
            top_p=0.95 if sampling else None,
            logits_processor=(
                LogitsProcessorList([LogprobsModifier(cfg.lam, cfg.eps, batch["attention_mask"], cfg.repetition_penalty)])
                if cfg.antidistillation else None
            ),
            renormalize_logits=True if cfg.antidistillation else False,
            use_cache=True,
            eos_token_id=tokenizer.eos_token_id,
            # Under antidistillation the penalty is applied inside LogprobsModifier instead.
            repetition_penalty=1.0 if cfg.antidistillation else cfg.repetition_penalty,
        )
        generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
        for text in generated_texts:
            traces.append(text.replace(tokenizer.pad_token, ""))

    if accelerator.is_main_process:
        log_color(traces[0], title="First trace")
    dataset_shard = dataset_shard.add_column(cfg.trace_colname, traces)
    if cfg.antidistillation:
        # The student models are no longer needed; free their memory.
        if cfg.use_jvp:
            del student, f_student, student_params, student_buffers, student_grads
        else:
            del student, dstudent
        torch.cuda.empty_cache()

    dataset_shard = dataset_shard.map(
        is_correct,
        fn_kwargs={"trace_colname": cfg.trace_colname},
        desc="Checking raw correctness"
    )
    dataset_shard = dataset_shard.rename_columns({"is_correct": "is_raw_correct"})

    # ------------------------------------------------------------------
    # Answer forcing
    # ------------------------------------------------------------------
    # The marker that opens the assistant turn depends on the teacher, because
    # the traces being split were produced by the teacher. A legacy
    # ``cfg.student`` key, if present, is still honoured for the Qwen check.
    teacher_name = cfg.teacher.lower()
    legacy_student_name = str(cfg.get("student", None) or "").lower()
    if 'llama' in teacher_name:
        response_string = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    elif "r1" in teacher_name:
        response_string = "<｜Assistant｜>"
    elif "qwen" in teacher_name or "qwen" in legacy_student_name:
        response_string = "<|im_start|>assistant\n"
    else:
        raise ValueError(f"Unknown model {cfg.teacher}")

    if cfg.answer_force:
        # Close the thinking section if the model did not, then ask for the answer.
        traces_ = []
        for text in traces:
            if "</think>" in text.split(response_string)[-1]:
                traces_.append(text + ANSWER_FORCE_STRING)
            else:
                traces_.append(text + "\n</think>" + ANSWER_FORCE_STRING)
        # Half the generation batch size, but never 0 (a zero step would break the slicing below).
        af_batch_size = max(1, cfg.batch_size // 2)
        traces_batched = [traces_[i:i+af_batch_size] for i in range(0, len(traces_), af_batch_size)]
        traces_af = []
        for batch in tqdm(traces_batched, total=len(traces_batched)):
            # The tokenizer call below adds special tokens again, so strip the
            # leading BOS and any EOS from the decoded text first.
            batch = [text.replace(tokenizer.bos_token, '', 1).replace(eos_token, '') for text in batch]
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to(accelerator.device)
            outputs = teacher.generate(
                **inputs,
                do_sample=False,
                temperature=None,
                top_p=None,
                logits_processor=None,
                renormalize_logits=False,
                max_new_tokens=32,
                use_cache=True,
                eos_token_id=tokenizer.eos_token_id
            )
            generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
            for text in generated_texts:
                text = text.replace(tokenizer.pad_token, "")
                if not text.endswith(tokenizer.eos_token):
                    text = text + tokenizer.eos_token
                traces_af.append(text)

        if accelerator.is_main_process:
            log_color(traces_af[0], title="First af trace")
        del teacher
        torch.cuda.empty_cache()

        dataset_shard = dataset_shard.add_column(cfg.trace_colname+"_af", traces_af)
        dataset_shard = dataset_shard.map(
            is_correct,
            fn_kwargs={"trace_colname": cfg.trace_colname+"_af"},
            desc="Checking af correctness"
        )
        dataset_shard = dataset_shard.rename_columns({"is_correct": "is_af_correct"})

    # ------------------------------------------------------------------
    # Gather shards and write results
    # ------------------------------------------------------------------
    tmp_dir = Path(tempfile.mkdtemp(prefix="tmp_ds_"))
    shard_path = tmp_dir / f"shard_rank_{accelerator.process_index:05d}"
    dataset_shard.save_to_disk(shard_path)

    accelerator.wait_for_everyone()

    all_paths = gather_object([shard_path])
    if accelerator.is_main_process:
        trace_dataset = datasets.concatenate_datasets([datasets.load_from_disk(path) for path in all_paths])

        trace_dataset.save_to_disk(cfg.trace_path)
        trace_dataset.to_parquet(cfg.trace_path + ".parquet")

        for path in all_paths:
            # Each shard sits alone inside its rank's mkdtemp directory, so
            # remove that parent directory rather than leaving it behind.
            shutil.rmtree(Path(path).parent, ignore_errors=True)

        example_row = trace_dataset[random.randint(0, len(trace_dataset)-1)]
        log_color(example_row["problem"], title="Example Problem")
        log_color(example_row["solution"], title="Example Solution")
        log_color(example_row[cfg.trace_colname], title=f"Example Trace [tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}]")
        if cfg.answer_force:
            log_color(example_row[cfg.trace_colname + "_af"], title=f"Example AF Trace [tau={cfg.tau:.2e}, lam={cfg.lam:.2e}, eps={cfg.eps:.2e}]")

        trace_df = trace_dataset.to_pandas()
        trace_len_stats = trace_df[cfg.trace_colname].str.len().describe().to_dict()
        raw_accuracy = float(trace_df["is_raw_correct"].mean())
        # The "is_af_correct" column only exists when answer forcing ran.
        af_accuracy = float(trace_df["is_af_correct"].mean()) if cfg.answer_force else None

        full_cfg = OmegaConf.to_container(cfg, resolve=True)
        hydra_cfg = HydraConfig.get()
        full_cfg["hydra"] = {
            "run_dir": hydra_cfg.run.dir,
            "job_name": hydra_cfg.job.name,
            "cwd": hydra_cfg.runtime.cwd,
        }
        full_cfg["stats"] = {
            "raw_accuracy": raw_accuracy,
            "af_accuracy": af_accuracy,
            "trace_len_stats": trace_len_stats,
        }
        yaml_path = cfg.trace_path + ".yaml"
        with open(yaml_path, "w") as f:
            OmegaConf.save(full_cfg, f)
        log.info(f"Configuration saved to {yaml_path}")

        def flatten_dict(d, parent_key='', sep='.'):
            """Flatten nested dicts into one level, joining keys with ``sep``."""
            items = []
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                if isinstance(v, dict):
                    items.extend(flatten_dict(v, new_key, sep=sep).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        with open(cfg.trace_registry, "a") as f:  # jsonl
            f.write(json.dumps(flatten_dict(full_cfg)) + "\n")
        log.info(f"Metadata appended to {cfg.trace_registry}")

        content = Syntax(OmegaConf.to_yaml(full_cfg, resolve=True), 'yaml', theme="monokai")
        log_color(content, title="Final Config")

        if cfg.use_wandb and cfg.teacher_cfg:
            # Attach the metrics to the teacher's existing wandb run.
            with open(cfg.teacher_cfg, "r") as f:
                teacher_cfg = yaml.safe_load(f)
            wandb_run_id = teacher_cfg.get("wandb_run_id")
            if wandb_run_id is None:
                raise ValueError("wandb is true but wandb_run_id not found in teacher config")
            wandb.init(
                project="antidistillation",
                id=wandb_run_id,
                resume="allow",
            )
            metrics = {"eval_raw_accuracy": raw_accuracy}
            if af_accuracy is not None:
                metrics["eval_af_accuracy"] = af_accuracy
            wandb.log(metrics)
    accelerator.end_training()

if __name__ == "__main__":
    main()
