from tqdm.auto import tqdm
from llms.proposal import ProposalModel
# from llms.proposal_embed import ProposalModel
from llms.critic import CriticModel
from utils.util import deduplication
from utils.loggers import loggers, update_log_folder, get_record_dir
from accelerate.utils import release_memory
from datasets import load_from_disk, concatenate_datasets
import os
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import json
from copy import copy
import wandb


class Adapter(CriticModel):
    '''
    The Adapter achieves post-adaptation by:
    - generating raw outputs with the Proposal LLM
    - evaluating those outputs with the Critic LLM
    - using the evaluation scores and an MDP search to select the final output
    '''
    def __init__(self, prompt, config, accelerator):
        '''
        Set up the proposal model, the critic (parent class) and the experiment trackers.

        prompt: prefix text placed in front of every question (may be empty/None).
        config: mapping of run options; this method reads "data_path", "task",
            "wandb_project" and "wandb_group".
        accelerator: object used for tracker initialisation and forwarded to
            the parent initializer.
        '''
        self.config, self.prompt, self.accelerator = config, prompt, accelerator

        # proposal: only built when no pre-collected data path is configured
        has_collected_data = bool(self.config["data_path"])
        if not has_collected_data:
            self.proposal = ProposalModel(config)

        # critic: parent initializer; self.optimizer is read right below, so it
        # is presumably created there — TODO confirm against CriticModel
        super().__init__(config, accelerator)
        self.acc_table = wandb.Table(columns=["stage", "accuracy"])
        # Accuracy function used by evaluate(); left unset here.
        self.get_accuracy = None

        tracker_project = self.config['task'] + '+' + self.config['wandb_project']
        current_lr = self.optimizer.param_groups[0]["lr"]
        tracker_config = self.config | {"learning_rate": current_lr}
        wandb_options = {
            "save_code": True,
            "name": datetime.now().strftime("%Y%m%d-%H%M%S"),
            "group": self.config['wandb_group'],
        }
        self.accelerator.init_trackers(
            project_name=tracker_project,
            config=tracker_config,
            init_kwargs={"wandb": wandb_options},
        )

    def generate_negative(self, question, input_string, search_step=None):
        '''
        Ask the proposal model for a single completion of a partial chain of thought.

        The query is the stored prompt followed by the question on a new line
        (or just the question when no prompt is set). input_string is the
        partial answer passed through to the proposal model.
        search_step is forwarded only when config["gsm8k"] is truthy;
        otherwise None is forwarded.

        Returns whatever self.proposal.get_response returns.
        '''
        if self.prompt:
            prompt = f"{self.prompt}\n{question}".strip()
        else:
            prompt = question

        cot_token_budget = self.config.get("max_tokens_cot", 25600)
        if not self.config["gsm8k"]:
            search_step = None

        return self.proposal.get_response(
            prompt,
            input_string,
            n=1,
            max_tokens=cot_token_budget,
            extract_first_sentence=False,
            search_step=search_step,
        )

    def generate(self, question, input_string, search_step=0, batch_idx=0):
        '''
        Propose candidate continuations and score them with the critic.

        The proposal model generates candidates for the query
        "[prompt]\\n[question]" plus the partial answer input_string; the
        critic then scores and embeds every candidate.

        search_step: forwarded to the proposal model in step-level search only.
        batch_idx: only used to label the log entries.

        Returns the string '<SKIP>' when the proposal model returns '<SKIP>' or
        nothing; otherwise a dict with keys "text" (candidates after
        deduplication), "scores" and "embeddings".
        '''
        if self.prompt:
            prompt = f"{self.prompt}\n{question}".strip()
        else:
            prompt = question

        step_level_search = self.config["max_length"] > 1
        if step_level_search:
            # Step-level search: one reasoning step per call
            candidates = self.proposal.get_response(
                prompt,
                input_string,
                max_tokens=self.config["max_tokens"],
                extract_first_sentence=True,
                search_step=search_step,
            )
        else:
            # Greedy search: generate until a blank line
            candidates = self.proposal.get_response(
                prompt,
                input_string,
                stop=["\n\n"],
                max_tokens=512,
                extract_first_sentence=False,
            )

        if candidates == '<SKIP>' or len(candidates) == 0:
            return '<SKIP>'

        # Trim whitespace and drop candidates that consist of a lone period.
        trimmed = (text.strip() for text in candidates)
        candidates = [text for text in trimmed if text != '.']

        use_semantic_model = self.config.get("load_semantic_model", None)
        dedup_options = {
            "num_to_keep": self.config['beam_size'],
            "fill_to": self.config['num_candidates'],
        }
        if use_semantic_model:
            dedup_options["semantic_model"] = self.semantic_model
            dedup_options["semantic_tokenizer"] = self.semantic_tokenizer
            dedup_options["context"] = input_string
        candidates = deduplication(candidates, **dedup_options)

        def _without_chat_markers(text):
            # The critic input has the chat-template markers removed.
            return text.replace('<|im_start|>', '').replace('<|im_end|>', '')

        # Get scores from Critic
        scoring_prefix = prompt + _without_chat_markers(input_string)
        texts_to_score = [
            scoring_prefix + _without_chat_markers(text.strip())
            for text in candidates
        ]
        scores = self.get_scores_from_texts(texts_to_score)
        embeddings = self.get_embeddings_from_texts(texts_to_score)

        # "<EMPTY>" candidates get a score of -inf
        for position, text in enumerate(candidates):
            if text == "<EMPTY>":
                scores[position] = float('-inf')

        score_values = scores.tolist()
        divider = '=' * 20
        query = prompt + input_string

        pr_lines = [r + f' (Score: {s})' for r, s in zip(candidates, score_values)]
        pr_text = '\n'.join(pr_lines)
        loggers["proposal"].info(f"\nQuestion: {batch_idx}\n{divider}\nQuery:\n{query}\nResponses:\n{pr_text}")

        ad_entries = [f"{r} (Score: {s})" for r, s in zip(texts_to_score, score_values)]
        ad_text = '\n\n'.join(ad_entries)
        loggers["critic"].info(f"\nQuestion: {batch_idx}\n{divider}\n{ad_text}")

        return {
            "text": candidates,
            "scores": scores,
            "embeddings": embeddings,
        }

    def evaluate(self, eval_dataset, use_adapter=True, stage_name=""):
        '''
        Run evaluate_batch over eval_dataset and report accuracy.

        eval_dataset: dataset to evaluate; it is batched with
            config["num_online_eval_size"] (default: the whole dataset).
        use_adapter, stage_name: forwarded to evaluate_batch; stage_name also
            labels the printed and logged accuracy.

        On the main process the per-batch results are accumulated,
        self.get_accuracy(overall_results) is called to obtain
        (accuracy, std), the outcome is written to result.json inside
        get_record_dir(), and the accuracy table is logged through the
        accelerator. Other processes only run evaluate_batch.

        Raises TypeError on the main process when self.get_accuracy is unset.
        '''
        num_online_eval_size = self.config.get("num_online_eval_size", len(eval_dataset))
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=num_online_eval_size,
            shuffle=False,
            collate_fn=lambda x: x,  # keep each batch as a plain list of samples
        )
        overall_results = dict(completions=[], ground_truths=[], rounds=[])
        record_results = []

        # The bar is advanced once per batch and closed even if a batch fails.
        with tqdm(
            total=len(eval_dataloader),
            desc="Evaluating",
            disable=not self.accelerator.is_local_main_process,
        ) as progress_bar:
            for batch in eval_dataloader:
                results = self.evaluate_batch(batch, use_adapter, stage_name)
                if self.accelerator.is_main_process:
                    for key in overall_results:
                        overall_results[key].extend(results[key])
                    # Each ground-truth entry is a mapping; store it with its completion.
                    record_results.extend(
                        {**data, 'completion': completion}
                        for data, completion in zip(results['ground_truths'], results['completions'])
                    )
                progress_bar.update(1)

        if not self.accelerator.is_main_process:
            return

        if self.get_accuracy is None:
            raise TypeError(
                "get_accuracy is not set; assign a callable that maps the "
                "collected results to (accuracy, std) before calling evaluate()"
            )

        accuracy, std = self.get_accuracy(overall_results)
        saved_result = {'acc': accuracy, 'std': std, 'record': record_results}
        with open(os.path.join(get_record_dir(), 'result.json'), 'w', encoding='utf-8') as f:
            json.dump(saved_result, f, indent=2)
        print(f"\nStage: {stage_name}, Accuracy: {accuracy * 100:.2f}% ± {std * 100:.2f}%")

        self.acc_table.add_data(stage_name, accuracy)
        # NOTE(review): a copy of the table is logged, presumably so that each
        # call is recorded as a new table by the tracker — confirm.
        self.accelerator.log({"accuracy": copy(self.acc_table)})

    def train(self, train_dataset, test_dataset):
        '''
        Run the optional pre-training evaluations, then either fine-tune on
        previously collected data or collect new data online.

        train_dataset: training samples. When config["validation_ratio"] > 0
            and test_dataset is None, an unshuffled tail split of it becomes
            the evaluation set.
        test_dataset: evaluation samples, or None.

        Flow:
        1. config["eval_blackbox"] / config["eval_unfinetuned"]: evaluate
           without / with the adapter.
        2. If config["do_train"] is falsy, shut the trackers down and exit the
           process with status 0.
        3. If config["data_path"] is set: load every saved dataset found under
           it, run train_step config["num_online_finetuning_repeat"] times,
           optionally evaluate and save the critic.
        4. Otherwise: for every epoch and batch, call prepare_for_critic
           (during blackbox warm-up epochs) or prepare_for_both with the
           output paths for that batch.

        Dataset directories under config["data_path"] are selected by name
        prefix, config["data_dir_prefix"] (default "2025", i.e. timestamped
        folders created in 2025).

        Raises SystemExit(0) when config["do_train"] is falsy, and ValueError
        when config["data_path"] contains no matching dataset directory.
        '''
        def _collect_data(data_dir, dir_prefix):
            # Load and concatenate every dataset directory whose name starts
            # with dir_prefix, skipping anything located under a 'traindpo' path.
            dataset_list = []
            for root, dirs, _files in os.walk(data_dir):
                if 'traindpo' in root:
                    continue
                for dir_name in dirs:
                    if dir_name.startswith(dir_prefix):
                        dataset_list.append(load_from_disk(os.path.join(root, dir_name)))
            if not dataset_list:
                raise ValueError(
                    f"No dataset directory starting with '{dir_prefix}' found under '{data_dir}'"
                )
            return concatenate_datasets(dataset_list)

        torch.cuda.empty_cache()
        self.accelerator.free_memory()

        # get parameters
        num_epochs = self.config['num_epochs']
        use_blackbox_warmup = self.config.get("use_blackbox_warmup", False)
        num_epochs_blackbox_warmup = self.config.get("num_epochs_blackbox_warmup", 1)
        num_online_finetuning_repeat = self.config.get("num_online_finetuning_repeat", 1)
        save_critic_model = self.config.get("save_critic_model", None)
        skip_eval = self.config.get("skip_eval", False)

        # split train-validation dataset
        validation_ratio = self.config.get("validation_ratio", 0.)
        if validation_ratio > 0 and test_dataset is None:
            train_dataset, eval_dataset = train_dataset.train_test_split(test_size=validation_ratio, shuffle=False).values()
        else:
            eval_dataset = test_dataset

        # eval_dataset is None when no test set is given and no split is requested.
        eval_size = len(eval_dataset) if eval_dataset is not None else 0
        self.accelerator.print(f"Train size: {len(train_dataset)}, eval size: {eval_size}")
        num_online_dataloader_size = self.config.get("num_online_dataloader_size", len(train_dataset))
        train_dataloader = DataLoader(train_dataset, batch_size=num_online_dataloader_size, shuffle=True, collate_fn=lambda x: x)
        self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.optimizer, self.lr_scheduler
        )

        # Optional evaluations before any training:
        # (config flag, log sub-folder, use_adapter, stage name)
        pre_train_evals = (
            ("eval_blackbox", "eval_blackbox_only", False, "Blackbox only"),
            ("eval_unfinetuned", "eval_raw_adapter", True, "Raw adapter"),
        )
        for flag_key, log_subfolder, with_adapter, stage in pre_train_evals:
            if not self.config[flag_key]:
                continue
            update_log_folder(
                f"{self.config['search_method']}/StepWise={self.config['use_stepwise_supervision']}/Critic={os.path.basename(self.config['load_critic_model'])}/{log_subfolder}",
                self.accelerator.process_index,
            )
            self.evaluate(eval_dataset, use_adapter=with_adapter, stage_name=stage)
            self.accelerator.wait_for_everyone()

        if not self.config["do_train"]:
            # Same shutdown as the normal end of training, then leave the
            # process. SystemExit is raised directly because the exit()
            # builtin is only installed by the site module.
            self.accelerator.wait_for_everyone()
            self.accelerator.end_training()
            raise SystemExit(0)

        now_time = datetime.now().strftime('%Y%m%d-%H%M')
        log_root = f"{self.config['search_method']}/StepWise={self.config['use_stepwise_supervision']}"

        if self.config["data_path"]:
            # directly do finetuning using collected data
            update_log_folder(
                f"{log_root}/direct_finetuning/{now_time}",
                self.accelerator.process_index,
            )
            dataset_path = self.config["data_path"]
            dir_prefix = self.config.get("data_dir_prefix", "2025")
            with self.accelerator.main_process_first():
                train_dataset = _collect_data(dataset_path, dir_prefix)
            self.accelerator.print(f"Train size: {len(train_dataset)}")
            train_loader = self.build_dataloader(train_dataset)
            for _ in range(num_online_finetuning_repeat):
                self.train_step(train_loader)
            if not skip_eval:
                self.evaluate(
                    eval_dataset,
                    use_adapter=True,
                    stage_name="Direct finetuning",
                )
            if save_critic_model:
                save_path = f"{save_critic_model}/direct_finetuning"
                if self.accelerator.is_main_process:
                    os.makedirs(save_path, exist_ok=True)
                    print(f'Saving model to {save_path}')
                    unwrapped_model = self.accelerator.unwrap_model(self.model)
                    unwrapped_model.save_pretrained(save_path)
                    self.tokenizer.save_pretrained(save_path)
                self.accelerator.wait_for_everyone()
        else:
            # Offline warmup is ignored at this stage; only the online loop runs.
            model = 'llama3-8b-instruct' if 'llama' in self.config['whitebox'].lower() else 'qwen2.5-7b-instruct'
            with tqdm(
                total=num_epochs * len(train_dataloader),
                desc="Training",
                disable=not self.accelerator.is_local_main_process,
            ) as progress_bar:
                for epoch in range(num_epochs):
                    for idx, batch in enumerate(train_dataloader):
                        update_log_folder(
                            f"{log_root}/finetuning_epoch_{epoch}_idx_{idx}",
                            self.accelerator.process_index,
                        )
                        header = f"\n{'='*20}\nepoch {epoch} | idx {idx}"
                        if self.accelerator.is_main_process:
                            loggers["train"].info(header)
                            loggers["eval"].info(header)
                        loggers["search"].info(header)
                        loggers["tensor"].info(header)
                        loggers["proposal"].info(header)

                        # Warm-up epochs collect data with the blackbox only.
                        blackbox_only = epoch < num_epochs_blackbox_warmup and use_blackbox_warmup
                        adapter_tag = 'False' if blackbox_only else 'True'

                        # <save_data_path>/<task>/<model>/use_adapter=<tag>/<stage>[/split_<n>]/<time>
                        base_dir = f"{self.config['save_data_path']}/{self.config['task']}/{model}/use_adapter={adapter_tag}"
                        if self.config["train_split"] is not None:
                            run_suffix = f"/split_{self.config['train_split']}/{now_time}"
                        else:
                            run_suffix = f"/{now_time}"
                        dataset_path = f"{base_dir}/train_epoch_{epoch}_idx_{idx}{run_suffix}"
                        dpo_dataset_path = f"{base_dir}/traindpo_epoch_{epoch}_idx_{idx}{run_suffix}/train_dpo.jsonl"

                        if blackbox_only:
                            # Collect data using solely blackbox
                            self.prepare_for_critic(batch, dataset_path, use_adapter=False)
                        else:
                            # Collect data using blackbox and adapter
                            self.prepare_for_both(batch, dpo_dataset_path, dataset_path, use_adapter=True)
                        progress_bar.update(1)

        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()

    def prepare_for_critic(self, *args, **kwargs):
        '''
        Hook to be overridden by subclasses.

        train() calls it as
        prepare_for_critic(batch, dataset_path, use_adapter=False)
        during blackbox warm-up epochs. Arguments are accepted here so that an
        un-overridden call raises NotImplementedError rather than TypeError.
        '''
        raise NotImplementedError(
            f"{type(self).__name__} must implement prepare_for_critic()"
        )

    def prepare_for_proposal_with_reflection(self):
        '''
        Placeholder to be overridden; always raises NotImplementedError here.
        '''
        raise NotImplementedError()
    
    def prepare_for_proposal(self):
        '''
        Placeholder to be overridden; always raises NotImplementedError here.
        '''
        raise NotImplementedError()
    
    def prepare_for_both(self, *args, **kwargs):
        '''
        Hook to be overridden by subclasses.

        train() calls it as
        prepare_for_both(batch, dpo_dataset_path, dataset_path, use_adapter=True)
        outside the blackbox warm-up epochs. Arguments are accepted here so
        that an un-overridden call raises NotImplementedError rather than
        TypeError.
        '''
        raise NotImplementedError(
            f"{type(self).__name__} must implement prepare_for_both()"
        )
    
    def evaluate_batch(self, *args, **kwargs):
        '''
        Hook to be overridden by subclasses.

        evaluate() calls it as evaluate_batch(batch, use_adapter, stage_name)
        and reads the keys 'completions', 'ground_truths' and 'rounds' from
        the returned mapping. Arguments are accepted here so that an
        un-overridden call raises NotImplementedError rather than TypeError.
        '''
        raise NotImplementedError(
            f"{type(self).__name__} must implement evaluate_batch()"
        )