import sys
import os
import transformers
import transformers.generation
from transformers import AutoTokenizer,AutoModelForCausalLM
import torch
from transformers.generation.stopping_criteria import StoppingCriteria,StoppingCriteriaList,STOPPING_CRITERIA_INPUTS_DOCSTRING,add_start_docstrings
import subprocess
import time
import argparse
import gym
import babyai_text
import numpy as np
import json
from peft import PeftModel
import random

# Verdict strings ("STATUS: GOOD" / "STATUS: BAD").
# NOTE(review): succ_sentence / fail_sentence are not referenced in the visible
# code; presumably used by related scripts that judge trajectories — confirm.
judge_prefix = 'STATUS: '
succ_word = 'GOOD'
fail_word = 'BAD'

succ_sentence = judge_prefix + succ_word
fail_sentence = judge_prefix + fail_word

# Prompt headers. These strings are sent to the model verbatim, so any edit
# changes model inputs (and may hurt models fine-tuned on exactly this text).
# Actor header: prefixed to the in-context examples used to propose actions.
prompt_header_actor = 'Assume that you are an agent in a Grid World. Given a goal, your task is to execute a sequence of actions to achieve the goal.\nPossible action of the agent: turn left, turn right, go forward, pick up, drop, toggle.\n\n'
# Critic header: used for the critic prompt when no PEFT adapter is loaded.
# NOTE(review): "Please reasoning" is ungrammatical; left as-is because it is
# part of the model input.
prompt_header_critic = "Given you a partial trajectory interacting with a Grid World. Your task is to determine whether the last step is GOOD, BAD or UNKNOWN to achieve the final goal. Please reasoning step by step and end your response with 'This step is GOOD.', 'This step is BAD.' or 'This step is UNKNOWN.'.\n\n"
# Future-prediction header: used to build the (non-PEFT) future prompt.
prompt_header_future = "Given you a partial trajectory interacting with a Grid World. Your task is to predict future outcomes.\n\n"
# Experience header: prefix of the success/failure prediction prompt.
prompt_header_experience = 'Given you a trajectory interacting with a Grid World, your task is to predict whether current trajectory will fail or succeed.'+'\nHere is the task you need to predict.\n'
# Order matters: an action's index in this list is passed to env.step().
action_space = ["turn left","turn right","go forward","pick up","drop","toggle"]
# Labels returned by get_task_type and accepted as --task_type.
task_types = ['goto','pickup','gotoafterpickup', 'pickupthengoto', 'putnextto', 'opendoor']


def get_valid_actions(obs):
    """Derive the admissible primitive actions from textual observations.

    Only single-clause lines (no ``'and'``) of the form
    ``"You see ... 1 step forward"`` describe the cell directly ahead; such a
    cell is classified as wall, door, or (anything else) object. Any line
    containing ``'carry'`` means the agent is holding something.

    Args:
        obs: iterable of observation description strings.

    Returns:
        list[str]: always starts with ``'turn left'``, ``'turn right'``, then,
        in order, ``'pick up'`` (object ahead, hands empty), ``'go forward'``
        and, if holding, ``'drop'`` (cell ahead empty), ``'toggle'`` (door ahead).
    """
    obj_ahead = wall_ahead = door_ahead = holding = False

    for line in obs:
        single_clause = 'and' not in line
        facing = line.startswith('You see') and line.endswith('1 step forward')
        if single_clause and facing:
            if 'wall' in line:
                wall_ahead = True
            elif 'door' in line:
                door_ahead = True
            else:
                obj_ahead = True
        if 'carry' in line:
            holding = True

    actions = ['turn left', 'turn right']
    if obj_ahead and not holding:
        actions.append('pick up')
    # Moving forward (and dropping) only makes sense when nothing blocks the cell ahead.
    if not (obj_ahead or wall_ahead or door_ahead):
        actions.append('go forward')
        if holding:
            actions.append('drop')
    if door_ahead:
        actions.append('toggle')
    return actions

def lora_to_base(model):
    """Disable LoRA adapter layers so ``model`` runs as its base model.

    Best-effort: if ``model.base_model`` has no ``disable_adapter_layers``
    (e.g. no adapter is loaded) or the call fails, a notice is printed instead
    of raising.

    Args:
        model: a model whose ``base_model`` may expose ``disable_adapter_layers``.
    """
    try:
        model.base_model.disable_adapter_layers()
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate; ordinary failures remain a best-effort no-op.
        print("No adapter layers to disable")
    
def base_to_lora(model):
    """Re-enable LoRA adapter layers on ``model``.

    Best-effort: if ``model.base_model`` has no ``enable_adapter_layers``
    (e.g. no adapter is loaded) or the call fails, a notice is printed instead
    of raising.

    Args:
        model: a model whose ``base_model`` may expose ``enable_adapter_layers``.
    """
    try:
        model.base_model.enable_adapter_layers()
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate; ordinary failures remain a best-effort no-op.
        print("No adapter layers to enable")
        
def get_task_type(goal):
    """Classify a mission string into one of ``task_types`` by its wording.

    A comma selects ``'pickupthengoto'``. Otherwise the whitespace-separated
    word count picks the family: 12 -> ``'gotoafterpickup'``,
    9 -> ``'putnextto'``, 5 -> ``'pickup'`` / ``'goto'`` by prefix. Every other
    goal falls back to ``'opendoor'``. Sanity ``assert``s check the expected
    key phrases for each family.

    Raises:
        AssertionError: if the key phrases for the chosen family are missing
        (only while assertions are enabled).
    """
    if ',' in goal:
        return 'pickupthengoto'

    word_count = len(goal.split())
    if word_count == 12:
        assert 'go to' in goal and 'pick up' in goal
        return 'gotoafterpickup'
    if word_count == 9:
        assert 'put' in goal and 'next to' in goal
        return 'putnextto'
    if word_count == 5:
        for prefix, label in (('pick up', 'pickup'), ('go to', 'goto')):
            if goal.startswith(prefix):
                return label

    # Anything not matched above must be a door-opening mission.
    assert 'open' in goal and 'door' in goal
    return 'opendoor'
    
class MyStopping(StoppingCriteria):
    """Stopping criterion that fires once every generated tail contains ``stop_str``.

    The tail is the decoded text of each row of ``input_ids`` after the first
    ``start_length`` (prompt) tokens.
    """

    def __init__(self,start_length, tokenizer,stop_str):
        super().__init__()
        self.start_length = start_length
        self.tokenizer = tokenizer
        self.stop_str = stop_str

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        # Decode only the newly generated tokens of every sequence in the batch.
        tails = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        return all(self.stop_str in tail for tail in tails)

def get_stopping(start_length, stop_strs, tokenizer):
    """Return a StoppingCriteriaList holding one MyStopping per entry of ``stop_strs``.

    transformers stops generation when any criterion in the list fires. Here
    that means: for some stop string, every generated tail (text after
    ``start_length`` tokens) contains it.
    """
    return StoppingCriteriaList(
        [MyStopping(start_length, tokenizer, stop_str) for stop_str in stop_strs]
    )

class MyStoppingCount(StoppingCriteria):
    """Stopping criterion: every generated tail holds ``stop_str`` at least ``stop_count`` times.

    The tail is the decoded text of each row of ``input_ids`` after the first
    ``start_length`` (prompt) tokens. Occurrences are counted with
    ``str.count``, i.e. non-overlapping.
    """

    def __init__(self,start_length, tokenizer,stop_str, stop_count):
        super().__init__()
        self.start_length = start_length
        self.tokenizer = tokenizer
        self.stop_str = stop_str
        self.stop_count = stop_count

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        tails = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        for tail in tails:
            # One sequence below the threshold keeps the whole batch generating.
            if tail.count(self.stop_str) < self.stop_count:
                return False
        return True

def get_stopping_count(start_length, stop_strs, tokenizer):
    """Build a StoppingCriteriaList from ``(stop_str, stop_count)`` pairs.

    Each pair becomes a MyStoppingCount criterion. transformers halts
    generation as soon as any one of them fires.
    """
    return StoppingCriteriaList(
        [MyStoppingCount(start_length, tokenizer, stop_str, count)
         for stop_str, count in stop_strs]
    )

def generate_critic(model, tokenizer, prompt, stop=('\n',)):
    """Greedily generate a one-line critic response for ``prompt``.

    Decodes up to 100 new tokens with ``do_sample=False``. Generation stops
    early once the generated text contains any string in ``stop`` (via
    ``get_stopping_count`` with a count of 1).

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: matching tokenizer.
        prompt: full prompt text; no special tokens are added.
        stop: stop strings. The default is an immutable tuple (was a mutable
            list default); any sequence of strings works.

    Returns:
        str: the first line of the generated text, stripped, or ``''`` if
        ``model.generate`` raised (e.g. CUDA out of memory).
    """
    torch.cuda.empty_cache()

    input_tok = tokenizer([prompt], add_special_tokens=False)
    input_ids = torch.LongTensor(input_tok['input_ids']).to(model.device)
    prompt_len = input_ids.shape[-1]

    stopping = get_stopping_count(prompt_len, [(s, 1) for s in stop], tokenizer)
    try:
        outputs = model.generate(
            input_ids=input_ids,
            past_key_values=None,
            use_cache=True,
            max_new_tokens=100,
            output_scores=True,
            return_dict_in_generate=True,
            stopping_criteria=stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            top_p=1,
        )
    except Exception:
        # Best-effort: a failed generation (typically OOM) yields an empty
        # critic. ``Exception`` rather than a bare ``except:`` so Ctrl-C still works.
        torch.cuda.empty_cache()
        return ''

    generated = outputs.sequences[:, prompt_len:]
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    del input_ids, generated, outputs
    torch.cuda.empty_cache()

    # Always keep only the first line, independent of which stop string fired.
    return output.split('\n')[0].strip()

def compute_output_logp_norm(model,prompt,output,tokenizer):
    """Return the mean per-token log-probability of ``output`` given ``prompt``.

    ``prompt + output`` is tokenized as one sequence and scored in a single
    forward pass. The first ``len(tokenize(prompt))`` tokens are treated as
    context; the rest is the continuation whose log-probs are averaged.

    Args:
        model: causal LM; called as ``model(input_ids, attention_mask=...)``
            and must return an object with ``.logits`` of shape
            ``(batch, seq, vocab)``.
        prompt: context text (must tokenize to at least one token).
        output: continuation text to score.
        tokenizer: matching tokenizer; no special tokens are added.

    Returns:
        float: mean natural-log probability of the continuation tokens, or
        ``float('-inf')`` when ``output`` adds no tokens. The previous code
        produced NaN here, which ``np.argmax`` would then pick.

    Raises:
        ValueError: if ``prompt`` tokenizes to zero tokens (there is no
            position that predicts the first continuation token).
    """
    torch.cuda.empty_cache()
    with torch.no_grad():
        input_tok = tokenizer([prompt], add_special_tokens=False, padding=True)
        output_tok = tokenizer([prompt + output], add_special_tokens=False, padding=True)
        # Only the prompt length is needed, so the prompt ids stay on the CPU.
        prompt_len = torch.LongTensor(input_tok['input_ids']).shape[-1]
        output_ids = torch.LongTensor(output_tok['input_ids']).to(model.device)
        output_attention_mask = torch.LongTensor(output_tok['attention_mask']).to(model.device)

        if prompt_len == 0:
            raise ValueError('prompt must tokenize to at least one token')
        if output_ids.shape[-1] <= prompt_len:
            return float('-inf')

        # NOTE(review): assumes tokenize(prompt) is a prefix of
        # tokenize(prompt + output). A token merge at the boundary shifts the
        # split by one token.
        res = model(output_ids[:, :-1], attention_mask=output_attention_mask[:, :-1])
        # Logits at position i predict token i + 1, so rows prompt_len-1 .. end
        # score the continuation. Upcast before log_softmax: fp16 log-probs lose
        # precision when ranking close candidates.
        logits = res.logits[0, prompt_len - 1:].float()
        targets = output_ids[0, prompt_len:].unsqueeze(-1)
        logp = torch.log_softmax(logits, dim=-1).gather(1, targets).mean().item()
    return logp

def generate_future(model, prompt, tokenizer, predict_step):
    """Greedily generate a predicted continuation of the trajectory in ``prompt``.

    Decoding (``do_sample=False``, at most 1024 new tokens) stops at whichever
    comes first:
    - the generated text contains ``2 + 3 * predict_step`` newlines;
    - it contains ``'GOOD.'``;
    - it contains ``'BAD.'``.

    NOTE(review): the 3-lines-per-step budget presumably matches the
    Action/Observation/Critic line triplets built in ``run_a_task_type`` — confirm.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        prompt: full prompt text; no special tokens are added.
        tokenizer: matching tokenizer.
        predict_step: number of future steps to allow before stopping.

    Returns:
        str: the stripped generated text followed by ``'\\n'``, or ``''`` if
        ``model.generate`` raised (e.g. CUDA out of memory).
    """
    torch.cuda.empty_cache()

    input_tok = tokenizer([prompt], add_special_tokens=False)
    input_ids = torch.LongTensor(input_tok['input_ids']).to(model.device)
    prompt_len = input_ids.shape[-1]

    stopping = get_stopping_count(prompt_len, [
        ('\n', 2 + predict_step * 3),
        ('GOOD.', 1),
        ('BAD.', 1)], tokenizer)
    try:
        outputs = model.generate(
            input_ids=input_ids,
            past_key_values=None,
            use_cache=True,
            max_new_tokens=1024,
            output_scores=True,
            return_dict_in_generate=True,
            stopping_criteria=stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )
    except Exception:
        # Best-effort: a failed generation yields ''. ``Exception`` rather
        # than a bare ``except:`` so Ctrl-C still interrupts the run.
        torch.cuda.empty_cache()
        return ''

    generated = outputs.sequences[:, prompt_len:]
    output = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    del input_ids, generated, outputs
    torch.cuda.empty_cache()

    return output.strip() + '\n'

def generate_from_prompt_experience_grounded_decoding(if_peft, model, prompt, prompt_future, prompt_experience, tokenizer, stopping, predict_step = 4, max_length=256, temperature=0, top_p=1, num_beams=5, num_return_sequences = 4, valid_actions=None):
    """Propose next-action candidates with beam search and pick the most likely one.

    Steps:
    1. Run beam search on ``prompt``, stopping once a newline appears.
    2. Keep the first line of each returned beam, de-duplicated and sorted.
    3. Discard candidates that neither start with ``'think'`` nor are in
       ``valid_actions``. If none survive, pick a random valid action.
    4. Score each candidate with ``compute_output_logp_norm`` and return the
       best one.

    Args:
        model, tokenizer: causal LM and its tokenizer.
        prompt: actor prompt ending just before the action text.
        max_length: maximum number of new tokens per beam.
        top_p, num_beams, num_return_sequences: beam-search settings.
        valid_actions: admissible action strings. Defaults to ``action_space``
            when None (previously None crashed with a TypeError).
        if_peft, prompt_future, prompt_experience, predict_step, temperature:
            accepted for interface compatibility; currently unused.
        stopping: ignored; a newline stopping criterion is always built here.

    Returns:
        tuple[str, str]: (chosen action without trailing newline, a message
        listing every scored candidate as ``[score]: text``), or ``('', '')``
        if ``model.generate`` raised.
    """
    if valid_actions is None:
        valid_actions = action_space

    # ---- generate action candidates ----
    torch.cuda.empty_cache()
    input_tok = tokenizer([prompt], add_special_tokens=False, padding=True)
    input_ids = torch.LongTensor(input_tok['input_ids']).to(model.device)
    attention_mask = torch.LongTensor(input_tok['attention_mask']).to(model.device)
    prompt_len = input_ids.shape[-1]
    generation_config = transformers.GenerationConfig(
        num_beams=num_beams,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    newline_stopping = get_stopping(prompt_len, ['\n'], tokenizer)
    try:
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            generation_config=generation_config,
            stopping_criteria=newline_stopping,
        )
    except Exception:
        # Best-effort: the caller treats ('', '') as "abandon this episode".
        torch.cuda.empty_cache()
        return '', ''

    beams = tokenizer.batch_decode(outputs.sequences[:, prompt_len:], skip_special_tokens=True)

    del input_ids, attention_mask, outputs
    torch.cuda.empty_cache()

    # ---- filter candidates ----
    # First line of each beam, de-duplicated and sorted (same order np.unique gave).
    first_lines = sorted({b.split('\n')[0].strip() + '\n' for b in beams})
    texts = [t for t in first_lines if t.startswith('think') or t.strip() in valid_actions]
    if not texts:
        texts.append(random.choice(valid_actions) + '\n')

    # ---- score and choose ----
    scores = [compute_output_logp_norm(model, prompt, t, tokenizer) for t in texts]
    text_chosen = texts[int(np.argmax(scores))].strip()

    msg = ''.join(f'[{score:.5f}]: {text}' for score, text in zip(scores, texts))
    return text_chosen, msg


def run_a_task_type(if_peft,model,tokenizer,succ1,succ2,fail1,task_type,start_idx,num_task):
    """Roll out the LLM agent on BabyAI episodes of one task type, printing transcripts.

    Reads the module-level global ``env`` (assigned in the ``__main__`` block)
    and reseeds it with 2024. It resets the env repeatedly and skips missions
    whose ``get_task_type`` differs from ``task_type``.

    For each accepted episode it runs up to 30 steps. Each step:
    - choose an action with ``generate_from_prompt_experience_grounded_decoding``;
    - step the environment, or use the observation ``'OK'`` for ``think`` actions;
    - append a critic line from ``generate_critic``.

    Everything is printed to stdout; nothing is returned.

    Args:
        if_peft: True when ``model`` carries LoRA adapters. The critic call is
            then wrapped in ``base_to_lora`` / ``lora_to_base`` and the critic
            and future prompts contain no in-context examples.
            NOTE(review): presumably expects adapters to be disabled on entry
            (the ``__main__`` block calls ``lora_to_base`` first) — confirm.
        model: causal LM used for both actor and critic.
        tokenizer: tokenizer matching ``model``.
        succ1, succ2: successful example trajectories (text), shown to the actor.
        fail1: failed example trajectory, used in the non-PEFT critic/future prompts.
        task_type: one of ``task_types``; other missions are skipped.
        start_idx: matching episodes are counted from 1, and episodes whose
            count is below ``start_idx`` are skipped (0 and 1 behave the same).
        num_task: the loop ends after the episode counted as ``num_task``, so
            this is the last episode index rather than the number of episodes run.
    """
    env.seed(2024)
    cnt = 0
    while cnt<num_task:
        obs = env.reset()
        goal = obs[0]['mission']
        # Computed before the task-type filter; only used if the episode is kept.
        valid_actions = get_valid_actions(obs[-1]['descriptions'])
        if get_task_type(goal)!=task_type:
            # print('ERROR:',goal,task_type)
            # sys.stdout.flush()
            continue
        cnt += 1
        if cnt<start_idx:
            continue
        done = False
        reward = 0
        torch.cuda.empty_cache()

        init_prompt_actor = prompt_header_actor+'Here are two examples.\n'+succ1+'\n\n'+succ2+'\n\nHere is the task you need to complete.\n'
        if if_peft:
            # With adapters loaded, critic/future prompts carry no in-context examples.
            init_prompt_critic = ''
            init_prompt_future = ''
        else:
            init_prompt_critic = prompt_header_critic+'Here is a successful example.\n'+succ1+'\n\nHere is a failed example.\n'+fail1+'\n\nHere is the task you need to complete.\n'
            init_prompt_future = prompt_header_future+'Here is a successful example.\n'+succ1+'\n\nHere is a failed example.\n'+fail1+'\n\nHere is the task you need to complete.\n'
        init_prompt_experience = prompt_header_experience

        # Running trajectory text, extended with Action/Observation/Critic lines each step.
        prompt = 'Goal of the agent:'+goal
        prompt += '\nObservation:'+', '.join(obs[-1]['descriptions'])

        print('Goal of the agent:'+goal)
        print('Observation:'+', '.join(obs[-1]['descriptions']))
        sys.stdout.flush()

        # At most 30 agent steps per episode.
        for step in range(30):
            torch.cuda.empty_cache()
            action,msg = generate_from_prompt_experience_grounded_decoding(if_peft,model,init_prompt_actor+prompt+'\nAction:',init_prompt_future+prompt+'\nAction:',init_prompt_experience+prompt+'\nAction:',tokenizer,stopping=None,predict_step=4,max_length=128,temperature=0,top_p=1,num_beams=5,num_return_sequences=5, valid_actions=valid_actions)
            torch.cuda.empty_cache()

            # ('', '') signals that generation raised; abandon this episode.
            if action=='' and msg=='':
                break

            if action.startswith('think'):
                # 'think' actions do not touch the environment.
                new_obs = 'OK'
                done = False
                reward = 0
            else:
                # The index in action_space is the environment action id.
                a_id = action_space.index(action)
                img,reward,done,new_obs = env.step(a_id)
                valid_actions = get_valid_actions(new_obs['descriptions'])
                new_obs = ", ".join(new_obs["descriptions"])

            print(f'==>>\nAction Choosing Message:\n{msg}<<==')
            print(f'****** Execution {step}:')
            sys.stdout.flush()

            print(f'Action:{action}')
            print(f'Observation:{new_obs}')
            sys.stdout.flush()

            prompt += '\n'+f'Action:{action}'
            prompt += '\n'+f'Observation:{new_obs}'

            torch.cuda.empty_cache()
            # The critic runs with LoRA adapters enabled (PEFT case); they are
            # disabled again so the next actor call uses the base model.
            if if_peft:
                base_to_lora(model)
            critic_info = generate_critic(model,tokenizer,init_prompt_critic+prompt+'\nCritic:')
            if if_peft:
                lora_to_base(model)
            torch.cuda.empty_cache()

            prompt += f'\nCritic:{critic_info}'
            print(f'Critic:{critic_info}')
            print('******')
            sys.stdout.flush()
            if done:
                break
        print('reward:',reward,'Done?:',done)
        print('-----------------------------------')
        print()
        sys.stdout.flush()

if __name__=='__main__':
    # Script entry point: load model/tokenizer (optionally with a LoRA
    # adapter), load the in-context examples, create the BabyAI environment
    # and run the agent on one task type.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path',required=True)
    parser.add_argument('--cache_dir')
    # `choices` makes argparse reject unknown task types with a usage error.
    # It replaces an `assert`, which is stripped under `python -O`.
    parser.add_argument('--task_type',type=str,default='goto',choices=task_types)
    parser.add_argument('--peft_model_path',default=None)
    parser.add_argument('--num_task',type=int,default=50)
    parser.add_argument('--start_idx',type=int,default=0)

    args = parser.parse_args()

    if_peft = bool(args.peft_model_path)

    model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir,torch_dtype=torch.float16,device_map='auto')
    # Use the same cache directory as the model (previously ignored for the tokenizer).
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir, padding_side='left')
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if if_peft:
        model = PeftModel.from_pretrained(model,args.peft_model_path)
    model = model.eval()

    # In-context example trajectories, indexed by task type below.
    with open('succ_examples.json','r',encoding='utf-8') as f:
        examples = json.load(f)
    with open('succ_examples_2.json','r',encoding='utf-8') as f:
        examples_2 = json.load(f)
    with open('fail_examples.json','r',encoding='utf-8') as f:
        fail_examples = json.load(f)

    # `env` is a module-level global read by run_a_task_type. Create it once
    # and reset until a mission of the requested type appears, instead of
    # constructing (and never closing) a new environment on every attempt.
    env = gym.make("BabyAI-MixedTestLocal-v0")
    while True:
        res = env.reset()
        if get_task_type(res[0]['mission'])==args.task_type:
            break

    # Start with adapters disabled: the actor runs on the base model and
    # run_a_task_type enables LoRA only around critic calls.
    if if_peft:
        lora_to_base(model)

    run_a_task_type(if_peft,model,tokenizer,examples[args.task_type],examples_2[args.task_type],fail_examples[args.task_type],args.task_type,args.start_idx,args.num_task)
