# -*- coding: utf-8 -*-


import numpy as np
import scipy.special as sp
import matplotlib.pyplot as plt
import copy
import torch
def GaussianMatrix(X,Y,sigma):
    """Return the Gaussian kernel matrix between the rows of X and Y.

    Entry ``[i, j]`` is ``exp(-d2 / (2 * sigma**2))`` where ``d2`` is the
    squared Euclidean distance between ``X[i]`` and ``Y[j]``, expanded as
    ``|x|^2 + |y|^2 - 2 * x.y``.

    Args:
        X: 2-D tensor with one sample per row, shape (n, d).
        Y: 2-D tensor with one sample per row, shape (m, d).
        sigma: kernel bandwidth.

    Returns:
        Tensor of shape (n, m).
    """
    x_sq_norm = (X * X).sum(-1)
    y_sq_norm = (Y * Y).sum(-1)

    # A column of |x|^2 plus a row of |y|^2 broadcasts to the full (n, m) grid.
    sq_dist = (x_sq_norm.unsqueeze(-1) + y_sq_norm.unsqueeze(0)) - (2 * X) @ Y.T

    return torch.exp(-sq_dist / 2 / sigma ** 2)

def CSD_4(x1,x2,y1,y2,sigma = 1): # conditional cs divergence
    """Estimate the conditional Cauchy-Schwarz (CS) divergence of two samples.

    Sample 1 is the paired rows ``(x1[i], y1[i])`` and sample 2 the paired
    rows ``(x2[i], y2[i])``. Gaussian kernel matrices are built on the inputs
    (K*) and on the outputs (L*), within each sample and across the two.

    Args:
        x1, y1: inputs and outputs of the first sample, one row per
            observation (same number of rows).
        x2, y2: inputs and outputs of the second sample, one row per
            observation (same number of rows).
        sigma: bandwidth shared by every Gaussian kernel.

    Returns:
        The divergence estimate as a Python float. It is ``nan`` when either
        sample has no rows, because the estimate then contains 0/0 and
        log(0) terms.
    """
    # as_tensor accepts arrays, lists and tensors alike; unlike torch.tensor
    # it neither copies nor warns when handed an existing tensor.
    x1, x2, y1, y2 = (torch.as_tensor(a) for a in (x1, x2, y1, y2))

    # Input-space kernels: within each sample and across samples.
    K1 = GaussianMatrix(x1, x1, sigma)
    K2 = GaussianMatrix(x2, x2, sigma)
    K12 = GaussianMatrix(x1, x2, sigma)
    K21 = GaussianMatrix(x2, x1, sigma)

    # Output-space kernels with the same layout.
    L1 = GaussianMatrix(y1, y1, sigma)
    L2 = GaussianMatrix(y2, y2, sigma)
    L12 = GaussianMatrix(y1, y2, sigma)
    L21 = GaussianMatrix(y2, y1, sigma)

    # Row sums of the within-sample input kernels serve as normalisers.
    k1_rows = K1.sum(-1)
    k2_rows = K2.sum(-1)

    self_term1 = ((K1 * L1).sum(-1) / (k1_rows ** 2)).sum(0)
    self_term2 = ((K2 * L2).sum(-1) / (k2_rows ** 2)).sum(0)

    cross_term1 = ((K12 * L12).sum(-1) / (k1_rows * K12.sum(-1))).sum(0)
    cross_term2 = ((K21 * L21).sum(-1) / (k2_rows * K21.sum(-1))).sum(0)

    # One estimate per cross direction; the result is their average.
    cs1 = -2 * torch.log2(cross_term1) + torch.log2(self_term1) + torch.log2(self_term2)
    cs2 = -2 * torch.log2(cross_term2) + torch.log2(self_term1) + torch.log2(self_term2)

    return ((cs1 + cs2) / 2).item()

class Open2d:
    """Open 2-D grid world with eight unit moves and a visit-count map.

    The agent position is a pair of integer coordinates. Every position set
    by ``reset`` or ``step`` is counted in ``trace``, a GRID_SIZE x GRID_SIZE
    array in which position ``(x, y)`` owns cell
    ``[x + GRID_OFFSET, y + GRID_OFFSET]``. Positions outside that window
    still move the agent but are not counted.
    """

    GRID_SIZE = 203       # side length of the visit-count map
    GRID_OFFSET = 100     # map index of coordinate 0 along each axis
    POSITION_SCALE = 100  # divisor applied to positions returned by predict()
    HISTORY_WINDOW = 20   # look-back, in steps, of predict()'s windowed trace

    def __init__(self, init_position = np.array([0,0]), stride = 1):
        """Create the environment with an empty visit map.

        Args:
            init_position: stored on the instance; reset() does not read it.
            stride: stored on the instance; step() always moves by the raw
                action vector and does not read it.
        """
        self.init_position = init_position
        self.stride = stride
        self.state = None
        self.trace = np.zeros([self.GRID_SIZE, self.GRID_SIZE])
        # Action id -> displacement: four axis moves followed by four diagonals.
        self.action_map = {0: np.array([0,1]),
                   1: np.array([0,-1]),
                   2: np.array([1,0]),
                   3: np.array([-1,0]),
                   4: np.array([1,1]),
                   5: np.array([-1,-1]),
                   6: np.array([-1,1]),
                   7: np.array([1,-1]),
                   }
        # Snapshots of ``trace`` taken after each step; only the most recent
        # HISTORY_WINDOW are retained (see step()).
        self.trac_history = []

    def _grid_cell(self, position):
        """Return the ``trace`` index of *position*, or None if it is off the map."""
        row = position[0] + self.GRID_OFFSET
        col = position[1] + self.GRID_OFFSET
        n_rows, n_cols = self.trace.shape
        # Explicit bounds check: a negative index would otherwise wrap around
        # and count the visit in a cell on the opposite edge of the map.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            return row, col
        return None

    def _count_visit(self, position):
        """Increment the visit count of *position* when it lies on the map."""
        cell = self._grid_cell(position)
        if cell is not None:
            self.trace[cell] += 1

    def reset(self, position = np.array([0,0])):
        """Place the agent at *position*, count the visit and return the state.

        ``trace`` and ``trac_history`` are left untouched, so visit counts
        accumulate across resets.
        """
        # Copy so that neither the shared default array nor the caller's
        # array is aliased by the environment state.
        self.state = np.array(position)
        self._count_visit(self.state)
        return self.state

    def step(self, action):
        """Apply the move for *action* (a key of ``action_map``) and return the new state."""
        self.state = self.state + self.action_map[action]
        self._count_visit(self.state)
        self.trac_history.append(self.trace.copy())
        # Each snapshot is a full copy of the map and predict() never looks
        # further back than HISTORY_WINDOW steps, so drop anything older.
        del self.trac_history[:-self.HISTORY_WINDOW]
        return self.state

    def predict(self, trace = False):
        """Describe the outcome of every action from the current state.

        Nothing is mutated. The i-th entry of each returned list belongs to
        action ``i``.

        Args:
            trace: when true, also build the visit map that would result
                from taking each action.

        Returns:
            ``(xa, y, trace_p)`` where ``xa[i]`` is the scaled current
            position concatenated with the move vector, ``y[i]`` is the
            scaled next position, and ``trace_p[i]`` is ``trace`` plus one
            visit at the next position (``trace_p`` is empty unless *trace*
            is true).
        """
        xa = []
        y = []
        trace_p = []
        trace_p_h = []

        scaled_state = self.state / self.POSITION_SCALE
        if trace:
            # Visits accumulated over the last HISTORY_WINDOW steps; falls
            # back to the full map while fewer snapshots exist.
            if len(self.trac_history) < self.HISTORY_WINDOW:
                recent = self.trace
            else:
                recent = self.trace - self.trac_history[-self.HISTORY_WINDOW]

        for action in range(len(self.action_map)):
            move = self.action_map[action]
            next_position = self.state + move
            xa.append(np.concatenate((scaled_state, move)))
            y.append(next_position / self.POSITION_SCALE)
            if trace:
                add_n = np.zeros(self.trace.shape)
                cell = self._grid_cell(next_position)
                if cell is not None:
                    add_n[cell] = 1
                trace_p.append(self.trace + add_n)
                trace_p_h.append(recent + add_n)

        # trace_p_h (the windowed variant) is built but is not part of the
        # return value.
        return xa, y, trace_p#,trace_p_h

    def visual(self):
        """Show the natural log of the visit counts as an image."""
        # Unvisited cells are log(0) = -inf; silence the divide-by-zero
        # warning that would otherwise be raised on every call.
        with np.errstate(divide="ignore"):
            log_counts = np.log(self.trace)
        plt.imshow(log_counts)
        plt.show()
        
        

# Exploration driver: 100 episodes of 200 steps each in one shared
# environment, showing the accumulated visit map after every episode.
env = Open2d()

for epoch in range(100):
    xa, y = [], []       # features / targets of the transitions taken so far
    diver_list = []      # candidate scores of every step (not read below)
    deepth = 50          # not read below
    state = env.reset()

    start_action = np.random.randint(8)  # drawn once per episode, not used below
    for i in range(200):
        possible_xa, possible_y, _ = env.predict()

        # Score each candidate move: append it to the history, split the
        # history at i // 2 and measure the divergence between both parts.
        Tlen = i
        half = Tlen // 2
        scores = []
        for cand_xa, cand_y in zip(possible_xa, possible_y):
            x_all = np.array([*xa, cand_xa])
            y_all = np.array([*y, cand_y])
            scores.append(
                CSD_4(x_all[:half], x_all[half:],
                      y_all[:half], y_all[half:], sigma=1)
            )
        diver_list.append(scores)
        diver = np.array(scores)

        if np.isnan(diver).any():
            # At least one score is undefined: fall back to a uniform random move.
            action = np.random.randint(8)
        else:
            # Take the highest-scoring move, breaking ties at random.
            best_moves = np.where(diver == diver.max())[0]
            action = np.random.choice(best_moves)

        xa.append(possible_xa[action])
        y.append(possible_y[action])

        next_state = env.step(action)
        state = next_state

    env.visual()
  
    