#!/usr/bin/env python3

import argparse
import pickle
import collections
import os
import json
import torch
import numpy as np
from dataclasses import dataclass
import random

from omegaconf import DictConfig
from tqdm import tqdm

import peano
import proofsearch
from util import plot_vegalite, translate_object, format_blocks_with_indent


@dataclass()
class Run:
    """A single run loaded from one entry of a run-spec file.

    Fields are positional in the order below; ``load_runs_from_specs`` builds
    instances as ``Run(run_dir, agents, outcomes, theory, hindsight, logprobs)``.
    """

    # Run directory; doubles as the run's identifier.
    id: str
    # One entry per iteration (load_run may fill these with None placeholders).
    agents: list[torch.nn.Module]
    # outcomes[i] is the list of outcome dicts of iteration i.
    outcomes: list[list[dict]]
    # 'theory' value from the run spec, used to label plot records.
    theory: str
    # 'hindsight' flag from the run spec.
    hindsight: bool
    # Records read from the run's logprobs.json, tagged with run-level fields.
    logprobs: list[dict]


def load_run(run_dir, load_agents=False):
    """Load the per-iteration outcomes (and optionally agents) of a run.

    Iterations are discovered by probing ``outcomes_0.json``,
    ``outcomes_1.json``, ... inside ``run_dir`` until the first missing index.

    Args:
        run_dir: Output directory of the run.
        load_agents: When True, load the agent of each iteration ``i`` from
            ``<run_dir>/<i>.pt`` with ``torch.load``. Defaults to False, in
            which case the returned agent list holds one ``None`` placeholder
            per iteration, so callers that only need outcomes skip loading
            checkpoints.

    Returns:
        ``(agents, outcomes)``: two lists with one entry per iteration.
        For ``i >= 1``, ``outcomes[i]`` keeps only the records whose
        ``'iteration'`` equals ``i``. ``outcomes[0]`` is returned as read.
    """
    n_iters = 0
    while os.path.exists(os.path.join(run_dir, f'outcomes_{n_iters}.json')):
        n_iters += 1

    device = None
    if load_agents:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    agents, outcomes = [], []
    for i in range(n_iters):
        if load_agents:
            # NOTE(review): torch.load unpickles; only use on trusted run dirs.
            # If checkpoints are whole pickled modules, newer torch versions
            # (weights_only=True by default) may refuse them -- confirm.
            agents.append(torch.load(os.path.join(run_dir, f'{i}.pt'), map_location=device))
        else:
            agents.append(None)
        with open(os.path.join(run_dir, f'outcomes_{i}.json'), 'r') as f:
            outcomes.append(json.load(f))

    # Outcomes are cumulative on disk -- make them per-iteration.
    for i in range(n_iters - 1, 0, -1):
        outcomes[i] = [o for o in outcomes[i] if o['iteration'] == i]

    return agents, outcomes


def load_runs_from_specs(specs_path):
    """Load every run listed in a JSON run-spec file.

    The spec file is a list of objects with ``'run_dir'``, ``'theory'`` and
    ``'hindsight'`` keys. Each record of a run's ``logprobs.json`` is tagged
    with ``'run_id'`` (the run directory), ``'Theory'`` and ``'Run Hindsight'``
    so that records from several runs can be plotted together.

    Returns:
        A list of ``Run`` objects in spec order.
    """
    with open(specs_path, 'r') as spec_file:
        specs = json.load(spec_file)

    loaded = []
    for spec in specs:
        directory = spec['run_dir']
        theory_name = spec['theory']
        uses_hindsight = spec['hindsight']
        run_agents, run_outcomes = load_run(directory)

        with open(os.path.join(directory, 'logprobs.json'), 'r') as logprob_file:
            records = json.load(logprob_file)
        for record in records:
            record['run_id'] = directory
            record['Theory'] = theory_name
            record['Run Hindsight'] = uses_hindsight

        loaded.append(Run(directory, run_agents, run_outcomes,
                          theory_name, uses_hindsight, records))

    return loaded


def plot_difficulty(specs_path):
    """Plot proof logprobs across iterations for the runs flagged with hindsight.

    Only runs whose spec sets ``hindsight`` contribute. All of their logprob
    records are plotted, with no per-record filtering.
    """
    hindsight_runs = (run for run in load_runs_from_specs(specs_path) if run.hindsight)
    records = [record for run in hindsight_runs for record in run.logprobs]

    plot_vegalite(
        'difficulty',
        records,
        'difficulty',
        {'$TITLE': 'Problem difficulty across iterations'},
    )


def plot_curriculum(specs_path):
    """Plot the difficulty of newly proved conjectures at each iteration.

    Keeps only non-hindsight records scored by the agent of the same iteration
    that produced the conjecture. It then relabels ``'Run Hindsight'`` as a
    readable string.
    """
    kept = [
        record
        for run in load_runs_from_specs(specs_path)
        for record in run.logprobs
        if record['Agent Iteration'] == record['Conjecture Iteration']
        and not record['Hindsight']
    ]

    hindsight_labels = ["Without Hindsight", "With Hindsight"]
    relabeled = []
    for record in kept:
        entry = dict(record)
        # NOTE(review): constant label, presumably consumed by the
        # 'curriculum' Vega-Lite spec -- confirm.
        entry["Iteration"] = "Conjecture Iteration"
        entry["Run Hindsight"] = hindsight_labels[record['Run Hindsight']]
        relabeled.append(entry)

    plot_vegalite(
        'curriculum',
        relabeled,
        'curriculum',
        {'$TITLE': 'Difficulty of new provable conjectures at each iteration'},
    )


def plot_hard_conjectures(run_dir, run_name):
    """Plot mean and low-percentile proof logprobs of each iteration's own conjectures.

    Reads ``<run_dir>/logprobs.json`` (the file ``evaluate_logprobs`` writes).
    It keeps records where the scoring agent's iteration equals the
    conjecture's iteration. For each such iteration it plots three values:
    the mean proof logprob, the 20th percentile and the 10th percentile.
    A lower logprob means the proof is less likely under the policy, so the
    percentiles are labelled as the hardest conjectures.

    Args:
        run_dir: Output directory of the run.
        run_name: Name used in the output file name and the plot title.
    """
    # Only logprobs.json is needed here, so the run's outcomes are not loaded.
    with open(os.path.join(run_dir, 'logprobs.json'), 'r') as f:
        logprobs = json.load(f)

    # Group proof logprobs by iteration, keeping agent == conjecture iteration.
    proof_logprobs_by_iteration = collections.defaultdict(list)
    for record in logprobs:
        if record['Agent Iteration'] == record['Conjecture Iteration']:
            proof_logprobs_by_iteration[record['Agent Iteration']].append(
                record['Proof logprob'])

    datapoints = []
    for it, pl in proof_logprobs_by_iteration.items():
        # Lists are never empty: a key is only created by an append above.
        summaries = (
            ('Mean', sum(pl) / len(pl)),
            ('Hardest 20%', np.percentile(pl, 20)),
            ('Hardest 10%', np.percentile(pl, 10)),
        )
        for reference, value in summaries:
            datapoints.append({
                'Iteration': it,
                'Proof logprob': value,
                'Reference': reference,
            })

    plot_vegalite(
        'hard_conjectures',
        datapoints,
        f'hard-conjectures-{run_name}',
        {'$TITLE': f'{run_name} :: Hard conjectures'}
    )


def plot_provable_ratio(run_specs):
    """Plot, per run and iteration, the fraction of non-hindsight conjectures proved.

    Each non-hindsight outcome is emitted with a ``'Proved'`` 0/1 column
    (1 when its ``'proof'`` is not None), its iteration index, the run's
    theory, and a readable hindsight label.
    """
    hindsight_labels = ['Without Hindsight', 'With Hindsight']
    rows = []

    for run in load_runs_from_specs(run_specs):
        for iteration, iteration_outcomes in enumerate(run.outcomes):
            for outcome in iteration_outcomes:
                if outcome['hindsight']:
                    continue
                rows.append({
                    **outcome,
                    'Proved': int(outcome['proof'] is not None),
                    'Iteration': iteration,
                    'Theory': run.theory,
                    'Hindsight': hindsight_labels[run.hindsight],
                })

    plot_vegalite(
        'provable',
        rows,
        'provable',
        {'$TITLE': 'Fraction of proven conjectures by iteration'},
    )


def evaluate_logprobs(run_dir, theory_path, premises):
    """Score every proved conjecture of a run under the policy of every iteration.

    Each distinct problem with recorded actions and a proof is scored once,
    at the first iteration where it appears that way. The log-probability of
    its recorded action sequence is computed under ``agents[it]._policy`` for
    each agent iteration ``it``. Records are written to
    ``<run_dir>/logprobs.json``.

    Args:
        run_dir: Output directory of the run (``outcomes_<i>.json``, ``<i>.pt``).
        theory_path: Path to the background theory file.
        premises: Comma-separated premise names used for proof search.
    """
    agents, outcomes = load_run(run_dir)
    iterations = len(agents)

    # load_run may return None placeholders instead of loaded agents. Scoring
    # needs real policies, so load any missing agent from '<i>.pt'; otherwise
    # every evaluation below raises and an empty logprobs.json is written.
    # NOTE(review): torch.load unpickles -- only use on trusted run dirs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    agents = [agent if agent is not None
              else torch.load(os.path.join(run_dir, f'{i}.pt'), map_location=device)
              for i, agent in enumerate(agents)]

    with open(theory_path, 'r') as f:
        theory = f.read()

    premises = premises.split(',')

    records = []
    seen_problems = set()
    n_failed = 0

    print(iterations, 'iterations, theory =', theory_path, '& premises =', premises)

    for j, outcomes_j in enumerate(outcomes):
        print('Iteration', j, 'with', len(outcomes_j), 'outcomes.')

        for out in tqdm(outcomes_j):
            if not out['actions'] or not out['proof']:
                continue

            if out['problem'] in seen_problems:
                continue

            seen_problems.add(out['problem'])

            state = peano.PyProofState(theory, premises, out['problem'])
            actions = out['actions']
            hindsight = out['hindsight']

            for it in range(iterations):
                root = proofsearch.TreeSearchNode(proofsearch.HolophrasmNode([state]))
                try:
                    logprob = root.solution_logprob_under_policy(agents[it]._policy,
                                                                 actions)
                except Exception:
                    # Best-effort: stop scoring this problem under the remaining
                    # agents, but count it instead of failing silently.
                    n_failed += 1
                    break
                records.append({
                    'Problem': out['problem'],
                    'Agent Iteration': it,
                    'Conjecture Iteration': j,
                    'Proof logprob': logprob,
                    'Hindsight': hindsight,
                })

    if n_failed:
        print(n_failed, 'problems could not be scored under every agent (evaluation raised).')

    with open(os.path.join(run_dir, 'logprobs.json'), 'w') as f:
        json.dump(records, f)


def plot_extrinsic_success(run_paths, run_names):
    """Plot success on human-written theorems and print newly solved problems.

    Args:
        run_paths: Paths to JSON files. Each file holds a list of records with
            ``'checkpoint'``, ``'success'`` and ``'problem'`` keys.
        run_names: Display names, one per entry of ``run_paths``.

    Raises:
        ValueError: If ``run_paths`` and ``run_names`` differ in length.
    """
    if len(run_paths) != len(run_names):
        raise ValueError(f'Got {len(run_paths)} run paths but {len(run_names)} run names.')

    data = []
    outcomes_by_run = []

    for run_path, run_name in zip(run_paths, run_names):
        with open(run_path) as f:
            outcomes = json.load(f)
        outcomes_by_run.append((run_name, outcomes))
        for o in outcomes:
            data.append({
                'Theory': run_name,
                'checkpoint': o['checkpoint'],
                'success': o['success'],
                'problem': o['problem'],
            })

    plot_vegalite(
        'extrinsic',
        data,
        'extrinsic',
        {'$TITLE': 'Success rate on human-written theorems'}
    )

    # For each run separately, print the problems each checkpoint solves that
    # the previous checkpoint (in sorted order) did not.
    for run_name, outcomes in outcomes_by_run:
        print(f'{run_name}:')
        solved_by_each = {}
        for o in outcomes:
            # Register every checkpoint, even ones that solved nothing.
            solved = solved_by_each.setdefault(o['checkpoint'], set())
            if o['success']:
                solved.add(o['problem'])

        checkpoints = sorted(solved_by_each)
        for prev, cur in zip(checkpoints, checkpoints[1:]):
            new_problems = solved_by_each[cur] - solved_by_each[prev]
            if new_problems:
                print(f'New problems at iteration {cur}:', new_problems)


def _find_files_with_name(name):
    """Return paths to every file called ``name`` under the current directory.

    Paths are built from ``os.walk('.')`` (so they start with ``'.'``), in
    walk order.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, filenames in os.walk('.')
        if name in filenames
    ]


def plot_mcts_iterations_by_logprob():
    """Gather every ``mcts_iterations_by_logprob.json`` below the current directory and plot them.

    Each entry must carry non-None ``'Iterations'`` and ``'Logprob'`` values.
    These are plotted as MCTS iterations against proof log-likelihood.
    """
    points = []

    for json_path in _find_files_with_name('mcts_iterations_by_logprob.json'):
        print('Found', json_path)
        with open(json_path, 'r') as handle:
            entries = json.load(handle)

        for entry in entries:
            assert entry['Iterations'] is not None
            assert entry['Logprob'] is not None
            point = {
                'MCTS Iterations': entry['Iterations'],
                'Proof log-likelihood': entry['Logprob'],
            }
            points.append(point)

    print(len(points), 'datapoints.')

    plot_vegalite(
        'mcts_iterations',
        points,
        'mcts_iterations',
        {'$TITLE': 'MCTS iterations by logprob'},
    )


def print_proofs(run_dir):
    """Print every proved problem of a run, grouped under an 'Iteration i:' header."""
    outcomes_per_iteration = load_run(run_dir)[1]

    for iteration, iteration_outcomes in enumerate(outcomes_per_iteration):
        print(f'Iteration {iteration}:')

        for outcome in iteration_outcomes:
            if not outcome['proof']:
                continue
            print('Problem:', outcome['problem'])
            print('Proof:')
            print(format_blocks_with_indent(outcome['proof']))
            print()


def run_mcts_iterations_by_logprob(run_dir, theory_path, max_per_iteration=25,
                                   max_mcts_nodes=2000):
    """Measure how many MCTS iterations proof search needs on already-proved conjectures.

    For each iteration ``i``, the outcomes of that iteration are shuffled.
    Proof search is re-run with agent ``i`` on non-hindsight outcomes whose
    ``'iteration'`` is ``i`` and that have a proof, stopping after
    ``max_per_iteration`` successful searches. Each success is recorded with
    its search iteration count and the outcome's stored ``'logprob'``.
    Results go to ``<run_dir>/mcts_iterations_by_logprob.json``.

    Args:
        run_dir: Output directory of the run (``outcomes_<i>.json``, ``<i>.pt``).
        theory_path: Background theory file. Premises are read from
            ``theory_path + '.premises'``, one per line.
        max_per_iteration: Cap on successful searches recorded per iteration.
        max_mcts_nodes: Value assigned to each agent's ``_max_mcts_nodes``
            before searching.
    """
    agents, outcomes = load_run(run_dir)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with open(theory_path, 'r') as f:
        theory = f.read()

    with open(theory_path + '.premises', 'r') as f:
        premises = f.read().splitlines()

    data_points = []

    for i, outcomes_i in enumerate(outcomes):
        print(f'Iteration {i}:')
        random.shuffle(outcomes_i)
        dp = 0

        for o in outcomes_i:
            if o['hindsight']:
                continue

            it = o['iteration']

            if it != i:
                continue

            if dp >= max_per_iteration:
                break

            if not o['proof']:
                continue

            agent = agents[it]
            if agent is None:
                # load_run may return None placeholders; load this agent on
                # first use and cache it for later outcomes.
                # NOTE(review): torch.load unpickles -- trusted run dirs only.
                agent = torch.load(os.path.join(run_dir, f'{it}.pt'), map_location=device)
                agents[it] = agent
            agent._max_mcts_nodes = max_mcts_nodes

            state = peano.PyProofState(theory, premises, o['problem'])
            result = agent.proof_search(o['problem'], state)

            if result.success:
                dp += 1
                data_points.append({
                    'Iteration': it,
                    'Problem': o['problem'],
                    'Iterations': result.iterations,
                    'Logprob': o['logprob'],
                })

    with open(os.path.join(run_dir, 'mcts_iterations_by_logprob.json'), 'w') as f:
        json.dump(data_points, f)



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Analyze the results of a run.')
    parser.add_argument('--run', type=str, help='The output directory of the run to analyze.')
    parser.add_argument('--run-name', type=str, help='The name of the run.')
    parser.add_argument('--run-specs', type=str, help='JSON spec of the runs.')
    parser.add_argument('--theory', type=str, help='The background theory to use for analysis.')
    parser.add_argument('--premises', type=str, help='Premises used for proof search.')
    parser.add_argument('--evaluate-logprobs', action='store_true', help='Evaluate logprobs of each solution against each agent.')
    parser.add_argument('--plot-difficulty', action='store_true',
                        help='Plot solution logprob across iterations for each policy.')
    parser.add_argument('--plot-curriculum', action='store_true',
                        help='Plot difficulty of new problems for current policy at each iteration.')
    parser.add_argument('--plot-hard-conjectures', action='store_true',
                        help='Plot hard conjecture logprobs across iterations.')
    parser.add_argument('--plot-provable-ratio', action='store_true',
                        help='Plot the fraction of provable conjectures across iterations for each run.')
    parser.add_argument('--plot-extrinsic-success', action='store_true',
                        help='Plot the success rate across iterations on extrinsic eval.')
    parser.add_argument('--mcts-iterations-by-logprob', action='store_true',
                        help='Measure MCTS iterations by logprob.')
    parser.add_argument('--plot-mcts-iterations-by-logprob', action='store_true',
                        help='Plot MCTS iterations by logprob (from files written by --mcts-iterations-by-logprob).')
    parser.add_argument('--run-dirs', type=str, nargs='*',
                        help='Extrinsic-evaluation result JSON files of the runs to analyze.')
    parser.add_argument('--run-names', type=str, nargs='*', help='The names of the runs.')
    parser.add_argument('--print-proofs', action='store_true', help='Print proofs.')

    args = parser.parse_args()

    if args.evaluate_logprobs:
        evaluate_logprobs(args.run, args.theory, args.premises)

    # These two plots read a JSON run-spec file: prefer --run-specs, and fall
    # back to --run only so that older invocations keep working.
    if args.plot_difficulty:
        plot_difficulty(args.run_specs or args.run)

    if args.plot_curriculum:
        plot_curriculum(args.run_specs or args.run)

    if args.plot_hard_conjectures:
        plot_hard_conjectures(args.run, args.run_name)

    if args.plot_provable_ratio:
        plot_provable_ratio(args.run_specs)

    if args.print_proofs:
        print_proofs(args.run)

    if args.plot_extrinsic_success:
        plot_extrinsic_success(args.run_dirs, args.run_names)

    if args.mcts_iterations_by_logprob:
        run_mcts_iterations_by_logprob(args.run, args.theory)

    if args.plot_mcts_iterations_by_logprob:
        plot_mcts_iterations_by_logprob()
