import torch
import numpy as np
import matplotlib.pyplot as plt
from kmeans_pytorch import kmeans, kmeans_predict
import pickle
from sklearn.preprocessing import StandardScaler, LabelEncoder
from copy import deepcopy
from cluster.reward import get_reward_strongArm, get_reward_fold, get_reward_ng


def dict_to_tensor(dict):
    """Pack the values of a mapping into a 1-D tensor ordered by key.

    The parameter name shadows the builtin ``dict``; it is kept so callers
    passing it by keyword keep working.

    Args:
        dict: Mapping whose values are scalars assignable to a tensor element.

    Returns:
        A tensor of length ``len(dict)`` created by ``torch.zeros`` (so it has
        torch's default dtype), holding the values in ascending key order.
    """
    ordered_items = sorted(dict.items())
    packed = torch.zeros(len(ordered_items))
    for position, (_, value) in enumerate(ordered_items):
        packed[position] = value
    return packed


# Performance metrics used as clustering features for each supported env.
_ENV_FEATURES = {
    'strongArm': {'delay', 'input_ref_noise'},
    'fold': {'gain', 'pm_dm', 'rms_noise_out_dm', 'lg_ugb'},
    'ng': {'gain', 'phm', 'ugbw'},
}


def _get_reward_fn(env):
    """Return the reward function used to score corners for ``env``."""
    # Built at call time so the table only needs the module-level imports.
    return {
        'strongArm': get_reward_strongArm,
        'fold': get_reward_fold,
        'ng': get_reward_ng,
    }[env]


def _build_feature_matrix(features, corner_infos):
    """Stack per-corner feature dicts into a 2-D tensor, one row per corner in ``corner_infos`` order."""
    x = torch.zeros(len(corner_infos), len(features[corner_infos[0]]))
    for row, corner_info in enumerate(corner_infos):
        x[row] = dict_to_tensor(features[corner_info])
    return x


def _pick_worst(info_groups, rewards):
    """Return ``{group_id: {corner: reward}}`` holding each group's lowest-reward corner.

    Only rewards strictly below 1 can be selected. A group with no such corner
    (including an empty group) maps to ``{'tt': 1}``. On ties the first corner
    in group order wins.
    """
    perf_choice = {}
    for gid, group in info_groups.items():
        # NOTE(review): 'tt' is presumably the typical process corner used as a
        # fallback; it need not be a key of the input data — confirm.
        min_reward, min_corner = 1, "tt"
        for corner_info in group:
            reward = rewards[corner_info]
            if reward < min_reward:
                min_reward, min_corner = reward, corner_info
        perf_choice[gid] = {min_corner: min_reward}
    return perf_choice


def choose_next(data: dict, env, num_clusters):
    """Cluster corners by performance and pick the worst-rewarded corner per cluster.

    Each corner is scored with the env-specific reward function on its full
    performance dict. Its performance is then reduced to the env's feature
    metrics, standardized, and grouped with k-means into ``num_clusters``
    clusters. From each cluster the corner with the lowest reward is chosen.

    Args:
        data: Mapping ``corner_info -> perf`` where ``perf`` is a dict of
            metric name -> value. It is not modified.
        env: One of ``'strongArm'``, ``'fold'`` or ``'ng'``.
        num_clusters: Number of k-means clusters.

    Returns:
        Dict ``corner_info -> reward`` with at most one entry per cluster.
        A cluster with no corner rewarded below 1 contributes ``'tt'`` with
        reward 1. Clusters that choose the same corner collapse into one entry.

    Raises:
        ValueError: if ``env`` is unsupported, ``data`` is empty, or the
            corners do not all report the same set of feature metrics.
    """
    if env not in _ENV_FEATURES:
        raise ValueError(
            f"unsupported env {env!r}; expected one of {sorted(_ENV_FEATURES)}"
        )
    if not data:
        raise ValueError("data must contain at least one corner")
    valid = _ENV_FEATURES[env]
    get_reward = _get_reward_fn(env)

    # Score on the full perf dict, but cluster only on the env's feature
    # metrics. New dicts are built so the caller's data is left intact.
    rewards = {}
    features = {}
    for corner_info, perf in data.items():
        rewards[corner_info] = get_reward(corner_info=corner_info, perf=perf)
        features[corner_info] = {k: v for k, v in perf.items() if k in valid}

    # Corner lookup table in sorted order; row i of x belongs to corner_infos[i].
    corner_infos = sorted(features)
    # Columns come from sorted keys, so every corner must have the same key set
    # or columns would silently mean different metrics (or broadcast).
    feature_keys = set(features[corner_infos[0]])
    if any(set(features[c]) != feature_keys for c in corner_infos):
        raise ValueError("all corners must report the same feature metrics")
    x = _build_feature_matrix(features, corner_infos)

    # Standardize each metric so large-magnitude metrics do not dominate the
    # Euclidean distance used by k-means.
    x = torch.from_numpy(StandardScaler().fit_transform(x))

    # Fixed seed for reproducible clustering (kmeans_pytorch presumably draws
    # its initial centers from numpy's global RNG).
    np.random.seed(123)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cluster_ids_x, _ = kmeans(
        X=x, num_clusters=num_clusters, distance='euclidean', device=device
    )

    # Group corners by cluster id; every cluster gets an entry, even if empty.
    info_groups = {str(i): [] for i in range(num_clusters)}
    for corner_info, gid in zip(corner_infos, cluster_ids_x.tolist()):
        info_groups[str(gid)].append(corner_info)

    perf_choice = _pick_worst(info_groups, rewards)
    next_choice = {}
    for choice in perf_choice.values():
        next_choice.update(choice)

    import pprint
    print('next corners:')
    pprint.pprint(perf_choice)
    return next_choice
