import argparse
import numpy as np
import torch
import yaml

import sys
import os
sys.path.append(os.path.abspath('../src/'))

import data_utils
import nn_utils
import plot_utils

from config import DataConfig, ExpConfigStatic, ExpConfigAdaptive
from data_utils import DataStream1, MergedStream
from estimator import StaticBaseline, ScoreFunctionThresholdEstimatorDistrShft

# Compute device handed to ScoreFunctionThresholdEstimatorDistrShft (used by the
# ASAT and FSAT settings below); falls back to CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _build_shifted_stream(args, config, seed):
    """Build the two-phase stream whose OOD source changes at ``config['shift_time']``.

    Phase 1 is configured with ``num=config['shift_time']`` and OOD dataset
    ``args.ood1_name``; phase 2 with ``num=config['num'] - config['shift_time']``
    and OOD dataset ``args.ood2_name``. Model, ID dataset, ``gamma`` and
    ``init_g`` are shared, and the two phases are concatenated by ``MergedStream``.
    """
    common = dict(
        model=args.model_name,
        id_name=args.id_name,
        gamma=config['gamma'],
        init_g=args.init_g,
    )
    data_config1 = DataConfig(ood_name=args.ood1_name, num=config['shift_time'], **common)
    data_config2 = DataConfig(
        ood_name=args.ood2_name, num=config['num'] - config['shift_time'], **common
    )
    # NOTE(review): both phases are seeded with the same ``seed``; if DataStream1
    # draws ID samples from that seed, the two phases may repeat the same ID
    # samples — confirm this is intended.
    stream1 = DataStream1(data_config1, seed)
    stream2 = DataStream1(data_config2, seed)
    return MergedStream([stream1, stream2])


def _adaptive_exp_config(config, update_freq):
    """Build the ``ExpConfigAdaptive`` shared by ASAT and FSAT.

    Every field comes from ``config`` except ``update_freq``, which is the only
    setting that differs between the two methods.
    """
    return ExpConfigAdaptive(
        method_id=1,
        alpha=config['alpha'],
        delta=config['delta'],
        p=config['p'],
        ucb=config['ucb'],
        mode_estimator=config['estimation_mode'],
        training_params=config['training_params'],
        beta=config['beta'],
        c=config['c'],
        input_size=config['input_size'],
        num_epoch=config['num_epoch'],
        batch_size=config['batch_size'],
        update_freq=update_freq,
        c_heuristic=config['c_heuristic'],
        window_est=config['window_est'],
    )


def _evaluate(estimator, config):
    """Return ``(fpr_lst, tpr_lst)`` computed from a finished estimator's instances."""
    # The plot helpers receive config['window_est'] divided by 1000 and truncated
    # to int — presumably they work in units of 1000 steps; confirm in plot_utils.
    window_est = int(config['window_est'] / 1000)
    instances = estimator.out_all['instances']
    fpr_lst = plot_utils.get_evalfpr_lst(instances, window_est)
    tpr_lst = plot_utils.get_evaltpr_lst(instances, window_est)
    return fpr_lst, tpr_lst


def _run_adaptive(args, config, seed, update_freq):
    """Run ScoreFunctionThresholdEstimatorDistrShft on the shifted stream and evaluate it."""
    exp_config = _adaptive_exp_config(config, update_freq)
    stream = _build_shifted_stream(args, config, seed)
    estimator = ScoreFunctionThresholdEstimatorDistrShft(stream, exp_config, seed, device)
    estimator.run()
    return _evaluate(estimator, config)


def asat(args, config, seed, update_freq=None):
    """Run the ASAT setting and return ``(fpr_lst, tpr_lst)``.

    Args:
        args: parsed CLI namespace (model/ID/OOD dataset names, ``init_g``).
        config: ASAT config dict loaded from YAML.
        seed: seed forwarded to the data streams and the estimator.
        update_freq: schedule forwarded to ``ExpConfigAdaptive``. Defaults to
            ``[100] * 20 + [500] * 100``, the schedule previously hard-coded here.
    """
    if update_freq is None:
        update_freq = [100] * 20 + [500] * 100
    # Copy so the estimator never shares (and possibly mutates) the caller's list.
    return _run_adaptive(args, config, seed, list(update_freq))


def fsat(args, config, seed):
    """Run the FSAT setting (same estimator as ASAT, empty ``update_freq``).

    Returns ``(fpr_lst, tpr_lst)``.
    """
    return _run_adaptive(args, config, seed, [])


def fsft(args, config, seed):
    """Run the FSFT setting with ``StaticBaseline`` targeting ``config['tpr']``.

    Returns ``(fpr_lst, tpr_lst)``.
    """
    exp_config = ExpConfigStatic(q=config['tpr'], metric='tpr')
    stream = _build_shifted_stream(args, config, seed)
    estimator = StaticBaseline(stream, exp_config)
    estimator.run()
    return _evaluate(estimator, config)

def _load_config(name):
    """Load and return the YAML config ``../configs/distr_shft/<name>.yaml``.

    The path is relative to the current working directory (like the ``../src/``
    entry added to ``sys.path`` at import time), so the script is expected to be
    launched from its own directory.
    """
    with open(f'../configs/distr_shft/{name}.yaml', 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def main():
    """Parse CLI arguments, run ASAT, FSAT and FSFT, and save the comparison plot."""
    parser = argparse.ArgumentParser(description='Run distribution shift settings with given args')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--id_name', type=str, required=True)
    parser.add_argument('--ood1_name', type=str, required=True)
    parser.add_argument('--ood2_name', type=str, required=True)
    parser.add_argument('--init_g', type=str, required=True)
    parser.add_argument('--seed', type=int, required=True)
    parser.add_argument('--output', type=str, default='./result.png',
                        help='path of the saved comparison plot (default: ./result.png)')
    args = parser.parse_args()

    #### ASAT ####
    config = _load_config('asat')
    # alpha/num for the shared plot come from the ASAT config; read them before
    # any run so a missing key fails fast.
    alpha = config['alpha']
    num = config['num']
    print('==== Starting ASAT ====\n')
    asat_fpr, asat_tpr = asat(args, config, seed=args.seed)

    #### FSAT ####
    config = _load_config('fsat')
    print('==== Starting FSAT ====\n')
    fsat_fpr, fsat_tpr = fsat(args, config, seed=args.seed)

    #### FSFT ####
    config = _load_config('fsft')
    print('==== Starting FSFT ====\n')
    fsft_fpr, fsft_tpr = fsft(args, config, seed=args.seed)

    # Plot results
    plot_utils.plot_stationary(
        [asat_fpr, fsat_fpr, fsft_fpr],
        [asat_tpr, fsat_tpr, fsft_tpr],
        [],
        None,
        alpha,
        num,
        'Stationary Result',
        args.output,
    )


if __name__ == "__main__":
    main()