import argparse
import numpy as np
import torch
import yaml

import sys
import os
sys.path.append(os.path.abspath('../src/'))

import data_utils
import nn_utils
import plot_utils

from config import DataConfig, ExpConfigStatic, ExpConfigAdaptive
from data_utils import DataStream1
from estimator import StaticBaseline, ScoreFunctionThresholdEstimator

# Device handed to ScoreFunctionThresholdEstimator: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _make_data_config(args, config):
    """Build the DataConfig shared by every method from CLI args and a loaded YAML config."""
    return DataConfig(
        model=args.model_name,
        ood_name=args.ood_name,
        id_name=args.id_name,
        num=config['num'],
        gamma=config['gamma'],
        init_g=args.init_g,
    )


def _make_adaptive_config(config, update_freq):
    """Build the ExpConfigAdaptive used by ASAT and FSAT.

    The two methods read the same keys from their YAML config and differ only
    in the ``update_freq`` list passed here.
    """
    return ExpConfigAdaptive(
        method_id=1,
        alpha=config['alpha'],
        delta=config['delta'],
        p=config['p'],
        ucb=config['ucb'],
        mode_estimator=config['estimation_mode'],
        training_params=config['training_params'],
        beta=config['beta'],
        c=config['c'],
        input_size=config['input_size'],
        num_epoch=config['num_epoch'],
        batch_size=config['batch_size'],
        update_freq=update_freq,
        c_heuristic=config['c_heuristic'],
    )


def _run_and_evaluate(estimator):
    """Run ``estimator`` and return ``(fpr_lst, tpr_lst)`` computed by plot_utils from its instances."""
    estimator.run()
    instances = estimator.out_all['instances']
    fpr_lst = plot_utils.get_evalfpr_lst(instances, estimator.n)
    tpr_lst = plot_utils.get_evaltpr_lst(instances, estimator.n)
    return fpr_lst, tpr_lst


def asat(args, config, seed):
    """Run ASAT: ScoreFunctionThresholdEstimator with a non-empty ``update_freq`` schedule.

    Args:
        args: parsed CLI namespace (model_name, ood_name, id_name, init_g).
        config: dict loaded from the ASAT YAML config.
        seed: seed passed to both the data stream and the estimator.

    Returns:
        ``(fpr_lst, tpr_lst)`` evaluation lists from plot_utils.
    """
    data_config = _make_data_config(args, config)
    # Schedule: 20 entries of 100, 20 of 500, then 100 of 1000. Built fresh per
    # call so nothing downstream can mutate a shared list between runs.
    update_freq = [100] * 20 + [500] * 20 + [1000] * 100
    exp_config = _make_adaptive_config(config, update_freq)

    stream = DataStream1(data_config, seed)
    estimator = ScoreFunctionThresholdEstimator(stream, exp_config, seed, device)
    return _run_and_evaluate(estimator)


def fsat(args, config, seed):
    """Run FSAT: the same estimator as ASAT but with an empty ``update_freq`` list.

    Args:
        args: parsed CLI namespace (model_name, ood_name, id_name, init_g).
        config: dict loaded from the FSAT YAML config.
        seed: seed passed to both the data stream and the estimator.

    Returns:
        ``(fpr_lst, tpr_lst)`` evaluation lists from plot_utils.
    """
    data_config = _make_data_config(args, config)
    exp_config = _make_adaptive_config(config, [])

    stream = DataStream1(data_config, seed)
    estimator = ScoreFunctionThresholdEstimator(stream, exp_config, seed, device)
    return _run_and_evaluate(estimator)


def fsft(args, config, seed):
    """Run FSFT: StaticBaseline configured with ``ExpConfigStatic(q=config['tpr'], metric='tpr')``.

    Args:
        args: parsed CLI namespace (model_name, ood_name, id_name, init_g).
        config: dict loaded from the FSFT YAML config.
        seed: seed for the data stream (StaticBaseline itself takes no seed).

    Returns:
        ``(fpr_lst, tpr_lst)`` evaluation lists from plot_utils.
    """
    data_config = _make_data_config(args, config)
    exp_config = ExpConfigStatic(q=config['tpr'], metric='tpr')

    stream = DataStream1(data_config, seed)
    estimator = StaticBaseline(stream, exp_config)
    return _run_and_evaluate(estimator)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run stationary settings with given args')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--id_name', type=str, required=True)
    parser.add_argument('--ood_name', type=str, required=True)
    parser.add_argument('--init_g', type=str, required=True)
    parser.add_argument('--seed', type=int, required=True)
    args = parser.parse_args()

    def load_config(name):
        """Load ``../configs/stationary/<name>.yaml``.

        The path is relative to the current working directory, like the
        ``../src/`` entry added to sys.path at import time.
        """
        with open(f'../configs/stationary/{name}.yaml', 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)

    #### ASAT ####
    config = load_config('asat')
    # alpha and num for the final plot are taken from the ASAT config.
    alpha = config['alpha']
    num = config['num']
    print('==== Starting ASAT ====\n')
    asat_fpr, asat_tpr = asat(args, config, seed=args.seed)

    #### FSAT ####
    config = load_config('fsat')
    print('==== Starting FSAT ====\n')
    fsat_fpr, fsat_tpr = fsat(args, config, seed=args.seed)

    #### FSFT ####
    config = load_config('fsft')
    print('==== Starting FSFT ====\n')
    fsft_fpr, fsft_tpr = fsft(args, config, seed=args.seed)

    # Plot results (curve order: ASAT, FSAT, FSFT).
    plot_utils.plot_stationary(
        [asat_fpr, fsat_fpr, fsft_fpr],
        [asat_tpr, fsat_tpr, fsft_tpr],
        [],
        None,
        alpha,
        num,
        'Stationary Result',
        './result.png',
    )