import json
import torch
import logging
from datetime import datetime
from pathlib import Path
from typing import Iterable, Optional, TYPE_CHECKING
from copy import deepcopy
import random
import numpy as np

from src.run.dataloader import DataLoader
from src.run.distributed import is_main_process, is_distributed, get_rank, get_raw_model, DDP
from src.model.config import ModelConfig, Transformer
from src.model.demix import DemixTransformer

if TYPE_CHECKING:
    from src.run.config import RunConfig

# --------------------------------------------------------------------------- #
# utils                                                                       #
# --------------------------------------------------------------------------- #

def convert_ordered_args(
    alpha: float | None = None,
    beta: float | None = None,
    aux_prc: float | None = None,
    core_prc: float | None = None,
    len_aux: int = -1,
    len_core: int = -1,
) -> dict[str, float]:
    """
    Convert alpha & beta to aux_prc & core_prc.

    aux_prc is the proportion of finetuning that's aux:
        aux_prc = ft_aux / (ft_aux + ft_core) 
        alpha = ft_core / ft_aux
        thus...
        aux_prc = 1 / (1 + alpha)

    core_prc is the proportion of core-only run that's the base:
        ft_aux = len_aux * beta
        ft_core = ft_aux * alpha
        len_base = len_aux + len_core - ft_core - ft_aux
        len_base = len_core * core_prc
        thus...
        core_prc = 1 + (len_aux / len_core) * (1 - beta * (alpha + 1))

    alpha is the ratio of core to aux data during ft:
        alpha = ft_core / ft_aux
        aux_prc = ft_aux / (ft_aux + ft_core)
        thus...
        alpha = (1 - aux_prc) / aux_prc

    beta is the proportion of aux data used during training:
        aux_prc = ft_aux / (ft_aux + ft_core)
        ft_core = alpha * ft_aux
        ft_aux = beta * len_aux
        len_base = len_aux + len_core - ft_core - ft_aux
        len_base = len_core * core_prc
        thus...
        beta = aux_prc * (1 - (len_core / len_aux) * (core_prc - 1))
    """
    
    assert len_aux >= 0 and len_core >= 0, "len_aux and len_core must be >= 0"

    if aux_prc is None:
        if alpha is not None:
            aux_prc = 1 / (1 + alpha)
    
    if alpha is None:
        if aux_prc is not None:
            alpha = (1 - aux_prc) / aux_prc

    if beta is None:
        if aux_prc is not None and core_prc is not None:
            if len_aux > 0 and len_core > 0:
                beta = aux_prc * (1 - (len_core / len_aux) * (core_prc - 1))

    if core_prc is None:
        if alpha is not None and beta is not None:
            if len_aux > 0 and len_core > 0:
                core_prc = 1 + (len_aux / len_core) * (1 - beta * (alpha + 1))

    if alpha is None and aux_prc is None:
        if core_prc is not None and beta is not None:
            alpha = (1 - (core_prc - 1) * (len_core / len_aux)) / beta - 1
            aux_prc = 1 / (1 + alpha)

    return {"alpha": alpha, "beta": beta, "aux_prc": aux_prc, "core_prc": core_prc}

def log_batch_counts(batches: list[tuple[str, tuple, tuple] | str], logger: logging.Logger) -> None:
    batches = sorted(batches, key=lambda x: (x[1], x[0], x[2]))
    batch_counts = {}
    for batch in batches:
        if batch not in batch_counts:
            batch_counts[batch] = 0
        batch_counts[batch] += 1
    for batch, count in batch_counts.items():
        logger.info(f"Batch [{batch}] count: {count}")

def calc_lora_rank(
    model_config: "ModelConfig",
    lora_attn: bool,
    lora_mlp: bool,
    mlp_dim: int,
    aux_dim: int,
) -> int:
    """
    Pick a LoRA rank whose parameter count roughly matches an aux MLP block.

    The target budget is ``2 * embed_dim * aux_dim + aux_dim + embed_dim``:
    two embed_dim x aux_dim weight matrices plus biases (presumably one aux
    expert). A LoRA adapter on an (in -> out) matrix costs
    ``rank * (in + out)`` parameters. The per-rank cost is summed over the
    selected targets, and the rank is the budget divided by that cost,
    rounded, with a minimum of 1.

    Args:
        model_config: provides num_heads, num_key_value and embed_dim.
        lora_attn: include the attention projections (q, kv, out) as targets.
        lora_mlp: include the two MLP matrices (embed_dim <-> mlp_dim).
        mlp_dim: MLP hidden size used for the MLP LoRA cost.
        aux_dim: aux hidden size used for the parameter budget.

    Returns:
        The LoRA rank, always >= 1.

    Raises:
        ValueError: if neither lora_attn nor lora_mlp is set. With no LoRA
            targets there is nothing to size the rank against; previously
            this case hit a ZeroDivisionError.
    """
    if not (lora_attn or lora_mlp):
        raise ValueError("calc_lora_rank requires lora_attn and/or lora_mlp to be True")

    num_heads = model_config.num_heads
    num_key_value = model_config.num_key_value
    embed_dim = model_config.embed_dim

    moe_aux_params = 2 * embed_dim * aux_dim + aux_dim + embed_dim

    head_dim = embed_dim // num_heads
    # Per-rank LoRA cost (in + out) of each attention projection.
    attn_lora_dim = (
        2 * embed_dim                                   # c_attn_q: embed -> embed
        + (embed_dim + 2 * num_key_value * head_dim)    # c_attn_kv: embed -> 2 * kv heads
        + 2 * embed_dim                                 # c_proj: embed -> embed
    )

    # Per-rank LoRA cost of the up (embed -> mlp) and down (mlp -> embed) matrices.
    mlp_lora_dim = 2 * (embed_dim + mlp_dim)

    total_lora_dim = 0
    if lora_attn:
        total_lora_dim += attn_lora_dim
    if lora_mlp:
        total_lora_dim += mlp_lora_dim

    return max(1, round(moe_aux_params / total_lora_dim))

def get_routed_dims(
    model_config: "ModelConfig", 
    run_config: "RunConfig",
    expert_dist: str,
    aux_prc: float | None = None,
) -> tuple[int, int]:

    if expert_dist == "add":
        routed_mlp_dim = model_config.mlp_dim
        routed_aux_dim = model_config.mlp_dim // 8

    elif expert_dist == "equal_one":
        routed_mlp_dim = model_config.mlp_dim
        routed_aux_dim = model_config.mlp_dim

    elif expert_dist == "equal_sum":
        num_labels = len(run_config.aux_labels) + 1
        hid_dim = round(model_config.mlp_dim / num_labels)
        routed_mlp_dim = hid_dim
        routed_aux_dim = hid_dim

    elif expert_dist == "prc_one":

        k = 64
        assert model_config.mlp_dim % k == 0, "model_config.mlp_dim must be divisible by k"

        len_all_core = len(run_config.loaders["core"]["train"])
        len_all_aux = sum([len(run_config.loaders[label]["train"]) for label in run_config.aux_labels])
        len_all_data = len_all_core + len_all_aux
        if aux_prc is None:
            aux_prc = len_all_aux / len_all_data
        core_prc = 1 - aux_prc
        routed_mlp_dim = int(round(model_config.mlp_dim * core_prc) // k) * k # round to nearest multiple of k
        routed_aux_dim = model_config.mlp_dim - routed_mlp_dim

        assert routed_aux_dim >= 0, "routed_aux_dim must be >= 0"
        assert routed_mlp_dim + routed_aux_dim == model_config.mlp_dim, "routed_mlp_dim + routed_aux_dim must be equal to model_config.mlp_dim"
        assert routed_mlp_dim % k == 0, "routed_mlp_dim must be divisible by k"
        assert routed_aux_dim % k == 0, "routed_aux_dim must be divisible by k"

    elif expert_dist == "prc_sum":

        k = 64
        assert model_config.mlp_dim % k == 0, "model_config.mlp_dim must be divisible by k"

        num_aux_labels = len(run_config.aux_labels)
        num_labels = num_aux_labels + 1
        len_all_core = len(run_config.loaders["core"]["train"])
        len_all_aux = sum([len(run_config.loaders[label]["train"]) for label in run_config.aux_labels])
        len_all_data = len_all_core + len_all_aux
        if aux_prc is None:
            aux_prc = len_all_aux / len_all_data
        core_prc = 1 - aux_prc
        routed_mlp_dim = int(round(model_config.mlp_dim * core_prc) // k) * k # round to nearest multiple of k

        routed_aux_dim = model_config.mlp_dim - routed_mlp_dim
        routed_aux_dim = routed_aux_dim // num_aux_labels

        assert routed_aux_dim >= 0, "routed_aux_dim must be >= 0"
        assert routed_mlp_dim % k == 0, "routed_mlp_dim must be divisible by k"

    else:
        raise ValueError(f"Invalid expert distribution: {expert_dist}")

    return routed_mlp_dim, routed_aux_dim


def make_model(
    model_class: Transformer,
    model_config: ModelConfig,
    run_config: "RunConfig",
    extra_args: Optional[dict] = None,
) -> Transformer:
    """
    Build ``model_class(model_config, **extra_args)`` and prepare it for training.

    Steps, in order:
        1. instantiate the model;
        2. move it to ``run_config.device`` as bfloat16;
        3. log its parameter counts;
        4. ``torch.compile`` it (``dynamic=True``) if ``run_config.do_compile`` is set;
        5. wrap it in DDP when running distributed.

    Args:
        model_class: the model *class* to instantiate (not an instance).
        model_config: passed as the constructor's first positional argument.
        run_config: supplies ``device``, ``logger`` and ``do_compile``.
        extra_args: extra constructor keyword arguments (none when omitted).

    Returns:
        The model, possibly compiled and/or DDP-wrapped.
    """
    device, logger, do_compile = run_config.device, run_config.logger, run_config.do_compile
    ddp = is_distributed()

    ctor_kwargs = {} if extra_args is None else extra_args
    net = model_class(model_config, **ctor_kwargs).to(device, dtype=torch.bfloat16)

    log_model_params(net, logger)

    if do_compile:
        net = torch.compile(net, dynamic=True)

    if ddp:
        # NOTE(review): get_rank() is used as the CUDA device index. That is only
        # correct if it is the node-local rank (or on a single node). Confirm.
        rank = get_rank()
        net = DDP(
            net,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
        )

    return net

def copy_model(model: Transformer, device: torch.device, do_compile: bool = False) -> Transformer:
    """
    Return an independent deep copy of ``model`` for finetuning or evaluation.

    Any wrapper is stripped with ``get_raw_model`` before copying, so the
    original is left untouched. The copy is then moved to ``device`` as
    bfloat16, optionally compiled (``dynamic=True``), and re-wrapped in DDP
    when running distributed.
    """
    ddp = is_distributed()
    clone = deepcopy(get_raw_model(model)).to(device, dtype=torch.bfloat16)

    if do_compile:
        clone = torch.compile(clone, dynamic=True)

    if ddp:
        rank = get_rank()
        clone = DDP(
            clone,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
        )

    return clone

def get_timestamp() -> str:
    """Return the current local (naive) time as ``YYYY-MM-DD_HH-MM-SS``."""
    now = datetime.now()
    return f"{now:%Y-%m-%d_%H-%M-%S}"


def ensure_dir(p: Path) -> None:
    """Create directory ``p`` along with any missing parents.

    Does nothing if ``p`` already exists as a directory.
    """
    p.mkdir(exist_ok=True, parents=True)


def log_line(msg: dict, log_fp: Path) -> None:
    """Append ``msg`` as one JSON line to ``log_fp``, on the main process only.

    Values that are not JSON-serializable are written via ``str()``
    (``default=str``).
    """
    if not is_main_process():
        return

    with open(log_fp, "a") as fh:
        record = json.dumps(msg, default=str)
        fh.write(record + "\n")


def get_batch(loader: DataLoader) -> tuple[torch.Tensor, torch.Tensor, str | None]:
    """Pull the next batch from ``loader`` and split it into shifted inputs/targets.

    ``loader.next_batch()`` returns ``(batch, label)``. Along dim 1 the inputs
    drop the last position and the targets drop the first, so
    ``targets[:, t] == batch[:, t + 1]``. Assumes ``batch`` is at least 2-D,
    with the sequence on dim 1 (TODO confirm).

    Returns:
        ``(inputs, targets, label)``.
    """
    seq, label = loader.next_batch()
    inputs, targets = seq[:, :-1], seq[:, 1:]
    return inputs, targets, label

def log_model_params(model: Transformer, logger: logging.Logger) -> None:
    """Log the model's total parameter count and, for routed models, per-label counts.

    Any wrapper is removed with ``get_raw_model`` first. When
    ``model_type == "routed"``, one line is logged for "core", for each of
    ``config.aux_labels``, and additionally for "SHARED" on a
    DemixTransformer. Each count comes from ``model.get_params(label)``.

    Args:
        model: the (possibly wrapped) Transformer.
        logger: destination for the log lines.
    """
    raw = get_raw_model(model)
    class_name = type(raw).__name__
    total = sum(param.numel() for param in raw.parameters())
    logger.info(f"Initialized {class_name} with {total:,} total parameters")

    if raw.model_type != "routed":
        return

    labels = ["core"] + raw.config.aux_labels
    if isinstance(raw, DemixTransformer):
        labels.append("SHARED")

    for label in labels:
        count = sum(param.numel() for param in raw.get_params(label))
        logger.info(f"  Label '{label}' has {count:,} parameters")

def get_select_mask(
    labels: list[str],
    selected_labels: Optional[Iterable[str]],
    device: torch.device,
) -> torch.Tensor:
    """
    Build a boolean mask over ``labels`` marking which experts are selected.

    Args:
        labels: all expert labels. Position i in the mask corresponds to labels[i].
        selected_labels: labels to switch on. ``None`` selects every label.
        device: device on which the mask is allocated.

    Returns:
        A bool tensor of shape ``(len(labels),)``.

    Raises:
        AssertionError: if a selected label is not in ``labels``.
    """
    num_experts = len(labels)
    if selected_labels is None:
        return torch.ones(num_experts, device=device, dtype=torch.bool)

    mask = torch.zeros(num_experts, device=device, dtype=torch.bool)
    known = set(labels)
    for name in selected_labels:
        assert name in known, f"Unknown expert label '{name}' not in {labels}"
        mask[labels.index(name)] = True
    return mask


def restore_state(
    model: Transformer,
    checkpoint_path: str,
    device: torch.device,
    logger: logging.Logger,
) -> tuple[Transformer, dict]:
    """
    Load checkpoint weights into ``model`` and return the remaining checkpoint state.

    The checkpoint is loaded with ``map_location=device``. Its "model" entry
    is applied to the unwrapped model (``get_raw_model``) and then removed
    from the returned dict, so the weights are not kept twice. Whatever else
    the checkpoint holds (e.g. "step", "total_steps") is returned unchanged,
    for resuming training.

    Returns:
        ``(model, state)``, where ``model`` is the same object that was passed in.

    NOTE(review): torch.load unpickles the file, so only load trusted
    checkpoints. Recent torch versions default to weights_only=True, which may
    reject non-tensor entries; confirm against the installed version.
    """
    logger.info(f"Restoring state from: {checkpoint_path}")

    state = torch.load(checkpoint_path, map_location=device)
    weights = state["model"]
    get_raw_model(model).load_state_dict(weights)

    # len(weights) counts state-dict entries (tensors), not scalar parameters.
    logger.info(f"Restored {len(weights)} model params, step {state.get('step', '?')}/{state.get('total_steps', '?')}")

    del state["model"]

    return model, state

def set_seeds(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)