import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import argparse
import numpy as np

from pathlib import Path

import smfg
import smfg.problems
import smfg.vi as vi
from smfg.agents import AgentPAFullFeedback, AgentPABanditFeedback
from smfg.experiments import ExperimentPMABFullFeedback, ExperimentPMABBanditFeedback

import dill as pickle

# Command-line interface of the experiment runner.
# The problem type and feedback mode are restricted with `choices` so that an
# unsupported value is rejected by argparse before any output file is created.
parser = argparse.ArgumentParser(description='Run an experiment of SMFG.')

# Required experiment configuration.
parser.add_argument('-N', required=True, type=int, help="Number of agents to simulate.")
parser.add_argument('-K', required=True, type=int, help="Number of arms")
parser.add_argument(
    "-p", '--problem_type', required=True, type=str,
    choices=["linear", "exp", "kl", "bb"],
    help="SMFG problem type to generate and simulate. (linear, exp, kl or bb)",
)
parser.add_argument(
    "-f", '--feedback', required=True, type=str,
    choices=["full", "bandit"],
    help="Feedback type. (bandit or full)",
)
parser.add_argument("-o", '--output', required=True, type=str, help="Output directory.")
parser.add_argument("-e", '--epoch', required=True, type=int, help="Number of epochs.")
parser.add_argument('--sigma', required=True, type=float, help="Noise level.")

# Seeds: one for generating the problem instance, one for the simulation itself.
parser.add_argument('--problemseed', default=1, type=int, help="Random seed for generating problem.")
parser.add_argument("-s", '--seed', default=1, type=int, help="Random seed.")

# Hyper-parameters are parsed as strings because they accept either a number
# or the literal "auto" (resolved later in the script).
parser.add_argument('--epsilon', type=str, default="auto", help="Exploration probability. (by default: scales with number of agents)")
# parser.add_argument('--tlength', type=str, default="auto", help="Exploration episode length for bandits.")
parser.add_argument('-t', "--tau", default="auto", type=str, help="Regularization parameter.")

args = parser.parse_args()

# Report the run configuration and reserve a fresh output file for this run.
print(f"Generating PMAB of type {args.problem_type}, N={args.N}, K={args.K}, feedback={args.feedback}.")

print(f"Set seed to {args.problemseed} for generating problem.")
problemseed = args.problemseed

print(f"Output directory: {args.output}.")

# exist_ok=True avoids the check-then-create race when several runs are
# launched at the same time against the same output directory.
os.makedirs(args.output, exist_ok=True)

# Claim the first unused run<i>.pkl. Opening with mode "x" creates the file
# exclusively, so two concurrent runs can never reserve the same name (the
# previous exists()-then-touch() sequence could hand both the same file).
i = 0
while True:
    output_file = os.path.join(args.output, f"run{i}.pkl")
    try:
        with open(output_file, "x"):
            pass
        break
    except FileExistsError:
        i += 1


### initialize experiment
# Problem dimensions as plain integers.
N = int(args.N)
K = int(args.K)

# Regularization parameter: use the value given on the command line, or fall
# back to 2 / cbrt(N) when "auto" was requested.
if args.tau != "auto":
    tau = float(args.tau)
else:
    print("Auto determining tau")
    tau = 2 / np.cbrt(args.N)

# Generate the problem instance. NumPy's global RNG is seeded with the problem
# seed, and the same seed is also passed explicitly to the generator.
# The generators are looked up in `smfg.problems`, the package this script
# imports; the former `populationbandits` prefix is never imported and raised
# a NameError.
# NOTE(review): assumes the generate_* functions are exposed by smfg.problems
# under these names -- confirm against the package.
np.random.seed(args.problemseed)
if args.problem_type == "linear":
    problem = smfg.problems.generate_new_linear_problem(K=K, seed=problemseed)
elif args.problem_type == "exp":
    problem = smfg.problems.generate_new_exp_problem(K=K, seed=problemseed)
elif args.problem_type == "kl":
    problem = smfg.problems.generate_new_kl_problem(K=K, seed=problemseed)
elif args.problem_type == "bb":
    problem = smfg.problems.generate_beach_bar_problem(K=K, seed=problemseed)
else:
    raise ValueError("Problem type not known.")

# Build the agent population and the matching experiment for the requested
# feedback model. Anything other than "full" or "bandit" is rejected.
# (Disabled) eta0 = 0.5 / tau
if args.feedback not in ("full", "bandit"):
    raise Exception("Unknown feedback.")

if args.feedback == "full":
    agents = [AgentPAFullFeedback(K=args.K, tau=tau) for _ in range(N)]
    exp = ExperimentPMABFullFeedback(agents=agents, K=K, operator=problem, sigma=args.sigma)
else:
    # Exploration probability: explicit value, or 2 / sqrt(N) when "auto".
    if args.epsilon != "auto":
        epsilon = float(args.epsilon)
    else:
        print("Auto determining epsilon")
        epsilon = 2 / np.sqrt(N)

    # (Disabled) exploration episode length:
    #     tlength = int(np.log(N*T)) if args.tlength == "auto" else int(args.tlength)

    agents = [AgentPABanditFeedback(K=args.K, tau=tau, epsilon=epsilon) for _ in range(N)]
    exp = ExperimentPMABBanditFeedback(agents=agents, K=K, operator=problem, sigma=args.sigma)


### solve VI
# Run the VI solver once per candidate step size and keep the solution whose
# final iterate has the smallest strong gap.
# The search starts from +inf rather than a fixed magic threshold, so a
# solution is kept even when every gap is large.
min_gap = float("inf")
best_sol = None
for lr in [0.001, 0.01, 0.1, 1.]:
    sol, _, _, x_path = vi.solve_vi(K=args.K, operator=problem, iterations=5000, eta=lr, return_path=True)
    gap = vi.compute_strong_gap(problem, x_path[-1])
    if gap <= min_gap:
        min_gap = gap
        best_sol = sol

# No candidate was accepted (e.g. every gap was NaN): fail here with a clear
# message instead of crashing later when the solution is subtracted from the
# agents' policies.
if best_sol is None:
    raise RuntimeError("VI solver did not produce a usable solution for any step size.")

print(f"Found solution with strong gap {min_gap}.")
sol = best_sol

### start PMAB experiment
def _save_state(path, experiment, solution, seed, problem_seed, logs):
    """Pickle the experiment, reference solution, seeds and logs to *path*.

    The file is opened in binary mode and overwritten on every call.
    """
    state_dict = {
        "experiment": experiment,
        "sol": solution,
        "seed": seed,
        "problem_seed": problem_seed,
        "logs": logs,
    }
    with open(path, 'wb') as f:
        pickle.dump(state_dict, f)


# Number of epochs between intermediate checkpoints.
CHECKPOINT_INTERVAL = 200000

# The simulation uses its own seed, separate from the problem seed.
print(f"Set seed to {args.seed} for simulation.")
seed = args.seed
np.random.seed(args.seed)

logs = []

# NOTE(review): these three lists are never appended to in the visible code;
# all per-epoch metrics are stored in `logs`.
policy_deviations = []
exploitabilities = []
l2_dists = []

for e in range(args.epoch):
    if e % 100 == 0:
        print(f"Epoch {e}")

    actions, mean_rewards, feedbacks = exp.run_epoch()

    # Empirical fraction of agents that played each arm in this epoch.
    action_freq = np.bincount(actions, minlength=K) / N

    # Stack the agents' policies once and reuse them for every metric below.
    policies = np.array([a.policy for a in agents])

    # Mean L1 distance of the individual policies to the population mean.
    mean_policy = np.mean(policies, axis=0)
    policy_dev = np.mean([np.abs(p - mean_policy).sum() for p in policies])

    # Largest gap between the best arm's mean reward and an agent's expected
    # reward under its own policy.
    best_reward = np.max(mean_rewards)
    max_exploitability = np.max([best_reward - np.dot(p, mean_rewards) for p in policies])

    # Mean Euclidean distance of the policies to the VI solution.
    l2_dist = np.mean([np.sqrt(((p - sol) * (p - sol)).sum()) for p in policies])

    logs.append({
        "policy_deviation": policy_dev,
        "action_freq": action_freq,
        "mean_rewards": mean_rewards,
        "mean_policy": mean_policy,
        "max_exploitability": max_exploitability,
        "agent1_policy": np.copy(agents[0].policy),
        "mean_l2_dist": l2_dist,
    })

    # Periodic checkpoint so that a long run can be inspected or recovered.
    if (e + 1) % CHECKPOINT_INTERVAL == 0:
        _save_state(output_file, exp, sol, seed, args.problemseed, logs)


print("Experiment done, saving...")
_save_state(output_file, exp, sol, seed, args.problemseed, logs)