import numpy as np
from Sampler import Sampler, run_sim_random
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from itertools import count
from tqdm import tqdm
import torch 
from torch import nn
from torch.distributions import Categorical, MultivariateNormal
import scipy
import argparse
from SearchSpaces import Metrics

# Command-line configuration: both options are integers with fixed defaults.
parser = argparse.ArgumentParser(description='Make simulations')
parser.add_argument('--length', default=200, type=int)
parser.add_argument('--avgs', default=500, type=int)
args = parser.parse_args()

# Tag identifying this configuration; echoed to stdout.
exp_name = "101_length_{}_avgs_{}_update".format(args.length, args.avgs)
print(exp_name)

def run_de(sampler, embs, bounds):
    """Run differential evolution over an embedding space until the sampler stops it.

    Every candidate point proposed by DE is snapped to the nearest row of
    ``embs`` (Euclidean distance), and that row index is evaluated with
    ``sampler.sample``. The reward is negated because DE minimises.

    Args:
        sampler: context-manager object exposing ``sample(idx, only_unique=...)``.
        embs: 2-D torch tensor with one embedding per candidate (row index is
            assumed to be the index the sampler understands).
        bounds: per-dimension ``(low, high)`` pairs, as accepted by scipy.
    """
    # Local import: a bare ``import scipy`` does not guarantee that the
    # ``optimize`` submodule is loaded on older SciPy versions.
    from scipy.optimize import differential_evolution

    def sample(emb):
        # Explicit float64 conversion instead of relying on torch accepting a
        # numpy operand in ``tensor - ndarray``.
        query = torch.as_tensor(np.asarray(emb, dtype=np.float64))
        idx = (embs - query).norm(dim=1).argmin()
        return -sampler.sample(idx.item(), only_unique=False)

    with sampler:
        # The objective is piecewise constant (nearest-neighbour lookup), so the
        # population can reach identical energies. With tol=0 scipy then
        # declares convergence and returns early. Restart DE (as the TPE runner
        # restarts fmin) so that the run is ended by the sampler, presumably
        # by raising from ``sample``, rather than by the optimiser.
        for _ in count():
            differential_evolution(sample, bounds, maxiter=20000, tol=0.00)
            
def run_hyperopt(sampler, embs, bounds):
    """Run TPE (hyperopt) over an embedding space until the sampler stops it.

    Every point proposed by TPE is snapped to the nearest row of ``embs``
    (Euclidean distance), and that row index is evaluated with
    ``sampler.sample``. The reward is negated because ``fmin`` minimises.

    Args:
        sampler: context-manager object exposing ``sample(idx, only_unique=...)``.
        embs: 2-D torch tensor with one embedding per candidate (row index is
            assumed to be the index the sampler understands).
        bounds: per-dimension ``(low, high)`` pairs; dimension ``i`` becomes a
            uniform hyperparameter labelled ``str(i)``.
    """
    from hyperopt import fmin, tpe, hp

    def sample(emb):
        # Explicit float64 conversion instead of relying on torch accepting a
        # numpy operand in ``tensor - ndarray``.
        query = torch.as_tensor(np.asarray(emb, dtype=np.float64))
        idx = (embs - query).norm(dim=1).argmin()
        return -sampler.sample(idx.item(), only_unique=False)

    # The search space does not change between restarts, so build it once.
    space = [hp.uniform(str(dim), b[0], b[1]) for dim, b in enumerate(bounds)]

    with sampler:
        # fmin is restarted indefinitely. The run is expected to be ended from
        # within the sampler, so max_evals only caps a single restart.
        for _ in count():
            fmin(sample, space=space, algo=tpe.suggest, max_evals=2000)

class EMA:
    """Exponential moving average with bias correction.

    The weighted sum of observed values and the sum of their weights are both
    decayed by ``momentum`` on every update. Their ratio is the average, which
    removes the pull toward the zero initial state: after a single update the
    average equals that value exactly.

    Calling ``value()`` before any ``update()`` evaluates ``0 / 0`` and raises
    ``ZeroDivisionError``.
    """

    def __init__(self, momentum):
        self._momentum = momentum
        self._numerator = self._denominator = 0

    def update(self, value):
        """Fold ``value`` into the running average."""
        keep = self._momentum
        fresh = 1 - keep
        self._numerator = keep * self._numerator + fresh * value
        # Decayed total weight; dividing by it is the bias correction.
        self._denominator = keep * self._denominator + fresh

    def value(self):
        """Return the bias-corrected average of every value seen so far."""
        return self._numerator / self._denominator
            

def run_sim_reinforce_disc(sampler, lr = 1e-2, momentum = 0.9, bounds = None, idx_from_vec = None):
    """Search a discrete space with REINFORCE and an EMA reward baseline.

    The policy is a product of independent categoricals, one per search-space
    dimension; dimension ``i`` has ``bounds[i]`` valid choices. Each step draws
    a choice vector, maps it to an index with ``idx_from_vec`` and asks
    ``sampler`` for its reward.

    If 30 consecutive draws all return ``None``, the policy is treated as
    converged and the function switches to uniform random sampling. A ``None``
    result is taken here to mean "already seen", given ``only_unique=True``
    (TODO confirm against ``Sampler``).

    Neither loop terminates on its own. Presumably the sampler ends the run
    from inside ``with sampler``, as in the other runners.

    Args:
        sampler: ``Sampler``-like context manager exposing
            ``sample(idx, only_unique=...)`` and ``len()``.
        lr: Adam learning rate for the policy logits.
        momentum: decay of the EMA reward baseline.
        bounds: 1-D integer array-like of per-dimension choice counts.
            Previously this was read from a module-level global ``bounds``
            that the script does not define. That global is still used as a
            fallback when the argument is omitted.
        idx_from_vec: callable mapping the sampled choice list to an index
            accepted by ``sampler.sample``. Defaults to a stub that always
            returns 0, because the lookup data was too large to include in the
            supplementary material.

    Raises:
        ValueError: if ``bounds`` is neither passed nor available as a global.
    """
    if bounds is None:
        # Backward compatibility with the original implicit-global lookup.
        bounds = globals().get("bounds")
        if bounds is None:
            raise ValueError(
                "run_sim_reinforce_disc requires `bounds` "
                "(number of choices per dimension)"
            )
    bounds = np.asarray(bounds)

    if idx_from_vec is None:
        # Placeholder: the real vector -> index table is not shipped.
        idx_from_vec = lambda vec: 0

    class Policy(nn.Module):
        def __init__(self, bounds):
            super().__init__()
            n_choices = int(bounds.max())
            # Every row is padded to bounds.max() columns ...
            self.params = nn.Parameter(torch.zeros((len(bounds), n_choices)))
            # ... so mask[i, j] is True for the invalid choices j >= bounds[i].
            # A buffer, so that it follows the module across .to(device).
            mask = torch.arange(n_choices)[None, :] >= torch.as_tensor(bounds)[:, None]
            self.register_buffer("mask", mask)

        def forward(self):
            # Mask invalid choices in the logits rather than overwriting the
            # parameter in place; masked entries get probability 0.
            logits = self.params.masked_fill(self.mask, float("-inf"))
            m = Categorical(logits=logits)
            action = m.sample()
            return m.log_prob(action), action.cpu().tolist()

    policy = Policy(bounds)
    baseline = EMA(momentum)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    with sampler:
        for _ in count():
            # Try up to 30 times to draw an architecture not sampled before.
            for _ in range(30):
                log_prob, vec = policy()
                reward = sampler.sample(idx_from_vec(vec), only_unique=True)
                if reward is not None:
                    break

            if reward is None:
                # REINFORCE has converged: 30 draws in a row produced nothing
                # new. Fall back to random sampling for the rest of the run.
                for _ in count():
                    sampler.sample(np.random.randint(len(sampler)))

            # The baseline absorbs the current reward before the advantage is
            # computed; on the very first step the advantage is therefore 0.
            baseline.update(reward)
            policy_loss = (-log_prob * (reward - baseline.value())).sum()
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

# --- Experiment setup --------------------------------------------------------
# Every 25th architecture of the "101" space's CIFAR10 metrics. The a2v
# embeddings below are subsampled the same way. The CENA file is presumably
# already the matching subset (it is loaded without slicing).
sampler = Sampler(Metrics.get_metrics("101", "CIFAR10")[::25], args.length)

embs_cena = np.load("embs/cena_101_subset.npy")
embs_cena = torch.tensor(embs_cena).to(torch.float32)

embs_a2v = np.load("embs/a2v_101.npy")[::25]
embs_a2v = torch.tensor(embs_a2v)

# Per-dimension (low, high) search bounds: the observed range for a2v, and the
# unit interval for CENA.
bounds_a2v  = torch.stack((embs_a2v.min(axis=0)[0], embs_a2v.max(axis=0)[0])).T.numpy()
bounds_cena = torch.stack((torch.zeros(embs_cena.shape[-1]), torch.ones(embs_cena.shape[-1]))).T.numpy()

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# --- Random-search baseline (fixed 2000 repetitions, independent of --avgs) --
sampler.reset(label='Random Search')
for _ in tqdm(range(2000)):
    run_sim_random(sampler)

# --- Embedding-space optimisers ---------------------------------------------
# (label, plot style, runner, embeddings, bounds). TPE and DE each keep one
# colour across embeddings; A2V curves are drawn dashed.
experiments = [
    ("CENA -- TPE", dict(color=colors[3]), run_hyperopt, embs_cena, bounds_cena),
    ("CENA -- DE", dict(color=colors[4]), run_de, embs_cena, bounds_cena),
    ("A2V -- TPE", dict(color=colors[3], dashes=[2, 2]), run_hyperopt, embs_a2v, bounds_a2v),
    ("A2V -- DE", dict(color=colors[4], dashes=[2, 2]), run_de, embs_a2v, bounds_a2v),
]
# Loop names are prefixed so they don't shadow a global `bounds`, which
# run_sim_reinforce_disc falls back to.
for exp_label, exp_style, exp_runner, exp_embs, exp_bounds in experiments:
    sampler.reset(label=exp_label, **exp_style)
    for _ in tqdm(range(args.avgs)):
        exp_runner(sampler, exp_embs, exp_bounds)

# --- Plot --------------------------------------------------------------------
sampler.plot_regrets()
plt.ylim(None, 0.03)
plt.xlabel("Number of Evaluated Architectures")
plt.ylabel("Regret (Best Arch - Best Found Arch)")
plt.show()


