import torch
import numpy as np
import random
import os
import logging
import hydra
import wandb

from omegaconf import DictConfig, OmegaConf

from src.data_gen import Sampler


def set_random_seed(seed, verbose=False):
    """Seed every random number generator used by the project.

    Switches cuDNN to deterministic algorithm selection, then seeds, in order:
    ``torch.manual_seed``, ``torch.cuda.manual_seed``, NumPy's global RNG
    (``np.random.seed``) and Python's ``random`` module.

    Args:
        seed: Seed value passed unchanged to every generator.
        verbose: When True, log the seed at INFO level before seeding.
    """
    if verbose:
        logging.info(f"Setting random seed to {seed}")

    torch.backends.cudnn.deterministic = True

    # Same call order as the individual statements would give.
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)


def init_model(cfg: DictConfig, sampler: Sampler, grad_tracker_on: bool):
    """Fill in the data-dependent model settings, then build the model.

    ``cfg`` is mutated in place before instantiation:

    * ``cfg["model"]["max_len"]``    <- ``sampler.get_max_len() + 1``
    * ``cfg["model"]["eager_attn"]`` <- ``cfg["model"]["myopic"] or grad_tracker_on``
    * ``cfg["model"]["vocab_size"]`` <- ``sampler.vocab_size``

    Args:
        cfg: Run configuration; must contain ``model`` (with ``myopic``) and ``device``.
        sampler: Data sampler supplying the sequence length and vocabulary size.
        grad_tracker_on: Whether gradient tracking is enabled; if so, eager
            attention is forced on.

    Returns:
        ``hydra.utils.instantiate(cfg["model"])`` moved via ``.to(cfg["device"])``.
    """
    # Derive the values first; each only reads state it does not write.
    max_len = sampler.get_max_len() + 1
    eager_attn = cfg["model"]["myopic"] or grad_tracker_on
    vocab_size = sampler.vocab_size

    cfg["model"]["max_len"] = max_len
    cfg["model"]["eager_attn"] = eager_attn
    cfg["model"]["vocab_size"] = vocab_size

    logging.info(f"Initializing model with eager_attn={cfg['model']['eager_attn']}")

    model = hydra.utils.instantiate(cfg["model"])
    return model.to(cfg["device"])


def prepare_verify_probes_list(cfg: DictConfig) -> bool:
    """Resolve the list of probes to verify and report whether any exist.

    Probes come from exactly one of two sources:

    * ``cfg["verify_probes"]``: an explicit list, used as-is.
    * ``cfg["autogenerate_verify_probes"]``: when truthy, the list is built
      as every ``(coord_name, layer_idx, seq_idx)`` combination of
      ``autogenerate_coord_names`` x ``autogenerate_layer_indices`` x
      ``autogenerate_sequence_indices``, with the sequence index varying
      fastest. The result is written back to ``cfg["verify_probes"]`` and the
      flag is cleared, so a repeated call does not regenerate.

    ``cfg`` is mutated in place. After autogeneration, if a W&B run is
    active, the updated config is pushed to ``wandb.config``.

    Args:
        cfg: Run configuration (an OmegaConf ``DictConfig`` or a plain mapping).

    Returns:
        True if at least one probe is configured after resolution.

    Raises:
        AssertionError: if explicit ``verify_probes`` are given while
            autogeneration is also requested.
    """
    def _has_probes() -> bool:
        # ``verify_probes: null`` means "no probes"; don't crash in len(None).
        probes = cfg["verify_probes"]
        return probes is not None and len(probes) > 0

    assert not (cfg["autogenerate_verify_probes"] and _has_probes()), \
        "Cannot have both verify_probes and autogenerate_verify_probes"

    if cfg["autogenerate_verify_probes"]:
        probes_list = [
            (coord_name, layer_idx, seq_idx)
            for coord_name in cfg["autogenerate_coord_names"]
            for layer_idx in cfg["autogenerate_layer_indices"]
            for seq_idx in cfg["autogenerate_sequence_indices"]
        ]
        cfg["verify_probes"] = probes_list
        cfg["autogenerate_verify_probes"] = False
        logging.info(f"Autogenerated verify_probes: {probes_list}")

        if wandb.run:
            # wandb.config.update() parses real dicts only. Other objects fall
            # back to vars(obj), which for a DictConfig exposes OmegaConf
            # internals instead of the config values. Convert to plain
            # Python containers first.
            if OmegaConf.is_config(cfg):
                config_values = OmegaConf.to_container(cfg, resolve=True)
            else:
                config_values = dict(cfg)
            wandb.config.update(config_values, allow_val_change=True)

    return _has_probes()
