import inspect
import os

import torch
from lion_pytorch import Lion
from torch.nn import functional as F
import wandb


def _load_cached_pair(model_cls, tokenizer_cls, name, cache_root="cache_models"):
    """Return ``(model, tokenizer)`` for ``name``, mirroring both under ``cache_root``.

    The cache is checked per model directory rather than for ``cache_root`` as
    a whole, so a model that has not been mirrored yet is still fetched by its
    name even when other models are already cached.
    """
    local_dir = os.path.join(cache_root, name)
    if os.path.isdir(local_dir):
        return (
            model_cls.from_pretrained(local_dir),
            tokenizer_cls.from_pretrained(local_dir),
        )
    # Load both pieces before writing anything, so a failed load does not
    # leave a half-populated directory that later looks like a valid cache.
    model = model_cls.from_pretrained(name)
    tokenizer = tokenizer_cls.from_pretrained(name)
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)
    return model, tokenizer


def _distinct_ngram_ratios(texts):
    """Return the distinct 1-, 2- and 3-gram counts of ``texts`` over the word count.

    Words are obtained with ``str.split(" ")``. All three ratios share the
    same denominator (the total number of words across ``texts``).
    """
    unigrams, bigrams, trigrams = set(), set(), set()
    total_words = 0
    for text in texts:
        words = text.split(" ")
        total_words += len(words)
        unigrams.update(words)
        bigrams.update("_".join(pair) for pair in zip(words, words[1:]))
        trigrams.update(
            "_".join(triple) for triple in zip(words, words[1:], words[2:])
        )
    if total_words == 0:
        # Only reachable when ``texts`` is empty; avoid a ZeroDivisionError.
        return 0.0, 0.0, 0.0
    return (
        len(unigrams) / total_words,
        len(bigrams) / total_words,
        len(trigrams) / total_words,
    )


@torch.no_grad()
def validate(
    tokenizer,
    model,
    accelerator,
    t_max,
    t_min,
    step,
    wandb=None,
    a_lm_model_name="gpt2",
    tw_model=None,
    simplified_inputs: bool = False,
):
    """Sample a small batch from ``model`` and log text-quality metrics.

    Ten sequences of 64 positions are sampled with ``sample_euler``, starting
    from all-pad token ids with an all-False conditioning mask. The argmax of
    the returned logits is decoded with ``tokenizer`` and the decoded texts
    are scored as follows:

    * ``alm_loss``: causal-LM loss of ``a_lm_model_name`` on each text, summed
      (not averaged) over the batch.
    * ``grammar``: mean softmax probability of class index 1 from
      ``textattack/roberta-base-CoLA`` (presumably the "acceptable" label;
      confirm against the model card).
    * ``dist_1`` / ``dist_2`` / ``dist_3``: distinct uni/bi/tri-gram counts
      divided by the total number of space-separated words.

    The texts are always written to ``examples_step_{step}.txt`` in the
    working directory. Metrics and the file are sent to ``wandb`` only when
    it is not ``None``.

    Args:
        tokenizer: Tokenizer of the sampled model; must expose
            ``pad_token_id`` and ``decode``.
        model: Model forwarded to ``sample_euler``.
        accelerator: Object whose ``device`` attribute selects the sampling
            device.
        t_max: Forwarded to ``get_sigmas_karras`` as its third positional
            argument.
        t_min: Forwarded to ``get_sigmas_karras`` as its second positional
            argument.
        step: Step used for the wandb log call and in the examples file name.
        wandb: Optional wandb-like object providing ``log`` and ``save``. The
            parameter shadows the module-level ``wandb`` import.
        a_lm_model_name: Name of the causal LM used for ``alm_loss``.
        tw_model: Forwarded to ``get_sigmas_karras``.
        simplified_inputs: Forwarded to ``sample_euler``.

    Returns:
        None.
    """
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        AutoModelForSequenceClassification,
    )
    from ddlm.sampler.euler import (
        get_sigmas_karras,
        sample_euler,
    )

    model_g, tokenizer_g = _load_cached_pair(
        AutoModelForSequenceClassification,
        AutoTokenizer,
        "textattack/roberta-base-CoLA",
    )
    a_lm_model, a_lm_tokenizer = _load_cached_pair(
        AutoModelForCausalLM, AutoTokenizer, a_lm_model_name
    )

    sigmas = get_sigmas_karras(
        100, t_min, t_max, tw_model=tw_model, device=accelerator.device
    )

    batch_size = 10
    input_ids = torch.full(
        (batch_size, 64), tokenizer.pad_token_id, dtype=torch.long
    )
    conditioning_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    outputs, _ = sample_euler(
        model=model,
        simplified_inputs=simplified_inputs,
        sigmas=sigmas.to(accelerator.device),
        input_ids=input_ids.to(accelerator.device),
        conditioning_mask=conditioning_mask.to(accelerator.device),
        disable=True,
    )

    examples = [
        tokenizer.decode(outputs.logits[i].argmax(-1)) for i in range(batch_size)
    ]

    # The scoring models are used as loaded (never moved to the accelerator
    # device), so the tokenized batches are left where the tokenizer puts them.
    a_loss = 0
    for example in examples:
        batch = a_lm_tokenizer(example, return_tensors="pt")
        batch["labels"] = batch["input_ids"]
        a_loss += a_lm_model(**batch).loss.item()

    dist1, dist2, dist3 = _distinct_ngram_ratios(examples)

    examples_path = f"examples_step_{step}.txt"
    # Explicit UTF-8: decoded text may contain characters that the platform's
    # default encoding cannot represent.
    with open(examples_path, "w", encoding="utf-8") as out:
        out.writelines([e + "\n" for e in examples])

    inputs_g = tokenizer_g(examples, return_tensors="pt", padding=True, truncation=True)

    outputs_g = model_g(**inputs_g)
    probs_g = F.softmax(outputs_g.logits, dim=-1)[:, 1].mean().item()

    if wandb is not None:
        wandb.log(
            {
                "alm_loss": a_loss,
                "grammar": probs_g,
                "dist_1": dist1,
                "dist_2": dist2,
                "dist_3": dist3,
            },
            step=step,
            commit=False,
        )
        wandb.save(examples_path)

def get_optimizer(model, lr, betas, weight_decay, optimizer="adamw"):
    """Build an AdamW or Lion optimizer with weight decay on >=2-D parameters only.

    Trainable parameters (``requires_grad`` is True) are split into two
    groups: tensors with two or more dimensions receive ``weight_decay``,
    all other tensors receive a weight decay of 0.0. The sizes of both groups
    are printed.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        lr: Learning rate passed to the optimizer.
        betas: Beta coefficients passed to the optimizer.
        weight_decay: Weight decay applied to the decayed group.
        optimizer: Either ``"adamw"`` or ``"lion"``.

    Returns:
        A ``torch.optim.AdamW`` instance for ``"adamw"`` or a ``Lion``
        instance (constructed with ``use_triton=True``) for ``"lion"``.

    Raises:
        ValueError: If ``optimizer`` is not one of the supported names.
    """
    # An explicit raise keeps the check alive under ``python -O``; with a bare
    # ``assert`` an unknown name would fall through and be returned as-is.
    if optimizer not in ("adamw", "lion"):
        raise ValueError(
            f"optimizer must be 'adamw' or 'lion', got {optimizer!r}"
        )

    # Names of >=2-D parameters that should nevertheless skip weight decay.
    no_decay_2d = []  # ["embeddings.weight"]

    # Only parameters that require grad are handed to the optimizer.
    trainable = {pn: p for pn, p in model.named_parameters() if p.requires_grad}

    # Any parameter that is 2D or more is weight decayed, otherwise not:
    # matmul weights and embeddings decay, biases and layernorms don't.
    decay_params = [
        p for n, p in trainable.items() if p.dim() >= 2 and n not in no_decay_2d
    ]
    nodecay_params = [
        p for n, p in trainable.items() if p.dim() < 2 or n in no_decay_2d
    ]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    if optimizer == "adamw":
        # Use the fused implementation only when the constructor accepts it
        # AND every trainable parameter lives on a CUDA device. PyTorch
        # releases that only implement fused AdamW for CUDA tensors raise when
        # it is requested for CPU parameters.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = (
            fused_available
            and bool(trainable)
            and all(p.is_cuda for p in trainable.values())
        )
        extra_args = {"fused": True} if use_fused else {}
        built = torch.optim.AdamW(
            optim_groups, lr=lr, betas=betas, weight_decay=weight_decay, **extra_args
        )
        print(f"using fused AdamW: {use_fused}")
    else:
        built = Lion(
            optim_groups, lr=lr, betas=betas, weight_decay=weight_decay, use_triton=True
        )
        print("using Lion with triton")
    return built
