from algo.adapter import Adapter
from utils.util import accumulate_strings, formulate_string
from utils.loggers import loggers, get_debug_dir
import wandb
from concurrent.futures import TimeoutError
from algo.beam_new_search import Beam_New_Search
from algo.linear_backup import Next_Linear_Search
from copy import copy
from accelerate.utils import gather_object
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
import random
import os
from openai import AzureOpenAI
import json
import concurrent.futures
import torch
from algo.task_adapters.gpt_sampler import ChatCompletionSampler
import logging


class Reasoning_Adapter(Adapter):
    def __init__(self, prompt, config, accelerator):
        self.config = config
        self.prompt = prompt
        self.accelerator = accelerator

        if os.getenv("PROCESSOR", "") == "gpt-4o-mini":
            self.sampler = ChatCompletionSampler(model="gpt-4o-mini")
        else:
            print(f"Unknown processor: {os.getenv('PROCESSOR')}; set 'PROCESSOR=gpt-4o-mini' and 'OPENAI_API_KEY=YOUR_KEY' for best results.")
            raise ValueError(f"MATH requires PROCESSOR atm. AIME is fine without it.")


        self.stop_criterion = None
        self.get_accuracy = None
        self.is_correct = None
        self.qa_template = None

        super().__init__(
                config=config,
                prompt=prompt,
                accelerator=accelerator
            )

    def get_ans_from_proposal(self, q, n=1, temp=1):
        """Ask the proposal model for ``n`` answers to question ``q``.

        The query is the QA template filled with ``q`` and an empty answer,
        prefixed by ``self.prompt`` when that is non-empty. Generation stops
        at a blank line and is capped at ``config["max_tokens_cot"]`` tokens
        (512 when the key is absent). The query and responses are logged to
        the ``proposal`` logger.

        Args:
            q: Question text inserted into the QA template.
            n: Number of responses requested.
            temp: Sampling temperature forwarded to the proposal model.

        Returns:
            Whatever ``self.proposal.get_response`` returns; callers in this
            class treat ``None`` as a failed request.
        """
        question_block = self.formulate_qa(q=q, a="")
        if self.prompt:
            prompt = f"{self.prompt}\n{question_block}"
        else:
            prompt = question_block

        request = dict(
            prompt=prompt,
            partial_steps='',
            n=n,
            stop=["\n\n"],
            max_tokens=self.config.get("max_tokens_cot", 512),
            extract_first_sentence=False,
            temp=temp,
        )
        ans = self.proposal.get_response(**request)
        loggers['proposal'].info(f"\n{'='*20}\nQuery:\n{prompt}\nResponses:\n{ans}")
        return ans
    
    def prepare_for_critic(self, batch, dataset_path="", use_adapter=False):
        """Build the critic warm-up dataset from proposal-model samples.

        Each item of ``batch`` is sampled
        ``config["num_candidates_blackbox_warmup"]`` times from the proposal
        model. The samples are labelled with ``self.is_correct`` and turned
        into positive/negative QA texts. Work is split across accelerator
        processes and, inside each process, across a thread pool. On the main
        process the gathered texts, together with texts built from
        ``self.get_positive_ans``, are handed to ``self.build_critic_dataset``.

        Args:
            batch: Indexable collection of dataset items.
            dataset_path: Forwarded to ``build_critic_dataset`` as ``save_to``.
            use_adapter: Must be falsy; this method only serves warm-up.

        Returns:
            None.

        Raises:
            AssertionError: If ``use_adapter`` is truthy or
                ``config["threadpool"]`` is not enabled.
        """
        assert not use_adapter, "Only used for warmup reward model training, Adapter is not used in this task"
        assert self.config.get("threadpool", False), "Threadpool is required for this task"

        list_idx = list(range(len(batch)))
        progress_bar = tqdm(total=len(list_idx), desc="prepare", disable=not self.accelerator.is_local_main_process)

        def judge(text, ground_truth):
            # With config["gsm8k"] set the checker gets the raw ground truth;
            # otherwise it gets the "answer" field plus the GPT sampler.
            if self.config["gsm8k"]:
                return self.is_correct(text, ground_truth)
            return self.is_correct(text, ground_truth["answer"], self.sampler)

        def qa_steps(question, answers):
            # One QA text per string produced by accumulate_strings
            # (presumably cumulative reasoning steps -- confirm in utils.util).
            return [
                self.formulate_qa(q=question, a=step)
                for step in accumulate_strings(answers, gsm8k=self.config["gsm8k"])
            ]

        def label_completions(b, completions, ground_truth):
            """Split the QA texts of one item into (positives, negatives)."""
            question = self.formulate_question(b)
            if self.config["use_stepwise_supervision"]:
                # Judge each full completion once; all of its step texts
                # inherit that label.
                groups = (
                    (self.formulate_qa(q=question, a=c), qa_steps(question, [c]))
                    for c in completions
                )
            else:
                # Judge every accumulated text on its own.
                groups = ((text, [text]) for text in qa_steps(question, completions))

            # Dicts are used as ordered sets: they deduplicate while keeping
            # a deterministic order, unlike list(set(...)).
            positives, negatives = {}, {}
            for judged_text, texts in groups:
                is_positive = judge(judged_text, ground_truth) and self.config["use_outcome_supervision"]
                (positives if is_positive else negatives).update(dict.fromkeys(texts))
            # A text judged both ways counts as positive only.
            return list(positives), [text for text in negatives if text not in positives]

        def process_batch_item(args):
            idx, items = args
            b = items[idx]
            ground_truth = self.extract_ground_truth(b)
            question = self.formulate_question(b)
            sampled = self.get_ans_from_proposal(question, n=self.config["num_candidates_blackbox_warmup"])
            # Advance the bar for failed requests too, so it can reach 100%.
            progress_bar.update(self.accelerator.num_processes)
            if sampled is None:
                return None
            # Strip before deduplicating so answers that differ only in
            # surrounding whitespace are judged once.
            sampled = list(dict.fromkeys(ans.strip() for ans in sampled))
            return sampled, ground_truth, b

        self.accelerator.wait_for_everyone()
        try:
            with self.accelerator.split_between_processes(list_idx) as batch_idx:
                results = dict(negative_texts=[], positive_texts=[])
                data_lists = [(idx, batch) for idx in batch_idx]

                # NOTE(review): the result is discarded. Presumably this
                # forces CUDA linear-algebra initialisation on the main
                # thread before worker threads start -- confirm. The device
                # is hard-coded to cuda:0.
                torch.inverse(torch.ones((1, 1), device="cuda:0"))
                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    for result in executor.map(process_batch_item, data_lists):
                        if not result:
                            continue
                        completions, ground_truth, b = result
                        positive_texts, negative_texts = label_completions(b, completions, ground_truth)
                        results["positive_texts"].extend(positive_texts)
                        results["negative_texts"].extend(negative_texts)

                # gather_object receives a one-element list per process.
                results = [results]
        finally:
            progress_bar.close()

        gathered_results = gather_object(results)

        if self.accelerator.is_main_process:
            positive_texts = []
            negative_texts = []

            # Reference answers always contribute positive texts.
            for b in batch:
                question = self.formulate_question(b)
                positive_texts.extend(qa_steps(question, self.get_positive_ans(b)))

            for result in gathered_results:
                negative_texts.extend(result["negative_texts"])
                positive_texts.extend(result["positive_texts"])

            self.build_critic_dataset(
                positive_texts,
                negative_texts,
                save_to=dataset_path,
            )

    def prepare_for_both(self, batch, dpo_dataset_path="", dataset_path="", use_adapter=False):
        """Prepare training data for both the adapter (critic) and the proposal.

        Candidate answers for every item of ``batch`` come either from the
        proposal model (``use_adapter`` falsy) or from ``Next_Linear_Search``
        with negative augmentation (``use_adapter`` truthy). Every candidate
        is labelled with ``self.is_correct`` and contributes to two datasets:

        * critic data: positive/negative QA texts, passed to
          ``self.build_critic_dataset``;
        * proposal data: one record per item with its prompt and
          positive/negative completions, passed to
          ``self.build_proposal_dataset`` once with ``mode="sft"`` and once
          with ``mode="dpo"``.

        Work is split across accelerator processes and a thread pool; the
        datasets are built on the main process only.

        Args:
            batch: Indexable collection of dataset items.
            dpo_dataset_path: Forwarded to ``build_proposal_dataset`` as
                ``save_to``.
            dataset_path: Forwarded to ``build_critic_dataset`` as ``save_to``.
            use_adapter: Whether candidates come from the adapter-guided
                search instead of plain proposal sampling.

        Returns:
            None.

        Raises:
            AssertionError: If ``config["threadpool"]`` is not enabled.
            ValueError: If ``use_adapter`` is truthy and the configuration is
                not ``search_method == "beam"`` with ``mode == "linear"``.
        """
        assert self.config.get("threadpool", False), "Threadpool is required for this task"

        list_idx = list(range(len(batch)))
        progress_bar = tqdm(total=len(list_idx), desc="prepare", disable=not self.accelerator.is_local_main_process)

        def judge(text, ground_truth):
            # With config["gsm8k"] set the checker gets the raw ground truth;
            # otherwise it gets the "answer" field plus the GPT sampler.
            if self.config["gsm8k"]:
                return self.is_correct(text, ground_truth)
            return self.is_correct(text, ground_truth["answer"], self.sampler)

        def build_prompt(question):
            qa_text = self.formulate_qa(q=question, a="")
            return f"{self.prompt}\n{qa_text}" if self.prompt else qa_text

        def qa_steps(question, answers):
            # One QA text per string produced by accumulate_strings
            # (presumably cumulative reasoning steps -- confirm in utils.util).
            return [
                self.formulate_qa(q=question, a=step)
                for step in accumulate_strings(answers, gsm8k=self.config["gsm8k"])
            ]

        def label_completions(question, completions, ground_truth):
            """Split the QA texts of one item into (positives, negatives)."""
            if self.config["use_stepwise_supervision"]:
                # Judge each full completion once; all of its step texts
                # inherit that label.
                groups = (
                    (self.formulate_qa(q=question, a=c), qa_steps(question, [c]))
                    for c in completions
                )
            else:
                # Judge every accumulated text on its own.
                groups = ((text, [text]) for text in qa_steps(question, completions))

            # Dicts are used as ordered sets: they deduplicate while keeping
            # a deterministic order, unlike list(set(...)).
            positives, negatives = {}, {}
            for judged_text, texts in groups:
                is_positive = judge(judged_text, ground_truth) and self.config["use_outcome_supervision"]
                (positives if is_positive else negatives).update(dict.fromkeys(texts))
            # A text judged both ways counts as positive only.
            return list(positives), [text for text in negatives if text not in positives]

        def collect_candidates(idx, question):
            """Return candidate answers for one item, or None on failure."""
            if not use_adapter:
                return self.get_ans_from_proposal(question, n=self.config["num_candidates_blackbox_warmup"])
            if not (self.config["search_method"] == "beam" and self.config["mode"] == "linear"):
                raise ValueError("Invalid search method")
            beam_search = Next_Linear_Search(
                params=self.config,
                thought_generator=self.generate,
                init_sequence=question,
                stop_criterion=self.stop_criterion,
                qa_template=self.qa_template,
                score_func=self.get_scores_from_texts,
                device=self.accelerator.device,
                negative_generator=self.generate_negative,
                batch_idx=idx,
            )
            res = beam_search(return_with_init=False, negative_augment=True)
            if res is None:
                return None
            negative_ans, negative_augment = res
            if negative_ans is None:
                return None
            return negative_ans + negative_augment

        def process_batch_item(args):
            idx, items = args
            b = items[idx]
            ground_truth = self.extract_ground_truth(b)
            question = self.formulate_question(b)
            candidates = collect_candidates(idx, question)
            # Advance the bar for failed items too, so it can reach 100%.
            progress_bar.update(self.accelerator.num_processes)
            if candidates is None:
                return None
            # Remove chat markers and whitespace first, then deduplicate, so
            # answers that differ only in those are judged once.
            cleaned = (
                ans.replace('<|im_start|>', '').replace('<|im_end|>', '').strip()
                for ans in candidates
            )
            return list(dict.fromkeys(cleaned)), ground_truth, b

        self.accelerator.wait_for_everyone()
        try:
            with self.accelerator.split_between_processes(list_idx) as batch_idx:
                # Critic (adapter) texts, accumulated at step level.
                results = dict(negative_texts=[], positive_texts=[])
                dpo_results = []
                data_lists = [(idx, batch) for idx in batch_idx]

                # NOTE(review): the result is discarded. Presumably this
                # forces CUDA linear-algebra initialisation on the main
                # thread before worker threads start -- confirm. The device
                # is hard-coded to cuda:0.
                torch.inverse(torch.ones((1, 1), device="cuda:0"))
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    for result in executor.map(process_batch_item, data_lists):
                        if not result:
                            continue
                        completions, ground_truth, b = result
                        question = self.formulate_question(b)
                        # TODO: check if this check is correct

                        # Proposal model data: the raw completion is judged,
                        # not the formatted QA text.
                        qa_pair = dict(prompt=build_prompt(question), positive_texts=[], negative_texts=[])
                        for completion in completions:
                            if judge(completion, ground_truth) and self.config["use_outcome_supervision"]:
                                if self.config["gsm8k"]:
                                    completion = formulate_string(completion)
                                qa_pair["positive_texts"].append(completion)
                            else:
                                qa_pair["negative_texts"].append(completion)
                        dpo_results.append(qa_pair)

                        # Adapter model data.
                        positive_texts, negative_texts = label_completions(question, completions, ground_truth)
                        results["positive_texts"].extend(positive_texts)
                        results["negative_texts"].extend(negative_texts)

                # gather_object receives a one-element list per process.
                results = [results]
        finally:
            progress_bar.close()

        gathered_results = gather_object(results)
        gathered_dpo_results = gather_object(dpo_results)

        # Add the reference answers as positive data and build the datasets.
        if self.accelerator.is_main_process:
            positive_texts = []
            negative_texts = []
            qa_pairs = []

            # Index proposal records by prompt, keeping the first record for
            # each prompt, so every batch item is matched with one lookup.
            first_by_prompt = {}
            for pair in gathered_dpo_results:
                first_by_prompt.setdefault(pair["prompt"], pair)

            for b in batch:
                question = self.formulate_question(b)
                positive_ans = self.get_positive_ans(b)
                positive_texts.extend(qa_steps(question, positive_ans))

                pair = first_by_prompt.get(build_prompt(question))
                if pair is not None:
                    qa_pairs.append(pair)
                    # The first reference answer is always a positive example.
                    pair["positive_texts"].append(positive_ans[0])

            for result in gathered_results:
                negative_texts.extend(result["negative_texts"])
                positive_texts.extend(result["positive_texts"])

            self.build_critic_dataset(
                positive_texts,
                negative_texts,
                save_to=dataset_path,
            )
            for mode in ("sft", "dpo"):
                self.build_proposal_dataset(
                    qa_pairs,
                    save_to=dpo_dataset_path,
                    mode=mode,
                )

    def log_accuracy(self, results):
        """Gather per-process results and log the overall accuracy.

        Every process joins the barrier and the gather; only the main process
        merges the gathered pieces, calls ``self.get_accuracy`` and writes to
        the ``accuracy`` logger.

        Args:
            results: Mapping with ``"completions"``, ``"ground_truths"`` and
                ``"rounds"`` lists for the calling process.

        Returns:
            None.
        """
        self.accelerator.wait_for_everyone()
        gathered_results = gather_object([results])
        if not self.accelerator.is_main_process:
            return

        merged = {key: [] for key in ("completions", "ground_truths", "rounds")}
        for shard in gathered_results:
            for key, values in merged.items():
                values.extend(shard[key])

        # The second value returned by get_accuracy is not used here.
        accuracy, _ = self.get_accuracy(merged)
        loggers["accuracy"].info(f"\nAccuracy: {accuracy * 100:.2f}%\n")


    def evaluate_batch(self, eval_dataset, use_adapter=True, stage_name=""):
        """Answer every item of ``eval_dataset`` for each evaluation round.

        Each item is attempted ``config["num_eval_rounds"]`` times. Answers
        come from a beam search (``use_adapter`` truthy) or directly from the
        proposal model. Work is split across accelerator processes and, when
        ``config["threadpool"]`` is enabled, across a thread pool.

        An item that yields no answer is recorded with the completion
        ``"Unable to solve this question"`` in both execution modes, so the
        evaluated set does not depend on whether the thread pool is enabled.

        Args:
            eval_dataset: Indexable collection of dataset items.
            use_adapter: Use the adapter-guided beam search when truthy,
                otherwise a single proposal sample at
                ``config["temperature"]``.
            stage_name: Label shown on the progress bar.

        Returns:
            On the main process, a dict with the lists ``"completions"``,
            ``"ground_truths"`` and ``"rounds"`` gathered from all processes;
            ``None`` on every other process.

        Raises:
            ValueError: If ``use_adapter`` is truthy and
                ``config["search_method"]`` is not ``"beam"`` or
                ``config["mode"]`` is neither ``"new"`` nor ``"linear"``.
        """
        num_rounds = self.config["num_eval_rounds"]
        num_items = len(eval_dataset)
        split_dict = {
            "list_idx": list(range(num_items)) * num_rounds,
            "round_idx": [r for r in range(num_rounds) for _ in range(num_items)],
        }
        progress_bar = tqdm(total=len(split_dict["list_idx"]), desc=stage_name, disable=not self.accelerator.is_local_main_process)

        def build_search(idx, question):
            """Create the search object selected by ``config["mode"]``."""
            mode = self.config["mode"]
            common = dict(
                params=self.config,
                thought_generator=self.generate,
                init_sequence=question,
                stop_criterion=self.stop_criterion,
                qa_template=self.qa_template,
                score_func=self.get_scores_from_texts,
                batch_idx=idx,
            )
            if mode == "new":
                return Beam_New_Search(**common)
            if mode == "linear":
                return Next_Linear_Search(
                    device=self.accelerator.device,
                    negative_generator=self.generate_negative,
                    **common,
                )
            # Fail with a clear message instead of hitting an unbound local.
            raise ValueError(f"Invalid search mode: {mode!r}")

        def process_batch_item(args):
            idx, round_idx, dataset = args
            b = dataset[idx]
            ground_truth = self.extract_ground_truth(b)
            question = self.formulate_question(b)
            if use_adapter:
                if self.config["search_method"] != "beam":
                    raise ValueError("Invalid search method")
                answer = build_search(idx, question)(return_with_init=False)
            else:
                answer = self.get_ans_from_proposal(q=question, n=1, temp=self.config["temperature"])

            progress_bar.update(self.accelerator.num_processes)
            if answer is None or len(answer) == 0:
                return None
            return answer, ground_truth, round_idx

        def record(results, outcome, idx, round_idx):
            """Append one outcome, or a placeholder for an unsolved item."""
            if outcome:
                answer, ground_truth, round_idx = outcome
                completion = answer[0]
            else:
                completion = "Unable to solve this question"
                ground_truth = self.extract_ground_truth(eval_dataset[idx])
            results["completions"].append(completion)
            results["ground_truths"].append(ground_truth)
            results["rounds"].append(round_idx)

        threadpool = self.config.get("threadpool", False)
        self.accelerator.wait_for_everyone()
        try:
            with self.accelerator.split_between_processes(split_dict) as splitted_dict:
                results = dict(completions=[], ground_truths=[], rounds=[])
                data_lists = [
                    (idx, round_idx, eval_dataset)
                    for idx, round_idx in zip(splitted_dict["list_idx"], splitted_dict["round_idx"])
                ]

                if threadpool:
                    # NOTE(review): the result is discarded. Presumably this
                    # forces CUDA linear-algebra initialisation on the main
                    # thread before worker threads start -- confirm. The
                    # device is hard-coded to cuda:0.
                    torch.inverse(torch.ones((1, 1), device="cuda:0"))
                    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                        # Executor.map yields results in input order, so the
                        # outcomes line up with data_lists.
                        outcomes = list(executor.map(process_batch_item, data_lists))
                else:
                    outcomes = [process_batch_item(item) for item in data_lists]

                for (idx, round_idx, _), outcome in zip(data_lists, outcomes):
                    record(results, outcome, idx, round_idx)

                # gather_object receives a one-element list per process.
                results = [results]
        finally:
            progress_bar.close()

        self.accelerator.wait_for_everyone()
        gathered_results = gather_object(results)

        if self.accelerator.is_main_process:
            results = dict(completions=[], ground_truths=[], rounds=[])
            for shard in gathered_results:
                for key, values in results.items():
                    values.extend(shard[key])
        else:
            results = None
        self.accelerator.wait_for_everyone()
        return results

    def formulate_qa(self, q, a):
        return self.qa_template.replace("<Q>", q).replace("<A>", a)


    def get_positive_ans(self, b):
        pass
    
    
    def formulate_question(self, b):
        pass


    def extract_ground_truth(self, b):
        pass