import statistics
import matplotlib.pyplot as plt

import numpy as np
from math import sqrt

import os

import pandas as pd
import torch
import torch.nn as nn
from torch.nn import init

import torch.nn.functional as F
from scipy import spatial

# Hyper-parameters for training the NEC label networks.
num_obs_NEC_nn = 12   # network input size; convertQtable keeps 8 + 4 = 12 observation tokens
num_NEC_labels = 40   # number of discrete NEC labels (network output size)
NEC_nn_lr = 0.002     # Adam learning rate
traning_steps = 70    # number of training iterations (name is referenced by the training loop)

# Q-tables (observation + Q-values) for agent 0 and agent 1.
# NOTE(review): both paths are empty placeholders and must be filled in before running.
AGENT0_QTABLE_PATH = ""
AGENT1_QTABLE_PATH = ""
agent0_qtable = pd.read_csv(AGENT0_QTABLE_PATH)
agent1_qtable = pd.read_csv(AGENT1_QTABLE_PATH)

def convertQtable(q_table):
    """Parse a Q-table DataFrame into observation features and Q-vectors.

    Column 1 (by position) must hold the observation as a bracketed,
    whitespace-separated vector. Column 2 must hold the Q-values as nested
    one-element lists such as ``[[q0], [q1], ...]``. This format is inferred
    from the parsing below.

    Observation tokens 0-7 and 10-13 are kept (12 features). Tokens 8-9 are
    dropped.

    Returns:
        ``(x, list_q, list_o)``:
        - ``x``: float32 tensor of shape ``(rows, features)``.
        - ``list_q``: per-row lists of float Q-values.
        - ``list_o``: per-row lists of the selected float observation features.
    """
    list_o = []
    list_q = []
    # Positional column access via iloc. Plain ``row[1]`` on a labelled
    # Series relies on a deprecated positional fallback in pandas.
    for obs_str, q_str in zip(q_table.iloc[:, 1], q_table.iloc[:, 2]):
        tokens = obs_str[1:-1].split()
        selected = tokens[0:8] + tokens[10:14]
        list_o.append([float(tok) for tok in selected])

        # "[[a], [b], [c]]" -> "a], [b], [c" -> ["a", "b", "c"]
        list_q.append([float(num) for num in q_str[2:-2].split('], [')])

    x = torch.tensor(np.array(list_o), dtype=torch.float32)
    return x, list_q, list_o

# NEC network architecture: a two-hidden-layer MLP producing label logits.
class NEC(nn.Module):
    """Feed-forward classifier mapping an observation vector to NEC-label logits.

    Structure: two ReLU hidden layers of width ``hidden_size``, then a linear
    output layer with ``n_class`` units. The outputs are raw logits; no softmax
    is applied.

    Args:
        n_feature: size of the input observation vector.
        n_class: number of NEC labels (output size).
        hidden_size: width of both hidden layers. The default 1200 keeps
            existing checkpoints loadable.
    """

    def __init__(self, n_feature, n_class, hidden_size=1200):
        super().__init__()
        self.fc1 = nn.Linear(n_feature, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.NEC_out = nn.Linear(hidden_size, n_class)

    def forward(self, x):
        """Map input ``(..., n_feature)`` to logits ``(..., n_class)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.NEC_out(x)

def entropy(p):
    """Shannon entropy (natural log) of probability tensor ``p``.

    A 1-D ``p`` gives the entropy of that distribution. For a 2-D ``p`` the
    total over all rows is divided by the row count, i.e. the mean per-row
    entropy. Any other rank raises ``NotImplementedError``. A 1e-8 offset
    keeps ``log`` finite at zero probabilities.
    """
    rank = p.dim()
    if rank not in (1, 2):
        raise NotImplementedError
    total = -torch.sum(p * torch.log(p + 1e-8))
    return total / float(len(p)) if rank == 2 else total
def loss_equal(net, x):
    """Return the two entropy terms used to balance the NEC label assignment.

    Args:
        net: module mapping ``x`` to label logits along the last dimension.
        x: batch of observations with samples along dim 0.

    Returns:
        ``(ent, ent_ave)``:
        - ``ent``: ``entropy`` of the per-sample label distributions. For a
          2-D batch this is the mean per-sample entropy.
        - ``ent_ave``: entropy of the batch-averaged label distribution.
    """
    p_logit = net(x)
    # Softmax over the label axis. The implicit-dim form is deprecated and
    # warns on every call.
    p = F.softmax(p_logit, dim=-1)
    p_ave = torch.sum(p, dim=0) / x.shape[0]
    ent = entropy(p)
    return ent, -torch.sum(p_ave * torch.log(p_ave + 1e-8))

def euclidean_distance(row1, row2):
    """Return the Euclidean (L2) distance between two numeric sequences.

    Every component contributes to the distance.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    if len(row1) != len(row2):
        raise ValueError("rows must have the same length")
    return sqrt(sum((a - b) ** 2 for a, b in zip(row1, row2)))

def get_neighbors(train, test_row, num_neighbors):  # get 16 neighbors
    distances = list()
    for i, train_row in enumerate(train):
        dist = spatial.distance.cosine(test_row, train_row) #TODO cosine similarity
        distances.append((train_row, dist, i))
    distances.sort(key=lambda tup: tup[1])
    neighbors = list()
    for i in range(num_neighbors):
        neighbors.append((distances[i][0], distances[i][2]))
    return neighbors

def get_norm_by_mean(list_q):
    """Centre every Q-vector on its own arithmetic mean.

    Returns a new list of lists. Each value is replaced by ``value - mean``
    of its row; the input is not modified.
    """
    centred_rows = []
    for q_row in list_q:
        row_mean = sum(q_row) / len(q_row)
        centred_rows.append([value - row_mean for value in q_row])
    return centred_rows

def get_norm_and_tanh(list_q):
    """Centre each Q-vector on its mid-range, rescale it, and squash with tanh.

    For a row ``q`` with ``b = (max(q) + min(q)) / 2``, each value ``v``
    becomes ``tanh((v - b) / |max(q)|)``. The scale is the magnitude of the
    maximum, so value ordering is preserved even when ``max(q) < 0``. When
    ``max(q) == 0`` the centred values are passed to ``tanh`` unscaled,
    avoiding a division by zero.

    Returns:
        A new list of Python float lists, one per input row.
    """
    normalised = []
    for q_row in list_q:
        max_q = max(q_row)
        min_q = min(q_row)
        mid = (max_q + min_q) / 2
        scale = abs(max_q) or 1.0
        centred = [(value - mid) / scale for value in q_row]
        normalised.append(np.tanh(centred).tolist())
    return normalised


def loss_lp(network, x, list_q, list_o, num_neighbors=16):
    """Neighbourhood label-consistency penalty for the NEC labels.

    Each sample's Q-vector is normalised with ``get_norm_and_tanh``. For each
    sample ``i``, its nearest neighbours ``j`` in that space are found by
    cosine distance via ``get_neighbors``. The penalty accumulates
    ``(1 - cos_dist(q_i, q_j)) * (label_i - label_j) ** 2``, where ``label``
    is the arg-max NEC label the network predicts for the sample's observation.

    Args:
        network: module mapping observation vectors to label logits.
        x: unused; kept so existing call sites keep working. Labels are
            computed from ``list_o``.
        list_q: per-sample Q-value lists.
        list_o: per-sample observation lists; ``list_o[i]`` pairs with ``list_q[i]``.
        num_neighbors: neighbours considered per sample, capped at the sample count.

    Returns:
        The accumulated penalty as a plain (non-tensor) number; 0 when
        ``list_q`` is empty.

    NOTE(review): labels come from ``argmax`` and are converted to Python ints,
    so this value carries no gradient. Adding it to a tensor loss only shifts
    the reported value.
    """
    if not list_q:
        return 0

    # Loop-invariant: normalise all Q-vectors once rather than once per sample.
    norm_list_q = get_norm_and_tanh(list_q)
    k = min(num_neighbors, len(norm_list_q))

    # One batched forward pass labels every sample. The labels are detached
    # ints, so no autograd graph is needed.
    with torch.no_grad():
        obs = torch.tensor(list_o, dtype=torch.float32)
        probs = F.softmax(network(obs), dim=-1)
        labels = torch.argmax(probs, dim=-1).tolist()

    loss = 0
    for i, q_i in enumerate(norm_list_q):
        loss_x_i = 0
        for neighbor_q, j in get_neighbors(norm_list_q, q_i, k):
            # Cosine similarity of the normalised Q-vectors weights the squared
            # label difference.
            sim = 1 - spatial.distance.cosine(q_i, neighbor_q)
            loss_x_i += sim * (labels[i] - labels[j]) ** 2
        loss += loss_x_i

    return loss

def runner(policy, x, list_q, list_o):
    """Run one optimisation step on ``policy.network`` and report its labels.

    Args:
        policy: object exposing ``network`` (the label model) and ``optimizer``.
        x: batch of observations, one row per sample.
        list_q: per-sample Q-values, passed to ``loss_lp``.
        list_o: per-sample observations, passed to ``loss_lp``.

    Returns:
        ``(NEC_label, losses)``:
        - ``NEC_label``: the arg-max label (int) of every row of ``x``,
          computed with the network as it was *before* this step's update.
        - ``losses``: the total loss as a Python float.
    """
    # Labels are only reported, so no autograd graph is needed for them.
    with torch.no_grad():
        probs = F.softmax(policy.network(x), dim=-1)
        NEC_label = torch.argmax(probs, dim=-1).tolist()

    loss_eq1, loss_eq2 = loss_equal(policy.network, x)
    # Minimising this lowers per-sample entropy (confident labels) and raises
    # the entropy of the averaged distribution (labels used evenly).
    loss_eq = loss_eq1 - 4 * loss_eq2
    loss_part1 = loss_lp(policy.network, x, list_q, list_o)
    loss = 0.0001 * loss_part1 + 10 * loss_eq

    policy.optimizer.zero_grad()  # clear gradients from the previous step
    loss.backward()
    policy.optimizer.step()

    return NEC_label, loss.item()

def label_distance_and_count_A(NEC_label, list_q, agent_id):
    """Summarise one agent's label clustering: print metric A and write CSVs.

    Metric A: for every non-empty label, sum its normalised Q-vectors
    element-wise and take the largest component. A is the total over labels.

    Two CSV files are written under ``./QDT_outputs/``:
    - per-label sample counts;
    - a label-by-label matrix of cosine distances between cluster means, with
      an extra ``with_in_cluster_ave_dis`` column holding each cluster's mean
      cosine distance to its own mean.
    Empty labels get 0 in both.

    Args:
        NEC_label: predicted label per sample; must be in ``range(num_NEC_labels)``.
        list_q: per-sample Q-value lists, aligned with ``NEC_label``.
        agent_id: used in the printed message and the output file names.
    """
    norm_tanh_list_q = get_norm_and_tanh(list_q)

    # Group normalised Q-vectors by label; every label gets a (possibly empty) slot.
    label_dis = {label: [] for label in range(num_NEC_labels)}
    for label, q_vec in zip(NEC_label, norm_tanh_list_q):
        label_dis[label].append(q_vec)

    # Metric A. Labels with no members have nothing to sum and are skipped.
    sum_max_A = 0
    for members in label_dis.values():
        if members:
            column_sums = [sum(column) for column in zip(*members)]
            sum_max_A += max(column_sums)
    print("A equals to: ", sum_max_A, "for agent number:", agent_id)

    count = [len(label_dis[label]) for label in range(num_NEC_labels)]
    label_num = pd.DataFrame(count)
    label_num.to_csv('./QDT_outputs/exp1_QDT_100Labels_label_count_stage_2_{}.csv'.format(agent_id))

    # Per-label mean vector (None for empty labels) and the mean cosine
    # distance of the members to that mean (0 for empty labels).
    mean_list = []
    with_in_cluster_dis = []
    for label in range(num_NEC_labels):
        members = label_dis[label]
        if not members:
            mean_list.append(None)
            with_in_cluster_dis.append(0)
            continue
        np_q_vector = np.array(members)
        mean = np_q_vector.mean(axis=0)
        mean_list.append(mean)
        total = sum(spatial.distance.cosine(mean, row) for row in np_q_vector)
        with_in_cluster_dis.append(total / len(np_q_vector))

    # Cosine distance between every pair of cluster means; 0 where either is empty.
    different_label_dis = []
    for current in mean_list:
        row = []
        for other in mean_list:
            if current is None or other is None:
                row.append(0)
            else:
                row.append(spatial.distance.cosine(current, other))
        different_label_dis.append(row)

    df = pd.DataFrame(different_label_dis)
    df['with_in_cluster_ave_dis'] = with_in_cluster_dis

    df.to_csv('./QDT_outputs/exp1_QDT_100Labels_cluster_CosDis_stage_2_{}.csv'.format(agent_id))


def q_distance(list_q, agent_id):
    """Write the pairwise cosine-distance matrix of the raw Q-vectors to CSV.

    Entry ``[r][c]`` is ``scipy.spatial.distance.cosine(list_q[r], list_q[c])``.
    Diagonal entries are written as 0 directly.
    """
    size = len(list_q)
    matrix = [
        [0 if r == c else spatial.distance.cosine(list_q[r], list_q[c]) for c in range(size)]
        for r in range(size)
    ]
    pd.DataFrame(matrix).to_csv('./QDT_outputs/exp1_QDT_100Labels_Qv_CosDis_stage_2_{}.csv'.format(agent_id))

def calculate_B_C(list_q, agent_id):
    """Print metrics B and C for one agent's tanh-normalised Q-vectors.

    B is the sum over samples of each vector's largest component. C is the
    largest component of the element-wise sum of all vectors.
    """
    normalised = get_norm_and_tanh(list_q)

    total_B = sum(max(q_row) for q_row in normalised)
    print("B equals to: ", total_B, "for agent number:", agent_id)

    column_totals = normalised[0]
    for q_row in normalised[1:]:
        column_totals = [acc + val for acc, val in zip(column_totals, q_row)]
    print("C equals to: ", max(column_totals), "for agent number:", agent_id)

# Training wrapper: an NEC label network plus the optimizer that updates it.
class policy:
    """Bundle an ``NEC`` network with the Adam optimizer that trains it.

    Attributes are documented inline: ``network`` is the NEC model and
    ``optimizer`` is Adam over its parameters.
    """

    def __init__(self):
        # n_feature must match the observation width produced by convertQtable.
        model = NEC(n_feature=num_obs_NEC_nn, n_class=num_NEC_labels)
        self.network = model  # the NEC label network
        self.optimizer = torch.optim.Adam(model.parameters(), lr=NEC_nn_lr, betas=(0.9, 0.99))  # Adam over its parameters

# Every output below is written under ./QDT_outputs/; make sure it exists.
os.makedirs('./QDT_outputs', exist_ok=True)

x_0, list_q_0, list_o_0 = convertQtable(agent0_qtable)
x_1, list_q_1, list_o_1 = convertQtable(agent1_qtable)

# Diagnostics on the raw Q-tables, computed before any training.
q_distance(list_q_0, 0)
q_distance(list_q_1, 1)

calculate_B_C(list_q_0, 0)
calculate_B_C(list_q_1, 1)

net_0 = policy()
net_1 = policy()

steps = traning_steps
loss_list_0 = []
loss_list_1 = []
# NEC contains only Linear layers and ReLU (no dropout/batch-norm), so the
# train/eval mode does not change its outputs; set training mode once.
net_0.network.train()
net_1.network.train()
for t in range(steps):
    print("Current E: " , t)

    NEC_label_0, losses_0 = runner(net_0, x_0, list_q_0, list_o_0)
    NEC_label_1, losses_1 = runner(net_1, x_1, list_q_1, list_o_1)
    loss_list_0.append(losses_0)
    loss_list_1.append(losses_1)

    # After the final step, export cluster statistics and labelled Q-tables.
    if t == steps - 1:
        label_distance_and_count_A(NEC_label_0, list_q_0, 0)
        label_distance_and_count_A(NEC_label_1, list_q_1, 1)
        agent0_qtable['NEC_label'] = NEC_label_0
        agent1_qtable['NEC_label'] = NEC_label_1
        # Both agents' labelled tables go to the same output directory.
        agent0_qtable.to_csv('./QDT_outputs/exp1_QDT_100Labels_qlabel_stage_2_{}.csv'.format(0))
        agent1_qtable.to_csv('./QDT_outputs/exp1_QDT_100Labels_qlabel_stage_2_{}.csv'.format(1))


torch.save(net_0.network.state_dict(), './QDT_outputs/exp1_QDT_100Labels_stage2_0.pkl')
torch.save(net_1.network.state_dict(), './QDT_outputs/exp1_QDT_100Labels_stage2_1.pkl')

# Training-loss curves for both networks.
plt.figure()
plt.plot(range(len(loss_list_0)), loss_list_0, label = 'NEC 0')
plt.plot(range(len(loss_list_1)), loss_list_1, label = 'NEC 1')

plt.xlabel('step' )
plt.ylabel('training loss')
plt.legend(loc='lower left')
plt.savefig('./QDT_outputs/exp1_QDT_100Labels_stage_2.png', format='png')
plt.show()

