import json
import warnings
from collections import namedtuple
from pathlib import Path
from typing import Any
import sys

import torch
import wandb
from torch import nn

from theory.theory import create_embedding_matrix

# sys.path.append("..")

# Registry of the model classes that can be instantiated in this environment.
# A name is appended only after the imports backing it have succeeded.
available_mamba_architectures = []
try:
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    from mamba_ssm.models.config_mamba import MambaConfig
    from experiments.mamba_with_dropout import MambaWithDropout
    available_mamba_architectures.extend(['mamba_ssm'])
except Exception:
    # Best effort: if any of the three imports above fails (for whatever
    # reason), 'mamba_ssm' is simply left out of the registry.
    # warnings.warn("could not import 'mamba_ssm'")
    pass
# The imports below are not guarded: if they fail, importing this module fails.
from mamba_tiny.model import ModelArgs, Mamba as MambaTiny
available_mamba_architectures.extend(['mamba_tiny'])
from theory.simplified_linear_mamba import SimplifiedLinearMamba, set_simplified_linear_model_ideal_weights

available_mamba_architectures.extend(['mamba_linear'])

# Directory containing this file; every config path below is resolved against it.
CONFIG_DIR = Path(__file__).parent

# Names of the sub-directories of CONFIG_DIR holding each kind of config.
MODELS = "models"
DATASETS = "datasets"
WANDB = "wandb"

# Names of the sub-directories of CONFIG_DIR/MODELS/<model_class> that hold
# the shared positional-args and keyword-args JSON files.
ARGS = "args"
KWARGS = "kwargs"


def get_wandb_api_key():
    """Return the W&B API key stored in ``CONFIG_DIR / "wandb" / "api_key.txt"``.

    Returns:
        The file content with leading/trailing whitespace stripped, or
        ``None`` when the file does not exist.
    """
    key_path = CONFIG_DIR / WANDB / "api_key.txt"
    if not key_path.exists():
        return None
    return key_path.read_text().strip()

def get_wandb_entity():
    """Return the W&B entity stored in ``CONFIG_DIR / "wandb" / "entity.txt"``.

    Returns:
        The file content with leading/trailing whitespace stripped, or
        ``None`` when the file does not exist.
    """
    entity_path = CONFIG_DIR / WANDB / "entity.txt"
    if not entity_path.exists():
        return None
    return entity_path.read_text().strip()


def wandb_login():
    """Call ``wandb.login`` with the locally configured API key.

    The key comes from ``get_wandb_api_key()`` (which may return ``None``),
    and the call is made with ``relogin=True`` and ``force=True``.
    """
    api_key = get_wandb_api_key()
    wandb.login(key=api_key, force=True, relogin=True)


def get_model_arg_and_kwargs_from_config(model_class: str, model_variant: str):
    """Load the constructor args, kwargs and weight-init config of a model variant.

    The variant ("master") config is read from
    ``CONFIG_DIR / MODELS / model_class / f"{model_variant}.json"``. It may
    contain the following optional keys:

    * ``model_args``: name of a JSON file in the ``args`` sub-directory.
    * ``model_kwargs``: name, or collection of names, of JSON files in the
      ``kwargs`` sub-directory; they are merged in order, later files
      overriding earlier ones.
    * ``ssm_kwargs``: inline mapping merged on top of the loaded kwargs.
    * ``weights_init``: returned as-is.

    Args:
        model_class: Name of the model class directory.
        model_variant: Name of the variant config file (without ``.json``).

    Returns:
        A ``(model_args, model_kwargs, weights_init)`` tuple. ``model_args`` is
        the parsed args file, or ``()`` when the variant names none;
        ``model_kwargs`` is a dict (possibly empty); ``weights_init`` is the
        raw config value or ``None``.

    Raises:
        FileNotFoundError: If the variant config or a referenced file is missing.
        json.JSONDecodeError: If any of the files is not valid JSON.
    """
    class_dir = CONFIG_DIR / MODELS / model_class

    # get master config
    variant_config = json.loads((class_dir / f"{model_variant}.json").read_text())

    # get model args (`()` is kept as the "no args file" value for backward compatibility)
    model_args = ()
    args_name = variant_config.get('model_args')
    if args_name is not None:
        model_args = json.loads((class_dir / ARGS / f"{args_name}.json").read_text())

    # get model kwargs
    kwargs_names = variant_config.get('model_kwargs')
    if kwargs_names is None:
        kwargs_names = []
    elif isinstance(kwargs_names, str):
        # A bare string is a single file name; iterating it directly would
        # split it into characters and look up one file per character.
        kwargs_names = [kwargs_names]

    model_kwargs = {}
    for name in kwargs_names:
        model_kwargs.update(json.loads((class_dir / KWARGS / f"{name}.json").read_text()))

    # add ssm kwargs (inline values take precedence over the kwargs files)
    ssm_kwargs = variant_config.get('ssm_kwargs')
    if ssm_kwargs is not None:
        model_kwargs.update(ssm_kwargs)

    return model_args, model_kwargs, variant_config.get('weights_init')


def get_model_from_config(V: int, D: int, N: int, model_class: str, model_variant: str, dropout_rate: float = 0):
    """Instantiate a model of the requested class/variant with the given dimensions.

    Args:
        V: Written to the ``vocab_size`` model argument.
        D: Written to the ``d_model`` model argument.
        N: Written to the ``d_state`` model argument (nested under
            ``ssm_cfg`` for ``'mamba_ssm'``).
        model_class: One of the names in ``available_mamba_architectures``.
        model_variant: Name of the variant config to load for that class.
        dropout_rate: Only used for ``'mamba_ssm'``: any value other than 0
            selects ``MambaWithDropout``. For other classes a non-zero value
            is ignored and a warning is emitted.

    Returns:
        The instantiated model, with weights initialised according to the
        variant's ``weights_init`` config when one is present.

    Raises:
        ValueError: If ``model_class`` is not available in this environment.
    """
    # Validate before touching the file system, so an unavailable class gets
    # this message rather than a FileNotFoundError from the config loader.
    if model_class not in available_mamba_architectures:
        raise ValueError(
            f"{model_class = } is currently unavailable; "
            f"choose from: {available_mamba_architectures}\n"
        )

    # get args and kwargs
    model_args, model_kwargs, model_weights_init = get_model_arg_and_kwargs_from_config(
        model_class=model_class,
        model_variant=model_variant)

    # The loader returns `()` when the variant has no args file; the dimension
    # assignments below need a mutable mapping.
    model_args = dict(model_args) if model_args else {}

    if dropout_rate != 0 and model_class != 'mamba_ssm':
        warnings.warn(
            f"{dropout_rate = } is ignored for {model_class = }; "
            f"dropout is only applied for 'mamba_ssm'",
            stacklevel=2,
        )

    # set model dimensions and instantiate
    if model_class == 'mamba_ssm':

        # set dims (create the nested ssm config when the args file omits it)
        model_args['vocab_size'] = V
        model_args['d_model'] = D
        model_args.setdefault('ssm_cfg', {})['d_state'] = N

        # instantiate
        config = MambaConfig(**model_args)
        if dropout_rate == 0:
            model = MambaLMHeadModel(config=config, **model_kwargs)
        else:
            model = MambaWithDropout(config=config, dropout_rate=dropout_rate, **model_kwargs)

    elif model_class == 'mamba_tiny':

        # set dims
        model_args['vocab_size'] = V
        model_args['d_model'] = D
        model_args['d_state'] = N

        # instantiate
        model = MambaTiny(args=ModelArgs(**model_args), **model_kwargs)

    elif model_class == 'mamba_linear':

        # this class takes the dimensions directly; loaded args/kwargs are not used
        model = SimplifiedLinearMamba(V=V, D=D, N=N)

    else:
        # only reachable if a class is registered as available but has no branch here
        raise ValueError(f"unknown {model_class = }")

    # if required, initialize weights
    if model_weights_init is not None:
        _set_model_weights_from_config(model, model_class, model_weights_init, V=V, D=D, N=N)

    return model


def _set_model_weights_from_config(
        model: nn.Module,
        model_class: str,
        weights_config: dict[str, str],
        V=None, D=None, N=None,
):

    freeze_initialized = weights_config.get("freeze_initialized_weights", False)

    if model_class == 'mamba_linear':
        if weights_config.get("set_ideal_weights", False):
            set_simplified_linear_model_ideal_weights(model)
            return

    if (E_embedding_type := weights_config.get("init_E", None)) is not None:
        E = create_embedding_matrix(V=V, M=D, embedding_type=E_embedding_type)  # (V, D)

        if model_class == 'mamba_tiny':
            with torch.no_grad():
                model.embedding.weight.copy_(E)
                model.lm_head.weight.copy_(E)
            if freeze_initialized:
                _freeze_modules(modules=[model.embedding, model.lm_head])

        elif model_class == 'mamba_ssm':
            with torch.no_grad():
                model.backbone.embedding.weight.copy_(E)
                model.lm_head.weight.copy_(E)
            if freeze_initialized:
                _freeze_modules(modules=[model.backbone.embedding, model.lm_head])

    if (S_embedding_type := weights_config.get("init_S", None)) is not None:
        S = create_embedding_matrix(V=D, M=N, embedding_type=S_embedding_type)  # (D, N)
        raise NotImplementedError


def _freeze_modules(modules: list[nn.Module]):

    for module in modules:

        for p in module.parameters():
            p.requires_grad = False


def get_dataset_kwargs_from_dataset_config(dataset: str):
    """Load the JSON config of a dataset.

    Args:
        dataset: Name of the config file (without ``.json``) inside
            ``CONFIG_DIR / DATASETS``.

    Returns:
        The parsed JSON content of that file.
    """
    config_path = CONFIG_DIR / DATASETS / f"{dataset}.json"
    raw_json = config_path.read_text()
    return json.loads(raw_json)


def get_dataset_config_by_split_from_dataset_config(dataset_config: dict[str, Any], seed: int = None):
    """Expand one dataset config into per-split configs for train/val/test.

    Args:
        dataset_config: Mapping with the required keys ``V``, ``L``,
            ``batch_size``, ``N_facts`` and ``split_size`` (itself a mapping
            with ``train``/``val``/``test`` entries), plus an optional
            ``kwargs`` mapping of extra entries copied into every split.
            A missing or ``None`` ``kwargs`` is treated as empty.
        seed: Base seed, or ``None``. Split ``i`` (in the order train, val,
            test) receives ``seed + i`` so that each split is seeded
            differently; with ``None`` every split gets ``seed=None``.

    Returns:
        A dict mapping each split name to its own config dict. The required
        keys take precedence over same-named entries in ``kwargs``.

    Raises:
        KeyError: If a required key or split size is missing.
    """
    # `or {}` also covers an explicit `"kwargs": null` in the JSON config
    extra_kwargs = dataset_config.get("kwargs") or {}

    dataset_config_by_split: dict[str, dict[str, Any]] = {}
    for offset, split in enumerate(("train", "val", "test")):
        dataset_config_by_split[split] = {
            **extra_kwargs,
            "V": dataset_config["V"],
            "L": dataset_config["L"],
            "batch_size": dataset_config["batch_size"],
            "N_facts": dataset_config["N_facts"],
            "dataset_size": dataset_config["split_size"][split],
            # set a different seed for each split
            "seed": (seed + offset) if seed is not None else None,
        }

    return dataset_config_by_split
