# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math
from typing import List
import torch
from torch import nn
import torch.nn.functional as F
import copy
import random
import json
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)
from transformers import AutoTokenizer, AutoModel
import re


    
class FunctionLM(nn.Module):
    """Frozen base language model plus a trainable head that scores function tokens.

    The head ``func_embed`` maps the base model's hidden state to one logit per
    entry of ``func_dict``. Those logits are concatenated after the base model's
    token logits, so function index ``i`` is addressed as token id
    ``FUNC_TOKEN_OFFSET + i`` in training labels and in generated sequences.

    Args:
        base_model: Model exposing ``params`` (``dim``, ``max_batch_size``,
            ``max_seq_len``), ``forward(tokens, start_pos)`` returning a pair whose
            second item is the hidden state, and ``output(h)`` returning token
            logits. It is put in eval mode and its parameters are frozen.
        tokenizer: Tokenizer exposing ``encode(text, bos=..., eos=...)``,
            ``decode(ids)``, ``pad_id`` and ``eos_id``.
        func_dict: Mapping of function name to function index. The head has
            ``len(func_dict)`` outputs, so indices are presumably
            ``0 .. len(func_dict) - 1``.
        load_path: Optional path to saved ``func_embed`` weights, either a bare
            weight tensor or a state dict. ``None`` or the string ``"None"``
            skips loading.
        inference_mode: ``"func_embedding"`` lets :meth:`generate` emit function
            tokens; any other value suppresses them.
        device: Device for the function head and for tensors built by this
            class. Defaults to ``"cuda"``.
    """

    # Id of the first function token. Assumed to equal the base model's
    # vocabulary size, because function logits are appended after the token
    # logits. TODO confirm against the base model's vocabulary size.
    FUNC_TOKEN_OFFSET = 32000

    def __init__(self, base_model, tokenizer, func_dict, load_path=None,
                 inference_mode="func_embedding", device="cuda"):
        super().__init__()
        self.inference_mode = inference_mode
        self.device = device
        self.model = base_model
        self.tokenizer = tokenizer
        self.func_dict = func_dict
        # Reverse lookup: function index -> function name.
        self.func_list = {v: k for k, v in func_dict.items()}
        self.func_embed = nn.Linear(base_model.params.dim, len(func_dict), bias=False).to(device)

        if load_path is not None and load_path != "None":
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            embedding = torch.load(load_path, map_location=device)
            if isinstance(embedding, torch.Tensor):
                # A bare weight tensor was saved; wrap it as a state dict.
                embedding = {"weight": embedding}

            # Drop extra rows when the checkpoint covers more functions than func_dict.
            if embedding["weight"].shape[0] > len(func_dict):
                print(f"Truncated the function embedding from {embedding['weight'].shape[0]} to {len(func_dict)}")
                embedding["weight"] = embedding["weight"][:len(func_dict)]

            self.func_embed.load_state_dict(embedding)

        # Only func_embed is trainable: the base model stays frozen in eval mode.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False
        self.logits_bias = 0

    def set_bias(self, logits_bias):
        """Set the value added to the function logits at every ``generate`` step.

        A scalar works. Presumably any tensor broadcastable to
        ``(bsz, len(func_dict))`` also works.
        """
        self.logits_bias = logits_bias

    def _extract_func_name(self, eq):
        """Return the ``func_dict`` key named by the first ``[...]`` or ``<...>`` span in ``eq``.

        Raises:
            ValueError: If ``eq`` contains no bracketed span.
        """
        if "[" in eq:
            match = re.search(r"(\[.*?\])", eq)
        elif "<" in eq:
            match = re.search(r"(<.*?>)", eq)
        else:
            match = None
        if match is None:
            raise ValueError(f"no bracketed function call found in target {eq!r}")
        op = match.group(1)
        # Keys may be stored with or without their surrounding brackets.
        if op not in self.func_dict:
            op = op[1:-1]
        return op

    def get_loss(self, raw_inputs, only_functoken=False, crop_prompt=False):
        """Compute the loss for one example and per-function counts.

        Args:
            raw_inputs: Single-element list holding a dict with ``"text"``,
                ``"start_token_idx"`` / ``"end_token_idx"`` (positions of each call
                span in the bos-prefixed token sequence) and either ``"tar_eq"``
                (one target call string per span) or ``"api"`` (used as
                ``"<api>"`` when ``"tar_eq"`` is absent). The dict is not modified.
            only_functoken: If True, only function-token labels contribute to
                the loss.
            crop_prompt: If True, mask the labels of everything up to and including
                the last ``"A:"``, and compute the counts only over unmasked positions.

        Returns:
            ``(loss, results)``. ``loss`` is the cross entropy over base-vocabulary
            plus function logits. ``results`` maps ``"tp"``, ``"pred"`` and
            ``"true"`` to numpy arrays of per-function counts, ordered like
            ``func_dict``'s keys.
        """
        assert len(raw_inputs) == 1
        example = raw_inputs[0]
        offset = self.FUNC_TOKEN_OFFSET

        with torch.no_grad():
            if crop_prompt:
                prompt = "A:".join(example['text'].split("A:")[:-1]) + "A:"
                prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
                assert prompt in example["text"]

            # Shifted below: inputs drop the trailing <eos>, labels drop the leading <bos>.
            raw_input_ids = torch.tensor(self.tokenizer.encode(example["text"], bos=True, eos=True))
            labels = raw_input_ids.clone()

            if "tar_eq" in example:
                tar_eq = example["tar_eq"]
            else:
                tar_eq = ["<" + example["api"] + ">"]

            for s, t, eq in zip(example["start_token_idx"], example["end_token_idx"], tar_eq):
                op = self._extract_func_name(eq)
                # The first token of the span becomes the function token; the rest are ignored.
                labels[s] = self.func_dict[op] + offset
                labels[s + 1:t] = -100

            if crop_prompt:
                labels[:len(prompt_tokens)] = -100
            if only_functoken:
                labels[labels < offset] = -100
            inputs = raw_input_ids[:-1].expand(1, -1).to(self.device)
            labels = labels[1:].expand(1, -1).to(self.device)

            _, h = self.model(inputs, 0)  # h: (bsz, seqlen, dim)
            token_logits = self.model.output(h)  # (bsz, seqlen, vocab_size)

        # Outside no_grad so gradients reach func_embed.
        func_logits = self.func_embed(h.float())  # (bsz, seqlen, len(func_dict))
        concat_logits = torch.cat([token_logits, func_logits], dim=-1)
        loss = F.cross_entropy(concat_logits.view(-1, concat_logits.shape[-1]), labels.view(-1), ignore_index=-100)

        # Per-function true-positive / predicted / actual counts.
        pred = torch.argmax(concat_logits, dim=-1).view(-1)
        labels = labels.view(-1)
        if crop_prompt:
            keep = labels != -100
            pred = pred[keep]
            labels = labels[keep]

        func_ids = [self.func_dict[op] + offset for op in self.func_dict]
        label_funcs = torch.stack([labels == fid for fid in func_ids], dim=0)  # (len(func_dict), n)
        pred_funcs = torch.stack([pred == fid for fid in func_ids], dim=0)  # (len(func_dict), n)
        results = {
            "tp": torch.sum(label_funcs & pred_funcs, dim=-1).detach().cpu().numpy(),
            "pred": torch.sum(pred_funcs, dim=-1).detach().cpu().numpy(),
            "true": torch.sum(label_funcs, dim=-1).detach().cpu().numpy(),
        }
        return loss, results

    @staticmethod
    def _boost_object_logits(logits_token, tokens, cur_pos, obj_encodings, left_idx, right_idx, bonus=10):
        """Add ``bonus`` in place to logits that continue an object name after ``left_idx``.

        Finds the most recent ``left_idx`` token. For each encoding in
        ``obj_encodings`` whose prefix matches the tokens emitted since then, it
        boosts that encoding's next token. If the emitted tokens already equal a
        whole encoding, it boosts ``right_idx`` instead. Each token id is boosted
        at most once per call.
        """
        last_left = torch.where(tokens[:, :cur_pos] == left_idx)[1][-1].item()
        n_done = cur_pos - last_left - 1  # object tokens already emitted after the opener
        boosted = set()

        if n_done == 0:
            for enc in obj_encodings:
                if enc and enc[0] not in boosted:
                    boosted.add(enc[0])
                    logits_token[:, enc[0]] += bonus
            return

        emitted = tokens[0, last_left + 1:cur_pos]
        for enc in obj_encodings:
            if len(enc) > n_done:
                nxt = enc[n_done]
                if nxt not in boosted and torch.all(emitted == torch.tensor(enc[:n_done], device=tokens.device)):
                    logits_token[:, nxt] += bonus
                    boosted.add(nxt)
            elif len(enc) == n_done and right_idx not in boosted:
                if torch.all(emitted == torch.tensor(enc, device=tokens.device)):
                    logits_token[:, right_idx] += bonus
                    boosted.add(right_idx)

    @torch.no_grad()
    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        stop_token: Optional[List] = None,  # ints and/or lists of ints; 29897: ), 3892: )=
        return_top: int = 0,
        disable_func: Optional[List[str]] = None,
        disable_token: Optional[List[int]] = None,  # 29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
        no_left_parens: bool = False,
        objs: Optional[List[str]] = None,
    ) -> List[str]:
        """Decode one prompt, stopping at the first function token or stop condition.

        Args:
            prompts: Exactly one prompt string (``bsz == 1`` is asserted).
            max_gen_len: Maximum number of tokens to generate (also capped by
                ``params.max_seq_len``).
            temperature: Softmax temperature. ``<= 0`` selects argmax decoding.
            top_p: Nucleus-sampling threshold, used when ``temperature > 0``.
            stop_token: Token ids (ints) and/or token-id sequences (lists) that
                end generation once produced.
            return_top: If ``> 0``, also return a per-step log of the chosen token
                and the ``return_top`` highest ``(token id, logit)`` pairs.
            disable_func: Function names whose logits are forced to ``-1e5``.
            disable_token: Base-vocabulary token ids whose logits are forced to ``-1e5``.
            no_left_parens: When a function token ends generation, append only its
                name instead of the name plus ``"("`` / ``" <"``.
            objs: Object names. While inside an open ``<`` span, tokens that
                continue one of these names get +10 on their logits, as does the
                closing token once a name is complete.

        Returns:
            The decoded strings, or ``(decoded, generation_log)`` when
            ``return_top > 0``. If a function token ended generation, the text is
            the prefix before it followed by the function's name. Decoding is cut
            at the first eos token, if any.
        """
        stop_token = [] if stop_token is None else stop_token
        disable_func = [] if disable_func is None else disable_func
        disable_token = [] if disable_token is None else disable_token
        objs = [] if objs is None else objs

        bsz = len(prompts)
        assert bsz == 1
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        device = self.device
        offset = self.FUNC_TOKEN_OFFSET
        # Token ids that open / close an object span; presumably "<" and ">" in the
        # base vocabulary -- TODO confirm.
        left_idx = 529
        right_idx = 29958
        left_state = False  # True after left_idx is emitted, until right_idx is emitted

        # Drop the first and last token of each encoding (presumably the "<" and ">" pieces).
        obj_encodings = [self.tokenizer.encode("<" + obj + ">", bos=False, eos=False)[1:-1] for obj in objs]
        stop_token_substr = [torch.tensor(x).to(device).long() for x in stop_token if isinstance(x, list)]
        stop_token_single = [x for x in stop_token if isinstance(x, int)]

        generation_log = []  # per step: (chosen token, [(token, logit), ...])

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).to(device).long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        # Last written position; stays on the final prompt token if nothing is generated.
        cur_pos = start_pos - 1

        for cur_pos in range(start_pos, total_len):
            _, h = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            last_h = h[:, -1, :]
            logits_token = self.model.output(last_h).float()  # (bsz, vocab_size)
            logits_func = self.func_embed(last_h.float())  # (bsz, len(func_dict))
            if self.inference_mode != "func_embedding":
                logits_func = torch.zeros_like(logits_func) - 1e5

            if len(disable_token) > 0:
                logits_token[:, disable_token] = -1e5

            if left_state:
                self._boost_object_logits(logits_token, tokens, cur_pos, obj_encodings, left_idx, right_idx)

            for func in disable_func:
                logits_func[:, self.func_dict[func]] = -1e5
            logits_func += self.logits_bias
            logits = torch.cat([logits_token, logits_func], dim=-1)  # (bsz, vocab_size + len(func_dict))
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # Keep prompt tokens where the prompt has not ended yet.
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            if return_top > 0:
                top_ids = torch.argsort(logits[0, :], descending=True)[:return_top]
                generation_log.append(
                    (next_token[0].item(), [(i.item(), logits[0, i.item()].item()) for i in top_ids])
                )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

            tok = next_token[0].item()
            if tok == left_idx:
                left_state = True
            if tok == right_idx:
                left_state = False

            if tok >= offset or tok in stop_token_single:
                break
            if any(torch.equal(tokens[0, cur_pos - len(substr) + 1:cur_pos + 1], substr) for substr in stop_token_substr):
                break

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[:len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[:t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            # After an eos cut, cur_pos may lie past the end of t.
            if cur_pos < len(t) and t[cur_pos] >= offset:
                func_name = self.func_list[t[cur_pos] - offset]
                prefix = self.tokenizer.decode(t[:cur_pos])
                if no_left_parens:
                    decoded.append(prefix + func_name)
                elif "<" in self.func_list[0]:
                    decoded.append(prefix + func_name + "(")
                elif "[" in self.func_list[0]:
                    decoded.append(prefix + func_name + " <")
                else:
                    raise NotImplementedError
            else:
                decoded.append(self.tokenizer.decode(t[:cur_pos + 1]))

        if return_top > 0:
            return decoded, generation_log
        return decoded
    
    
def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)


    return next_token