import argparse
import json
import random
import sys
from pathlib import Path
from time import perf_counter, time

import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

import params
from AdversarialLearner import AdversarialLearner
from BaggingMaj import BaggingMaj
from DataReadFunctions import read_boone, read_diabetes, read_adversarial, read_forest_cover, read_higgs
from DisjointMaj import DisjointMaj
from OverlapMaj import OverlapMaj

def experiment(classifier, X, y, X_test, y_test, adversarial=False):
    """Fit ``classifier`` on (X, y), print its accuracies and return them.

    Args:
        classifier: object exposing sklearn-style ``fit(X, y)`` and
            ``score(X, y)`` methods.
        X, y: training samples and labels.
        X_test, y_test: held-out samples and labels.
        adversarial: when True, additionally score the classifier on every
            point of ``range(params.u)`` that does not occur in ``X``, with
            all of those points labelled 1.

    Returns:
        dict with ``fit_time`` (seconds spent in ``fit``), ``sample_score``,
        ``test_score`` and, when ``adversarial`` is True,
        ``outside_sample_score``.
    """
    # perf_counter is monotonic and high-resolution; time() follows the
    # system clock and can jump (e.g. NTP adjustments) during a long fit.
    start = perf_counter()
    classifier.fit(X, y)
    execution_time = perf_counter() - start

    sample_score = classifier.score(X, y)
    test_score = classifier.score(X_test, y_test)
    results = {'fit_time': execution_time, 'sample_score': sample_score, 'test_score': test_score}

    print("Fitting time:", execution_time, "seconds")
    print("Sample accuracy:", sample_score)
    print("Test accuracy:", test_score)
    if adversarial:
        # NOTE(review): assumes X is a numpy array of integer points drawn from
        # range(params.u) (one feature per sample) — confirm with read_adversarial.
        not_X = np.setdiff1d(np.arange(params.u), X.flatten())[:, np.newaxis]
        not_y = np.ones(not_X.shape[0], dtype=int)
        outside_sample_score = classifier.score(not_X, not_y)
        print("Outside sample accuracy:", outside_sample_score)
        results['outside_sample_score'] = outside_sample_score
    print()
    return results

def test_pure_ada(X, y, X_test, y_test, weak_classifier=None, adversarial=False):
    """Evaluate a single AdaBoost model (SAMME, ``params.n_rounds`` rounds).

    ``weak_classifier`` is passed as AdaBoost's ``estimator``. Returns the
    metrics dict from ``experiment`` with ``n_voting_classifiers`` set to 1.
    """
    print("Pure Adaboost:")
    model = AdaBoostClassifier(
        estimator=weak_classifier,
        n_estimators=params.n_rounds,
        algorithm='SAMME',
    )
    metrics = experiment(model, X, y, X_test, y_test, adversarial=adversarial)
    return {**metrics, 'n_voting_classifiers': 1}

def test_disjoint_maj(X, y, X_test, y_test, n_voting_classifiers, weak_classifier=None, adversarial=False):
    """Evaluate a ``DisjointMaj`` ensemble of ``n_voting_classifiers`` voters.

    Each voter is configured with ``params.n_rounds`` weak classifiers.
    Returns the metrics dict from ``experiment`` tagged with the requested
    ``n_voting_classifiers``.
    """
    header = f"Disjoint Majority {n_voting_classifiers}:"
    print(header)
    ensemble = DisjointMaj(
        n_voting_classifiers=n_voting_classifiers,
        n_weak_classifiers=params.n_rounds,
        weak_classifier=weak_classifier,
    )
    metrics = experiment(ensemble, X, y, X_test, y_test, adversarial=adversarial)
    return {**metrics, 'n_voting_classifiers': n_voting_classifiers}

def test_bagging_maj(X, y, X_test, y_test, n_voting_classifiers, weak_classifier=None, adversarial=False):
    """Evaluate a ``BaggingMaj`` ensemble of ``n_voting_classifiers`` voters.

    Each voter is configured with ``params.n_rounds`` weak classifiers.
    Returns the metrics dict from ``experiment`` tagged with the requested
    ``n_voting_classifiers``.
    """
    header = f"Bagging Majority {n_voting_classifiers}:"
    print(header)
    ensemble = BaggingMaj(
        n_voting_classifiers=n_voting_classifiers,
        n_weak_classifiers=params.n_rounds,
        weak_classifier=weak_classifier,
    )
    metrics = experiment(ensemble, X, y, X_test, y_test, adversarial=adversarial)
    return {**metrics, 'n_voting_classifiers': n_voting_classifiers}

def test_overlap_maj(X, y, X_test, y_test, n_voting_classifiers=None, weak_classifier=None, adversarial=False):
    """Evaluate an ``OverlapMaj`` ensemble.

    The returned metrics dict (from ``experiment``) records
    ``n_voting_classifiers`` as the number of hypotheses the fitted ensemble
    holds, not the requested value.
    """
    header = f"Overlap Majority {n_voting_classifiers}:"
    print(header)
    ensemble = OverlapMaj(
        n_voting_classifiers=n_voting_classifiers,
        n_weak_classifiers=params.n_rounds,
        weak_classifier=weak_classifier,
    )
    metrics = experiment(ensemble, X, y, X_test, y_test, adversarial=adversarial)
    # Read after fitting. Presumably fit decides the count when
    # n_voting_classifiers is None — confirm in OverlapMaj.
    n_hypotheses = len(ensemble.hypotheses)
    return {**metrics, 'n_voting_classifiers': n_hypotheses}

def test_all(X, y, X_test, y_test, weak_classifier=None, *, voting_sizes=(3, 5, 11, 15, 21, 29)):
    """Run pure AdaBoost and every majority ensemble for each ensemble size.

    Args:
        X, y: training samples and labels.
        X_test, y_test: held-out samples and labels.
        weak_classifier: base learner forwarded unchanged to every method.
        voting_sizes: numbers of voting classifiers to try for the disjoint,
            bagging and overlap majority ensembles (keyword-only).

    Returns:
        dict with the run configuration (``n_rounds`` and the sample counts),
        the ``'Pure Adaboost'`` metrics dict, and, for each majority method,
        a list holding one metrics dict per entry of ``voting_sizes``.
    """
    results = {'n_rounds': params.n_rounds, 'n_training_samples': len(X), 'n_test_samples': len(X_test)}
    results['Pure Adaboost'] = test_pure_ada(X, y, X_test, y_test, weak_classifier)
    results['Disjoint Majority'] = []
    results['Bagging Majority'] = []
    results['Overlap Majority'] = []
    # Interleave the methods for each size so that the run order (and any
    # global RNG consumption) matches the historical behaviour.
    for n_voters in voting_sizes:
        results['Disjoint Majority'].append(test_disjoint_maj(X, y, X_test, y_test, n_voters, weak_classifier))
        results['Bagging Majority'].append(test_bagging_maj(X, y, X_test, y_test, n_voters, weak_classifier))
        results['Overlap Majority'].append(test_overlap_maj(X, y, X_test, y_test, n_voters, weak_classifier))
    return results

def test_all_full_overlap(X, y, X_test, y_test, weak_classifier=None, adversarial=False, *,
                          voting_sizes=(3, 5, 11, 15, 21, 29)):
    """Run all methods, using a single default-configured ``OverlapMaj``.

    Disjoint and bagging majorities are evaluated once per entry of
    ``voting_sizes``. Overlap majority is evaluated once, with
    ``n_voting_classifiers`` left at its default (``None``).

    Args:
        X, y: training samples and labels.
        X_test, y_test: held-out samples and labels.
        weak_classifier: base learner forwarded unchanged to every method.
        adversarial: if True, record the adversarial construction parameters
            from ``params`` and have ``experiment`` compute the
            outside-sample score as well.
        voting_sizes: ensemble sizes for the disjoint and bagging majorities
            (keyword-only).

    Returns:
        dict with the run configuration, the ``'Pure Adaboost'`` metrics,
        lists of per-size metrics for disjoint/bagging majority, and a single
        metrics dict under ``'Overlap Majority'``.
    """
    results = {'n_rounds': params.n_rounds, 'n_training_samples': len(X), 'n_test_samples': len(X_test)}
    if adversarial:
        results |= {'universe_size': params.u, 'VC-dimension': params.VC, 'gamma': params.gamma, 't': params.t, 's': params.s}
    results['Pure Adaboost'] = test_pure_ada(X, y, X_test, y_test, weak_classifier, adversarial=adversarial)
    results['Disjoint Majority'] = []
    results['Bagging Majority'] = []
    for n_voters in voting_sizes:
        results['Disjoint Majority'].append(
            test_disjoint_maj(X, y, X_test, y_test, n_voters, weak_classifier, adversarial=adversarial))
        results['Bagging Majority'].append(
            test_bagging_maj(X, y, X_test, y_test, n_voters, weak_classifier, adversarial=adversarial))
    results['Overlap Majority'] = test_overlap_maj(X, y, X_test, y_test, weak_classifier=weak_classifier,
                                                   adversarial=adversarial)
    return results

if __name__ == '__main__':
    # Command-line entry point: run the experiment suite for one dataset over
    # one seed (--seed) or seeds 1-5, writing results/<dataset><seed>.json.
    DATASETS = ('boone', 'diabetes', 'forest_cover', 'higgs', 'adversarial')

    parser = argparse.ArgumentParser(
        prog="python experiment.py",
        description="Run experiments on a dataset",
    )
    parser.add_argument("dataset", type=str, help="Dataset to run the experiment on (one of \"boone\", \"diabetes\", \"forest_cover\", \"higgs\", \"adversarial\")")
    parser.add_argument("--seed", type=int, help="Random seed (default runs on seeds 1-5)")
    args = parser.parse_args()
    if args.dataset not in DATASETS:
        print("Invalid dataset:", args.dataset)
        # sys.exit, not the site-provided exit(), which is not guaranteed to exist.
        sys.exit(1)

    # Create the output directory before any experiment runs. Otherwise a
    # missing results/ directory only fails after a whole seed has been computed.
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    seeds = [args.seed] if args.seed is not None else range(1, 6)
    for seed in seeds:
        # Seed before loading data so that any randomness in the readers is
        # reproducible too (assumption: readers may shuffle/split — confirm).
        random.seed(seed)
        np.random.seed(seed)
        print(f"Seed: {seed}")
        if args.dataset == 'boone':
            X, y, X_test, y_test = read_boone()
            d = test_all(X, y, X_test, y_test)
        elif args.dataset == 'higgs':
            X, y, X_test, y_test = read_higgs(num_of_samples=300000)
            d = test_all(X, y, X_test, y_test)
        elif args.dataset == 'forest_cover':
            X, y, X_test, y_test = read_forest_cover()
            d = test_all(X, y, X_test, y_test)
        elif args.dataset == 'diabetes':
            X, y, X_test, y_test = read_diabetes()
            d = test_all_full_overlap(X, y, X_test, y_test)
        else:  # 'adversarial' — the only remaining validated choice
            X, y, X_test, y_test = read_adversarial(params.u, params.m)
            d = test_all_full_overlap(X, y, X_test, y_test, weak_classifier=AdversarialLearner(), adversarial=True)
        d['random_seed'] = seed
        with open(results_dir / f"{args.dataset}{seed}.json", 'w', encoding='utf-8') as f:
            json.dump(d, f, indent=4)