import os
import warnings
import json
import time
import copy
import time
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Literal

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DynamicCache

from embodied_cd.trl.models.type_aliases import (
    GenerationOutput,
    GenerationOutputWithCache,
)
from embodied_cd.common.env_utils import build_skill_list

# Closed set of decoding-strategy names. The two values appear to correspond to
# ``greedy_generation`` and ``beam_action_generation`` below.
# NOTE(review): the alias is not referenced in this chunk; confirm its users.
_Type_Decoding = Literal["greedy", "beam-action"]


def generation(model, tokenizer, query, max_length=80, **gen_params):
    """Generate a continuation of ``query`` with ``model.generate``.

    Args:
        model: Model exposing ``.device`` and a ``generate`` method.
        tokenizer: Tokenizer used to encode ``query`` and decode the output.
        query: Prompt text.
        max_length: Maximum number of *new* tokens to generate.
        **gen_params: Extra keyword arguments forwarded to ``model.generate``.
            They must not contain ``max_new_tokens`` or ``pad_token_id``,
            which are passed explicitly here.

    Returns:
        GenerationOutput with the decoded response (special tokens skipped).
        ``values`` and ``prob`` are ``None``. When several sequences are
        returned, only the first is decoded.
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]

    output_ids = model.generate(
        input_ids,
        **gen_params,
        max_new_tokens=max_length,
        # EOS doubles as the padding id (presumably to silence the
        # missing-pad-token warning).
        pad_token_id=tokenizer.eos_token_id,
    )
    # Index the first sequence explicitly. ``squeeze()`` keeps a 2-D tensor
    # when several sequences are returned (it would then slice rows, not
    # tokens) and yields a 0-d tensor for a one-token output.
    response_ids = output_ids[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)

    return GenerationOutput(
        query=query,
        response=response,
        values=None,
        prob=None,
    )


def greedy_generation(model, tokenizer, query, max_length=80):
    """Greedily decode up to ``max_length`` tokens while tracking probabilities.

    The whole sequence is re-run through ``model`` at every step (no KV
    cache). Decoding stops early once the EOS token is chosen.

    Args:
        model: Model exposing ``.device`` and returning ``.logits`` when called.
        tokenizer: Tokenizer used to encode ``query`` and decode the output.
        query: Prompt text.
        max_length: Maximum number of new tokens.

    Returns:
        GenerationOutput where ``response`` is decoded with special tokens kept
        (so a final EOS remains in the text). ``prob`` is the running product
        of the chosen tokens' probabilities: a tensor, or ``1.0`` if nothing
        was generated. ``log_probs`` holds each chosen token's log-probability
        as a Python float.
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]
    prob = 1.0
    log_probs = []
    for _ in range(max_length):
        logits = model(input_ids).logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        # log_softmax is numerically stable. log(softmax(x)) becomes -inf as
        # soon as a probability underflows to zero.
        token_log_probs = F.log_softmax(logits, dim=-1)

        next_token = torch.argmax(logits, dim=-1)
        token_id = next_token.item()
        prob = prob * probs[0, token_id]
        log_probs.append(token_log_probs[0, token_id].item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

        if token_id == tokenizer.eos_token_id:
            break

    # Index the single batch row explicitly. squeeze() would produce a 0-d
    # tensor for a one-token sequence and break the slice.
    response_ids = input_ids[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=False)
    return GenerationOutput(
        query=query,
        response=response,
        values=None,
        prob=prob,
        log_probs=log_probs,
    )


def control_generation(
    model,
    tokenizer,
    query,
    max_length=80,
    top_k=30,
    top_p=1.0,
    value_adapter="adapter_2",
    q_adapter="adapter_3",
):
    """Decode from value-reweighted top-k candidates.

    At each step the top-``top_k`` next-token probabilities are rescaled by
    ``sigmoid(Q) / sigmoid(V)``:

    * ``V`` is the model's ``values`` output at the last position of the
      current sequence.
    * ``Q`` is the ``values`` output at the last position after appending each
      candidate, evaluated with ``q_adapter`` active.

    The candidate with the largest rescaled probability is appended.
    Previously the rescaled scores were computed but ignored, and the token
    was taken greedily from the un-shifted distribution.

    NOTE: the first forward pass uses whichever adapter is active on entry.
    After each Q evaluation ``value_adapter`` is re-activated.

    Args:
        model: Model with ``.device``, ``.set_adapter`` and outputs exposing
            ``.logits`` and ``.values``.
        tokenizer: Tokenizer used to encode ``query`` and decode the output.
        query: Prompt text.
        max_length: Maximum number of new tokens.
        top_k: Number of candidates rescored per step (capped at vocab size).
        top_p: Currently unused; kept for interface compatibility.
        value_adapter: Adapter restored after each Q evaluation.
        q_adapter: Adapter used to score the candidate continuations.

    Returns:
        GenerationOutput with the decoded response (special tokens skipped).
        ``values`` and ``prob`` are ``None``.
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]

    # Pure inference: no autograd graphs are needed for the returned text.
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            # Presumably (1, 1) if ``values`` is (batch, seq) — TODO confirm.
            v_value = outputs.values.unsqueeze(-1)[:, -1, :]
            probs = F.softmax(next_token_logits, dim=-1)
            k = min(top_k, probs.shape[-1])
            candidate_probs, candidate_tokens = torch.topk(probs, k, dim=-1)

            # One row per candidate: current sequence + that candidate token.
            next_input_ids = torch.cat(
                [input_ids.repeat(k, 1), candidate_tokens.view(-1, 1)], dim=-1
            )
            model.set_adapter(q_adapter)
            try:
                q_values = model(next_input_ids).values.unsqueeze(-1)[:, -1, :]
            finally:
                # Restore the adapter even if the Q evaluation raises.
                model.set_adapter(value_adapter)

            # Distribution shift: scale each candidate by sigmoid(Q)/sigmoid(V).
            shift = torch.sigmoid(q_values) / torch.sigmoid(v_value)
            shifted_probs = candidate_probs * shift.reshape(1, -1)

            best = torch.argmax(shifted_probs, dim=-1)  # shape (1,)
            next_token = candidate_tokens.gather(-1, best.unsqueeze(-1))  # (1, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    response_ids = input_ids[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    return GenerationOutput(
        query=query,
        response=response,
        values=None,
        prob=None,
    )


def beam_token_generation(model, tokenizer, query, ngram_dict, object_list):
    """Constrained decoding that walks a nested n-gram dictionary.

    At each level the candidate phrases are the keys of ``ngram_dict``. If the
    level's first key is ``"noun"``, the candidates are the entries of
    ``object_list`` instead. Each candidate is encoded with a leading space and
    scored by the product of its token probabilities given the current text.
    The best candidate is appended, and decoding descends into the chosen
    key's value (or ``ngram_dict["noun"]`` on a noun level). Decoding stops
    when that value is not a non-empty dict.

    Args:
        model: Model exposing ``.device`` and returning ``.logits``.
        tokenizer: Tokenizer used for the query, the candidates and the output.
        query: Prompt text.
        ngram_dict: Nested dict of allowed phrases.
        object_list: Phrases used on ``"noun"`` levels.

    Returns:
        GenerationOutput with the stripped response. ``values`` and ``prob``
        are ``None``.

    Raises:
        ValueError: If a level offers no candidates (e.g. empty ``object_list``).
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]

    # Pure inference: avoid building autograd graphs on every forward pass.
    with torch.no_grad():
        while isinstance(ngram_dict, dict) and ngram_dict:
            level_keys = list(ngram_dict.keys())
            is_noun_level = level_keys[0] == "noun"
            candidates = object_list if is_noun_level else level_keys
            if len(candidates) == 0:
                raise ValueError("beam_token_generation: no candidates at this level")

            logits = model(input_ids).logits[:, -1, :]
            # Upcast so products of probabilities keep float32 precision.
            probs = F.softmax(logits.float(), dim=-1)

            candidate_ids, candidate_scores = [], []
            for phrase in candidates:
                # Leading space so the phrase is tokenized as a separate word.
                phrase_ids = tokenizer.encode(" " + phrase, return_tensors="pt").to(device)
                candidate_ids.append(phrase_ids)

                score = probs[0, phrase_ids[0, 0].item()]
                n_tokens = phrase_ids.shape[-1]
                if n_tokens > 1:
                    # One pass over prefix + phrase. The logits at position t
                    # predict token t + 1, so the last n_tokens - 1 predictions
                    # before the end score phrase tokens 1..n-1.
                    extended = torch.cat([input_ids, phrase_ids], dim=-1)
                    step_logits = model(extended).logits[0, -n_tokens:-1, :]
                    step_probs = F.softmax(step_logits.float(), dim=-1)
                    for i, token_id in enumerate(phrase_ids[0, 1:]):
                        score = score * step_probs[i, token_id.item()]
                candidate_scores.append(score)

            best = int(torch.argmax(torch.stack(candidate_scores)))
            input_ids = torch.cat([input_ids, candidate_ids[best]], dim=-1)
            ngram_dict = ngram_dict["noun"] if is_noun_level else ngram_dict[level_keys[best]]

    response_ids = input_ids[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    return GenerationOutput(
        query=query,
        response=response,
        values=None,
        prob=None,
    )


def _filter_skills_by_can_list(skill_list, can_list, env_name):
    """Restrict ``skill_list`` to the skills permitted by ``can_list``.

    The order of ``skill_list`` is kept and duplicates are dropped, so the
    result is deterministic. Building it via ``set`` made the order, and hence
    argmax tie-breaking, depend on string hash randomisation.
    """
    can_set = set(can_list)
    if env_name == "virtualhome":
        verbs = ("put", "place")
        # can_list put/place entries are assumed to be "<verb> <target>" (two
        # words); unpacking fails loudly otherwise, as before.
        can_verb_pairs = [c.split() for c in can_list if c.startswith(verbs)]

        def permitted(skill):
            if skill in can_set:
                return True
            if not skill.startswith(verbs):
                return False
            # "<verb> <target>" admits any put/place skill that starts with
            # <verb> and ends with <target>.
            return any(
                skill.startswith(pre) and skill.endswith(suf)
                for pre, suf in can_verb_pairs
            )

    elif env_name == "alfred":
        permitted = can_set.__contains__
    else:
        return skill_list
    return [s for s in dict.fromkeys(skill_list) if permitted(s)]


def _skill_log_prob(model, query_ids, skill_token_ids):
    """Return log p(skill tokens | query) under ``model`` as a Python float.

    ``query_ids`` is a (1, Q) tensor; ``skill_token_ids`` is a list of ids.
    """
    skill = torch.tensor(
        [skill_token_ids], dtype=query_ids.dtype, device=query_ids.device
    )
    input_ids = torch.cat([query_ids, skill], dim=-1)
    logits = model(
        input_ids=input_ids, attention_mask=torch.ones_like(input_ids)
    ).logits
    # Logits at position t predict token t + 1, so the skill tokens are
    # predicted by positions Q-1 .. Q+S-2.
    query_length = query_ids.shape[-1]
    log_probs = F.log_softmax(logits[0, query_length - 1 : -1].float(), dim=-1)
    return log_probs.gather(-1, skill[0].unsqueeze(-1)).sum().item()


def beam_action_generation(
    model,
    tokenizer,
    query,
    action_format,
    object_list,
    can_list=None,
    env_name="virtualhome",
    reduced_skills=True,
):
    """Choose the most likely full skill string as the next action.

    Candidate skills come from ``build_skill_list``. When ``can_list`` is
    given, they are filtered to the currently executable ones. Each candidate
    is scored by the probability of its tokens following ``query``.

    Each skill is tokenized and scored on its own, without padding, so the
    result does not depend on the tokenizer's padding side. Scores are summed
    in float32 log-space.

    Returns:
        GenerationOutput whose ``response`` is the best skill and whose
        ``prob`` is that skill's sequence probability (float). ``values`` is
        ``None``.

    Raises:
        ValueError: If no candidate skills remain after filtering.
    """
    skill_list = build_skill_list(action_format, object_list, env_name, reduced_skills)
    if can_list is not None:
        skill_list = _filter_skills_by_can_list(skill_list, can_list, env_name)
    if not skill_list:
        raise ValueError("beam_action_generation: no candidate skills to score")

    query_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
    skill_token_ids = tokenizer(skill_list)["input_ids"]

    # Pure inference: no autograd graphs are needed for scoring.
    with torch.no_grad():
        action_log_probs = [
            _skill_log_prob(model, query_ids, ids) for ids in skill_token_ids
        ]

    best = int(np.argmax(action_log_probs))
    return GenerationOutput(
        query=query,
        response=skill_list[best],
        values=None,
        prob=float(np.exp(action_log_probs[best])),
    )


def generation_using_cache(
    model, tokenizer, query, max_length=80, prompt_cache=None, **gen_params
):
    """Generate a continuation of ``query`` reusing a precomputed prompt cache.

    ``prompt_cache`` is deep-copied before generation. ``generate`` extends a
    cache object in place (the Hugging Face prompt-reuse recipe deep-copies
    for this reason, as does ``beam_token_generation_using_cache`` in this
    module), so without the copy a cache shared across calls would silently
    accumulate each call's tokens. The extended copy is returned.

    NOTE(review): HF ``generate`` presumably feeds only the positions beyond
    the cache length, so ``query`` is expected to start with the text already
    held in the cache — confirm for the wrapped model.

    Args:
        model: Model exposing ``.device`` and ``generate``.
        tokenizer: Tokenizer used to encode ``query`` and decode the output.
        query: Prompt text.
        max_length: Maximum number of new tokens.
        prompt_cache: Cached key/values for a prompt prefix, or ``None``.
        **gen_params: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        GenerationOutputWithCache with the decoded response (special tokens
        skipped) and the cache returned by ``generate``.
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]

    output = model.generate(
        input_ids,
        **gen_params,
        max_new_tokens=max_length,
        past_key_values=copy.deepcopy(prompt_cache),
        return_dict_in_generate=True,
    )
    # Index the first sequence explicitly instead of relying on squeeze().
    response_ids = output.sequences[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)

    return GenerationOutputWithCache(
        query=query,
        response=response,
        values=None,
        prompt_cache=output.past_key_values,
    )


def beam_token_generation_using_cache(
    model, tokenizer, query, ngram_dict, object_list, prompt_cache=None
):
    """N-gram-constrained decoding (like ``beam_token_generation``) with a cache.

    Every forward pass receives a fresh deep copy of ``prompt_cache``,
    because the forward pass extends the cache it is given. The caller's
    cache therefore stays unchanged.

    NOTE(review): the full ``input_ids`` (query + generated text) is fed
    together with the cache. An HF forward pass does not skip positions that
    are already cached. If ``prompt_cache`` covers a prefix of ``query``, that
    prefix is processed twice. Confirm whether callers pass a suffix-only
    query here.

    Args:
        model: Model exposing ``.device`` and returning ``.logits``.
        tokenizer: Tokenizer used for the query, the candidates and the output.
        query: Prompt text.
        ngram_dict: Nested dict of allowed phrases. A level whose first key is
            ``"noun"`` expands to ``object_list``.
        object_list: Phrases used on ``"noun"`` levels.
        prompt_cache: Cached key/values passed to every forward call, or ``None``.

    Returns:
        GenerationOutput with the stripped response. ``values`` and ``prob``
        are ``None``.

    Raises:
        ValueError: If a level offers no candidates (e.g. empty ``object_list``).
    """
    device = model.device
    input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
    input_length = input_ids.shape[-1]

    # Pure inference: avoid building autograd graphs on every forward pass.
    with torch.no_grad():
        while isinstance(ngram_dict, dict) and ngram_dict:
            level_keys = list(ngram_dict.keys())
            is_noun_level = level_keys[0] == "noun"
            candidates = object_list if is_noun_level else level_keys
            if len(candidates) == 0:
                raise ValueError(
                    "beam_token_generation_using_cache: no candidates at this level"
                )

            logits = model(
                input_ids, past_key_values=copy.deepcopy(prompt_cache)
            ).logits[:, -1, :]
            # Upcast so products of probabilities keep float32 precision.
            probs = F.softmax(logits.float(), dim=-1)

            candidate_ids, candidate_scores = [], []
            for phrase in candidates:
                # Leading space so the phrase is tokenized as a separate word.
                phrase_ids = tokenizer.encode(" " + phrase, return_tensors="pt").to(device)
                candidate_ids.append(phrase_ids)

                score = probs[0, phrase_ids[0, 0].item()]
                n_tokens = phrase_ids.shape[-1]
                if n_tokens > 1:
                    # The logits at position t predict token t + 1, so the last
                    # n_tokens - 1 predictions before the end score phrase
                    # tokens 1..n-1.
                    extended = torch.cat([input_ids, phrase_ids], dim=-1)
                    step_logits = model(
                        extended, past_key_values=copy.deepcopy(prompt_cache)
                    ).logits[0, -n_tokens:-1, :]
                    step_probs = F.softmax(step_logits.float(), dim=-1)
                    for i, token_id in enumerate(phrase_ids[0, 1:]):
                        score = score * step_probs[i, token_id.item()]
                candidate_scores.append(score)

            best = int(torch.argmax(torch.stack(candidate_scores)))
            input_ids = torch.cat([input_ids, candidate_ids[best]], dim=-1)
            ngram_dict = ngram_dict["noun"] if is_noun_level else ngram_dict[level_keys[best]]

    response_ids = input_ids[0, input_length:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    return GenerationOutput(
        query=query,
        response=response,
        values=None,
        prob=None,
    )


def control_beam_token_generation(
    model, tokenizer, reward_queries, plan_queries, action_list
):
    """Pick an action by combining value-head scores with plan likelihoods.

    For each ``(reward_query, plan_query, action)`` triple:

    * value: ``model(reward ids, num_logits_to_keep=N, average_pool=True).values``,
      where ``N`` is the token count of ``action + "<|im_end|>"``;
    * likelihood: the sum of log-probabilities of the action tokens following
      ``plan_query``. Each action token is scored by the logits at the
      position *before* it. The previous version read the logits one position
      too late and summed raw logits instead of log-probabilities.

    Each list is softmax-normalised across actions. The two are multiplied,
    and the action with the highest product is returned. Debug prints of the
    per-action and aggregated scores are kept.

    Returns:
        GenerationOutput with ``response`` set to the chosen action.
        ``query``, ``values`` and ``prob`` are ``None``.

    Raises:
        ValueError: If there is nothing to score (any input list is empty).
    """
    device = model.device
    values, scores = [], []
    # Pure inference: no autograd graphs are needed for selection.
    with torch.no_grad():
        for reward_query, plan_query, action in zip(
            reward_queries, plan_queries, action_list
        ):
            # value
            input_ids = tokenizer.encode(reward_query, return_tensors="pt").to(device)
            # Use shape[-1] of the (1, n) encoding. squeeze() turns a one-token
            # encoding into a 0-d tensor that has no last dimension.
            num_logits_to_keep = tokenizer.encode(
                action + "<|im_end|>", return_tensors="pt"
            ).shape[-1]
            value = model(
                input_ids, num_logits_to_keep=num_logits_to_keep, average_pool=True
            ).values.squeeze()
            values.append(value)

            # likelihood
            plan_ids = tokenizer.encode(plan_query, return_tensors="pt").to(device)
            plan_length = plan_ids.shape[-1]
            action_ids = tokenizer.encode(action, return_tensors="pt").to(device)
            logits = model(torch.cat([plan_ids, action_ids], dim=-1)).logits
            # The logits at position t predict token t + 1, so action token i
            # is predicted at position plan_length - 1 + i.
            step_log_probs = F.log_softmax(
                logits[0, plan_length - 1 : -1, :].float(), dim=-1
            )
            score = step_log_probs.gather(-1, action_ids[0].unsqueeze(-1)).sum()
            scores.append(score)

            print(action, value, score)

    if not values:
        raise ValueError("control_beam_token_generation: no actions to score")

    values = F.softmax(torch.stack(values).float(), dim=-1)
    probs = F.softmax(torch.stack(scores), dim=-1)

    q_values = values * probs

    print(values, probs)
    print(q_values)
    best = int(torch.argmax(q_values))
    response = action_list[best]
    return GenerationOutput(
        query=None,
        response=response,
        values=None,
        prob=None,
    )
