import transformers
import torch
from torch import nn
import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput
import wandb
import numpy as np
from llm_fairness import MODELS


def from_name(name, pretrained, tokenizer=None, max_length=None, task=None):
    """Build a causal LM (or only its backbone) for a supported checkpoint name.

    The config is loaded with
    ``transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)``,
    optionally overridden with a maximum sequence length and with the special
    token ids / vocabulary size of ``tokenizer``. The result is then passed to
    :func:`from_config`, which instantiates the model.

    Args:
        name: Checkpoint identifier; must be a member of ``MODELS``.
        pretrained: If truthy, pretrained weights are loaded for ``name``;
            otherwise a model is built from the config alone.
        tokenizer: Optional tokenizer. When given, its ``eos_token_id``,
            ``bos_token_id`` and ``pad_token_id`` replace the config's, and
            ``len(tokenizer.vocab)`` becomes ``config.vocab_size``.
        max_length: Optional maximum sequence length. Stored as
            ``config.rope_max_length`` when ``name`` contains ``"apple"``,
            and as ``config.max_position_embeddings`` otherwise.
        task: ``"lm"`` returns the full causal-LM model. Any other value,
            including the default ``None``, returns only the backbone:
            ``.transformer`` for OpenELM checkpoints, ``.model`` otherwise.

    Returns:
        The causal-LM model or its backbone module, depending on ``task``.

    Raises:
        ValueError: If ``name`` is not in ``MODELS``.
    """
    # Validate before loading anything so an unsupported name fails fast.
    if name not in MODELS:
        raise ValueError(f"`{name}` must be in {MODELS}")

    config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)

    if max_length is not None:
        # "apple" checkpoints (presumably OpenELM) take the context limit via
        # `rope_max_length`; all other supported configs use the standard
        # `max_position_embeddings` field.
        if "apple" in name:
            config.rope_max_length = max_length
        else:
            config.max_position_embeddings = max_length

    # Explicit None check: tokenizers define __len__, so plain truthiness would
    # test the vocabulary size instead of whether a tokenizer was supplied.
    if tokenizer is not None:
        config.eos_token_id = tokenizer.eos_token_id
        config.bos_token_id = tokenizer.bos_token_id
        config.pad_token_id = tokenizer.pad_token_id
        # NOTE(review): if this differs from the checkpoint's vocab size, the
        # pretrained load (ignore_mismatched_sizes=True) presumably discards
        # and re-initialises the mismatched embedding weights — confirm intended.
        config.vocab_size = len(tokenizer.vocab)

    # AutoConfig.from_pretrained(name) records `name` as config._name_or_path,
    # so from_config loads the same checkpoint and applies the same
    # task / backbone selection as this function previously did inline.
    return from_config(config, pretrained, task=task)


def from_config(config, pretrained, task=None):
    """Instantiate a causal LM (or only its backbone) from an existing config.

    Args:
        config: A ``transformers`` config whose ``_name_or_path`` must be one
            of ``MODELS``.
        pretrained: When truthy, weights are loaded from
            ``config._name_or_path`` with ``ignore_mismatched_sizes=True``;
            otherwise the model is built from ``config`` alone.
        task: ``"lm"`` returns the full causal-LM model. Any other value,
            including the default ``None``, returns only the backbone:
            ``.transformer`` for OpenELM checkpoints, ``.model`` otherwise.

    Returns:
        The causal-LM model or its backbone module, depending on ``task``.

    Raises:
        ValueError: If ``config._name_or_path`` is not in ``MODELS``.
    """
    checkpoint = config._name_or_path
    if checkpoint not in MODELS:
        raise ValueError(f"`{checkpoint}` must be in {MODELS}")

    auto_cls = transformers.AutoModelForCausalLM
    if not pretrained:
        # Randomly initialised weights, architecture taken from the config.
        lm = auto_cls.from_config(config=config, trust_remote_code=True)
    else:
        # Shape mismatches (e.g. an overridden vocab size) are tolerated
        # rather than raising during the pretrained load.
        lm = auto_cls.from_pretrained(
            checkpoint,
            config=config,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )

    if task == "lm":
        return lm

    # OpenELM exposes its decoder stack as `.transformer`; the other supported
    # architectures expose it as `.model`.
    backbone_attr = "transformer" if "OpenELM" in checkpoint else "model"
    return getattr(lm, backbone_attr)
