from copy import deepcopy
import sys
sys.path.append('./')
from experiments.t511b_verifier import Verifier

from collections import defaultdict
from itertools import product
import math
import numpy as np
from models.engine import Engine
import random
import tqdm
import pickle as pkl
import json
from typing import Dict, List, Tuple
from gadgets.lexical_diversity import lexical_diversity
import scipy.stats as stats
from sklearn.linear_model import LogisticRegression
import os
from gadgets.util import gpt3wrapper, convert_cmp_hs, classify_cmp, convert_cmp_to_ind
from models.preprocess import construct_blocks, prefix_subspan
from transformers import AutoTokenizer
from argparse import ArgumentParser
import argparse
from scipy.stats import norm


# Quick dry-run switch: uses DummyVerifier instead of the real Verifier,
# caps the number of hypotheses at 3, and ignores existing result files.
DEBUG = False
# GPT-2 tokenizer passed to prefix_subspan and used to count prompt tokens
# (loaded once, at import time).
tok = AutoTokenizer.from_pretrained('gpt2-medium')
# Number of sentences scored per round before weak hypotheses are re-filtered.
VERIFY_HYP_BLOCK_SIZE = 32
# Scale of the uniform random jitter added to every verifier score.
eps = 1e-5

def calculate_diff_w_significance(pos_scores, neg_scores, alpha=1e-5):
    pos_scores = np.array(pos_scores)
    neg_scores = np.array(neg_scores)
    mu = np.mean(pos_scores) - np.mean(neg_scores)
    p_value = stats.ttest_ind(pos_scores, neg_scores)[1]
    mu_variance = np.var(pos_scores) / len(pos_scores) + np.var(neg_scores) / len(neg_scores)
    mu_std = np.sqrt(mu_variance)
    target_z = norm.ppf(1 - alpha / 2)
    lo, hi = mu - target_z * mu_std, mu + target_z * mu_std
    optimistic_discriminative_power = max(np.abs(lo), np.abs(hi))
    pessimistic_discriminative_power = min(np.abs(lo), np.abs(hi))
    if lo * hi < 0:
        pessimistic_discriminative_power = 0
    return {
        'mu': mu,
        'p_value': p_value,
        'mu_std': mu_std,
        'lo': lo,
        'hi': hi,
        'optimistic_discriminative_power': optimistic_discriminative_power,
        'pessimistic_discriminative_power': pessimistic_discriminative_power
    }


class DistributionPairInstance0110:
    """Propose and adaptively verify hypotheses for one distribution pair.

    Hypotheses come from ``proposer``. Each active hypothesis is scored by
    ``verifier`` on the pair's sentences in random batches. After every batch,
    hypotheses that can no longer be among the top-K most discriminative
    (judged by confidence bounds) are deactivated.

    ``h2h_dicts`` maps hypothesis text to a dict with keys
    'hypothesis', 'sent2score' (sentence -> verifier score), 'active',
    'provenance' (proposer query args plus 'top_p' and 'idx') and
    'diff_w_significance' (latest ``calculate_diff_w_significance`` output).
    """

    def __init__(
        self,
        application,
        verifier,
        proposer,
        top_fraction=None,
        total_hypotheses_count=30
    ):
        """
        Args:
            application: dict with 'pos2score' and 'neg2score', each mapping
                sentence -> score for the positive / negative side.
            verifier: object exposing ``verify_ind_dicts_w_scores(dicts)``.
            proposer: object exposing ``propose_hypotheses(pos, neg)`` that
                returns ``(hypotheses, provenance_dict)``.
            top_fraction: top-p values passed to ``lexical_diversity``;
                defaults to [0.05, 0.2, 1.0].
            total_hypotheses_count: maximum number of distinct hypotheses kept.
        """
        self.orig_pos2score, self.orig_neg2score = application['pos2score'], application['neg2score']
        # Sentence -> 1. (positive side) / 0. (negative side), built from THIS
        # application's own score dicts. A sentence on both sides ends up 0.
        self.orig_sent2membership = {}
        for sent in self.orig_pos2score:
            self.orig_sent2membership[sent] = 1.
        for sent in self.orig_neg2score:
            self.orig_sent2membership[sent] = 0.

        self.proposer, self.verifier = proposer, verifier

        self.h2h_dicts = {}

        self.top_fraction = top_fraction
        if top_fraction is None:
            self.top_fraction = [0.05, 0.2, 1.0]
        self.total_hypotheses_count = total_hypotheses_count

    def get_hypotheses(self):
        """Fill ``h2h_dicts`` with up to ``total_hypotheses_count`` distinct hypotheses.

        Runs 3 rounds. Each round queries the proposer once per value in
        ``top_fraction``, using samples picked by ``lexical_diversity`` from
        each side's sentences sorted by descending score.
        """
        sorted_pos = sorted(self.orig_pos2score, key=self.orig_pos2score.get, reverse=True)
        sorted_neg = sorted(self.orig_neg2score, key=self.orig_neg2score.get, reverse=True)

        for idx in range(3):
            for p in self.top_fraction:
                if len(self.h2h_dicts) >= self.total_hypotheses_count:
                    break
                pos, neg = lexical_diversity(sorted_pos, sorted_neg, top_p=p, num_sentences=25)
                hyps, provenance = self.proposer.propose_hypotheses(pos, neg)
                provenance['top_p'] = p
                provenance['idx'] = idx
                for hyp in hyps:
                    if hyp not in self.h2h_dicts and len(self.h2h_dicts) < self.total_hypotheses_count:
                        h_dict = {
                            'hypothesis': hyp,
                            'sent2score': {},
                            'active': True,
                            'provenance': provenance,
                            'diff_w_significance': None
                        }
                        self.h2h_dicts[hyp] = h_dict

    def get_correlation_info(self):
        """Recompute ``diff_w_significance`` for every hypothesis.

        Uses the verifier scores collected so far, split by whether each
        sentence belongs to the positive or the negative side.
        """
        membership = self.orig_sent2membership
        for hyp_dict in self.h2h_dicts.values():
            sent2score = hyp_dict['sent2score']
            pos_scores = [s for sent, s in sent2score.items() if membership[sent] == 1.]
            neg_scores = [s for sent, s in sent2score.items() if membership[sent] == 0.]
            hyp_dict['diff_w_significance'] = calculate_diff_w_significance(pos_scores, neg_scores)

    def filter_weak_hypotheses(self, K=5):
        """Deactivate hypotheses that cannot be in the top K.

        A hypothesis is dropped when its optimistic bound is below the K-th
        largest pessimistic bound across all hypotheses (active or not).
        With fewer than K hypotheses, every one is trivially in the top K, so
        nothing is dropped. This check also avoids indexing an empty list when
        there are no hypotheses at all.
        """
        pessimistic_bounds = [hyp_dict['diff_w_significance']['pessimistic_discriminative_power'] for hyp_dict in self.h2h_dicts.values()]
        if len(pessimistic_bounds) < K:
            return
        threshold = sorted(pessimistic_bounds, reverse=True)[:K][-1]

        for hyp_dict in self.h2h_dicts.values():
            if hyp_dict['active'] and hyp_dict['diff_w_significance']['optimistic_discriminative_power'] < threshold:
                hyp_dict['active'] = False

    def verify_active(self):
        """Score active hypotheses on all sentences, pruning weak ones as it goes.

        Sentences are visited in random order, ``VERIFY_HYP_BLOCK_SIZE`` at a
        time. After each batch, bounds are recomputed and
        ``filter_weak_hypotheses`` runs, so later batches only cost verifier
        calls for hypotheses that are still competitive.
        """
        random_sent_order = list(self.orig_sent2membership.keys())
        random.shuffle(random_sent_order)

        cur_pointer = 0

        print('Filtering out weak hypotheses')

        with tqdm.tqdm(total=len(random_sent_order)) as pbar:
            while cur_pointer < len(random_sent_order):

                # Take a batch of sentences and score every still-active hypothesis on it.
                sents = random_sent_order[cur_pointer:cur_pointer+VERIFY_HYP_BLOCK_SIZE]
                cur_pointer += VERIFY_HYP_BLOCK_SIZE

                # 'pointer' lets each score be written back to its hypothesis dict.
                verifier_dicts = []
                for sent in sents:
                    for h, hyp_dict in self.h2h_dicts.items():
                        if not hyp_dict['active']:
                            continue
                        verifier_dict = {'hypothesis': h, 'text': sent, 'type': 'ind', 'pointer': hyp_dict}
                        verifier_dicts.append(verifier_dict)

                all_scores = list(self.verifier.verify_ind_dicts_w_scores(verifier_dicts))
                assert len(all_scores) == len(verifier_dicts)
                for d, s in zip(verifier_dicts, all_scores):
                    # Tiny random jitter (< eps); presumably to avoid exact ties /
                    # zero variance with constant scores -- confirm.
                    d['pointer']['sent2score'][d['text']] = s + eps * random.random()

                # Prune hypotheses whose upper bound cannot reach the top K.
                pbar.update(len(sents))
                self.get_correlation_info()
                self.filter_weak_hypotheses()

                pbar.set_description('Num hypotheses: %d' % len([h for h in self.h2h_dicts if self.h2h_dicts[h]['active']]))

    def run(self):
        """Propose hypotheses, verify them adaptively, and return ``{'hypotheses': h2h_dicts}``."""
        self.get_hypotheses()
        self.verify_active()
        return {
            'hypotheses': self.h2h_dicts
        }


def subsample(sent2score: Dict[str, float], subsample_size=1000) -> Dict[str, float]:
    """Return at most ``subsample_size`` random entries of ``sent2score``.

    When the dict is already small enough, the very same object is returned
    (no copy). Otherwise the keys are shuffled with the global ``random``
    state, and a new dict with the first ``subsample_size`` of them is built.
    """
    if len(sent2score) > subsample_size:
        keys = list(sent2score)
        random.shuffle(keys)
        chosen = keys[:subsample_size]
        return {key: sent2score[key] for key in chosen}
    return sent2score



# Generic example hypotheses for the proposer prompt: used to pad an
# application's own examples up to three, or used alone when
# use_default_hypotheses=True.
DEFAULT_HYPOTHESES = [
    "talks about politics, such as presidential election.",
    "contains insulting language for immigrants.",
    "uses double negation, i.e., using two negations in a sentence."
]
# Per-sample length limit passed to prefix_subspan (presumably in tokens -- confirm).
SINGLE_SAMPLE_MAX_LENGTH = 256
# A proposer prompt is accepted only if it is shorter than this many
# tokenizer (gpt2-medium) tokens.
MAX_PROMPT_LENGTH = 3200

class Proposer0110:
    """Propose hypotheses distinguishing two sentence groups via a GPT-3 prompt."""

    def __init__(self, application, use_default_hypotheses=False, single_max_length=SINGLE_SAMPLE_MAX_LENGTH, engine_name='text-davinci-003', temperature=0.7):
        """
        Args:
            application: dict describing the pair. Keys read: 'example_hypotheses'
                (unless use_default_hypotheses), 'dataset_description',
                'generation', 'pos_desc', 'neg_desc', 'target' and 'user'.
            use_default_hypotheses: use only DEFAULT_HYPOTHESES as prompt examples.
            single_max_length: per-sample limit passed to prefix_subspan.
            engine_name: engine passed to gpt3wrapper.
            temperature: sampling temperature passed to gpt3wrapper.
        """
        if use_default_hypotheses:
            # Copy: propose_hypotheses shuffles this list in place, which must
            # not reorder the module-level DEFAULT_HYPOTHESES.
            self.example_hypotheses = list(DEFAULT_HYPOTHESES)
        else:
            self.example_hypotheses = (application['example_hypotheses'] + DEFAULT_HYPOTHESES)[:3]

        self.application = application
        with open('models/templates/1228_w_hypotheses.txt', 'r') as template_file:
            self.prompt_template = template_file.read()
        self.single_max_length = single_max_length
        self.engine_name = engine_name
        self.temperature = temperature

    def propose_hypotheses(self, pos_sents, neg_sents):
        """Query the engine once and parse its answer into hypotheses.

        Args:
            pos_sents: sentences shown as group A.
            neg_sents: sentences shown as group B.

        Returns:
            ``(hypotheses, query_args)``. The hypotheses are the parsed strings
            with any trailing period removed. ``query_args`` is the exact
            argument dict sent to ``gpt3wrapper``.
        """
        application = self.application
        arg_dict = {
            'dataset_description': application['dataset_description'],
            'generation': application['generation'],
            'positive_description': application['pos_desc'],
            'negative_description': application['neg_desc'],
            'user': application['user'],
            'target': application['target']
        }
        random.shuffle(self.example_hypotheses)
        for i, hypothesis in enumerate(self.example_hypotheses):
            arg_dict[f'example_hypothesis_{i+1}'] = hypothesis

        # Per-sample truncation does not depend on how many samples end up in
        # the prompt, so do it once instead of on every retry below.
        A_sentences = [prefix_subspan(x, self.single_max_length, tok) for x in pos_sents]
        B_sentences = [prefix_subspan(x, self.single_max_length, tok) for x in neg_sents]

        # Shrink the number of in-context samples until the prompt fits. If it
        # never fits, the last attempt (2 samples) is used anyway.
        num_incontext_samples = 25
        while num_incontext_samples > 1:
            sent_subset = construct_blocks(A_sentences, B_sentences, num_incontext_samples=num_incontext_samples, truncate=False)
            arg_dict['A_block'] = sent_subset['A_block']
            arg_dict['B_block'] = sent_subset['B_block']
            prompt = self.prompt_template.format(**arg_dict)
            if len(tok.encode(prompt)) < MAX_PROMPT_LENGTH:
                break
            num_incontext_samples -= 1
            print('prompt too long, reducing num_incontext_samples to %d' % num_incontext_samples)

        query_args = {
            'engine': self.engine_name,
            'prompt': prompt,
            'temperature': self.temperature,
            'max_tokens': 512,
            'top_p': 1,
            'n': 1
        }

        result = gpt3wrapper(
            tag='proposer',
            **query_args
        )

        returned_text = result['choices'][0]['text']

        # Only the first paragraph is parsed; items are separated by "\n-".
        hs = []
        for h in returned_text.split('\n\n')[0].split('\n-'):
            h = convert_cmp_to_ind(h.replace('"', '').strip())
            if not h:  # None (not convertible) or empty
                continue
            if h.endswith('.'):
                h = h[:-1]
            hs.append(h)

        return hs, query_args


class DummyVerifier:
    """Stand-in verifier for DEBUG runs that never loads a model."""

    def verify_ind_dicts_w_scores(self, ind_dicts):
        """Lazily yield the constant score 0.01 once per element of ``ind_dicts``."""
        remaining = len(ind_dicts)
        while remaining > 0:
            remaining -= 1
            yield 0.01


def flip_application(application):
    """Return a deep copy of ``application`` with the positive and negative sides swapped.

    Swaps the descriptions, the top-level samples, the train/test split
    samples and, when present, the per-sentence score dicts. The input
    dict is left untouched.
    """
    def swap(container, key_a, key_b):
        container[key_a], container[key_b] = container[key_b], container[key_a]

    flipped = deepcopy(application)
    swap(flipped, 'pos_desc', 'neg_desc')
    swap(flipped, 'pos_samples', 'neg_samples')
    for split_name in ('train', 'test'):
        swap(flipped['split'][split_name], 'pos_samples', 'neg_samples')
    if 'pos2score' in flipped:
        swap(flipped, 'pos2score', 'neg2score')
    return flipped

if __name__ == '__main__':
    with open('data/benchmark_applications_2nddraft.pkl', 'rb') as f:
        applications = pkl.load(f)

    # applications = [x for x in applications if x['purely_exploratory']]
    model_path = 'models/ckpts/best_verifier/'

    # Attach per-sentence scores (at most 1000 per side) to every application.
    for application in applications:
        pair_id = application['pair_id']
        extreme_path = 'experiments/pair_extreme_backup/results/v1-pair_id%d/result.pkl' % int(pair_id)
        with open(extreme_path, 'rb') as f:
            extreme_results = pkl.load(f)
        application['pos2score'], application['neg2score'] = subsample(extreme_results['pos2score'], 1000), subsample(extreme_results['neg2score'], 1000)

    if DEBUG:
        verifier = DummyVerifier()
    else:
        verifier = Verifier(model_path=model_path)

    # Every application runs in its original orientation ('pos') and, when
    # application['flip'] is set, also with its two sides swapped ('neg').
    application_l = []
    for application in applications:
        application['orientation'] = 'pos'
        application_l.append(application)
        if application['flip']:
            new_application = flip_application(application)
            new_application['orientation'] = 'neg'
            application_l.append(new_application)

    os.makedirs('application_results', exist_ok=True)
    for application in application_l:
        print('===========================================================================================')
        application_id = application['v2_id']
        save_path = 'application_results/proposer_results_%s_%s.pkl' % (application_id, application['orientation'])
        if os.path.exists(save_path) and not DEBUG:
            continue
        print('save_path', save_path)
        proposer = Proposer0110(application)
        # Placeholder written before the (long) run, presumably so that other
        # runs skip this pair while it is in progress -- confirm.
        with open(save_path, 'wb') as f:
            pkl.dump('lock', f)
        # Module-level score dicts of the CURRENT application (each side is
        # already capped at 1000, so this subsample is a no-op safeguard).
        pos2score, neg2score = application['pos2score'], application['neg2score']
        pos2score, neg2score = subsample(pos2score, 2000), subsample(neg2score, 2000)

        dpi = DistributionPairInstance0110(application=application, proposer=proposer, verifier=verifier, total_hypotheses_count=30 if not DEBUG else 3)
        result = dpi.run()
        result['application_id'] = application_id
        result['flip'] = application['flip']
        with open(save_path, 'wb') as f:
            pkl.dump(result, f)
