#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sys
from tqdm import tqdm,trange
from sklearn.metrics.cluster import normalized_mutual_info_score

# In[176]:


class Magnum:
    """Ensemble of small maps trained on overlapping subsets of input columns.

    The input columns are grouped into (possibly overlapping) index sets
    ``self.set``.  :meth:`input` trains one ``map_algorithms`` instance per
    set on ``x[:, set]``.  :meth:`organize` merges the per-set clusterings
    into one global labelling by pairwise voting.

    NOTE(review): ``x`` is presumably a (time, input_shape) array whose rows
    have one dominant column.  The set builders treat ``np.argmax(x[row])``
    as "the active column" of that row; confirm this against callers.
    """

    # Number of consecutive rows of ``x`` scanned when building one set.
    _WINDOW = 20 * 11
    # Unit for the row offsets of the paired/companion windows.
    # NOTE(review): what the factor 11 represents is not visible here.
    _STEP = 11

    def __init__(self, map_algorithms, input_shape, global_algorithms, dim=3, set_len=10):
        """
        Args:
            map_algorithms: class instantiated once per set as
                ``map_algorithms(n_inputs, dim, adaptation_rate=...)``.  Its
                instances are used through ``input``, ``organize`` and
                ``plot_animation``.
            input_shape: number of input columns.
            global_algorithms: class used by :meth:`make_set_with_global_map`.
                Its instance must provide ``input(x)`` and ``syncmap``.
            dim: map dimensionality passed to every per-set map.
            set_len: nominal number of columns per set.
        """
        self.name = "Magnum"
        self.map_algorithm = map_algorithms
        self.global_algorithms = global_algorithms
        self.set_num = int(input_shape / 3)
        self.set = []
        self.set_init = True
        self.maps = []
        self.set_len = set_len
        self.input_size = input_shape
        # exist_matrix[c] becomes True once column c belongs to at least one set.
        self.exist_matrix = np.zeros(self.input_size, bool)
        self.dim = dim

    # ------------------------------------------------------------------
    # private helpers shared by the set builders
    # ------------------------------------------------------------------
    def _window_set(self, x, start, max_len, mark=True):
        """Return the distinct row-wise argmax columns of ``x[start:start+_WINDOW]``.

        Scanning stops as soon as ``max_len`` columns have been collected.
        Rows outside ``x`` are ignored.  A negative ``start`` is clamped to
        row 0, and a window running past the last row is truncated.  The
        original code instead wrapped around or raised IndexError.  When
        ``mark`` is true, the collected columns are flagged in
        ``self.exist_matrix``.
        """
        lo = max(start, 0)
        hi = max(start + self._WINDOW, 0)
        members = []
        for row in x[lo:hi]:
            if len(members) >= max_len:
                break
            col = np.argmax(row)
            if col not in members:
                members.append(col)
        if mark:
            for col in members:
                self.exist_matrix[col] = True
        return members

    def _add_window_sets(self, x, starts, max_len):
        """Append one window set per start row to ``self.set``, skipping empty windows."""
        for start in starts:
            members = self._window_set(x, start, max_len)
            if members:
                self.set.append(members)

    def _paired_start_points(self, x):
        """Draw ``set_num`` start rows; the second half copies the first half shifted by 5*_STEP."""
        starts = np.random.randint(x.shape[0] - self._WINDOW * 2 - 1, size=self.set_num)
        half = int(self.set_num / 2)
        starts[half:2 * half] = starts[:half] + 5 * self._STEP
        self.set_start_point = starts

    def _cover_missing(self, x, corr_offsets=()):
        """Try to build a ``set_len`` set starting at every row whose argmax column is uncovered.

        Accepted sets are collected in ``self.alternate_set``.  Returns the
        companion start rows ``row + offset`` (for every triggering row and
        every offset in ``corr_offsets``, in that order).
        """
        self.alternate_set = []
        corr_start = []
        for j in range(x.shape[0]):
            if self.exist_matrix[np.argmax(x[j])]:
                continue
            corr_start.extend(j + off for off in corr_offsets)
            candidate = self._window_set(x, j, self.set_len, mark=False)
            # Only full-size sets are kept.  Coverage is marked for kept sets
            # only, so exist_matrix never claims a column that no set contains.
            if len(candidate) == self.set_len:
                for col in candidate:
                    self.exist_matrix[col] = True
                self.alternate_set.append(candidate)
        return corr_start

    def _second_pass(self, x, corr_offsets=None):
        """Fallback pass shared by the make_set_num_lim* builders.

        If ``corr_offsets`` is None (make_set_num_lim), only the fallback
        sets are added.  Otherwise companion windows of up to set_len+1
        columns are added at every returned companion start, before the
        fallback sets.
        """
        print("not all the point in the set first")
        corr_start = self._cover_missing(x, corr_offsets or ())
        if corr_offsets is not None:
            print("len of corr_start= ", len(corr_start))
            self._add_window_sets(x, corr_start, self.set_len + 1)
        self.set.extend(self.alternate_set)

    # ------------------------------------------------------------------
    # set builders
    # ------------------------------------------------------------------
    def make_set_with_global_map(self, x, global_dim=5):
        """Build sets from nearest neighbours in a globally trained map.

        A ``global_algorithms`` instance is trained on ``x``.  Then, for
        ``set_num`` randomly chosen columns, the ``set_len`` columns closest
        (Euclidean distance between ``syncmap`` entries, the column itself
        included) form a set.  Any column still uncovered afterwards gets its
        own nearest-neighbour set.
        """
        self.global_map = self.global_algorithms(self.input_size, global_dim,
                                                 adaptation_rate=0.001 * self.input_size)
        self.global_map.input(x)
        self.set_start_point = np.random.choice(np.arange(self.input_size), self.set_num, False)
        # One entry of syncmap per column.  Flattening each entry keeps the
        # same Euclidean/Frobenius norm the per-pair loop used.
        coords = np.asarray(self.global_map.syncmap, dtype=float).reshape(self.input_size, -1)
        self.distant = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        for start in self.set_start_point:
            nearest = np.argsort(self.distant[start])[:self.set_len]
            self.exist_matrix[nearest] = True
            self.set.append(list(nearest))
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set")
            for point in range(self.input_size):
                if not self.exist_matrix[point]:
                    nearest = np.argsort(self.distant[point])[:self.set_len]
                    self.exist_matrix[nearest] = True
                    self.set.append(list(nearest))
        if np.sum(self.exist_matrix) < self.input_size:
            print(self.exist_matrix, "sce")
            print("not all the point in the set second")

    def make_set(self, x):
        """Build ``set_num`` sets of up to 9 distinct argmax columns from random row windows."""
        self.set_start_point = np.random.randint(x.shape[0] - self._WINDOW - 1, size=self.set_num)
        self._add_window_sets(x, self.set_start_point, 9)
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set")

    def make_set_num_lim(self, x):
        """Random windows of up to set_len+1 columns, then a fallback pass for uncovered columns."""
        self.set_start_point = np.random.randint(x.shape[0] - self._WINDOW - 1, size=self.set_num)
        self._add_window_sets(x, self.set_start_point, self.set_len + 1)
        if np.sum(self.exist_matrix) < self.input_size:
            self._second_pass(x)
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set 2nd")

    def make_set_num_lim2(self, x):
        """Like make_set_num_lim, with paired start rows and one companion window (+5*_STEP)."""
        self._paired_start_points(x)
        self._add_window_sets(x, self.set_start_point, self.set_len + 1)
        if np.sum(self.exist_matrix) < self.input_size:
            self._second_pass(x, corr_offsets=(5 * self._STEP,))
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set 2nd")

    def make_set_num_lim3(self, x):
        """Like make_set_num_lim2, with companion windows at +-2, +-3, +-4 and +-5 times _STEP."""
        self._paired_start_points(x)
        self._add_window_sets(x, self.set_start_point, self.set_len + 1)
        if np.sum(self.exist_matrix) < self.input_size:
            s = self._STEP
            offsets = (4 * s, -4 * s, 2 * s, -2 * s, 3 * s, -3 * s, -5 * s, 5 * s)
            self._second_pass(x, corr_offsets=offsets)
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set 2nd")

    def make_set_sequence(self, x, window=20 * 11, max_len=10):
        """Chop ``x`` row by row into consecutive sets of distinct argmax columns.

        A new set is started once the current one holds ``max_len`` columns
        or spans ``window`` rows.  The trailing partial set is kept as well.
        The original compared ``start - a``, which is never positive, and it
        dropped the last set.
        """
        set_temp = []
        start = 0
        for a in range(x.shape[0]):
            if (a - start) >= window or len(set_temp) >= max_len:
                start = a
                self.set.append(set_temp)
                set_temp = []
            b = np.argmax(x[a])
            if b not in set_temp:
                self.exist_matrix[b] = True
                set_temp.append(b)
        if set_temp:
            self.set.append(set_temp)
        if np.sum(self.exist_matrix) < self.input_size:
            print("not all the point in the set")

    # ------------------------------------------------------------------
    # training / inference
    # ------------------------------------------------------------------
    def input(self, x):
        """Build column sets (via make_set_num_lim3) and train one map per new set.

        Each map is created as ``map_algorithm(len(s), self.dim,
        adaptation_rate=0.001 * len(s))`` and fed ``x[:, s]``.  Only sets
        without a trained map yet are trained, so ``self.maps[i]`` always
        corresponds to ``self.set[i]``, even across repeated calls.
        """
        self.make_set_num_lim3(x)
        for i in trange(len(self.maps), len(self.set)):
            members = self.set[i]
            # Per-map hyper-parameters are chosen here.
            algorithm_class = self.map_algorithm(len(members), self.dim,
                                                 adaptation_rate=0.001 * len(members))
            algorithm_class.input(x[:, members])
            self.maps.append(algorithm_class)

    def organize(self, eps=0.5):
        """Merge the per-map clusterings into one label per column.

        Every map votes on each pair of its columns: +1 if it puts them in
        the same cluster, -1 otherwise (accumulated symmetrically in
        ``self.correlation_table``).  Columns joined by a positive total are
        then merged transitively.

        NOTE(review): columns that never take part in a positive pair keep
        label 0, so all of them share one cluster id.

        Returns:
            ``self.label``, a float array of length ``input_size``.
        """
        self.meta_label = np.zeros(self.input_size, dtype=int)
        self.correlation_table = np.zeros((self.input_size, self.input_size))
        self.label = np.zeros(self.input_size)
        for map_i, map_ in enumerate(self.maps):
            label = np.asarray(map_.organize(eps))
            members = np.asarray(self.set[map_i], dtype=int)[:label.shape[0]]
            votes = np.where(label[:, None] == label[None, :], 1.0, -1.0)
            np.fill_diagonal(votes, 0.0)  # no self-votes, matching i < k pairs only
            np.add.at(self.correlation_table, (members[:, None], members[None, :]), votes)
        # Visit positive entries in row-major order (same order as a nested
        # a/b loop).  Every visit relabels both endpoints' groups to a fresh id.
        label_count = 1
        for a, b in zip(*np.nonzero(self.correlation_table > 0)):
            if self.label[a] != 0:
                self.label[self.label == self.label[a]] = label_count
            if self.label[b] != 0:
                self.label[self.label == self.label[b]] = label_count
            self.label[a] = label_count
            self.label[b] = label_count
            label_count += 1
        return self.label

    def plot_animation(self, env):
        """Print true vs. map labels for every set and call each map's plot_animation.

        ``env`` must provide ``trueLabel()`` (indexable by a set) and ``name``.
        The animation name is ``env.name`` followed by the set's column ids,
        each followed by a comma.
        """
        tru_label = env.trueLabel()
        for i, map_ in enumerate(self.maps):
            members = self.set[i]
            print(tru_label[members])
            print(map_.organize())
            name = env.name + "".join(str(a) + "," for a in members)
            map_.plot_animation(name, tru_label[members])

    def evaluation(self, true_label, eps=0.5):
        """Return the NMI between :meth:`organize`'s labels and ``true_label``."""
        label = self.organize(eps)
        return normalized_mutual_info_score(label, true_label)
