# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib.pyplot as plt

from agents.sfql import SFQL
from agents.ql import QL
from features.tabular import TabularSF
from tasks.gridworld import Shapes
from utils.config import parse_config_file
from utils.stats import OnlineMeanVariance

# general training params: expose each section of gridworld.cfg as its own dict
config_params = parse_config_file('gridworld.cfg')
gen_params, test_params, task_params, agent_params, sfql_params, ql_params = (
    config_params[section]
    for section in ('GENERAL', 'TEST', 'TASK', 'AGENT', 'SFQL', 'QL')
)


# tasks
def generate_task():
    """Build a Shapes gridworld task with fixed shape rewards.

    The maze layout comes from the [TASK] section of the config. Shape rewards
    are fixed (instead of being drawn uniformly from [-1, 1] per shape) so that
    every run is clean and consistent.
    """
    shape_names = ['1', '2', '3']
    shape_values = [0.5, 0.25, -0.9]
    rewards = {shape: value for shape, value in zip(shape_names, shape_values)}
    maze = np.array(task_params['maze'])
    return Shapes(maze=maze, shape_rewards=rewards)
 

# agents
sfql = SFQL(TabularSF(**sfql_params), **agent_params)
ql = QL(**agent_params, **ql_params)
# only the SFQL agent is trained below; `ql` is constructed but not evaluated
agents, names = [sfql], ['SFQL']

# train
data_task_return = [OnlineMeanVariance() for _ in agents]
n_trials, n_samples, n_tasks = (
    test_params['n_trials'], test_params['n_samples'], test_params['n_tasks'])
for trial in range(n_trials):

    # every trial starts from freshly reset agents
    for agent in agents:
        agent.reset()

    # train each agent on a sequence of freshly generated tasks; `task` stays
    # bound to the last one, which the visualizations further down reuse
    for t in range(n_tasks):
        task = generate_task()

        # print task
        print('\nShapeIds:')
        print(task.shape_ids)
        print('\n')

        for agent, name in zip(agents, names):
            print('\ntrial {}, solving with {}'.format(trial, name))
            agent.train_on_task(task, n_samples)

    # fold this trial's reward history into the running statistics
    for stats, agent in zip(data_task_return, agents):
        stats.update(agent.reward_hist)

# plot the task return
import os

ticksize = 14
textsize = 18
# NOTE: not used by the figures below, which are created with figsize=(12, 6)
figsize = (20, 10)

plt.rc('font', size=textsize)  # controls default text sizes
plt.rc('axes', titlesize=textsize)  # fontsize of the axes title
plt.rc('axes', labelsize=textsize)  # fontsize of the x and y labels
plt.rc('xtick', labelsize=ticksize)  # fontsize of the tick labels
plt.rc('ytick', labelsize=ticksize)  # fontsize of the tick labels
plt.rc('legend', fontsize=ticksize)  # legend fontsize

plt.figure(figsize=(12, 6))
ax = plt.gca()
for i, name in enumerate(names):
    mean = data_task_return[i].mean
    # map each recorded point of the reward history back to a sample count
    n_sample_per_tick = n_samples * n_tasks // mean.size
    x = np.arange(mean.size) * n_sample_per_tick
    se = data_task_return[i].calculate_standard_error()
    plt.plot(x, mean, label=name)
    # shaded band of +/- one standard error across trials
    ax.fill_between(x, mean - se, mean + se, alpha=0.3)
plt.xlabel('sample')
plt.ylabel('cumulative reward')
plt.title('Cumulative Training Reward Per Task')
# add the legend before tight_layout so the layout takes it into account
plt.legend(ncol=2, frameon=False)
plt.tight_layout()
fig_name = 'sfql_return_test'
# all figures of this script are written under figures/; make sure it exists
os.makedirs('figures', exist_ok=True)
plt.savefig(f'figures/{fig_name}.png')

# visualizing the SFs
print('\nClustering SFs\n')
sf_dict = agents[0].sf.psi[0]
print('Num. SFs stored :', len(sf_dict))
# both lists follow sf_dict's iteration order: sf_list_ holds the raw SF
# arrays, sf_list their flattened copies (one row per state for the analysis)
sf_list_ = [sf_dict[state] for state in sf_dict]
sf_list = [sf.flatten() for sf in sf_list_]


import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import decomposition
from sklearn import cluster
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
import copy

# Dim reduction: project the flattened SFs onto their leading principal
# components for plotting; pca2/pca3 are reused later to project centroids

X_ = np.array(sf_list)

# 2D
pca2 = decomposition.PCA(n_components=2)
x_red2 = pca2.fit_transform(X_)
fig_name = 'SF visualization (2D)'
fig, ax = plt.subplots(figsize=(12, 6))
ax.scatter(x_red2[:, 0], x_red2[:, 1])
plt.savefig(f'figures/{fig_name}.png')

# 3D
pca3 = decomposition.PCA(n_components=3)
x_red3 = pca3.fit_transform(X_)
fig_name = 'SF visualization (3D)'
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.scatter(x_red3[:, 0], x_red3[:, 1], x_red3[:, 2])
plt.savefig(f'figures/{fig_name}.png')

# Clustering SFs: K-means with 4 clusters on the flattened SFs

kmeans = cluster.KMeans(n_clusters=4).fit(X_)
centroids = kmeans.cluster_centers_

# 2D: project the centroids with the PCA fitted on the SFs above
centroids_red2 = pca2.transform(centroids)

plt.figure(figsize=(12, 6))
ax = plt.gca()
plt.scatter(x_red2[:, 0], x_red2[:, 1])
plt.scatter(centroids_red2[:, 0], centroids_red2[:, 1], s=50, c='r')
fig_name = 'SF cluster visualization (2D)'
plt.savefig(f'figures/{fig_name}.png')


# 3D
centroids_red3 = pca3.transform(centroids)

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.scatter(x_red3[:, 0], x_red3[:, 1], x_red3[:, 2])
# Axes3D.scatter's 4th positional parameter is `zdir`, not a format string,
# so the colour must be given via `c=` (a positional 'ro' ends up in `zdir`
# and the centroids are not drawn in red)
ax.scatter(centroids_red3[:, 0], centroids_red3[:, 1], centroids_red3[:, 2], c='r', s=50)
fig_name = 'SF cluster visualization (3D)'
plt.savefig(f'figures/{fig_name}.png')

# Find closest state to each centroid

# cluster assignment of every SF under the fitted K-means model
X_cluster = kmeans.predict(X_)

# for each cluster h: index (into X_ / sf_list_) of the SF nearest to its centroid
n_clusters = centroids.shape[0]
states = np.zeros(n_clusters, dtype=int)
# start at +inf so any finite distance wins (a fixed sentinel can be exceeded)
state_dist = np.full(n_clusters, np.inf)
for i, x in enumerate(X_):
    h = X_cluster[i]
    d = np.linalg.norm(centroids[h] - x)
    # '<=' resolves ties in favour of the later state
    if d <= state_dist[h]:
        state_dist[h] = d
        states[h] = i


print('Distances to centroid from closest state :', state_dist)
print('State index :', states)

# initialize state visualizer: turn the maze's character cells into numbers so
# they can be drawn with pcolormesh (' ' and '_' -> 0, 'X' -> 9, 'G' -> 6); all
# remaining cells must be numeric strings for the final float conversion
state_maze0 = np.array(copy.deepcopy(task.maze))
for symbol, code in ((' ', 0), ('X', 9), ('G', 6), ('_', 0)):
    state_maze0[state_maze0 == symbol] = code
state_maze0 = state_maze0.astype('float64')

# visualize initial grid
fig_name = 'Initial Grid'
plt.figure(figsize=(12, 6))
ax = plt.gca()
plt.pcolormesh(state_maze0, cmap='tab20c', edgecolor='k', linewidth=2)
ax.set_xticks([])
ax.set_yticks([])
plt.colorbar()
ax.set_aspect('equal')
ax.invert_yaxis()
ax.set_title(f'Object rewards: {task.shape_rewards}')
plt.savefig(f'figures/{fig_name}.png')

# for each K-means cluster, draw the state whose SF lies closest to its centroid.
# sf_list_ (and therefore X_ and the indices in `states`) follows sf_dict's
# iteration order, so index k maps directly to the k-th key of sf_dict
state_keys = list(sf_dict.keys())
# keys of task.shape_ids are used directly as maze indices below
shape_positions = list(task.shape_ids.keys())
for h in range(len(states)):
    (r, c), coll = state_keys[int(states[h])]
    # visualize the state
    state_maze = copy.deepcopy(state_maze0)
    # remove every object that the state marks as collected (coll[i] == 1)
    for i, pos in enumerate(shape_positions):
        if coll[i] == 1:
            state_maze[pos] = 0
    # mark grid position
    state_maze[r, c] = 4
    # visualize state as image
    plt.figure(figsize=(12, 6))
    ax = plt.gca()
    plt.pcolormesh(state_maze, cmap='tab20c', edgecolor='k', linewidth=2)
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()
    ax.set_aspect('equal')
    ax.invert_yaxis()
    plt.title(f'Object rewards: {task.shape_rewards}')
    fig_name = f'Cluster {h} state'
    plt.savefig(f'figures/{fig_name}.png')
    # one figure is created per cluster; release it once saved
    plt.close()


# Find cluster bottlenecks: for every pair of K-means clusters (i < j) take the
# midpoint of the two centroids as the pair's bottleneck point, and record the
# Euclidean distance between the two centroids

bottleneck_dict = {}
centeroids_dist_dict = {}
n_centroids = len(centroids)
for i in range(n_centroids):
    for j in range(i + 1, n_centroids):
        bottleneck_dict[(i, j)] = (centroids[i] + centroids[j]) / 2
        centeroids_dist_dict[(i, j)] = np.linalg.norm(centroids[i] - centroids[j])

print('\nDistance between SF centeroids :', centeroids_dist_dict)

# Find bottleneck SFs: for each bottleneck point, the SF (row of X_) nearest to it

bottleneck_sf_dict = {}
for bottleneck, midpoint in bottleneck_dict.items():
    dists = np.linalg.norm(X_ - midpoint, axis=1)
    # no fixed distance cap, so every bottleneck gets an SF; on ties keep the
    # last nearest SF, as the previous '<=' scan did
    nearest = np.flatnonzero(dists == dists.min())[-1]
    bottleneck_sf_dict[bottleneck] = X_[nearest]


# for each bottleneck, draw the state whose (flattened) SF equals the bottleneck SF
# keys of task.shape_ids are used directly as maze indices below
shape_positions = list(task.shape_ids.keys())
for bottleneck, bottleneck_sf in bottleneck_sf_dict.items():
    for state in sf_dict:
        # bottleneck SFs are rows of X_, i.e. exact flattened copies of SFs in
        # sf_dict; array_equal also returns False (not an error) on shape mismatch
        if not np.array_equal(sf_dict[state].flatten(), bottleneck_sf):
            continue
        # visualize the state
        state_maze = copy.deepcopy(state_maze0)
        (r, c), coll = state
        # remove every object that the state marks as collected (coll[i] == 1)
        for i, pos in enumerate(shape_positions):
            if coll[i] == 1:
                state_maze[pos] = 0
        # mark grid position
        state_maze[r, c] = 4
        # visualize state as image
        plt.figure(figsize=(12, 6))
        ax = plt.gca()
        plt.pcolormesh(state_maze, cmap='tab20c', edgecolor='k', linewidth=2)
        plt.xticks([])
        plt.yticks([])
        plt.colorbar()
        ax.set_aspect('equal')
        ax.invert_yaxis()
        plt.title(f'Object rewards: {task.shape_rewards}')
        fig_name = f'Bottleneck {bottleneck} state'
        plt.savefig(f'figures/{fig_name}.png')
        # one figure is created per bottleneck; release it once saved
        plt.close()
        break


# finding the cluster centres using stochastic approximation (SA)

# base learning rate of the SA centroid updates
alpha = 0.1

# per-centroid update counters; they start at 1 so 1/sqrt(v) is finite
v = np.full(4, 1.0)

# clustering SFs given centroids
# NOTE(review): this function shadows the `sklearn.cluster` module imported
# above; `cluster.KMeans` is only used before this definition, so the script
# works, but any later use of the module name would hit this function instead.
def cluster(mu, X):
    """Assign every sample in X to its nearest centroid in mu.

    Args:
        mu: sequence/array of k centroids, each shaped like one sample.
        X: sequence/array of n samples.

    Returns:
        np.ndarray of n integer indices into mu. Distance is the Euclidean norm
        of the flattened difference; ties go to the lowest centroid index.
    """
    X = np.asarray(X)
    mu = np.asarray(mu)
    if X.shape[0] == 0:
        return np.zeros(0, dtype=int)
    # flatten each sample/centroid so that the distance equals
    # np.linalg.norm(m - x) on the original (possibly multi-dim) arrays
    X_flat = X.reshape(X.shape[0], -1)
    mu_flat = mu.reshape(mu.shape[0], -1)
    # (n, k) distances between every sample and every centroid, computed in
    # one vectorized pass instead of a Python double loop
    dists = np.linalg.norm(X_flat[:, None, :] - mu_flat[None, :, :], axis=-1)
    # argmin returns the first minimum, i.e. the lowest index on ties
    return np.argmin(dists, axis=1)

# initialize the 4 cluster centroids from distinct, randomly chosen SFs.
# Sampling without replacement matters: a duplicated starting centroid never
# wins an assignment (cluster() gives ties to the lower index) and so would
# never be updated.
init_idx = np.random.choice(X_.shape[0], size=4, replace=False)
# float copy so the in-place updates below neither alias X_ nor truncate
mu = X_[init_idx].astype(float)

# get initial clusters
cluster_x = cluster(mu, X_)

# iteratively update cluster centroids, one randomly sampled SF per step
for t in range(X_.shape[0]):

    # sample SF
    i = np.random.choice(X_.shape[0])
    # pick the cluster idx of SF
    h = cluster_x[i]
    # move centroid h towards the sampled SF; the step 2*alpha/sqrt(v[h])
    # shrinks as centroid h receives more updates
    mu[h] = mu[h] + 2 * alpha / np.sqrt(v[h]) * (X_[i] - mu[h])
    # update centroid update counter
    v[h] = v[h] + 1
    # update clusters
    cluster_x = cluster(mu, X_)

print('\nCentroids from K-means :', )
print(centroids)
print('\nCentroids from SAs :', )
print(mu)

# 2D: project the SA centroids with the PCA fitted on the SFs
mu_red2 = pca2.transform(mu)

plt.figure(figsize=(12, 6))
ax = plt.gca()
plt.scatter(x_red2[:, 0], x_red2[:, 1])
plt.scatter(mu_red2[:, 0], mu_red2[:, 1], s=50, c='r')
fig_name = 'SF cluster visualization - centroids from SA (2D)'
plt.savefig(f'figures/{fig_name}.png')


# 3D
mu_red3 = pca3.transform(mu)

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.scatter(x_red3[:, 0], x_red3[:, 1], x_red3[:, 2])
# Axes3D.scatter's 4th positional parameter is `zdir`, not a format string,
# so the colour must be given via `c=` (a positional 'ro' ends up in `zdir`)
ax.scatter(mu_red3[:, 0], mu_red3[:, 1], mu_red3[:, 2], c='r', s=50)
fig_name = 'SF cluster visualization  - centroids from SA (3D)'
plt.savefig(f'figures/{fig_name}.png')

