from functools import cmp_to_key
import torch
import pickle
from imports import DEVICE
import numpy as np


def sampleBatchPop(batchSize,w,h,dim,xlb,xub):
    '''
    Draw a batch of random populations with entries spread uniformly over
    [xlb, xub) (assuming xlb <= xub).

    Returns a float tensor of shape (batchSize, dim, w, h) moved to DEVICE.
    '''
    # Sample on the default generator first, then move, so the random stream
    # is the same one the original implementation consumed.
    unit = torch.rand((batchSize, dim, w, h)).to(DEVICE).float()
    span = xub - xlb
    return unit * span + xlb


def lossFun(S0,S1):
    f0=torch.mean(S0[:,0,:,:])
    f1=torch.mean(S1[:,0,:,:])
    loss=(f1-f0)/torch.abs(f0)
    return loss


def convexLoss(S):
    '''
    Return 0.5 * mean(S[:, 0, :, :]) ** 2 as a 0-d tensor.

    This is a convex function of the mean of channel 0.
    '''
    centre = S[:, 0, :, :].mean()
    return 0.5 * centre ** 2


def sortIndiv(batchPop):
    '''
    Sort the individuals of every population in a batch by fitness (ascending).

    Channel 0 of ``batchPop`` holds each individual's fitness, and channels
    1..dim hold its chromosome. Each (w, h) grid cell is one individual. Within
    each population, cells are reordered independently so that fitness is
    ascending in row-major (w, h) order. All channels of an individual move
    together.

    Input:
    batchPop: tensor of shape (batchSize, dim+1, w, h); it is not modified.
    Returns:
    a new tensor of shape (batchSize, dim+1, w, h) with sorted individuals.
    '''
    b, d, w, h = batchPop.shape
    n = w * h
    # reshape (not view) so non-contiguous inputs are accepted as well
    flat = batchPop.reshape(b, d, n)
    _, order = torch.sort(flat[:, 0, :], dim=1)  # (b, n) permutation per population
    # Apply the same permutation to every channel in one vectorized gather
    index = order.unsqueeze(1).expand(b, d, n)
    return torch.gather(flat, 2, index).reshape(b, d, w, h)




def calFitness(batchChrom,fun):
    '''
    Compute the fitness of a batch of populations.

    Input:
    batchChrom: tensor of shape (b, dim, w, h); each (w, h) cell is one
        individual whose chromosome lies along axis 1.
    fun: dict with key 'fun', a callable taking ((b, n, dim) tensor, fun['bias']).
        It is expected to return b*n fitness values, presumably shaped (b, n).
    Returns:
    tensor of shape (b, dim+1, w, h): channel 0 holds the fitness, and
    channels 1..dim hold a copy of batchChrom.
    '''
    b, dim, w, h = batchChrom.shape
    # (b, dim, w, h) -> (b, n, dim): one row per individual. reshape also
    # accepts non-contiguous inputs, unlike view.
    individuals = batchChrom.reshape(b, dim, w * h).permute(0, 2, 1)
    fitness = fun['fun'](individuals, fun['bias'])
    fitness = fitness.reshape(b, 1, w, h)
    # A single concatenation along the channel axis replaces the per-sample
    # cat+stack loop; it also works for an empty batch (b == 0).
    return torch.cat((fitness, batchChrom), dim=1)


def reOffSet(fun,zeroOffset=True):
    '''
    Reset the shift vector ``fun['bias']`` of an objective descriptor in place.

    Reads fun['bub'], fun['blb'] and fun['dim']; all three must exist, even when
    zeroOffset is true. If zeroOffset is true, the bias is a zero vector of
    length dim. Otherwise each entry is drawn uniformly from
    [fun['blb'], fun['bub']). The bias is moved to DEVICE and stored in the
    dict, and the same dict is returned.
    '''
    upper, lower, dim = fun['bub'], fun['blb'], fun['dim']
    if zeroOffset:
        shift = torch.zeros(dim)
    else:
        shift = torch.rand(dim) * (upper - lower) + lower
    fun['bias'] = shift.to(DEVICE)
    return fun


def dump(file,path):
    '''Pickle the object ``file`` into ``path``, overwriting any existing file.'''
    with open(path, 'wb') as handle:
        pickle.Pickler(handle).dump(file)

def load(path):
    '''
    Unpickle and return the object stored in ``path``.

    Warning: unpickling can execute arbitrary code; only load trusted files.
    '''
    with open(path, 'rb') as handle:
        return pickle.Unpickler(handle).load()



def fun2(x,bias=None):
    '''
    Shifted L1-norm objective: sum over axis 2 of |x - bias|.

    x: tensor with at least 3 dims; presumably (b, n, dim) as built by
        calFitness (TODO confirm for other callers).
    bias: optional shift, flattened to 1-D and broadcast against x.
    Returns x.abs-sum over axis 2, e.g. shape (b, n) for a (b, n, dim) input.
    '''
    shifted = x if bias is None else x - bias.view(-1)
    return shifted.abs().sum(dim=2)


# Default objective descriptor, in the format calFitness reads: the objective
# callable plus its shift vector (None means unshifted). reOffSet additionally
# reads 'dim', 'blb' and 'bub', which are not defined here.
fun = dict(fun=fun2, bias=None)

if __name__ == '__main__':
    # Sanity check: the global mean of a (b, r, c) tensor equals the
    # per-sample mean (sum / elements per sample) averaged over the batch.
    rows = [[1, 2, 3], [1, 2, 3]]
    sample = torch.from_numpy(np.array([rows, rows])).float()
    batch = sample.shape[0]
    per_sample = sample.shape[1] * sample.shape[2]
    print(torch.mean(sample), (torch.sum(sample) / per_sample) / batch)
    