import time
import argparse
import numpy as np
import matplotlib.pyplot as plt

from Algorithms.DataDropOja import DataDropOja
from Algorithms.ExperienceReplayOja import ExperienceReplayOja
from Algorithms.Offline import Offline
from Algorithms.VanillaOja import VanillaOja
from datetime import datetime
from MarkovChains.Cyclic import generate_cyclic_mc
from MarkovChains.FaultyServer import generate_faulty_server_mc
from MarkovChains.WorstCase1 import generate_worstcase1_mc
from MarkovChains.WorstCase2 import generate_worstcase2_mc
from MarkovChains.ErdosRenyi import generate_erdosrenyi_mc
from plot_experiments_eigengap_varying import plot_and_save
from numpy import linalg


def generate_markov_chain_data(mc, num_repetitions, is_faulty_server=False, offset=0, timesteps=None):
    """Simulate trajectories of a Markov chain and record (step, state, sample) tuples.

    Args:
        mc: Markov chain object providing ``get_initial_state()``,
            ``get_sample(state)`` and ``get_next_state(state)``.
        num_repetitions: Number of independent trajectories to simulate.
        is_faulty_server: If True, use the "faulty server" rule: the chain
            starts at ``np.array([0])``, a fresh sample is drawn only while the
            state equals 0, and otherwise the most recent sample is repeated.
        offset: Timesteps with index below ``offset`` are simulated but not
            recorded.
        timesteps: Number of timesteps per trajectory. Defaults to the
            script-level ``args.timesteps`` when omitted.

    Returns:
        dict mapping repetition index ``r`` to a list of
        ``(timestep, state, sample)`` tuples.
    """
    if timesteps is None:
        # Backward-compatible default: fall back to the argparse namespace that
        # the script creates at module level.
        timesteps = args.timesteps

    repetitions = {}
    for r in range(num_repetitions):
        data = []
        if not is_faulty_server:
            state = mc.get_initial_state()
            for i in range(timesteps):
                # Samples are only drawn for recorded timesteps.
                if i >= offset:
                    sample = mc.get_sample(state)
                    data.append((i, state, sample))
                state = mc.get_next_state(state)
        else:
            state = np.array([0])
            prev_value = -1  # placeholder; overwritten at step 0 because the start state is 0
            for i in range(timesteps):
                # ``.item()`` extracts the scalar explicitly; calling ``int()``
                # directly on a 1-d array is deprecated in recent NumPy.
                if int(np.asarray(state).item()) == 0:
                    # A new sample is drawn even before ``offset`` so that
                    # the repeated value is always up to date.
                    prev_value = mc.get_sample(state)
                if i >= offset:
                    data.append((i, state, prev_value))
                state = mc.get_next_state(state)
        repetitions[r] = data
    return repetitions


def generate_markov_chain_data_drop_data(mc, num_repetitions, is_faulty_server=False, offset=0,
                                         timesteps=None, drop_number=None):
    """Simulate Markov chain trajectories, keeping only every ``drop_number``-th step.

    The chain is advanced for ``timesteps * drop_number`` steps and a tuple is
    recorded only at step indices that are multiples of ``drop_number`` (and
    not below ``offset``).

    Args:
        mc: Markov chain object providing ``get_initial_state()``,
            ``get_sample(state)`` and ``get_next_state(state)``.
        num_repetitions: Number of independent trajectories to simulate.
        is_faulty_server: If True, use the "faulty server" rule: the chain
            starts at ``np.array([0])``, a fresh sample is drawn only while the
            state equals 0, and otherwise the most recent sample is repeated.
        offset: Step indices below ``offset`` are simulated but not recorded.
            The comparison is against the raw step index, not the index of the
            retained sample.
        timesteps: Number of retained-sample slots per trajectory. Defaults to
            the script-level ``args.timesteps`` when omitted.
        drop_number: Spacing between retained steps. Defaults to the
            script-level ``args.drop_number`` when omitted.

    Returns:
        dict mapping repetition index ``r`` to a list of
        ``(step_index, state, sample)`` tuples.
    """
    # Backward-compatible defaults: fall back to the argparse namespace that
    # the script creates at module level.
    if timesteps is None:
        timesteps = args.timesteps
    if drop_number is None:
        drop_number = args.drop_number

    total_steps = timesteps * drop_number
    repetitions = {}
    for r in range(num_repetitions):
        data = []
        if not is_faulty_server:
            state = mc.get_initial_state()
            for i in range(total_steps):
                if i % drop_number == 0:
                    # The sample is drawn at every retained step, including
                    # those before ``offset`` that are then discarded.
                    sample = mc.get_sample(state)
                    if i >= offset:
                        data.append((i, state, sample))
                state = mc.get_next_state(state)
        else:
            state = np.array([0])
            prev_value = -1  # placeholder; overwritten at step 0 because the start state is 0
            for i in range(total_steps):
                # ``.item()`` extracts the scalar explicitly; calling ``int()``
                # directly on a 1-d array is deprecated in recent NumPy.
                if int(np.asarray(state).item()) == 0:
                    # Sampling happens on every visit to state 0, even at
                    # dropped steps, so the repeated value stays current.
                    prev_value = mc.get_sample(state)
                if i % drop_number == 0 and i >= offset:
                    data.append((i, state, prev_value))
                state = mc.get_next_state(state)
        repetitions[r] = data
    return repetitions


def generate_iid_data(mc, num_repetitions, offset=0, timesteps=None):
    """Generate data whose states are drawn independently at each timestep.

    The first sample of every repetition comes from ``mc.get_initial_state()``;
    every later state is drawn with ``np.random.choice`` from
    ``mc.get_stationary_distribution()`` over ``mc.num_states`` states.

    Args:
        mc: Markov chain object providing ``get_initial_state()``,
            ``get_sample(state)``, ``get_stationary_distribution()`` and a
            ``num_states`` attribute.
        num_repetitions: Number of independent repetitions to generate.
        offset: Timesteps with index below ``offset`` are generated but not
            recorded.
        timesteps: Number of timesteps per repetition. Defaults to the
            script-level ``args.timesteps`` when omitted.

    Returns:
        dict mapping repetition index ``r`` to a list of
        ``(timestep, state, sample)`` tuples, where ``sample`` was drawn from
        the recorded ``state``.
    """
    if timesteps is None:
        # Backward-compatible default: fall back to the argparse namespace that
        # the script creates at module level.
        timesteps = args.timesteps

    # Loop-invariant inputs of the per-step draw; computed once instead of at
    # every timestep.
    state_ids = np.arange(mc.num_states)
    stationary = mc.get_stationary_distribution()

    iid_repetitions = {}
    for r in range(num_repetitions):
        data = []
        state = mc.get_initial_state()
        for i in range(timesteps):
            sample = mc.get_sample(state)
            # Record the state the sample was actually drawn from, *before*
            # redrawing; previously the freshly drawn (unrelated) state was
            # stored alongside the sample.
            if i >= offset:
                data.append((i, state, sample))
            state = np.random.choice(state_ids, 1, p=stationary)
        iid_repetitions[r] = data
    return iid_repetitions


if __name__ == "__main__":
    # Experiment driver: builds several 'worstcase2' Markov chains that differ
    # only in ``eps``, generates trajectories from each, runs VanillaOja on the
    # selected ones and hands the results to ``plot_and_save``.
    #
    # NOTE: ``args`` must stay a module-level name, because the data generators
    # defined above read ``args.timesteps`` / ``args.drop_number`` from it.
    parser = argparse.ArgumentParser(description='Arguments for Markov-Oja simulations.')
    parser.add_argument('--seed', type=int, default=5,
                        help='The random seed for all data generation and simulations.')
    parser.add_argument('--timesteps', type=int, default=5000, help='Number of timesteps for simulation.')
    parser.add_argument('--num_repetitions', type=int, default=20,
                        help='Number of repetitions over which to average the results.')
    parser.add_argument('--num_states', type=int, default=50, help='Number of states of Markov chain.')
    parser.add_argument('--num_dimensions', default=1000, type=int,
                        help='Number of dimensions of the data being generated at each state.')
    parser.add_argument('--markov_chain', default='erdosrenyi', help='Type of Markov Chain to use.')
    parser.add_argument('--experiment_identifier', default='standard', help='Identifying name of the experiment.')
    parser.add_argument('--output_dir', default='results', help='Directory for saving experiment results.')
    parser.add_argument('--cov_eigengap_threshold', type=int, default=60,
                        help='Threshold for eigengap of True Covariance '
                             'Matrix, Used for generating Cyclic Markov Chains.')
    parser.add_argument('--drop_number', type=int, default=10, help='Drop Number for Data Drop Algorithm.')
    parser.add_argument('--buffer_size', type=int, default=1000, help='Buffer size for experience replay.')
    parser.add_argument('--buffer_drop_number', type=int, default=100,
                        help='Drop Number for dropping data at beginning of each buffer.')
    # The default must be a list, like the value ``nargs='+'`` produces for
    # user input. A single space-joined string made the membership test below
    # a substring match instead of an exact name match.
    parser.add_argument('--algorithms', nargs='+', help='Algorithms to use.',
                        default=['worstcase2_eps_0.5',
                                 'worstcase2_eps_0.1',
                                 'worstcase2_eps_0.01',
                                 'worstcase2_eps_0.001'])
    parser.add_argument('--lr_multiplier', type=float, default=0.005, help='Multiplier for the learning rate.')
    parser.add_argument('--lr_decay', default=False, action='store_true',
                        help='Use this flag to enable decay of learning rate.')
    parser.add_argument('--offset', type=int, default=0, help='Offset after which to consider data.')
    parser.add_argument('--data_drop_optimize', default=False, action='store_true',
                        help='Use this flag to optimize data generation for data-drop over large number of timesteps.')
    args = parser.parse_args()
    print("Arguments : ", args)

    # Set random seed
    np.random.seed(args.seed)

    # Set experiment name
    current_timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
    experiment_name = args.experiment_identifier + "-" + current_timestamp + "-" + str(args.seed)

    # Initialise Markov Chain
    markov_chain_generator_fn = {'cyclic': generate_cyclic_mc,
                                 'faulty_server': generate_faulty_server_mc,
                                 'worstcase1': generate_worstcase1_mc,
                                 'worstcase2': generate_worstcase2_mc,
                                 'erdosrenyi': generate_erdosrenyi_mc}

    # Keyword arguments shared by every generator configuration.
    common_mc_args = {'num_states': args.num_states,
                      'num_dimensions': args.num_dimensions,
                      'seed': args.seed}
    eigengap_mc_args = dict(common_mc_args, cov_eigengap_threshold=args.cov_eigengap_threshold)

    # (label, eps) pairs in the order the chains are created and simulated.
    # NOTE(review): the last three labels do not match their eps values
    # (label 0.1 -> eps 0.005, 0.01 -> 0.001, 0.001 -> 0.0005). The values are
    # kept exactly as they were; confirm which side is intended, since these
    # labels are passed on to ``plot_and_save``.
    worstcase2_variants = [('worstcase2_eps_0.99', 0.99),
                           ('worstcase2_eps_0.9', 0.9),
                           ('worstcase2_eps_0.5', 0.5),
                           ('worstcase2_eps_0.1', 0.005),
                           ('worstcase2_eps_0.01', 0.001),
                           ('worstcase2_eps_0.001', 0.0005)]

    markov_chain_generator_args = {'cyclic': dict(eigengap_mc_args),
                                   'faulty_server': dict(common_mc_args),
                                   'worstcase1': dict(eigengap_mc_args)}
    for label, eps in worstcase2_variants:
        markov_chain_generator_args[label] = dict(eigengap_mc_args, eps=eps, beta=1.0)
    markov_chain_generator_args['erdosrenyi'] = dict(eigengap_mc_args,
                                                     p=2 * np.log(args.num_states) / args.num_states)

    # All chains are built before any data is generated, in the listed order.
    markov_chains = {}
    for label, _ in worstcase2_variants:
        markov_chains[label] = markov_chain_generator_fn['worstcase2'](**markov_chain_generator_args[label])

    # Generate Markov Chain data
    # Data is generated for every chain (not only the selected algorithms) so
    # that the sequence of random draws does not depend on ``--algorithms``.
    is_faulty_server = (args.markov_chain == 'faulty_server')
    start = time.time()
    markov_chain_data = {}
    for label, chain in markov_chains.items():
        markov_chain_data[label] = generate_markov_chain_data(chain, args.num_repetitions,
                                                              is_faulty_server,
                                                              args.offset)
    end = time.time()
    print("Time required to generate Markov Chain data : ", end - start)

    # One random unit-norm starting vector per repetition, shared by all
    # algorithms. The dimension is taken from the first chain.
    num_dimensions = markov_chains[worstcase2_variants[0][0]].num_dimensions
    w_init = []
    for _ in range(args.num_repetitions):
        w = np.random.randn(num_dimensions)
        w /= linalg.norm(w)
        w_init.append(w)

    # Apply Algorithms
    algorithm_fn = {}
    algorithm_args = {}
    for label, chain in markov_chains.items():
        algorithm_fn[label] = VanillaOja()
        algorithm_args[label] = {'data': markov_chain_data[label], 'markov_chain': chain,
                                 'lr_multiplier': args.lr_multiplier, 'lr_decay': args.lr_decay,
                                 'w_init': w_init, 'is_iid': False}

    results = {}
    valid_algorithms = args.algorithms
    print(valid_algorithms)

    # Requested names that match no configured algorithm used to be ignored
    # silently; surface them so typos are noticed.
    unknown_algorithms = [name for name in valid_algorithms if name not in algorithm_fn]
    if unknown_algorithms:
        print("Ignoring unknown algorithms : ", unknown_algorithms)

    for algorithm in algorithm_fn:
        if algorithm in valid_algorithms:
            print("Simulation for " + algorithm + " started.")
            start = time.time()
            results[algorithm] = algorithm_fn[algorithm].run_simulation(**algorithm_args[algorithm])
            end = time.time()
            print("Simulation for " + algorithm + " completed.")
            print("Time required : ", end - start)
            print("=================================================")

    # Plot and Save results
    plot_and_save(experiment_name, args.output_dir, results, markov_chains)