import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import argparse
import numpy as np

from pathlib import Path

import smfg
import smfg.problems
import smfg.vi as vi
from smfg.agents import AgentPAFullFeedback, AgentPABanditFeedback
from smfg.experiments import ExperimentPMABFullFeedback, ExperimentPMABBanditFeedback

import dill as pickle

# Command-line interface for one experiment run. All required options must be
# supplied; "auto" for --epsilon / --tau defers the choice to the script, which
# derives the value from the number of agents.
parser = argparse.ArgumentParser(
    description='Run an experiment of SMFG.')

parser.add_argument('-N', required=True, type=int, help="Number of agents to simulate.")
# parser.add_argument('-K', required=True, type=int, help="Number of arms")
parser.add_argument('--utd_data', required=True, type=str, help="UTD data file to load.")
# Restricting the accepted values makes argparse reject a typo immediately,
# before any output directory or run file has been created.
parser.add_argument("-f", '--feedback', required=True, type=str,
                    choices=("bandit", "full"),
                    help="Feedback type. (bandit or full)")
parser.add_argument("-o", '--output', required=True, type=str, help="Output directory.")
parser.add_argument("-e", '--epoch', required=True, type=int, help="Number of epochs.")
parser.add_argument('--sigma', required=True, type=float, help="Noise level.")

parser.add_argument("-s", '--seed', default=1, type=int, help="Random seed.")

# Kept as strings so that the literal "auto" can be distinguished from a number;
# numeric values are converted with float() where they are consumed.
parser.add_argument('--epsilon', type=str, default="auto", help="Exploration probability. (by default: scales with number of agents)")
parser.add_argument('-t', "--tau", default="auto", type=str, help="Regularization parameter.")

args = parser.parse_args()

print(f"Output directory: {args.output}.")

# exist_ok=True removes the window between "does it exist?" and "create it",
# which matters when several runs are launched into the same directory.
os.makedirs(args.output, exist_ok=True)

# Reserve the first unused run{i}.pkl. touch(exist_ok=False) creates the file
# exclusively, so two concurrently started runs can never claim the same name;
# on a collision we simply move on to the next index.
i = 0
while True:
    output_file = os.path.join(args.output, f"run{i}.pkl")
    try:
        Path(output_file).touch(exist_ok=False)
    except FileExistsError:
        i += 1
    else:
        break


### initialize experiment
N = args.N  # already an int (argparse type=int)

# Regularization parameter: user-supplied, or 2 / N^(1/3) when "auto".
if args.tau == "auto":
    print("Auto determining tau")
    tau = 2 / np.cbrt(N)
else:
    tau = float(args.tau)

# NOTE(review): the original referenced `populationbandits.problems`, a name
# that is never imported in this file (NameError at runtime). `smfg.problems`
# is the imported module that was otherwise unused -- confirm that it exposes
# `load_utd_problem`.
problem = smfg.problems.load_utd_problem(args.utd_data)
# The number of arms is taken from the number of fitted models in the data.
K = len(problem.problem_parameters["data"]["fitted_models"])

print(f"Loaded UTD models, K={K}.")

# eta0 = 0.5 / tau
if args.feedback == "full":
    agents = [AgentPAFullFeedback(K=K, tau=tau) for _ in range(N)]
    exp = ExperimentPMABFullFeedback(agents=agents, K=K, operator=problem, sigma=args.sigma)
elif args.feedback == "bandit":
    # Exploration probability: user-supplied, or 2 / sqrt(N) when "auto".
    if args.epsilon == "auto":
        print("Auto determining epsilon")
        epsilon = 2 / np.sqrt(N)
    else:
        epsilon = float(args.epsilon)

    agents = [AgentPABanditFeedback(K=K, tau=tau, epsilon=epsilon) for _ in range(N)]
    exp = ExperimentPMABBanditFeedback(agents=agents, K=K, operator=problem, sigma=args.sigma)
else:
    # ValueError is still an Exception, so existing handlers keep working.
    raise ValueError("Unknown feedback.")


### solve VI
# Solve the variational inequality once per candidate step size and keep the
# solution whose final iterate has the smallest strong gap. It is used below as
# the reference point for the per-epoch L2-distance metric.

# Start from +inf rather than an arbitrary large constant so that any finite
# gap is accepted, however large.
min_gap = float("inf")
best_sol = None
for lr in [0.001, 0.01, 0.1, 1.]:
    sol, epsilons, return_vals, x_path = vi.solve_vi(K=K, operator=problem, iterations=5000, eta=lr, return_path=True)
    # NOTE(review): the gap is evaluated at the last iterate of the path while
    # `sol` is what gets stored -- presumably the same point; confirm in smfg.vi.
    gap = vi.compute_strong_gap(problem, x_path[-1])
    if gap <= min_gap:
        min_gap = gap
        best_sol = sol

if best_sol is None:
    # Only reachable when every gap compared false (e.g. NaN). Fail here with a
    # clear message instead of much later when the metrics subtract `None`.
    raise RuntimeError("VI solver produced no usable solution for any step size.")

print(f"Found solution with strong gap {min_gap}.")
sol = best_sol

### start PMAB experiment
# NOTE(review): the RNG is seeded only here, after the agents and the
# experiment object were constructed -- any randomness used during their
# construction is not covered by the seed. Confirm this is intended.
seed = args.seed
np.random.seed(args.seed)

# Number of epochs between intermediate checkpoints written to `output_file`.
CHECKPOINT_EVERY = 200000


def _save_state():
    """Pickle the experiment, reference solution, seed, data file and logs.

    Overwrites ``output_file`` with the current state, so the same helper
    serves both intermediate checkpoints and the final save.
    """
    state_dict = {
        "experiment": exp,
        "sol": sol,
        "seed": seed,
        "utd_data_file": args.utd_data,
        "logs": logs,
    }
    with open(output_file, 'wb') as f:  # binary mode: pickle stream
        pickle.dump(state_dict, f)


# One dict of metrics per epoch.
logs = []

# Currently unused in this script; kept in case other code refers to them.
policy_deviations = []
exploitabilities = []
l2_dists = []

for e in range(args.epoch):
    if e % 100 == 0:
        print(f"Epoch {e}")

    actions, mean_rewards, feedbacks = exp.run_epoch()

    # Empirical share of agents that chose each arm this epoch.
    action_freq = np.bincount(actions, minlength=K) / N

    # Population-average policy and the mean L1 distance of agents from it.
    mean_policy = np.mean([a.policy for a in agents], axis=0)
    policy_dev = np.mean([np.abs(a.policy - mean_policy).sum() for a in agents])

    # Worst case over agents of (best arm's mean reward - the agent's expected
    # reward under its own policy).
    best_reward = np.max(mean_rewards)
    max_exploitability = np.max([best_reward - np.dot(a.policy, mean_rewards) for a in agents])

    # Mean Euclidean distance between each agent's policy and the VI solution.
    diffs = [a.policy - sol for a in agents]
    l2_dist = np.mean([np.sqrt((d * d).sum()) for d in diffs])

    logs.append({
        "policy_deviation": policy_dev,
        "action_freq": action_freq,
        "mean_rewards": mean_rewards,
        "mean_policy": mean_policy,
        "max_exploitability": max_exploitability,
        # Copied so that the logged value is a snapshot, independent of any
        # later change to the agent's policy array.
        "agent1_policy": np.copy(agents[0].policy),
        "mean_l2_dist": l2_dist,
    })

    if (e + 1) % CHECKPOINT_EVERY == 0:
        _save_state()


print("Experiment done, saving...")
_save_state()