import numpy as np
import SearchSpaces.Networks as n201
from Sampler import Sampler, run_sim_random
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from itertools import count
from tqdm import tqdm
import torch 
from torch import nn
from torch.distributions import Categorical, MultivariateNormal
import scipy
import argparse
from SearchSpaces import Metrics

parser = argparse.ArgumentParser(description='Make simulations')
parser.add_argument('data_set', choices=['CIFAR10', 'CIFAR100', 'ImageNet16'])
# Forwarded to Sampler as its second argument (presumably the per-run evaluation budget — confirm).
parser.add_argument('--length', type=int, default=200)
# Number of repetitions of each REINFORCE / embedding-based search method.
parser.add_argument('--avgs', type=int, default=500)

args = parser.parse_args()
exp_name = f"201_{args.data_set}_length_{args.length}_avgs_{args.avgs}_update"
print(exp_name)

# Candidate operation names and their position in that list (used to encode architectures).
ops = n201.NAS_BENCH_201
op_idx = {op: idx for idx, op in enumerate(ops)}
# Length of an architecture vector: one op index per slot, flattened over the cell's nodes
# (see vec_from_idx_slow). Hard-coded; must match the benchmark's cell layout.
num_nodes = 6
num_ops = len(ops)

def hash_vec(vec):
    """Encode an op-index vector as a single integer key.

    The vector is read as the digits of a base-``num_ops`` number (entry ``i``
    is the coefficient of ``num_ops ** i``), so distinct valid vectors always
    map to distinct keys.

    Args:
        vec: sequence of ``num_nodes`` integers, each in ``[0, num_ops)``.

    Returns:
        The key as a Python ``int``.

    Raises:
        ValueError: if ``vec`` has the wrong length or an out-of-range entry.
    """
    digits = np.asarray(vec)
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(digits) != num_nodes:
        raise ValueError(f"expected {num_nodes} op indices, got {len(digits)}")
    # Negative digits would break the base-num_ops encoding and could collide.
    if np.any(digits < 0) or np.any(digits >= num_ops):
        raise ValueError(f"op indices must lie in [0, {num_ops}), got {digits.tolist()}")

    return int(((num_ops ** np.arange(num_nodes)) * digits).sum())

def vec_from_idx_slow(idx):
    """Build the op-index vector of benchmark architecture ``idx`` from its config string.

    The config ``n201.b201_configs[idx]`` is parsed with
    ``n201.Structure.str2structure``. For every node, the op name (the first
    field of each entry) is mapped through ``op_idx``. The per-node results are
    concatenated in order into one flat numpy array.
    """
    structure = n201.Structure.str2structure(n201.b201_configs[idx])
    op_ids = []
    for node_entries in structure.nodes:
        for entry in node_entries:
            op_ids.append(op_idx[entry[0]])
    return np.array(op_ids)

# Precompute both directions of the architecture-index <-> op-vector mapping once,
# so the search loops can translate without re-parsing config strings.
vec_from_idx_lookup = np.array([vec_from_idx_slow(idx) for idx in range(len(n201.b201_configs))])
# Keyed by hash_vec(vector). If two indices ever shared a vector, the later index would win.
idx_from_vec_lookup = {hash_vec(v): idx for idx, v in enumerate(vec_from_idx_lookup)}

def vec_from_idx(idx):
    """Return the precomputed op-index vector for architecture ``idx``.

    This reads straight from ``vec_from_idx_lookup``. For a scalar ``idx`` the
    result is a row of that table, not a copy, so do not mutate it.
    """
    table = vec_from_idx_lookup
    return table[idx]

def idx_from_vec(vec):
    """Return the architecture index whose op-index vector equals ``vec``.

    ``vec`` is encoded with ``hash_vec`` (which validates it) and looked up in
    ``idx_from_vec_lookup``. An unknown key raises ``KeyError``.
    """
    key = hash_vec(vec)
    return idx_from_vec_lookup[key]


def run_de(sampler, embs, bounds):
    """Run SciPy differential evolution in embedding space until the sampler stops it.

    Each candidate point is snapped to the nearest row of ``embs`` (Euclidean
    distance). That row's index is evaluated with
    ``sampler.sample(idx, only_unique=False)``, and the score is negated
    because DE minimises.

    Args:
        sampler: evaluator used as a context manager. It is assumed to end the
            run itself once its budget is spent (e.g. by raising from
            ``sample`` and handling that in ``__exit__`` — confirm against
            ``Sampler``).
        embs: 2-D torch tensor with one embedding per row. The row index is
            passed to the sampler as the architecture index.
        bounds: per-dimension ``(low, high)`` pairs for the DE search box.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.optimize`` on older SciPy releases.
    from scipy.optimize import differential_evolution

    def sample(emb):
        # Stay in torch instead of mixing numpy functions with torch tensors.
        query = torch.as_tensor(np.asarray(emb))
        idx = torch.argmin((embs - query).norm(dim=1))
        return -sampler.sample(idx.item(), only_unique=False)

    with sampler:
        # differential_evolution can return before the budget is spent. With
        # tol=0 (and the default atol=0), it stops as soon as every population
        # member has the same score, e.g. when all of them snap to one
        # architecture. Restart it, as run_hyperopt restarts fmin, and let the
        # sampler end the run.
        for _ in count():
            differential_evolution(sample, bounds, maxiter=20000, tol=0.00)
            
def run_hyperopt(sampler, embs, bounds):
    """Run hyperopt's TPE in embedding space until the sampler stops it.

    Each proposed point is snapped to the nearest row of ``embs`` (Euclidean
    distance). That row's index is evaluated with
    ``sampler.sample(idx, only_unique=False)``, and the score is negated
    because ``fmin`` minimises.

    Args:
        sampler: evaluator used as a context manager. It is assumed to end the
            run itself once its budget is spent (confirm against ``Sampler``);
            the loop below never exits on its own.
        embs: 2-D torch tensor with one embedding per row. The row index is
            passed to the sampler as the architecture index.
        bounds: per-dimension ``(low, high)`` pairs. Each dimension gets a
            uniform prior over its range.
    """
    from hyperopt import fmin, tpe, hp

    # The search space never changes, so build it once instead of on every restart.
    space = [hp.uniform(str(dim), b[0], b[1]) for dim, b in enumerate(bounds)]

    def sample(emb):
        # Stay in torch instead of mixing numpy functions with torch tensors.
        query = torch.as_tensor(np.asarray(emb))
        idx = torch.argmin((embs - query).norm(dim=1))
        return -sampler.sample(idx.item(), only_unique=False)

    with sampler:
        # fmin returns after max_evals; restart it indefinitely and let the
        # sampler end the run, so max_evals is irrelevant.
        for _ in count():
            fmin(sample, space=space, algo=tpe.suggest, max_evals=2000)

class EMA:
    """Bias-corrected exponential moving average.

    Each update decays both a weighted sum of the observed values and the sum
    of the weights by ``momentum``. ``value()`` returns their ratio, which
    removes the pull towards 0 that a plain EMA initialised at 0 would have.
    Calling ``value()`` before the first ``update`` divides by zero.
    """

    def __init__(self, momentum):
        self._momentum = momentum
        self._numerator = 0    # decayed, weighted sum of observed values
        self._denominator = 0  # decayed sum of weights (bias correction)

    def update(self, value):
        """Fold ``value`` into the running average."""
        decay = self._momentum
        weight = 1 - decay
        self._numerator = decay * self._numerator + weight * value
        self._denominator = decay * self._denominator + weight

    def value(self):
        """Return the current bias-corrected average."""
        return self._numerator / self._denominator
            
def run_sim_reinforce_disc(sampler, lr=1e-2, momentum=0.9, max_tries=30):
    """Run one REINFORCE search over the discrete op space until the sampler stops it.

    The policy is an independent categorical distribution over ``num_ops``
    operations for each of the ``num_nodes`` slots. It starts uniform
    (all-zero logits). Each step samples an op-index vector, evaluates it with
    ``sampler.sample(..., only_unique=True)``, and takes one Adam step on the
    REINFORCE loss with an EMA reward baseline.

    The loops never exit on their own. The run is presumably ended by the
    sampler once its budget is spent (e.g. by raising inside ``sample`` and
    handling that in ``__exit__`` — confirm against ``Sampler``).

    Args:
        sampler: evaluator used as a context manager. ``sample`` returns a
            reward, or ``None`` when the architecture is not new.
        lr: Adam learning rate for the policy logits.
        momentum: momentum of the EMA reward baseline.
        max_tries: number of policy samples drawn per step while looking for
            an unseen architecture. If all of them are already seen, the
            policy is treated as converged and the run falls back to uniform
            random sampling. Must be at least 1.

    Raises:
        ValueError: if ``max_tries`` is smaller than 1.
    """
    if max_tries < 1:
        # With zero tries ``reward`` would never be bound in the loop below.
        raise ValueError(f"max_tries must be >= 1, got {max_tries}")

    class Policy(nn.Module):
        def __init__(self, num_nodes, num_ops):
            super(Policy, self).__init__()
            # One logit per (slot, op); zeros give a uniform initial policy.
            self.params = nn.Parameter(torch.zeros(num_nodes, num_ops))

        def forward(self):
            m = Categorical(nn.functional.softmax(self.params, dim=-1))
            action = m.sample()
            return m.log_prob(action), action.cpu().tolist()

    policy = Policy(num_nodes, num_ops)
    baseline = EMA(momentum)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    with sampler:
        for _ in count():
            # Try to sample an architecture the sampler has not seen yet.
            for _ in range(max_tries):
                log_prob, vec = policy()
                reward = sampler.sample(idx_from_vec(vec), only_unique=True)
                if reward is not None:
                    break

            if reward is None:
                # REINFORCE has converged: max_tries samples in a row were all
                # already-seen architectures. Sample uniformly at random until
                # the sampler ends the run.
                for _ in count():
                    sampler.sample(np.random.randint(len(sampler)))

            # The baseline is updated with the current reward before the
            # advantage is formed, so it already includes this reward.
            baseline.update(reward)
            policy_loss = (-log_prob * (reward - baseline.value())).sum()
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

sampler = Sampler(Metrics.get_metrics("201", args.data_set), args.length)

# Precomputed architecture embeddings. Row i is used as architecture i by
# run_de / run_hyperopt (the nearest row's index is passed to the sampler).
embs_cena = np.load("embs/cena_201.npy")
embs_cena = torch.tensor(embs_cena).to(torch.float32)

embs_a2v = np.load("embs/a2v_201.npy")
embs_a2v = torch.tensor(embs_a2v)

# Search boxes as (dim, 2) arrays of [low, high]. A2V uses the per-dimension
# min/max of its embeddings; CENA is searched in the unit cube.
bounds_a2v = torch.stack((embs_a2v.min(axis=0)[0], embs_a2v.max(axis=0)[0])).T.numpy()
bounds_cena = torch.stack((torch.zeros(embs_cena.shape[-1]), torch.ones(embs_cena.shape[-1]))).T.numpy()

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Each method is run repeatedly under its own label (reset() starts a new
# labelled series); the Sampler presumably aggregates the runs for plotting.
sampler.reset(label='Random Search')
# NOTE(review): random search uses a fixed 2000 repetitions rather than
# args.avgs — confirm this is intentional.
for _ in tqdm(range(2000)):
    run_sim_random(sampler)

sampler.reset(label='REINFORCE')
for _ in tqdm(range(args.avgs)):
    run_sim_reinforce_disc(sampler, 3e-2)

sampler.reset(label="CENA -- TPE", color=colors[3])
for _ in tqdm(range(args.avgs)):
    run_hyperopt(sampler, embs_cena, bounds_cena)

sampler.reset(label="CENA -- DE", color=colors[4])
for _ in tqdm(range(args.avgs)):
    run_de(sampler, embs_cena, bounds_cena)

# A2V curves reuse the CENA colours per optimiser and are drawn dashed.
sampler.reset(label="A2V -- TPE", color=colors[3], dashes=[2, 2])
for _ in tqdm(range(args.avgs)):
    run_hyperopt(sampler, embs_a2v, bounds_a2v)

sampler.reset(label="A2V -- DE", color=colors[4], dashes=[2, 2])
for _ in tqdm(range(args.avgs)):
    run_de(sampler, embs_a2v, bounds_a2v)

sampler.plot_regrets()
plt.ylim(None, 6)
plt.xlabel("Number Evaluated Architectures")
plt.ylabel("Regret (Best Arch - Best Found Arch)")
plt.show()


