from utils import *
from prompts import *
from collections import defaultdict
import math
import numpy as np
import re
import argparse
import time

import torch
import argparse

from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoConfig, LlamaTokenizer
from transformers import AutoModelForCausalLM, GenerationConfig



class Agent():
    """Wrapper around a locally hosted LLaMA-30B checkpoint used as the planner LLM.

    ``OPENAI_API_KEY`` is accepted so the constructor matches its call site in
    ``main``, but it is not used by this class.
    """

    def __init__(self, OPENAI_API_KEY):
        """Load the tokenizer, the fp16 model and the decoding configuration.

        Args:
            OPENAI_API_KEY: unused here; kept for call-site compatibility.
        """
        # NOTE(review): checkpoint path is hard-coded to a shared cluster location.
        self.model_name = "/export/share/ruimeng/ckpts/llm/llama_hf/30B"
        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_name)
        # "auto" delegates weight placement to transformers/accelerate;
        # get_device_map() builds a hand-tuned alternative if ever needed.
        self.device_map = "auto"
        print('Loading model:\n')
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=self.device_map is not None,
            load_in_8bit=False,
            )
        # Sampling restricted to top_k=1 with a near-zero temperature is, in
        # effect, greedy decoding: the most likely token is always chosen.
        self.generation_config = GenerationConfig(
            do_sample=True,
            temperature=0.0000000000000001,
            top_p=0.75,
            top_k=1,
            max_new_tokens=500,
        )

    def predict_answer(self, user_message, temperature=0.0):
        """Generate a completion for ``user_message`` and return the response text.

        The message is wrapped in an instruction/response prompt template,
        decoded with ``self.generation_config``, and the text following the
        first ``'Response:'`` marker is returned stripped and lower-cased.

        Args:
            user_message: instruction text inserted into the prompt template.
            temperature: currently ignored -- decoding always uses
                ``self.generation_config``. Kept so existing callers still work.

        Returns:
            The lower-cased response string.

        Raises:
            IndexError: if the decoded text contains no ``'Response:'`` marker.
        """
        with torch.no_grad():
            input_ids = self.tokenizer(
                f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
                ### Instruction: {user_message}
                ### Response:""",
                return_tensors="pt",
            ).input_ids
            # Use the model's own device rather than a hard-coded cuda:0 so the
            # call also works when the model lives elsewhere (e.g. CPU).
            input_ids = input_ids.to(self.model.device)
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                generation_config=self.generation_config,
            )
        decoded = self.tokenizer.batch_decode(generated_ids.cpu(), skip_special_tokens=True)[0]
        # The decoded text includes the prompt; keep only the model's answer.
        return decoded.split('Response:')[1].strip().lower()

    def get_device_map(self, model_name, device, do_int8):
        """Build a device map for ``model_name`` under fixed per-GPU memory budgets.

        Args:
            model_name: checkpoint name/path whose config is loaded.
            device: hardware profile; ``"a100-40g"`` short-circuits to ``"auto"``.
            do_int8: plan layer sizes as int8 instead of float16.

        Returns:
            ``"auto"`` for the A100-40G profile, otherwise the mapping returned
            by ``infer_auto_device_map`` (which is also printed).
        """
        if device == "a100-40g":
            return "auto"

        # Instantiate the architecture without real weights, only to size layers.
        with init_empty_weights():
            config = AutoConfig.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_config(config)

        # Budget: 18GiB on GPU 0 and 26GiB on each of GPUs 1-5.
        max_memory = {0: "18GiB"}
        for gpu in range(1, 6):
            max_memory[gpu] = "26GiB"
        device_map = infer_auto_device_map(
            model, max_memory=max_memory, dtype=torch.int8 if do_int8 else torch.float16,
            # Decoder blocks must stay whole on a single device.
            no_split_module_classes=["BloomBlock", "OPTDecoderLayer", "LLaMADecoderLayer", "LlamaDecoderLayer"]
        )
        print(device_map)
        del model
        return device_map
        



def _parse_step_sequence(answer):
    """Return ``[(step label, action), ...]`` parsed from an LLM answer.

    Lines containing ``'step sequence'`` are treated as headers and skipped.
    Every other line containing ``'step'`` must split on ``': '`` into exactly
    two parts; otherwise ``ValueError`` is raised so the caller can retry.
    Both parts are stripped and lower-cased.
    """
    pairs = []
    for line in answer.split('\n'):
        if 'step sequence' in line:
            continue
        if 'step' in line:
            label, action = line.split(': ')
            pairs.append((label.strip().lower(), action.strip().lower()))
    return pairs


def ucb_cot(agent, state, target, participating_blocks, step_action_score, step_counter, step_action_counter, UCB_CONSTANT, step_action_ucb, grid_reward, model_temperature=None):
    """Run one LLM rollout for a Blocksworld problem and update the UCB tables.

    The agent is prompted (with the current UCB table rendered as history when
    non-empty) for a step sequence. Up to 5 attempts are made, sleeping 3s
    after each failure. Counters are updated only from an answer that was
    generated and parsed successfully. The proposed actions are then replayed
    from ``state``:

    * target reached after step k -> ``grid_reward`` is added to steps 1..k
      and their UCB values are refreshed;
    * an invalid action, or a sequence that ends without reaching the target
      -> every proposed step gets a 0.0 reward and a UCB refresh.

    Args:
        agent: object with ``predict_answer(prompt, temperature)`` returning text.
        state: initial configuration as text.
        target: goal configuration as text.
        participating_blocks: forwarded to ``state_text2json``.
        step_action_score: ``{label: {action: score}}`` table, mutated in place
            (presumably nested defaultdicts, as built in ``main``).
        step_counter: per-label visit counts, mutated in place.
        step_action_counter: per-(label, action) counts, mutated in place.
        UCB_CONSTANT: exploration constant passed to ``get_ucb_score``.
        step_action_ucb: ``{label: {action: ucb}}`` table, mutated in place.
        grid_reward: reward credited to steps that reach the target.
        model_temperature: temperature passed to the agent. When None, falls
            back to the script-level ``args.model_temperature`` if such a
            global exists, else 0.0.

    Returns:
        ``(step_action_score, step_action_ucb, step_counter, step_action_counter)``.
    """
    if model_temperature is None:
        # Backward compatibility: callers used to rely on the global ``args``.
        model_temperature = getattr(globals().get('args'), 'model_temperature', 0.0)

    history = step_action_score2text_llm_following(step_action_ucb)
    if history:
        prompt = prompt_with_history_v2(state, target, history)
    else:
        prompt = prompt_without_history_v2(state, target, history)

    proposed = []  # [(label, action)] from the first answer that parses
    for _attempt in range(5):
        try:
            proposed = _parse_step_sequence(agent.predict_answer(prompt, model_temperature))
            break
        except Exception:
            # Generation or parsing failed; back off briefly and retry. Nothing
            # from the failed attempt has been recorded.
            time.sleep(3)

    for label, action in proposed:
        step_action_score[label][action] += 0.0  # register the action so it gets ranked
        step_counter[label] += 1
        step_action_counter[label][action] += 1
    # label -> last action proposed for it (a repeated label keeps the last one).
    latest = dict(proposed)

    def refresh(label, action, reward):
        # Credit ``reward`` to (label, action) and recompute its UCB value.
        step_action_score[label][action] += float(reward)
        step_action_ucb[label][action] = get_ucb_score(
            step_action_score[label][action], UCB_CONSTANT,
            step_counter[label], step_action_counter[label][action])

    tables = (step_action_score, step_action_ucb, step_counter, step_action_counter)
    target_json = state_text2json(target.lower().replace('.', ''), participating_blocks)

    for k, (label, action) in enumerate(proposed):
        new_state, valid_action = add_action_to_json_state(
            state_text2json(state.lower().replace('.', ''), participating_blocks), action)
        # NOTE(review): executed steps are counted a second time here (they were
        # already counted when parsed). Preserved as-is; confirm this is intended.
        step_counter[label] += 1
        step_action_counter[label][action] += 1
        if not valid_action:
            break
        if new_state == target_json:
            for done_label, done_action in proposed[:k + 1]:
                refresh(done_label, done_action, grid_reward)
            return tables
        state = state_json2text(new_state)

    # Invalid action, or the sequence ended without reaching the target.
    for label, action in latest.items():
        refresh(label, action, 0.0)
    return tables




def main(args):
    """Evaluate UCB-guided LLM planning on Blocksworld problems.

    For every answer length in ``args.no_of_answer_steps`` (comma-separated),
    each problem from ``get_blocksworld_data`` gets ``args.no_of_passes``
    rollouts of ``ucb_cot``. A problem counts as solved when the first step
    label in the score table has a non-zero best score. Progress and totals
    are printed; nothing is returned.

    Args:
        args: parsed CLI namespace (see the ``__main__`` block).
    """
    import copy  # explicit, rather than relying on a star import

    UCB_CONSTANT = args.exploration_constant
    grid_reward = args.reward
    # The LLM is expensive to load: create it once (on the first problem) and
    # reuse it for every problem, trial and answer length.
    agent = None

    for NO_OF_STEPS_IN_ANSWER in args.no_of_answer_steps.split(','):
        print('#'*50)
        print('No of steps in ans: ', NO_OF_STEPS_IN_ANSWER)
        print('#'*50)

        bw_data = get_blocksworld_data(int(NO_OF_STEPS_IN_ANSWER))

        for _ in range(args.no_of_trials):
            preds = []

            # NOTE(review): tqdm is not imported in this file; presumably it
            # comes from ``from utils import *``.
            for item_idx, item in tqdm(enumerate(bw_data)):
                # Fresh search statistics for every problem.
                step_action_score = defaultdict(lambda: defaultdict(float))
                step_action_ucb = defaultdict(lambda: defaultdict(float))
                step_counter = defaultdict(int)
                step_action_counter = defaultdict(lambda: defaultdict(int))

                if agent is None:
                    agent = Agent(args.OPENAI_API_KEY)

                init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])
                final_block_config = copy.deepcopy(init_block_config)
                gt_action_sequence = real_solution2text(item['real_solution'])

                # Derive the goal configuration by replaying the reference solution.
                for action in gt_action_sequence:
                    final_block_config, valid_action = add_action_to_json_state(final_block_config, action)
                    assert valid_action is True, 'Cannot reach final block config'

                # Learn: run iterations
                for _pass in range(args.no_of_passes):
                    step_action_score, step_action_ucb, step_counter, step_action_counter = ucb_cot(
                        agent, state_json2text(init_block_config),
                        state_json2text(final_block_config),
                        item['participating_blocks'],
                        step_action_score,
                        step_counter,
                        step_action_counter,
                        UCB_CONSTANT,
                        step_action_ucb,
                        grid_reward
                    )

                # Greedily read off the best action per step until a step has
                # never been rewarded.
                final_steps = []
                for stepi in step_action_score:
                    best_action, best_score = max(step_action_score[stepi].items(), key=lambda kv: kv[1])
                    if best_score == 0.0:
                        break
                    final_steps.append(best_action)

                # Scores are only ever increased by reaching the target, so a
                # non-empty final_steps means some rollout solved the problem.
                if final_steps:
                    preds.append(1)
                print(f"{item_idx + 1}:\t{sum(preds)}\n")

            print('No of questions: ', str(len(bw_data)))
            print('No of correct answers: ', sum(preds))
        print()

    
    


if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser(description='UCB-guided LLM planning on Blocksworld.')
    parser.add_argument('-no_of_passes', default=10, type=int, help='ucb_cot rollouts per problem')
    parser.add_argument('-no_of_trials', default=1, type=int, help='number of evaluation trials')
    parser.add_argument('-reward', default=1, type=int, help='reward credited to steps that reach the goal')
    parser.add_argument('-exploration_constant', default=10, type=float, help='UCB exploration constant')
    parser.add_argument('-model_temperature', default=0.0, type=float, help='temperature passed to the agent')
    # Never hard-code secrets in source: take the key from the environment.
    parser.add_argument('-OPENAI_API_KEY', default=os.environ.get('OPENAI_API_KEY'), help='API key (defaults to $OPENAI_API_KEY)')
    parser.add_argument('-no_of_answer_steps', default='2,4,6', help='comma-separated solution lengths to evaluate')
    args = parser.parse_args()

    main(args)