# -*- coding: utf-8 -*-
# save_grad.py
import argparse
import os
import sys

import datasets
import torch
import yaml
from accelerate import Accelerator
from datasets import load_from_disk
from rich import print as rprint
from rich.console import Console
from rich.panel import Panel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging as hf_logging
from trl import DataCollatorForCompletionOnlyLM

from utils import init

# Single Accelerator instance shared by the whole script.
accelerator = Accelerator()

if not accelerator.is_main_process:
    # Only the main process should be chatty: turn off progress bars from
    # `datasets` and `transformers`, and restrict transformers logging to errors.
    datasets.disable_progress_bar()
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()

    def tqdm(x, *args, **kwargs):
        """Pass-through stand-in for ``tqdm``: return the iterable untouched."""
        return x

def log_color(content, title=""):
    """Print *content* to stdout wrapped in a cyan-bordered Rich panel.

    Args:
        content: The renderable (e.g. a string) placed in the panel body.
        title: Optional panel title, aligned to the left of the border.
    """
    out = Console(file=sys.stdout, highlight=True)
    framed = Panel(
        content,
        title=title,
        title_align="left",
        border_style="cyan",
    )
    out.print(framed)

def merge_config_into_args(args, config):
    """Fill unset attributes of *args* from the *config* mapping.

    A config value is applied only when ``args`` lacks the attribute or its
    current value is ``None``; any other existing value (an explicit CLI value
    or an argparse default such as the one for ``--batch_size``) is kept.

    Args:
        args: Namespace-like object; mutated in place.
        config: Mapping of option names to values, or ``None`` (which is what
            ``yaml.safe_load`` returns for an empty document).

    Returns:
        The same ``args`` object, for convenience.
    """
    for key, value in (config or {}).items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args


def strip_first_bos(texts, bos_token):
    """Return *texts* as a list with the first *bos_token* occurrence removed.

    Args:
        texts: Iterable of strings.
        bos_token: The BOS token string, or a falsy value when the tokenizer
            defines none; in that case the texts are returned unchanged.

    Returns:
        A new list of strings, one per input text.
    """
    if not bos_token:
        return list(texts)
    return [text.replace(bos_token, "", 1) for text in texts]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate Hendrycks Math CoT traces.")
    parser.add_argument("holdout_config", type=str, help="Path to the holdout config.yaml file")
    parser.add_argument("--proxy_student", type=str, help="Proxy student model to use")
    parser.add_argument("--tokenizer", type=str, help="Tokenizer model to use")
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument("--trace_colname", type=str, help="Column name for CoT traces")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    args = parser.parse_args()

    # Options left unset on the command line are filled in from the YAML config.
    # `trace_path` and `exp_dir` have no CLI flag, so they can only come from it.
    with open(args.holdout_config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    merge_config_into_args(args, config)

    # Fail fast on missing options instead of hitting an AttributeError later
    # (for `exp_dir` that would be after the whole gradient pass).
    required_options = ("proxy_student", "tokenizer", "trace_colname", "trace_path", "exp_dir")
    missing_options = [key for key in required_options if getattr(args, key, None) is None]
    if missing_options:
        parser.error(f"missing required option(s) (CLI or config): {', '.join(missing_options)}")

    init(os.getenv("USER"), args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True, padding_side="left")
    if tokenizer.pad_token != "[PAD]":
        special_tokens = {"pad_token": "[PAD]"}
        tokenizer.add_special_tokens(special_tokens)

    model = AutoModelForCausalLM.from_pretrained(
        args.proxy_student,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        use_cache=True,
    )
    # The tokenizer may have gained a "[PAD]" token above; keep the embedding
    # matrix in sync with the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))
    if accelerator.is_main_process:
        print(f"Student model {args.proxy_student} loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M parameters")

    # The main process tokenizes the traces once and writes them to disk; every
    # process (including the main one) then loads that copy after the barrier.
    # NOTE(review): the hand-off path is hard-coded, so all processes presumably
    # need to see the same /tmp (single node) - confirm before multi-node use.
    cached_ds_path = "/tmp/cached_ds"
    if accelerator.is_main_process:
        ds = load_from_disk(args.trace_path)

        def preprocessor(examples):
            """Tokenize one batch of traces, without padding."""
            # Works whether or not the tokenizer defines a BOS token; the
            # original left `traces` unassigned when it did not.
            traces = strip_first_bos(examples[args.trace_colname], tokenizer.bos_token)
            return tokenizer(
                traces,
                padding=False,
                truncation=True,
                return_attention_mask=True,
                return_tensors=None,
            )

        input_ds = ds.map(
            preprocessor,
            batched=True,
            num_proc=96,
            remove_columns=ds.column_names,
            desc="Preprocessing dataset",
            load_from_cache_file=True,
        )
        dataset_size = len(input_ds)
        input_ds.save_to_disk(cached_ds_path)
        print(f"Loaded {args.trace_path} with {dataset_size} samples")
        print(f"Example trace: {tokenizer.decode(input_ds[0]['input_ids'])}")
    accelerator.wait_for_everyone()
    input_ds = load_from_disk(cached_ds_path)

    # The collator receives the token ids of the assistant marker as its
    # response template.
    response_str = "<｜Assistant｜>"
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(response_str, add_special_tokens=False),
        tokenizer=tokenizer,
        mlm=False,
    )
    dataloader = DataLoader(
        input_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=data_collator,
        num_workers=1,
        pin_memory=True,
    )

    # Prepare model and dataloader with accelerator
    model, dataloader = accelerator.prepare(model, dataloader)

    # One zero-initialised accumulator per trainable parameter, created on the
    # same device as the parameter.
    # NOTE(review): names are taken from the prepared model, so in multi-process
    # runs the keys may carry a wrapper prefix (e.g. "module.") - confirm that
    # consumers of student_grads.pt expect this.
    grads = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            grads[name] = torch.zeros_like(param.data)

    local_samples = 0
    model.train()
    for batch in tqdm(dataloader, desc="Accumulating gradients", disable=not accelerator.is_main_process):
        batch_samples = batch["input_ids"].size(0)
        local_samples += batch_samples
        outputs = model(**batch)
        # Weight the loss by the number of sequences in the batch; the
        # accumulated gradient is divided by the total sample count below
        # (assumes outputs.loss is a per-batch mean - TODO confirm).
        loss = outputs.loss * batch_samples
        accelerator.backward(loss)
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grads[name].add_(param.grad)
        model.zero_grad()

    local_tensor = torch.tensor([local_samples], device=accelerator.device)
    accelerator.wait_for_everyone()
    reduced_tensor = accelerator.reduce(local_tensor, reduction="sum")
    total_samples = reduced_tensor.item()
    if accelerator.is_main_process:
        print(f"Processed a total of {total_samples} samples across all processes")

    # `accelerator.reduce` returns the reduced tensor rather than updating its
    # argument, so the result has to be captured; discarding it (as the
    # original did) leaves only this rank's accumulator in `grads`.
    # Every rank must take part in the reduction; only the main one keeps it.
    for name in grads:
        summed = accelerator.reduce(grads[name], reduction="sum")
        if accelerator.is_main_process:
            grads[name] = summed / total_samples

    # Save the gradients on the main process
    if accelerator.is_main_process:
        # Create the output directory if needed so the save cannot fail on a
        # missing path after all the computation is done.
        os.makedirs(args.exp_dir, exist_ok=True)
        output_path = os.path.join(args.exp_dir, "student_grads.pt")
        torch.save(grads, output_path)
        print(f"Saved average gradients to {output_path}")

    accelerator.end_training()
