import hashlib
import json
import logging
from dataclasses import asdict, dataclass
from typing import Optional, Union

import torch
import torch.autograd
import torch.onnx
import torch.utils.checkpoint
import transformers
from peft import (LoraConfig, PeftModel, TaskType, get_peft_model,
                  prepare_model_for_kbit_training)
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from hip_attn.v1_3.attention import HiPAttentionArgs


def get_logger():
    """Configure basic process-wide logging and return the root logger.

    Messages are formatted as ``<date time> <LEVEL padded to 8> <message>``
    at INFO level. ``logging.basicConfig`` only installs a handler when the
    root logger has none yet, so calling this repeatedly does not stack
    handlers.

    Returns:
        The root ``logging.Logger``.
    """
    line_format = "%(asctime)s %(levelname)-8s %(message)s"
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=logging.INFO,
        format=line_format,
        datefmt=timestamp_format,
    )
    root_logger = logging.getLogger()
    return root_logger


@dataclass
class Config:
    """Flat configuration for model setup and training runs.

    Every attribute is an ordinary dataclass field, so ``asdict`` (and
    therefore :meth:`get_hash`) sees all of them in declaration order.
    Field order is part of the positional ``__init__`` signature and of the
    hash input, so do not reorder fields.
    """

    # ---------------- model setup ----------------
    model: str = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer_id: Optional[str] = None
    hf_token: Optional[str] = None
    use_lora: bool = False
    lora_r: int = 32
    init_from_checkpoint: Optional[str] = None
    disable_hip: bool = False
    # NOTE(review): both `use_quantization` and `quantize` (below) exist;
    # how they differ is not visible in this file.
    use_quantization: bool = False
    dtype: str = "bfloat16"
    save_model: Optional[str] = None

    # ---------------- training ----------------
    method: str = "hip"
    hip_attn_args: Optional[HiPAttentionArgs] = None
    quantize: bool = False
    dataset: str = "owt"  # alternatives noted by the author: "bs", "rp"
    lr: float = 5e-5
    weight_decay: float = 1e-2
    epochs: int = 100
    batch_size: int = 1
    accumulation_steps: int = 1
    save_steps: int = 100
    save_total_limit: int = 3
    seq_len: int = 32768
    max_steps: int = 10000
    eval_steps: int = 50
    model_checkpoint_dir: str = "./saves/checkpoints"
    checkpoint: Optional[str] = None
    warmup_steps: int = 20
    name: str = "default"
    val_split: float = 0.01
    use_long_ce: bool = False
    long_ce_k: int = 1024
    long_ce_block_size: int = 2048
    long_ce_gamma: float = 5.0
    long_ppl_alpha: float = 2.0
    long_ppl_beta: float = -2.0
    token_dropout_p: float = 0.0
    recompute_n: int = 1024
    disable_hip_tune: bool = False
    pooler_method: str = "dummy"
    # NOTE(review): annotated as a string, but __post_init__ replaces it
    # with the parsed JSON value (or {} when None), so after construction
    # it normally holds a dict.
    pooler_config: Optional[str] = None
    run_name: Optional[str] = None

    def __post_init__(self):
        """Normalize ``pooler_config`` after dataclass initialization.

        ``None`` becomes an empty dict. A string is parsed with
        ``json.loads`` (invalid JSON raises ``json.JSONDecodeError``). Any
        other value, such as an already-built dict, is left untouched.
        """
        raw = self.pooler_config
        if raw is None:
            self.pooler_config = {}
        elif isinstance(raw, str):
            self.pooler_config = json.loads(raw)

    def get_hash(self):
        """Return a SHA-256 hex digest of this config's JSON serialization.

        The digest covers every field, including ``hf_token`` and
        ``run_name``, so any field difference yields a different hash.
        Keys are emitted in field order without sorting, so two
        ``pooler_config`` dicts with the same items in a different
        insertion order hash differently.

        NOTE(review): ``json.dumps`` raises ``TypeError`` if any field value
        (e.g. a non-None ``hip_attn_args``) is not JSON-serializable after
        ``asdict`` conversion. Confirm what ``HiPAttentionArgs`` contains.
        """
        payload = json.dumps(asdict(self)).encode()
        return hashlib.sha256(payload).hexdigest()


# Module-level logger (the root logger returned by get_logger()). Creating it
# runs logging.basicConfig as a side effect at import time.
log: logging.Logger = get_logger()
