import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.animation as animation
import datetime
import torch
import pickle
import time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
import random
import torch.optim as optim
import pandas as pd
import scipy.stats
import os
from PIL import Image


def torch_fix_seed(seed):
    """Seed every RNG the script uses and request deterministic PyTorch kernels.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU and CUDA).
    """
    # Python random
    random.seed(seed)
    # NumPy global RNG
    np.random.seed(seed)
    # PyTorch (CPU + current CUDA device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # use_deterministic_algorithms is a function and must be *called*; assigning
    # ``= True`` would replace it with a bool and enable nothing.
    # warn_only=True: ops lacking a deterministic implementation (e.g. some cuBLAS
    # paths without CUBLAS_WORKSPACE_CONFIG) warn instead of raising mid-training.
    torch.use_deterministic_algorithms(True, warn_only=True)

    
    
def map_dist(zahyou1):
    """Return the pairwise Euclidean distance matrix of a set of points.

    Args:
        zahyou1: points indexed along the first dimension ("zahyou" = coordinates).
            Each point's remaining dimensions are flattened, so an ``(n, 2)`` tensor of
            map coordinates yields ordinary 2-D distances. A sequence of tensors is
            stacked first.

    Returns:
        ``(n, n)`` tensor in the default floating dtype with
        ``out[i, j] = ||zahyou1[i] - zahyou1[j]||``.
    """
    n_points = len(zahyou1)
    if n_points == 0:
        return torch.zeros(0, 0)
    if not torch.is_tensor(zahyou1):
        zahyou1 = torch.stack([torch.as_tensor(p) for p in zahyou1])
    points = zahyou1.reshape(n_points, -1)
    # Broadcast to (n, n, d) differences and take the 2-norm over the last axis.
    # This replaces the nested Python loop, which also recomputed an unused
    # row-normalised copy of the whole matrix on every inner iteration.
    diffs = points.unsqueeze(1) - points.unsqueeze(0)
    return torch.norm(diffs, dim=-1).to(torch.get_default_dtype())


        
class image_dist(nn.Module):
    """Learnable per-pixel weighting ``U`` used to compare images pairwise.

    ``forward`` returns, for a batch of images, the matrix of squared U-weighted
    Euclidean distances ``out[i, j] = sum_k (U_k * x_ik - U_k * x_jk) ** 2``.

    Args:
        n_pixels: number of features per image (default 256*256). ``U`` is
            initialised uniformly in ``[0.05, 0.20)``.
    """

    def __init__(self, n_pixels=256 * 256):
        super().__init__()
        rand = 0.05 + 0.15 * torch.rand(n_pixels)
        self.U = nn.Parameter(rand)

    def forward(self, x):
        """Return the ``(batch, batch)`` matrix of squared U-weighted distances.

        Args:
            x: batch of images; each must reshape to ``U.numel()`` features.
        """
        batch_size = len(x)
        weighted = x.reshape(batch_size, self.U.numel()) * self.U   # (batch, pixels)
        diff = (weighted.unsqueeze(1) - weighted.unsqueeze(0)) ** 2  # (batch, batch, pixels)
        # Shape follows the actual batch, not a module-level dataset length.
        return diff.sum(dim=2)

    

class DataSet:
    """Minimal map-style dataset pairing inputs ``X`` with targets ``t`` by index.

    It implements ``__len__`` and ``__getitem__``, so it can be wrapped in
    ``torch.utils.data.DataLoader``.
    """

    def __init__(self, layout_images_data, zahyou):
        # Parameter names are kept for backward compatibility; any indexable
        # inputs / targets work.
        self.X = layout_images_data
        self.t = zahyou

    def __len__(self):
        # The length comes from the targets; X is assumed to be at least as long.
        return len(self.t)

    def __getitem__(self, index):
        """Return the (input, target) pair at position *index*."""
        return self.X[index], self.t[index]

def load_gif_images_from_directory(directory):
    """Load every ``.gif`` file in *directory* as a grayscale 256x256 tensor.

    Args:
        directory: folder to scan (not recursive).

    Returns:
        List with one ``ToTensor()`` result per matching file. Each image is
        converted to mode ``'L'`` and resized to 256x256 first. Files are visited
        in ``os.listdir`` order, which is not guaranteed to be sorted.
        NOTE(review): results are later split/paired by position — consider sorting.
    """
    images = []
    for filename in os.listdir(directory):
        if not filename.endswith('.gif'):
            continue
        img_path = os.path.join(directory, filename)
        # ``with`` closes the file handle Pillow keeps open after a lazy open.
        # ``convert`` returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img:
            gray = img.convert('L').resize((256, 256))
        images.append(ToTensor()(gray))
    return images


def load_png_images_from_directory(directory):
    """Load every ``.png`` file in *directory* as a grayscale 256x256 tensor.

    Args:
        directory: folder to scan (not recursive).

    Returns:
        List with one ``ToTensor()`` result per matching file. Each image is
        converted to mode ``'L'`` and resized to 256x256 first. Files are visited
        in ``os.listdir`` order, which is not guaranteed to be sorted.
        NOTE(review): results are later paired with coordinates by position —
        consider sorting.
    """
    images = []
    for filename in os.listdir(directory):
        if not filename.endswith('.png'):
            continue
        img_path = os.path.join(directory, filename)
        # ``with`` closes the file handle Pillow keeps open after a lazy open.
        # ``convert`` returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img:
            gray = img.convert('L').resize((256, 256))
        images.append(ToTensor()(gray))
    return images

class crsom(nn.Module):
    """Square SOM layer whose activations are gated by a winner-centred neighbourhood.

    For input ``x`` and reference vector ``W[h, j]`` the activation is::

        O[h, j] = exp(-||U*x - U*W[h, j]||^2 / sigma) * exp(-d((h, j), winner) / s_t)

    Here ``U`` is a per-feature weight vector, and ``winner`` is the cell with the
    smallest weighted squared distance. ``d`` is the Euclidean distance on the grid
    in (row, col) coordinates. ``s_t = 200 * (10 / 200) ** (t / rRBF_epoch)``
    shrinks from 200 at ``t=0`` to 10 at ``t=rRBF_epoch``.

    Args:
        map_size: side length of the square map.
        W: initial reference vectors, converted with ``torch.Tensor``.
            Presumably shaped ``(map_size, map_size, 256*256)``, since forward()
            flattens each input to 256*256 features. TODO confirm.
    """

    def __init__(self, map_size, W):
        super().__init__()
        self.W = nn.Parameter(torch.Tensor(W))
        # base[r, c] == [r, c]: grid coordinates in (row, col) order. This matches the
        # winner decoding in forward() (row = index // map_size, col = index % map_size).
        # A non-persistent buffer follows model.to(device) without adding a
        # state_dict key, so existing checkpoints still load.
        grid = torch.tensor([[[r, c] for c in range(map_size)] for r in range(map_size)])
        self.register_buffer('base', grid, persistent=False)

    def forward(self, x, t, rRBF_epoch, sigma, map_size, U):
        """Return activations of shape ``(batch, map_size * map_size)``.

        Args:
            x: batch of inputs, each reshapeable to 256*256 features.
            t: current epoch; sets the neighbourhood width ``s_t``.
            rRBF_epoch: length of the ``s_t`` decay schedule.
            sigma: width of the Gaussian on the weighted feature distance.
            map_size: side length of the map (must match the constructor's).
            U: per-feature weights of length 256*256.
        """
        batch_size = len(x)
        x = x.reshape(batch_size, 256 * 256)
        UW = U * self.W   # (map, map, features)
        UX = U * x        # (batch, features)
        # Squared U-weighted distance from every input to every reference vector.
        diffarence = ((UX[:, None, None, :] - UW.unsqueeze(0)) ** 2).sum(dim=3)
        win_index = torch.argmin(diffarence.reshape(batch_size, map_size * map_size), dim=1)
        wini = torch.div(win_index, map_size, rounding_mode='floor')
        winj = win_index % map_size
        s_t = torch.tensor(200 * (10 / 200) ** (t / rRBF_epoch), device=x.device)
        win_neuron = torch.stack([wini, winj], dim=1).reshape(batch_size, 1, 1, 2)
        map_distance = torch.sqrt(torch.sum((win_neuron - self.base) ** 2, dim=3))
        sigma_winij = torch.exp(-map_distance / s_t)
        O_hij = torch.exp(-diffarence / sigma) * sigma_winij
        return O_hij.reshape(batch_size, map_size * map_size)
        
        
class rRBF(nn.Module):
    """Two-class classifier: CRSOM activations followed by a linear read-out.

    Args:
        map_size: side length of the CRSOM grid; the read-out takes
            ``map_size * map_size`` activations and produces 2 logits.
        W: initial CRSOM reference vectors, forwarded unchanged to ``crsom``.
    """

    def __init__(self, map_size, W):
        super().__init__()
        n_units = map_size * map_size
        self.CRSOM = crsom(map_size, W)
        self.classifier = nn.Linear(in_features=n_units, out_features=2)

    def forward(self, x, t, rRBF_epoch, sigma, map_size, U):
        """Return ``(batch, 2)`` logits; all arguments are passed through to the CRSOM."""
        activations = self.CRSOM(x, t, rRBF_epoch, sigma, map_size, U)
        logits = self.classifier(activations)
        return logits


colors_10 = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Layout images; their map coordinates are read per subject inside the main loop.
layout_images_data = load_png_images_from_directory('Preparation data/sample_Images')

device = torch.device('cpu')

names = ['person2']

# Early-stopping tolerance on the change of the mean epoch loss.
eps = 0.001

# Map sizes to sweep, and the matching number of epochs for learning U.
map_size_list = [25]
heatmap_epoch_list = [300]

rRBF_epoch = 300
sigma = 1
map_size = 25

# rRBF data: class-0 and class-1 images live in separate folders.
data_0_path = 'Preparation data/train/0'
data_1_path = 'Preparation data/train/1'

data_0_images = load_gif_images_from_directory(data_0_path)
data_1_images = load_gif_images_from_directory(data_1_path)

data_0_labels = torch.zeros(len(data_0_images), dtype=torch.int64)
data_1_labels = torch.ones(len(data_1_images), dtype=torch.int64)

data_images = data_0_images + data_1_images
data_labels = torch.cat([data_0_labels, data_1_labels], dim=0)

# Hold out every 4th sample (indices 0, 4, 8, ...) for testing; the rest train.
test_data = data_images[::4]
train_data = [img for k, img in enumerate(data_images) if k % 4 != 0]

test_label = list(data_labels[::4])
train_label = [lbl for k, lbl in enumerate(data_labels) if k % 4 != 0]

train_dataset = DataSet(train_data, train_label)
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,
)

test_dataset = DataSet(test_data, test_label)
testloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10,
    shuffle=True,
)

for map_size in map_size_list:
    seed_count = 0
    for name in names:
        device = torch.device('cpu')
        print(device)
        print(name, seed_count, map_size)
        print('initilize')

        # ---- Layout coordinates ------------------------------------------------
        # Coordinates of the layout images, rescaled from a 25-unit grid to map_size.
        # NOTE(review): pickle.load can execute code from the file; only load trusted data.
        with open('Preparation data/trans_' + name + '.txt', 'rb') as f_zahyou:
            zahyou = pickle.load(f_zahyou)
        zahyou_reshape = [[int(n[0]/25*map_size), int(n[1]/25*map_size)] for n in zahyou]
        left_side_minus_25 = [map_size - x for x in list(zip(*zahyou_reshape))[1]]
        plt.axes().set_aspect('equal')
        plt.xlim(-0.5, map_size+0.5)
        plt.ylim(-0.5, map_size+0.5)
        plt.grid()
        # The first 10 layout points are drawn as class '0', the rest as class '1'.
        plt.scatter(list(zip(*zahyou_reshape[0:10]))[0], left_side_minus_25[0:10], color=colors_10[0], label='0', marker='$0$')
        plt.scatter(list(zip(*zahyou_reshape[10:]))[0], left_side_minus_25[10:], color=colors_10[1], label='1', marker='$1$')
        plt.savefig('result/'+name + '_'+str(map_size)+'_init.png')
        plt.show()
        # Close explicitly so later plots never land on this figure (non-interactive backends).
        plt.close()

        # ---- Learn per-pixel weights U ----------------------------------------
        print('heatmap')
        # Seed BEFORE building model_U so the random initialisation of U is reproducible.
        torch_fix_seed(seed=42+seed_count)
        model_U = image_dist()
        optimizer_U = optim.Adam(model_U.parameters(), lr=0.001)

        U_dataset = DataSet(layout_images_data, zahyou)
        U_trainloader = torch.utils.data.DataLoader(U_dataset,
                                                    batch_size=len(layout_images_data),
                                                    shuffle=True)

        Loss = []
        min_loss = float('inf')

        map_epoch = map_size_list.index(map_size)
        heatmap_epoch = heatmap_epoch_list[map_epoch]
        n_layout = len(layout_images_data)

        criterion = nn.MSELoss()
        model_U.train()
        for epoch in range(heatmap_epoch):
            for data, target in U_trainloader:
                map_zahyou = torch.stack([target[0], target[1]], dim=1)
                optimizer_U.zero_grad()

                # Fit U-weighted image distances to the distances between layout coordinates.
                map_dist_matrix = map_dist(map_zahyou.float())
                output = model_U(data.float())

                loss = criterion(output.reshape(n_layout*n_layout, 1), map_dist_matrix.reshape(n_layout*n_layout, 1))
                loss.backward()
                Loss.append(loss.to('cpu').detach().numpy().copy())

                optimizer_U.step()

                # Checkpoint the lowest-loss U seen so far.
                if Loss[-1] <= min_loss:
                    min_loss = Loss[-1]
                    torch.save(model_U.state_dict(), 'result/model_U_OASIS_'+ name +'.pth')
        model_U.U.requires_grad = False
        # NOTE(review): this is the final-epoch U, not the lowest-loss checkpoint saved above.
        U_diagonal = model_U.U
        U_max = torch.max(U_diagonal)
        U_min = torch.min(U_diagonal)
        normalized_U_diagonal = (U_diagonal - U_min) / (U_max - U_min)
        cmap = colors.LinearSegmentedColormap.from_list('custom_cmap', ['blue', 'white', 'red'])

        plt.imshow(normalized_U_diagonal.reshape(256, 256), cmap=cmap, vmin=0, vmax=1, extent=[0, 256, 0, 256])
        plt.colorbar()
        plt.xticks([0, 256])
        plt.yticks([0, 256])
        plt.savefig('result/heat map_'+str(heatmap_epoch)+'_'+ name +'_U_normalized_map'+str(map_size)+'.png')
        plt.show()
        plt.close()

        # ---- Initialise SOM reference vectors W -------------------------------
        print('W')
        W = np.full((map_size, map_size, 256*256), 1).astype(np.float64)
        # W is indexed [row=y, col=x]; layout coordinates are stored as [x, y].
        for k, (x_k, y_k) in enumerate(zahyou_reshape):
            W[y_k, x_k] = layout_images_data[k].reshape(256*256)

        # Cells without a layout image are pulled toward their 3 nearest layout cells,
        # weighted by 1/distance, for W_epoch passes.
        occupied = {tuple(c) for c in zahyou_reshape}
        layout_cells = np.array(zahyou_reshape)
        W_epoch = 5
        for epoch in range(W_epoch):
            for i in range(map_size):
                for j in range(map_size):
                    if (j, i) in occupied:
                        continue
                    base = np.array([j, i])
                    distances = np.sqrt(np.sum((base - layout_cells)**2, axis=1))
                    k_dist = np.argsort(distances)
                    update = 0
                    for k in k_dist[:3]:
                        x_k, y_k = zahyou_reshape[k]
                        update = update + -1/distances[k] * (W[i, j] - W[y_k][x_k])
                    W[i, j] += update / 3

        # ---- Train the rRBF classifier ----------------------------------------
        print('rRBF')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(device)
        print(name, eps, sigma, map_size, heatmap_epoch)
        U_matrix = U_diagonal.to(device)

        model = rRBF(map_size, W)
        model.to(device)

        optimizer_rRBF = optim.Adam(model.parameters(), lr=0.001)

        Loss = []
        acc = []

        criterion = nn.CrossEntropyLoss()

        # inf guarantees at least one checkpoint is written before torch.load below.
        min_loss = float('inf')
        loss_over = 0
        model.train()
        for epoch in range(rRBF_epoch):
            train_count = 0
            sum_loss = []
            for inputs, labels in trainloader:
                data = inputs.to(device)
                target = labels.to(device)
                optimizer_rRBF.zero_grad()

                output = model(data, epoch, rRBF_epoch, sigma=sigma, map_size=map_size, U=U_matrix)
                loss = criterion(output, target)
                sum_loss.append(loss.to('cpu').detach().numpy().copy())

                loss.backward()
                optimizer_rRBF.step()

                train_count += (torch.argmax(output, dim=1) == target).sum().item()

            acc.append(train_count/len(train_dataset))
            Loss.append(np.mean(sum_loss))
            if Loss[-1] <= min_loss:
                min_loss = Loss[-1]
                torch.save(model.state_dict(), 'result/model_OASIS_'+ name +'.pth')

            # Early stop once the epoch loss has changed by less than eps for 5 epochs in a row.
            if epoch >= 1:
                if -eps < Loss[-1] - Loss[-2] < eps:
                    loss_over += 1
                else:
                    loss_over = 0

                if loss_over >= 5:
                    break

        plt.plot(range(len(Loss)), Loss, 'b')
        plt.xlabel('epoch', fontsize=18)
        plt.ylabel('loss', fontsize=18)
        plt.savefig('result/OASIS_trainingloss_'+ name +'.png')
        plt.show()
        plt.close()
        plt.plot(range(len(acc)), acc, 'r')
        plt.ylim(0, 1)
        plt.xlabel('epoch', fontsize=18)
        plt.ylabel('acc', fontsize=18)
        # NOTE: saving of the accuracy curve is disabled:
        #plt.savefig('result/OASIS_acc_'+ name +'.png')
        plt.close()

        # ---- Evaluate the best checkpoint -------------------------------------
        row = map_size
        col = row
        OASIS_map = np.zeros((row, col, 3)) + 255

        model.load_state_dict(torch.load('result/model_OASIS_'+ name +'.pth'))
        model_weight = model.CRSOM.W.to('cpu').detach().numpy().copy()
        # Per-feature weights the CRSOM applies before its winner search.
        U_weight = U_matrix.to('cpu').detach().numpy()

        class_0_i = []
        class_0_j = []
        class_1_i = []
        class_1_j = []

        scale_size = 0.3

        def i_map(input_data, count, train_y, scale_size):
            """Record the winning map cell of one sample, with jitter, for the scatter plot."""
            input_data = input_data.to('cpu').detach().numpy().copy()
            input_data = input_data.reshape(256*256)
            # Same U-weighted squared distance as crsom.forward, so the plotted
            # winner is the cell the model actually selects.
            min_index = np.argmin(((U_weight * (model_weight - input_data))**2).sum(axis=2))
            mini, minj = divmod(int(min_index), col)
            if train_y[count] == 0:
                class_0_i.append(row-mini+np.random.uniform(-scale_size, scale_size))
                class_0_j.append(minj+np.random.uniform(-scale_size, scale_size))
                OASIS_map[mini, minj] = (255, 0, 0)
            elif train_y[count] == 1:
                class_1_i.append(row-mini+np.random.uniform(-scale_size, scale_size))
                class_1_j.append(minj+np.random.uniform(-scale_size, scale_size))
                OASIS_map[mini, minj] = (0, 255, 0)

        model.eval()

        test_count = 0

        acc_test = []
        sum_loss_test = []

        with torch.no_grad():
            for inputs, labels in testloader:
                data = inputs.to(device)
                target = labels.to(device)

                # NOTE(review): `epoch` is the last training epoch reached, which sets s_t,
                # while the weights come from the best-loss checkpoint. Confirm this is intended.
                output = model(data, epoch, rRBF_epoch, sigma=sigma, map_size=map_size, U=U_matrix)

                loss = criterion(output, target)
                sum_loss_test.append(loss.to('cpu').detach().numpy().copy())

                test_count += (torch.argmax(output, dim=1) == target).sum().item()

        acc_test.append(test_count/len(test_dataset))
        Loss_test = np.mean(sum_loss_test)
        print('test_accuracy', acc_test)
        print('test_loss', Loss_test)
        formatted_accuracy = f'{acc_test[0]:.3g}'

        with open('result/test_accuracy_OASIS_'+ name +'.txt', 'w') as file:
            file.write(formatted_accuracy)

        # `idx` (not `time`) so the imported time module is not shadowed.
        for test_images, test_labels in testloader:
            for idx in range(len(test_images)):
                i_map(test_images[idx], idx, test_labels, scale_size)

        m_size = 20
        plt.axes().set_aspect('equal')
        plt.xlim(-0.5, row+0.5)
        plt.ylim(-0.5, col+0.5)
        plt.grid()
        plt.scatter(class_0_j, class_0_i, s=m_size, c=colors_10[0], marker='$0$')
        plt.scatter(class_1_j, class_1_i, s=m_size, c=colors_10[1], marker='$1$')
        plt.savefig('result/CRSOM_'+ name +'.png')
        plt.show()
        plt.close()

        seed_count += 1
