import pickle, argparse, os
import numpy as np
from graphgym.pyspice_utils import simulation



def simulate_pop(graph, feats):
    """Simulate every candidate sizing in ``feats`` on a fresh copy of ``graph``.

    Args:
        graph: Circuit graph exposing ``copy()`` and ``vs['feat']``; each
            candidate row is written to ``vs['feat']`` of its own copy.
        feats: 2-D array with one row of per-vertex feature values per
            candidate.

    Returns:
        np.ndarray: Float array of shape ``(len(feats), 3)``. Row ``i`` holds
        ``(gain[0] / 100, ugw / 1e9, pm / 90)`` for candidate ``i``, each
        rounded to 3 decimals. It holds ``(0, 0, 0)`` when building or
        simulating that candidate raised an ``Exception``.
    """
    out = []
    for row in feats:
        ft = row.tolist()
        g = graph.copy()
        try:
            g.vs['feat'] = ft
            sim = simulation(g, compute_features=True)
            gain = np.round(float(sim.gain[0] / 100), 3)
            ugw = np.round(float(sim.ugw / 1e9), 3)
            pm = np.round(float(sim.pm / 90), 3)
            out.append((gain, ugw, pm))
        except Exception:
            # Any failure scores the candidate as zero. Catch Exception rather
            # than using a bare except, so that Ctrl-C / SystemExit still stop
            # the (long) run instead of being swallowed.
            out.append((0, 0, 0))
    # Force float dtype and an (n, 3) shape, even for an empty or all-failed population
    return np.array(out, dtype=float).reshape(-1, 3)


def compute_fitness(sim_out, w, alpha=0.8, scale=(3.0, 30.0, 4.0)):
    """Score a population and update the per-metric fitness weights.

    Each metric column of ``sim_out`` is divided by ``scale`` and capped at
    1.0; it is not floored, so negative metrics lower the fitness. The weights
    are then blended towards ``1 - population mean`` of each scaled metric:
    ``w_new = alpha * w + (1 - alpha) * (1 - mean)``. A metric the population
    already scores high on therefore gets a lower weight. With ``alpha=1.0``
    the weights never change.

    Args:
        sim_out: Array of shape ``(n, m)`` with ``m == len(scale)``, e.g. the
            output of ``simulate_pop``.
        w: Current per-metric weights (length ``m``). Not modified in place.
        alpha: Fraction of the old weights kept at each update.
        scale: Per-metric normalisation divisors. The default equals the
            previously hard-coded values.

    Returns:
        tuple: ``(fitness, w)``. ``fitness`` has shape ``(n,)`` and is the
        mean of the weighted scaled metrics per individual; ``w`` holds the
        updated weights.
    """
    scaled_sim_out = (np.asarray(sim_out) / np.asarray(scale, dtype=float)).clip(max=1.0)
    avg = scaled_sim_out.mean(axis=0)
    w = alpha * np.asarray(w, dtype=float) + (1 - alpha) * (1 - avg)
    return (w * scaled_sim_out).mean(axis=1), w


def merge_feats(new_feats, n_new_samples, x_prob, mutation_prob):
    """Breed ``n_new_samples`` children from ``new_feats`` and append them.

    Each child starts as a copy of a random parent ``p1``. Each gene is
    replaced by the matching gene of a different random parent ``p2`` with
    probability ``x_prob``. Each gene is then redrawn as a random integer in
    [1, 100] with probability ``mutation_prob``. The same uniform draw is used
    for both tests, so mutated genes are a subset of crossed genes whenever
    ``mutation_prob <= x_prob``.

    Args:
        new_feats: 2-D array of parents (one row per individual). Not modified.
        n_new_samples: Number of children to create.
        x_prob: Per-gene crossover probability.
        mutation_prob: Per-gene mutation probability.

    Returns:
        np.ndarray: The parents followed by the children, shape
        ``(len(new_feats) + n_new_samples, n_genes)``.

    Raises:
        ValueError: If children are requested but fewer than 2 parents exist.
    """
    new_feats = np.asarray(new_feats)
    if n_new_samples <= 0:
        # Nothing to breed; concatenating with an empty 1-D array would fail
        return new_feats.copy()
    if len(new_feats) < 2:
        raise ValueError(f'merge_feats needs at least 2 parents, got {len(new_feats)}.')

    children = []
    for _ in range(n_new_samples):
        p1, p2 = np.random.choice(len(new_feats), 2, replace=False)
        # Copy: new_feats[p1] is a view, and editing it in place would corrupt
        # the parent (and every other child bred from it).
        child = new_feats[p1].copy()
        other = new_feats[p2]
        draws = np.random.rand(len(child))
        # Merge p2 into p1 with proba x_prob
        cross = draws < x_prob
        child[cross] = other[cross]
        # Generate brand new feats with proba mutation_prob, kept inside the
        # [1, 100] range used for the initial population
        mutate = draws < mutation_prob
        child[mutate] = (np.random.rand(mutate.sum()) * 100 + 1).round(0).clip(min=1.0, max=100.0)
        children.append(child)
    return np.concatenate([new_feats, np.array(children)], axis=0)


def genetic_algo(graph, pop_size=50, prune_prop=0.1, max_gen=15, k=0.2, alpha=0.8, w=None, x_prob=0.2,
                 mutation_prob=0.01):
    """Size the devices of ``graph`` with a simple elitist genetic algorithm.

    A random population of per-vertex feature vectors is drawn, simulated with
    ``simulate_pop`` and scored with ``compute_fitness``. Each generation:
      - keeps the ``int(prune_prop * pop_size)`` fittest individuals (elites);
      - refills the population with ``merge_feats`` children of those elites;
      - re-simulates and re-scores the whole population.

    Args:
        graph: Circuit graph; ``len(graph.vs['type'])`` sets the number of
            genes per individual.
        pop_size: Population size.
        prune_prop: Fraction of the population kept as elites each generation.
        max_gen: Number of generations. With ``0``, the best individual of the
            initial random population is returned.
        k: Unused. Kept for backward compatibility; it parametrised a disabled
            ``exp(-k * gen)`` mutation schedule.
        alpha: Weight-smoothing factor forwarded to ``compute_fitness``.
        w: Initial per-metric fitness weights. Defaults to ``[1.0, 1.0, 1.0]``.
        x_prob: Per-gene crossover probability.
        mutation_prob: Per-gene mutation probability (previously hard-coded
            to 0.01, which remains the default).

    Returns:
        list: Feature vector of the fittest individual in the last evaluated
        population.

    Raises:
        ValueError: If ``max_gen > 0`` and fewer than two elites would be kept.
    """
    # Avoid a shared mutable default array
    w = np.array([1.0, 1.0, 1.0]) if w is None else np.asarray(w, dtype=float)

    # Create population: one gene per vertex, ~N(50, 25) rounded and clipped to [1, 100]
    feats = np.round(np.random.randn(pop_size, len(graph.vs['type'])) * 25 + 50, 0).clip(min=1.0, max=100.0)
    n_keep = int(prune_prop * len(feats))
    if max_gen > 0 and n_keep < 2:
        # merge_feats needs two distinct parents among the elites
        raise ValueError(f'prune_prop * pop_size must keep at least 2 elites, got {n_keep}.')

    # Compute scores
    sim_out = simulate_pop(graph, feats)
    fitness, w = compute_fitness(sim_out, w, alpha=alpha)

    for _ in range(max_gen):
        # Prune pop: keep the n_keep fittest individuals (descending fitness)
        elite_idx = fitness.argsort()[::-1][:n_keep]
        elites = feats[elite_idx]
        # Merge: refill the population with children of the elites
        feats = merge_feats(elites, pop_size - n_keep, x_prob, mutation_prob)
        # Every individual, elites included, is re-simulated and re-scored
        # with the updated weights
        sim_out = simulate_pop(graph, feats)
        fitness, w = compute_fitness(sim_out, w, alpha=alpha)

    # Defined even when max_gen == 0 (the original raised UnboundLocalError)
    return feats[fitness.argmax()].tolist()


if __name__ == "__main__":

    ## Load cmd line args
    parser = argparse.ArgumentParser(description='Genetic algorithm for dataset device sizing.')

    parser.add_argument('--ds_path', default='datasets/CktBench101/raw/OCB101v2_graphs', type=str,
                        help='The path to the dataset.')
    parser.add_argument('--start_k', dest='start_k', type=int, required=True,
                        help='Index of the chunk of graphs to process: chunk i covers global graph '
                             'indices [i * chunk_size, (i + 1) * chunk_size).')
    parser.add_argument('--chunk_size', type=int, default=500,
                        help='Number of graphs per chunk.')

    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
    with open(args.ds_path, 'rb') as f:
        dataset = pickle.load(f)

    # The log directory (and its NEW_DS parent used for the output file) may not exist yet
    os.makedirs(os.path.join('NEW_DS', 'logs'), exist_ok=True)

    k = 0  # global graph index across all splits
    start_k = args.chunk_size * args.start_k
    end_k = start_k + args.chunk_size
    log_path = f'NEW_DS/logs/logs_{args.start_k}.txt'

    outputs = []
    for split_idx in range(len(dataset)):
        split = dataset[split_idx]

        for graph_idx in range(len(split)):

            # Only graphs inside this chunk are optimised; the others are just counted
            if start_k <= k < end_k:

                graph = split[graph_idx][0].copy()
                new_feats = genetic_algo(graph, pop_size=50, prune_prop=0.1, max_gen=10, alpha=1.0, x_prob=0.2, k=0.05)
                # y: re-simulate the selected sizing
                try:
                    graph.vs['feat'] = new_feats
                    sim = simulation(graph, compute_features=True)
                    g_rdm = np.round(float(sim.gain[0] / 100), 3)
                    ugw_rdm = np.round(float(sim.ugw / 1e9), 3)
                    pm_rdm = np.round(float(sim.pm / 90), 3)
                    # NOTE(review): order is (gain, pm, ugw), unlike simulate_pop's
                    # (gain, ugw, pm); presumably matches the dataset's y layout -- verify.
                    new_y = (g_rdm, pm_rdm, ugw_rdm)
                except Exception:
                    new_y = (0, 0, 0)

                outputs.append({'idx': k, 'feat': new_feats, 'y': new_y})

            k += 1

            # Progress marker: overwritten every 100 graphs of the current split
            if graph_idx % 100 == 0:
                with open(log_path, 'w') as f:
                    f.write(f'graph_idx: {graph_idx} done!')

    with open(os.path.join('NEW_DS', f'from_{start_k}_to_{end_k}'), 'wb') as f:
        pickle.dump(outputs, f)