import numpy as np
import torch
from tqdm import tqdm

from scipy.stats import bernoulli
from nn_utils import train_g, compute_scores

import warnings
warnings.filterwarnings("ignore")

###### FSFT Algorithm ######
class StaticBaseline:
    def __init__(self, stream, exp_config):
        """Set up the fixed-threshold baseline.

        Args:
            stream: data stream; must support ``len()`` and expose
                ``id_quantile_score(q=...)`` and ``ood_quantile_score(q=...)``.
            exp_config: experiment configuration providing ``q`` and ``metric``.

        Raises:
            ValueError: if ``exp_config.metric`` is neither ``'tpr'`` nor
                ``'fpr'`` (previously this surfaced later as an
                ``AttributeError`` on ``static_lambda`` inside ``run``).
        """
        # exp config
        self.q = exp_config.q
        self.metric = exp_config.metric

        # data stream attributes
        self.stream = stream
        self.n = len(self.stream)

        # static threshold: the (1 - q) quantile score reported by the stream,
        # taken from the ID scores for 'tpr' and from the OOD scores for 'fpr'
        if self.metric == 'tpr':
            self.static_lambda = self.stream.id_quantile_score(q=1-self.q)
        elif self.metric == 'fpr':
            self.static_lambda = self.stream.ood_quantile_score(q=1-self.q)
        else:
            raise ValueError(
                f"unsupported metric {self.metric!r}; expected 'tpr' or 'fpr'"
            )

        # experiment results (one entry appended per processed stream item)
        self.out_all = {
            'expected_tpr': [],
            'expected_fpr': [],
            'instances': [],
            'true_label': [],
            'pred_label': [],
        }

    def run(self):
        """Classify every stream item against the fixed threshold.

        Each item is predicted ``0`` when its score is at most
        ``self.static_lambda`` and ``1`` otherwise. Labels, the confusion
        category and the stream-reported FPR/TPR (in percent) at the threshold
        are appended to ``self.out_all``.
        """
        log = self.out_all
        threshold = self.static_lambda

        for step in tqdm(range(self.n)):
            # fetch the item, then advance the stream's internal counter
            _, score, label = self.stream[step]
            self.stream.inc_pos()

            # prediction
            predicted = 0 if score <= threshold else 1

            # record labels and the confusion-matrix category
            log['true_label'].append(label)
            log['pred_label'].append(predicted)
            if label == 0:
                log['instances'].append('tn' if predicted == 0 else 'fp')
            elif label == 1:
                log['instances'].append('fn' if predicted == 0 else 'tp')

            # stream-reported rates at the threshold, as percentages
            log['expected_fpr'].append(self.stream.fpr(threshold) * 100)
            log['expected_tpr'].append(self.stream.tpr(threshold) * 100)

###### Stationary Settings: FSAT and ASAT Algorithms ######
class ScoreFunctionThresholdEstimator:
    def __init__(self, stream, exp_config, seed, device=None):
        # method id
        self.method_id = exp_config.method_id
        # print(f'\n====== Starting : Method-{self.method_id} ======')

        # set seed
        self.seed = seed

        # device
        self.device = device

        # threshold estimation parameters
        self.alpha = exp_config.alpha
        self.delta = exp_config.delta
        self.p = exp_config.p
        self.epsilon = 10**-5
        self.ucb = exp_config.ucb                       # ucb interval methods
        
        # joint optimization on g and lambda parameters
        self.mode_estimator = exp_config.mode_estimator
        self.training_params = exp_config.training_params
        self.beta = exp_config.beta
        self.c = exp_config.c
        self.input_size = exp_config.input_size
        self.num_epoch = exp_config.num_epoch
        self.batch_size = exp_config.batch_size
        self.update_freq = exp_config.update_freq           # number of new OOD points required to update g 
        self.num_update_max = len(self.update_freq) if self.update_freq != None else 0      # number of maximum updates on g
        self.c_heuristic = exp_config.c_heuristic           # (C1, C2, C3)

        # data stream attributes
        self.stream = stream
        self.n = len(self.stream)
        self.min_lam = self.stream.ood_quantile_score(q=0.5)    # score min w.r.t. g0
        self.max_lam = self.stream.max_score()                  # score max w.r.t. g0

        # get ID training data from classification task (we assume ID distr does NOT change)
        self.z_id_train = self.stream.get_z1_train()

        # experiment results
        self.out_all = {
            'threshold': [], 
            'expected_tpr': [],
            'expected_fpr': [],
            'fpr_estimate': [],
            'feasibility_point': None,
            'update_points': [],
            'instances': [],
            'true_label': [], 
            'pred_label': [],
        }

    def ucb_theory(self, i_ood, eta=None):
        i_ood = np.array(i_ood)
        n_pred, n_imp = np.sum(i_ood==0), np.sum(i_ood==1)
        c_t = n_pred + n_imp/(self.p**2)
        n_t = n_pred + n_imp/self.p

        try:
            return np.sqrt((3*c_t/(n_t**2)) * (2*np.log(np.log(3*c_t/2)) + np.log(2/self.delta)))
        except (ValueError, ZeroDivisionError):
            return 0

    def ucb_heuristic(self, i_ood):
        i_ood = np.array(i_ood)
        n_pred, n_imp = np.sum(i_ood==0), np.sum(i_ood==1)
        c_t = n_pred + n_imp/(self.p**2)
        n_t = n_pred + n_imp/self.p
        c_1, c_2, c_3 = self.c_heuristic

        try:
            return c_1 * np.sqrt((c_t/(n_t**2)) * (np.log(np.log(c_2 * c_t)) + np.log(c_3/self.delta)))
        except (ValueError, ZeroDivisionError):
            return 0

    def dkw_bound(self, n, delta):
        return np.sqrt((1/(2*n)) * np.log(2/delta))

    def fpr_est(self, lam, s0_lst, i0_lst):
        """Importance-weighted estimate of the false-positive rate at ``lam``.

        A sample flagged ``0`` (no importance sampling) counts with weight 1;
        a sample flagged ``1`` counts with weight ``1 / self.p``. The estimate
        is the weighted fraction of scores strictly greater than ``lam``.
        Samples with any other flag are ignored.

        The same weight is used in the numerator and the denominator. The
        previous version truncated the denominator weight with ``int(1/p)``,
        which let the estimate exceed 1 whenever ``1/p`` is not an integer.

        Args:
            lam: threshold to evaluate.
            s0_lst: scores of the collected samples.
            i0_lst: 0/1 importance-sampling flags, aligned with ``s0_lst``.

        Returns:
            The weighted fraction as a float.

        Raises:
            ZeroDivisionError: if no sample carries a flag of 0 or 1.
        """
        weight_total = 0
        weight_above = 0

        for s, i in zip(s0_lst, i0_lst):
            if i == 0:      # no importance sampling
                weight = 1
            elif i == 1:    # importance sampled
                weight = 1 / self.p
            else:
                continue

            weight_total += weight
            if s > lam:
                weight_above += weight

        return weight_above / weight_total

    def tpr_est_iid(self, lam, s1_lst):
        """Return the fraction of scores in ``s1_lst`` strictly greater than ``lam``.

        Raises:
            ZeroDivisionError: if ``s1_lst`` is empty.
        """
        above = 0
        for score in s1_lst:
            if score > lam:
                above += 1
        return above / len(s1_lst)

    def binary_search(self, psi, s0_lst, i0_lst, lam_prev, fpr_est_prev):
        high = min(lam_prev, self.max_lam)
        low = self.min_lam 
        lam_hat = high
        fpr_final = fpr_est_prev
        feasible = False

        while high - low > self.epsilon:
            mid = (high + low)/2
            fpr_hat = self.fpr_est(mid, s0_lst, i0_lst)
            
            if fpr_hat + psi < self.alpha:
                # valid lambda exists in the range (low, high)
                feasible = True
                fpr_final = fpr_hat
                high = mid              # take the first half
                lam_hat = mid           # set the current estimate to be the highest lambda
            else:
                # valid lambda does NOT exist in the range (low, high)
                low = mid
        return lam_hat, feasible, fpr_final

    def run(self):
        """Process the stream once, adapting the threshold (and optionally g) online.

        For every item the current threshold ``lam_prev`` gives the prediction
        (``0`` when ``s <= lam_prev``, else ``1``). Items predicted ``0``
        always receive feedback; items predicted ``1`` receive feedback with
        probability ``self.p`` and are flagged as importance-sampled. Whenever
        an item with true label ``0`` receives feedback it is stored and the
        threshold is re-estimated with ``binary_search``. Once a feasible
        threshold exists, every ``update_freq[k]`` newly stored samples
        trigger a retraining of g via ``train_g``; the new g is adopted when a
        feasible threshold exists for it and (from the second update on) its
        estimated TPR is not worse than the old one by more than twice the
        DKW bound.

        Per-step results are appended to ``self.out_all``.

        Raises:
            ValueError: if ``self.ucb`` is neither ``'heuristic'`` nor
                ``'theory'`` (previously an unbound ``psi`` at first use).
        """
        if self.ucb == 'heuristic':
            ucb_width = self.ucb_heuristic
        elif self.ucb == 'theory':
            ucb_width = self.ucb_theory
        else:
            raise ValueError(
                f"unsupported ucb method {self.ucb!r}; expected 'heuristic' or 'theory'"
            )

        # record in each cycle T^(i)
        out = {
            'z_id': [], 'z_ood': [],
            's_id': [], 's_ood': [],
            'i_id': [], 'i_ood': [],
        }

        # initial threshold
        self.g = None
        lam_prev = np.inf
        fpr_est_prev = 0
        num_update = 0
        num_ood_curr = 0

        for t in tqdm(range(self.n)):
            # get data
            z, s, y = self.stream[t]

            # increment the internal counter of stream
            self.stream.inc_pos()

            # q: whether we get feedback; i: whether importance sampling was used
            if s <= lam_prev:
                y_hat = 0
                q = 1
                i = 0
            else:
                y_hat = 1
                q = bernoulli.rvs(self.p, size=1)[0]
                i = 1

            # is out-of-distribution & got feedback as out-of-distribution
            if y == 0 and q == 1:
                # samples only count towards the next g update once a
                # feasible threshold has been found
                if self.out_all['feasibility_point'] is not None:
                    num_ood_curr += 1

                out['z_ood'].append(z)
                out['s_ood'].append(s)
                out['i_ood'].append(i)

                # compute the confidence interval
                psi = ucb_width(out['i_ood'])

                # find optimal lambda that satisfies the fpr estimate constraint
                lam_curr, feasible, fpr_est_curr = self.binary_search(
                    psi, out['s_ood'], out['i_ood'], lam_prev, fpr_est_prev)

                if feasible:
                    if self.out_all['feasibility_point'] is None:
                        self.out_all['feasibility_point'] = t
                    lam_prev = lam_curr
                    fpr_est_prev = fpr_est_curr

            # store instance and labels
            self.out_all['true_label'].append(y)
            self.out_all['pred_label'].append(y_hat)
            if y == 0:
                if y_hat == 0:
                    self.out_all['instances'].append('tn')
                elif y_hat == 1:
                    self.out_all['instances'].append('fp')
            elif y == 1:
                if y_hat == 0:
                    self.out_all['instances'].append('fn')
                elif y_hat == 1:
                    self.out_all['instances'].append('tp')

            # keep record
            self.out_all['expected_fpr'].append(self.stream.fpr(lam_prev) * 100)
            self.out_all['expected_tpr'].append(self.stream.tpr(lam_prev) * 100)
            self.out_all['threshold'].append(lam_prev)
            self.out_all['fpr_estimate'].append(fpr_est_prev * 100)

            # update g
            if self.num_update_max > num_update and num_ood_curr >= self.update_freq[num_update]:

                num_ood_curr = 0
                num_update += 1

                if self.g is not None:
                    assert num_update > 1
                    s1_old_lst = compute_scores(self.g, self.z_id_train, self.device)

                # (1) optimize for g and lambda simultaneously
                # NOTE(review): self.g is overwritten even if the new model is
                # rejected below, so a later `s1_old_lst` may come from a model
                # the stream never adopted — confirm this is intended.
                self.g, _, _ = train_g(
                    self.device,
                    self.method_id, self.g, self.mode_estimator,
                    self.training_params, self.beta, self.c, self.p, self.input_size,
                    out['z_ood'], self.z_id_train,
                    out['i_ood'],
                    num_epoch=self.num_epoch, batch_size=self.batch_size,
                    show_log=False, seed=self.seed)

                # (2) find safe threshold on the new g's score range
                s0_new_lst = compute_scores(self.g, out['z_ood'], self.device)
                s1_new_lst = compute_scores(self.g, self.z_id_train, self.device)
                self.min_lam = np.min(s0_new_lst) - 0.1
                self.max_lam = np.max(s1_new_lst) + 0.1

                psi = ucb_width(out['i_ood'])
                lam_curr, feasible, _ = self.binary_search(
                    psi, s0_new_lst, out['i_ood'], np.inf, np.inf)

                # (3) model selection based on tpr estimation
                adopt_new_g = False
                if feasible:
                    if num_update == 1:
                        adopt_new_g = True
                    else:
                        old_tpr_est = self.tpr_est_iid(lam=lam_prev, s1_lst=s1_old_lst)
                        new_tpr_est = self.tpr_est_iid(lam=lam_curr, s1_lst=s1_new_lst)
                        slack = 2 * self.dkw_bound(n=len(self.z_id_train), delta=0.05)
                        adopt_new_g = new_tpr_est + slack > old_tpr_est

                if adopt_new_g:
                    # update according to the new g; keep a real list so later
                    # `.append` calls keep working
                    out['s_ood'] = list(s0_new_lst)
                    lam_prev = lam_curr
                    self.stream.update_stream(self.g, self.device)
                    self.out_all['update_points'].append(t)
                elif feasible:
                    self.out_all['update_points'].append(np.nan)

                # Re-anchor the search range to the scores the stream currently
                # uses. Previously this was skipped when the new g was
                # infeasible, leaving the range on the rejected g's scale.
                self.min_lam = self.stream.ood_quantile_score(q=0.5)
                self.max_lam = self.stream.max_score()

###### Distribution-Shift Settings: FSAT and ASAT Algorithms ######
class ScoreFunctionThresholdEstimatorDistrShft(ScoreFunctionThresholdEstimator):
    def __init__(self, stream, exp_config, seed, device=None):
        """Initialise the parent estimator, then add the distribution-shift settings.

        Reads ``exp_config.window_est`` and adds an empty
        ``'changes_detected'`` list to ``self.out_all``.
        """
        super().__init__(stream, exp_config, seed, device=device)

        # change detection record
        self.out_all['changes_detected'] = []

        # number of most recent collected samples used for estimation
        self.window_est = exp_config.window_est
    
    # override
    def binary_search(self, psi, s0_lst, i0_lst, lam_prev, fpr_est_prev=np.inf):
        """Bisect for the smallest threshold with ``fpr_est + psi <= alpha``.

        Unlike the parent implementation, the search always spans
        ``[self.min_lam, self.max_lam]`` (``lam_prev`` does not cap the upper
        end) and the constraint is non-strict (``<=``).

        ``fpr_est_prev`` is accepted as an optional trailing argument so the
        override can be called with the parent's five-argument signature; its
        default of ``np.inf`` reproduces the previous four-argument behaviour.

        Args:
            psi: confidence width added to the FPR estimate.
            s0_lst: scores of the collected samples.
            i0_lst: importance-sampling flags aligned with ``s0_lst``.
            lam_prev: unused; kept for signature compatibility.
            fpr_est_prev: FPR estimate returned when no midpoint is accepted.

        Returns:
            ``(lam_hat, feasible, fpr_final)``: the last accepted midpoint (or
            ``self.max_lam`` if none), whether any midpoint was accepted, and
            the FPR estimate at the last accepted midpoint (or
            ``fpr_est_prev`` if none).
        """
        high = self.max_lam
        low = self.min_lam
        lam_hat = high
        fpr_final = fpr_est_prev
        feasible = False

        while high - low > self.epsilon:
            mid = (high + low)/2
            fpr_hat = self.fpr_est(mid, s0_lst, i0_lst)

            if fpr_hat + psi <= self.alpha:
                # constraint holds at mid: remember it and search the lower half
                feasible = True
                fpr_final = fpr_hat
                lam_hat = mid
                high = mid
            else:
                # constraint violated at mid: search the upper half
                low = mid
        return lam_hat, feasible, fpr_final

    # override
    def run(self):
        """Process the stream once, estimating on a sliding window of recent feedback.

        Works like the parent ``run`` but every estimate (confidence width,
        threshold search, retraining of g) uses only the last
        ``self.window_est`` stored samples. After each accepted threshold a
        change-detection check is evaluated; when it fires, the stored
        feedback, the update counter and the threshold are reset and the step
        is recorded in ``self.out_all['changes_detected']``.

        Per-step results are appended to ``self.out_all``.

        Raises:
            ValueError: if ``self.ucb`` is neither ``'heuristic'`` nor
                ``'theory'``.
        """
        if self.ucb == 'heuristic':
            ucb_width = self.ucb_heuristic
        elif self.ucb == 'theory':
            ucb_width = self.ucb_theory
        else:
            raise ValueError(
                f"unsupported ucb method {self.ucb!r}; expected 'heuristic' or 'theory'"
            )

        window = self.window_est

        # record in each cycle T^(i)
        out = {
            'z_id': [], 'z_ood': [],
            's_id': [], 's_ood': [],
            'i_id': [], 'i_ood': [],
            'instances': []
        }

        # initial threshold
        self.g = None
        lam_prev = np.inf
        fpr_est_prev = 0
        num_update = 0
        num_ood_curr = 0

        for t in tqdm(range(self.n)):
            # get data
            z, s, y = self.stream[t]

            # increment the internal counter of stream
            self.stream.inc_pos()

            # q: whether we get feedback; i: whether importance sampling was used
            if s <= lam_prev:
                y_hat = 0
                q = 1
                i = 0
            else:
                y_hat = 1
                q = bernoulli.rvs(self.p, size=1)[0]
                i = 1

            # is out-of-distribution & got feedback as out-of-distribution
            if y == 0 and q == 1:
                # samples only count towards the next g update once a
                # feasible threshold has been found
                if self.out_all['feasibility_point'] is not None:
                    num_ood_curr += 1

                out['z_ood'].append(z)
                out['s_ood'].append(s)
                out['i_ood'].append(i)

                s_win = out['s_ood'][-window:]
                i_win = out['i_ood'][-window:]

                # compute the confidence interval
                psi = ucb_width(i_win)

                # find optimal lambda that satisfies the fpr estimate constraint
                lam_curr, feasible, fpr_est_curr = self.binary_search(psi, s_win, i_win, lam_prev)

                if feasible:
                    if self.out_all['feasibility_point'] is None:
                        self.out_all['feasibility_point'] = t
                    lam_prev = lam_curr
                    fpr_est_prev = fpr_est_curr

                    # change detection
                    # NOTE(review): lam_prev was just set to the threshold that
                    # binary_search accepted on these same inputs, so this
                    # condition looks unable to fire — confirm whether the
                    # check was meant to use the threshold from before the update.
                    if self.fpr_est(lam_prev, s_win, i_win) + psi > self.alpha:
                        num_ood_curr = 0
                        num_update = 0
                        lam_prev = np.inf
                        for key in out:
                            if key != 'instances':
                                out[key] = []
                        self.out_all['changes_detected'].append(t)
                        print(f'Shift Detected at {t}')

            # compute system fpr and tpr
            self.out_all['true_label'].append(y)
            self.out_all['pred_label'].append(y_hat)
            if y == 0:
                if y_hat == 0:
                    self.out_all['instances'].append('tn')
                elif y_hat == 1:
                    self.out_all['instances'].append('fp')
            elif y == 1:
                if y_hat == 0:
                    self.out_all['instances'].append('fn')
                elif y_hat == 1:
                    self.out_all['instances'].append('tp')

            # keep record
            self.out_all['expected_fpr'].append(self.stream.fpr(lam_prev) * 100)
            self.out_all['expected_tpr'].append(self.stream.tpr(lam_prev) * 100)
            self.out_all['threshold'].append(lam_prev)
            self.out_all['fpr_estimate'].append(fpr_est_prev * 100)

            # update g
            if self.num_update_max > num_update and num_ood_curr >= self.update_freq[num_update]:

                num_ood_curr = 0
                num_update += 1

                z_win = out['z_ood'][-window:]
                i_win = out['i_ood'][-window:]

                # The old-model scores are only needed for the TPR comparison,
                # which is skipped on the first update of a cycle. num_update
                # is reset on a detected shift while self.g stays trained, so
                # the former `assert num_update > 1` could fail here.
                if num_update > 1 and self.g is not None:
                    s1_old_lst = compute_scores(self.g, self.z_id_train, self.device)

                # (1) optimize for g and lambda simultaneously; samples and
                # their importance flags are windowed together so they stay aligned
                # NOTE(review): self.g is overwritten even if the new model is
                # rejected below — confirm this is intended.
                self.g, _, _ = train_g(
                    self.device,
                    self.method_id, self.g, self.mode_estimator,
                    self.training_params, self.beta, self.c, self.p, self.input_size,
                    z_win, self.z_id_train,
                    i_win,
                    num_epoch=self.num_epoch, batch_size=self.batch_size,
                    show_log=False, seed=self.seed)

                # (2) find safe threshold on the new g's score range
                s0_new_lst = compute_scores(self.g, z_win, self.device)
                s1_new_lst = compute_scores(self.g, self.z_id_train, self.device)
                self.min_lam = np.min(s0_new_lst) - 0.1
                self.max_lam = np.max(s1_new_lst) + 0.1

                psi = ucb_width(i_win)
                lam_curr, feasible, fpr_est_curr = self.binary_search(psi, s0_new_lst, i_win, np.inf)

                # (3) model selection based on tpr estimation
                adopt_new_g = False
                if feasible:
                    if num_update == 1:
                        adopt_new_g = True
                    else:
                        old_tpr_est = self.tpr_est_iid(lam=lam_prev, s1_lst=s1_old_lst)
                        new_tpr_est = self.tpr_est_iid(lam=lam_curr, s1_lst=s1_new_lst)
                        slack = 2 * self.dkw_bound(n=len(self.z_id_train), delta=0.05)
                        adopt_new_g = new_tpr_est + slack > old_tpr_est

                if adopt_new_g:
                    # update according to the new g
                    out['s_ood'][-window:] = s0_new_lst
                    lam_prev = lam_curr
                    self.stream.update_stream(self.g, self.device)
                    self.out_all['update_points'].append(t)
                elif feasible:
                    self.out_all['update_points'].append(np.nan)

                # Re-anchor the search range to the scores the stream currently
                # uses. Previously this was skipped when the new g was
                # infeasible, leaving the range on the rejected g's scale.
                self.min_lam = self.stream.ood_quantile_score(q=0.5)
                self.max_lam = self.stream.max_score()
