import os
import gc
import time
import math
import random
import pickle
import subprocess

import numpy as np

from .base import SamplingAlgorithmBase
from prover.lean.proof import ProofSummarizer
from prover.utils import ConcurrentJob


class TreeNode(object):
    """Search-tree node carrying discounted reward/visitation statistics.

    Persistent attributes live in ``self._info`` (this is what ``to_dict``
    serializes).  The running-job counters and the derived ``value`` /
    ``visitation`` / ``subtree_value`` / ``subtree_visitation`` attributes are
    plain instance attributes and are therefore *not* serialized.

    Args:
        parent: Parent node, or ``None`` for a root.
        code: Optional first code entry recorded for this node.
        **kwargs: Arbitrary attributes stored in ``_info`` and readable via
            ``node[key]``.
    """

    # Keys of ``_info`` that hold the persistent discounted statistics.
    _STAT_KEYS = (
        '_discounted_rewards',
        '_discounted_visitation',
        '_subtree_discounted_rewards',
        '_subtree_discounted_visitation',
    )

    def __init__(self, parent=None, code=None, **kwargs):
        self.parent = parent
        self.children = dict()
        self._info = dict(kwargs)
        if code is not None:
            self.update_code(code)

        # Statistics restored through kwargs (e.g. by ``from_dict``) are kept;
        # missing ones start at zero.
        for stat_key in self._STAT_KEYS:
            self._info.setdefault(stat_key, 0.0)
        self._num_running_jobs = 0
        self._subtree_num_running_jobs = 0
        self._update_value(gamma=0.0)  # gamma=0.0 is okay for initialization

    # basic tree supports
    @property
    def code(self):
        """Return one recorded code entry, chosen uniformly at random.

        Raises:
            KeyError: If no code has ever been recorded on this node.
        """
        return random.choice(self._info['_code_list'])

    def update_code(self, code):
        """Record ``code`` for this node unless an equal entry already exists."""
        code_list = self._info.setdefault('_code_list', [])
        if code not in code_list:
            code_list.append(code)

    def __getitem__(self, key):
        """Return the attribute stored in ``_info`` under ``key``."""
        return self._info[key]

    def to_node_list(self):
        """Return this node and all descendants in pre-order.

        Children are visited in the insertion order of ``self.children``.
        The traversal is iterative so that its cost is linear in the number
        of nodes and independent of the recursion limit.
        """
        ordered = []
        stack = [self]
        while stack:
            node = stack.pop()
            ordered.append(node)
            # Push children reversed so the first child is popped first.
            stack.extend(reversed(list(node.children.values())))
        return ordered

    def to_dict(self):
        """Serialize the subtree into nested ``dict(info=..., children=...)``.

        Note that ``info`` is the node's own ``_info`` dict, not a copy.
        """
        return dict(
            info=self._info,
            children={
                edge: child_node.to_dict()
                for edge, child_node in self.children.items()
            }
        )

    @classmethod
    def from_dict(cls, dict_data, parent=None):
        """Rebuild a subtree from the structure produced by ``to_dict``."""
        node = cls(
            parent=parent,
            **dict_data['info'],
        )
        node.children = {
            edge: cls.from_dict(child_dict, parent=node)
            for edge, child_dict in dict_data['children'].items()
        }
        return node

    # algorithm supports
    def update_reward(self, reward, gamma, first_node=True):
        """Fold ``reward`` into the discounted statistics and propagate upwards.

        The node on which the call starts (``first_node=True``) updates both
        its own and its subtree statistics; every ancestor only updates its
        subtree statistics.
        """
        info = self._info
        if first_node:
            info['_discounted_rewards'] = info['_discounted_rewards'] * gamma + reward
            info['_discounted_visitation'] = info['_discounted_visitation'] * gamma + 1.0
        info['_subtree_discounted_rewards'] = info['_subtree_discounted_rewards'] * gamma + reward
        info['_subtree_discounted_visitation'] = info['_subtree_discounted_visitation'] * gamma + 1.0
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.update_reward(reward, gamma, first_node=False)

    def start_new_job(self, gamma, first_node=True):
        """Register a running job on this node and on every ancestor's subtree."""
        if first_node:
            self._num_running_jobs += 1
        self._subtree_num_running_jobs += 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.start_new_job(gamma, first_node=False)

    def complete_job(self, gamma, first_node=True):
        """Unregister a running job on this node and on every ancestor's subtree."""
        if first_node:
            self._num_running_jobs -= 1
        self._subtree_num_running_jobs -= 1
        self._update_value(gamma)
        if self.parent is not None:
            self.parent.complete_job(gamma, first_node=False)

    @staticmethod
    def _discount_terms(gamma, num_jobs):
        """Return ``(decay, pending)`` for ``num_jobs`` running jobs.

        ``decay`` is ``gamma ** num_jobs`` and ``pending`` is the geometric
        sum ``1 + gamma + ... + gamma ** (num_jobs - 1)``.
        """
        decay = gamma ** num_jobs
        if gamma == 1.0:
            # The closed form divides by (1 - gamma); without discounting the
            # geometric sum is simply the number of terms.
            return decay, float(num_jobs)
        return decay, (1.0 - decay) / (1.0 - gamma)

    def _update_value(self, gamma):
        """Recompute the derived value/visitation attributes.

        Every running job is accounted for as one extra discounted visit that
        contributes no reward, so the resulting values drop while jobs are
        in flight.  The visitation in the denominator is floored at ``1e-2``.
        """
        info = self._info

        decay, pending = self._discount_terms(gamma, self._num_running_jobs)
        discounted_rewards = info['_discounted_rewards'] * decay
        discounted_visitation = info['_discounted_visitation'] * decay + pending
        self.value = discounted_rewards / max(discounted_visitation, 1e-2)
        self.visitation = discounted_visitation

        decay, pending = self._discount_terms(gamma, self._subtree_num_running_jobs)
        subtree_discounted_rewards = info['_subtree_discounted_rewards'] * decay
        subtree_discounted_visitation = info['_subtree_discounted_visitation'] * decay + pending
        self.subtree_value = subtree_discounted_rewards / max(subtree_discounted_visitation, 1e-2)
        self.subtree_visitation = subtree_discounted_visitation


class RMaxTS(SamplingAlgorithmBase):
    """Tree-search proof sampler using an RMax-style exploration reward.

    Partial proofs are organized in a tree whose edges are tactic goals
    (see ``_tree_update``).  Nodes to expand are picked by a UCB rule
    (``_select_node``), and a finished sample is rewarded with 1 when it
    added at least one new node to the tree and 0 otherwise
    (``_rmax_exploration_summarize_results``).

    Attributes such as ``cfg``, ``scheduler``, ``max_tokens``,
    ``log_interval`` and the ``process_print`` / ``_encode_length`` /
    ``_preprocess_data`` / ``_post_sample_info`` helpers are not defined here;
    presumably they are provided by ``SamplingAlgorithmBase``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.gamma = self.cfg.get('gamma', 0.99)
        self.sample_num = self.cfg.get('sample_num', 6400)
        self.concurrent_num = self.cfg.get('concurrent_num', 32)
        self.tactic_state_comment = self.cfg.get('tactic_state_comment', True)
        self.ckpt_interval = self.cfg.get('ckpt_interval', 128)

        self.ckpt_filename = 'checkpoint.pkl'
        self.node_cls = TreeNode
        # Stages executed in order for every sample by ``ConcurrentJob``.
        self.algorithm_pipeline = [
            self._tactic_tree_generate_proof,
            self._tactic_tree_parse_proof,
            self._rmax_exploration_summarize_results,
        ]

    # basic supports
    def _save_ckpt(self, ckpt_dict: dict):
        """Write ``ckpt_dict`` to ``self.ckpt_path``, keeping the old file as backup.

        The new checkpoint is first pickled to ``<ckpt_path>.tmp`` and only
        then moved into place, so ``ckpt_path`` never holds a half-written
        pickle.  The previous checkpoint, if any, becomes
        ``<ckpt_path>.backup``.
        """
        tmp_path = self.ckpt_path + '.tmp'
        with open(tmp_path, 'wb') as pkl_f:
            pickle.dump(ckpt_dict, pkl_f)
        # keep the previous checkpoint as a backup
        if os.path.exists(self.ckpt_path):
            os.replace(self.ckpt_path, self.ckpt_path + '.backup')
        # publish the fully written checkpoint
        os.replace(tmp_path, self.ckpt_path)

    # tree structure supports
    def _tree_setup(self, data):
        """Create or restore the search tree and prepare the proof summarizer.

        The primary checkpoint is preferred; the backup is only consulted
        when the primary one is missing or cannot be unpickled.

        Returns:
            Tuple ``(root, sample_count, yield_cache)``.
        """
        # initialize tree
        ckpt_info = None
        for _ckpt_path in [self.ckpt_path, self.ckpt_path + '.backup']:
            if not os.path.exists(_ckpt_path):
                continue
            try:
                # NOTE: pickle can execute arbitrary code while loading;
                # only checkpoints written by this sampler should live here.
                with open(_ckpt_path, 'rb') as pkl_f:
                    ckpt_info = pickle.load(pkl_f)
            except Exception:
                self.process_print(f'Checkpoint saved at {_ckpt_path} is broken.')
            else:
                # Stop at the first readable file so the older backup does
                # not override a newer, valid primary checkpoint.
                break
        if ckpt_info is not None:
            root = self.node_cls.from_dict(ckpt_info['root'])
            sample_count = ckpt_info['sample_count']
            yield_cache = ckpt_info['yield_cache']
            self.process_print(f'Load checkpoint from sample_count={sample_count}')
        else:
            root = self.node_cls(code=dict(tactic_code=str(), state_comment=str()), depth=0)
            sample_count = 0
            yield_cache = []

        # compile the root node with `sorry`
        self.proof_summarizer = ProofSummarizer(data=data, scheduler=self.scheduler)
        root_sorry = self.proof_summarizer.analyze('  sorry', require_verification=True)
        assert root_sorry.result['pass'], "Cannot parse a `sorry` tactic on root."
        self.root_goal = root_sorry.result['sorries'][-1]['goal']

        # other initialization
        self._last_selected_node = root

        return root, sample_count, yield_cache

    def _tree_new_child(self, parent):
        """Create a child node of ``parent`` one level deeper."""
        return self.node_cls(
            parent=parent,
            depth=parent['depth'] + 1,
        )

    def _tree_step(self, node, edge, code):
        """Follow ``edge`` from ``node``, creating the child when it is new.

        ``code`` is recorded on the child in either case.  Newly created
        children are appended to ``self.node_list``.
        """
        if edge not in node.children:
            new_node = self._tree_new_child(node)
            node.children[edge] = new_node
            self.node_list.append(new_node)
        child_node = node.children[edge]
        child_node.update_code(code)
        return child_node

    def _tree_update(self, proof):
        """Insert the tactic segments of ``proof`` into the tree.

        Returns:
            The deepest node reached while walking the proof.
        """
        node_walk, partial_code = self.root, str()

        # use tactic goals as tree edges
        segments = proof.segmentation()
        prev_goal = self.root_goal
        for info in segments:
            partial_code += info.tactic_code
            code = partial_code + info.state_comment if self.tactic_state_comment else partial_code
            # only prefixes that still fit the token budget become nodes, and
            # a segment that leaves the goal unchanged does not add an edge
            if self._encode_length(code) < self.max_tokens and info.goal != prev_goal:
                node_walk = self._tree_step(
                    node=node_walk, edge=info.goal,
                    code=dict(tactic_code=partial_code, state_comment=info.state_comment)
                )
                prev_goal = info.goal
        return node_walk

    # algorithm pipeline
    def _select_node(self):
        """Pick the node to expand next and register a running job on it.

        Starting from the root, each step compares the UCB score of stopping
        at the current node (own value/visitation) against the scores of
        descending into each child (subtree value/visitation).  Ties go to
        the earliest candidate, i.e. stopping first, then children in
        insertion order.
        """
        node = self.root
        while len(node.children) > 0:
            total_visitation = node.visitation + np.sum(
                [child.subtree_visitation for child in node.children.values()]
            )
            log_term = 2.0 * math.log(max(total_visitation, 2))

            def hoeffding_ucb(node_visitation):
                return math.sqrt(log_term / max(node_visitation, 1e-2))

            # ``None`` marks the choice of expanding the current node itself
            choice_list = [(node.value + hoeffding_ucb(node.visitation), None)]
            choice_list.extend(
                (child.subtree_value + hoeffding_ucb(child.subtree_visitation), child)
                for child in node.children.values()
            )
            _, best_child = max(choice_list, key=lambda x: x[0])
            if best_child is None:
                break
            node = best_child
        node.start_new_job(gamma=self.gamma)
        return node

    def _tactic_tree_generate_proof(self, data, node):
        """Pipeline stage 1: submit a generation request continuing ``node``'s code."""
        code_prefix = node.code
        extra_prompt = code_prefix['tactic_code']
        if self.tactic_state_comment:
            extra_prompt += code_prefix['state_comment']
        return dict(
            node=node,
            code_prefix=code_prefix,
            generator_request_id=self.scheduler.generator_submit_request(
                self._preprocess_data({**data, '_extra_prompt': extra_prompt}),
            ),
        )

    def _tactic_tree_parse_proof(self, node, code_prefix, generator_request_id):
        """Pipeline stage 2: once generation finished, submit the proof for analysis.

        Returns ``None`` while the generator result is not available yet.
        """
        code = self.scheduler.generator_get_request_status(generator_request_id)
        if code is None:
            return None
        code = code_prefix['tactic_code'] + code
        proof = self.proof_summarizer.analyze(code, require_verification=True)
        return dict(node=node, proof=proof)

    def _rmax_exploration_summarize_results(self, node, proof):
        """Pipeline stage 3: grow the tree and reward ``node``.

        Returns ``None`` while the proof result is not ready yet.
        """
        if not proof.is_result_ready():
            return None

        num_nodes_before = len(self.node_list)
        self._tree_update(proof)
        # RMax reward: 1 if the sample discovered any new node, else 0
        node.update_reward(int(len(self.node_list) > num_nodes_before), gamma=self.gamma)
        node.complete_job(gamma=self.gamma)

        return dict(
            code=proof.cleaned_code,
            result=proof.result,
        )

    # sampler interface
    def sample(self, data, prob_log_dir, **kwargs):
        """Run the tree search and yield ``(proposal, sample_info)`` for complete proofs.

        Results cached in a checkpoint are replayed first.  If the checkpoint
        already contains results, no new samples are started.  Checkpoints
        are written every ``ckpt_interval`` finished samples and once more
        when the loop ends.
        """
        self.ckpt_path = os.path.join(prob_log_dir, self.ckpt_filename)
        self.root, sample_count, yield_cache = self._tree_setup(data)
        self.node_list = self.root.to_node_list()
        for _proposal, _sample_info in yield_cache:
            yield _proposal, _sample_info
        gc.collect()  # release memory

        job_slots = [
            ConcurrentJob(self.algorithm_pipeline)
            for _ in range(self.concurrent_num)
        ]

        sample_budget = self.sample_num - sample_count if len(yield_cache) == 0 else 0
        while (sample_budget > 0) or any(not job.is_idle() for job in job_slots):
            for job in job_slots:
                if job.is_idle() and sample_budget > 0:
                    node = self._select_node()
                    self._last_selected_node = node
                    job.start(data=data, node=node)
                    sample_budget -= 1
                if not job.is_idle():
                    info = job.get_status()
                    if info is not None:
                        # output samples
                        sample_count += 1
                        if info['result']['complete']:
                            _proposal, _sample_info = info['code'], self._post_sample_info(
                                cost=sample_count, tree_size=len(self.node_list),
                            )
                            yield_cache.append((_proposal, _sample_info))
                            yield _proposal, _sample_info
                        # logging
                        if sample_count % self.log_interval == 0:
                            self.process_print('Progress: {} / {}    Tree Size: {}'.format(
                                sample_count, self.sample_num, len(self.node_list),
                            ))
                        # saving checkpoints
                        if sample_count % self.ckpt_interval == 0:
                            self._save_ckpt(dict(
                                root=self.root.to_dict(),
                                sample_count=sample_count,
                                yield_cache=yield_cache,
                            ))
                            if len(yield_cache) > 0:
                                # return after saving the checkpoint
                                # avoid overestimation caused by interrupt-restart loop
                                sample_budget = 0
            time.sleep(0.1)

        # save the final tree structure
        self._save_ckpt(dict(
            root=self.root.to_dict(),
            sample_count=sample_count,
            yield_cache=yield_cache,
        ))
