import argparse
import random
from urllib import request
from imports import *
from utils import *
from functionset import *
import functionset
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
from decn import *
import imports
from convex_torch import *
import time


def set_random_seed(seed=42):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python-level randomness.
    random.seed(seed)
    # Python hash seed. The running interpreter fixes its hash seed at startup,
    # so this only affects processes launched afterwards. Export it in the shell
    # when set/dict iteration order must be reproducible in this process too.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy's global RNG.
    np.random.seed(seed)
    # PyTorch RNGs: CPU first, then every visible GPU.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN: disable autotuning, which may choose different kernels from run to
    # run, and restrict cuDNN to deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Allow duplicate OpenMP runtimes to coexist in one process. This is presumably
# a workaround for the Intel OpenMP "libiomp5 already initialized" abort, which
# can occur when several libraries each bundle their own copy.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Registry of model classes, keyed by the value stored in params['modelname'].
MODEL_DICT = dict(
    decnws3=DECNws3,
    decnws30=DECNws30,
    decnnws15=DECNnws15,
)


def Train(maxEpoch=10000,
          params=None,saveStep=50,blackbox=True,idoffset=None):
    """Meta-train the model named by ``params['modelname']`` on ``params['funSet']``.

    Each epoch performs two kinds of optimiser updates:

    1. one Adam step per function in ``funSet``, each on its own freshly
       sampled batch;
    2. one extra Adam step on the loss averaged over all functions, computed
       on a single batch shared by every function.

    The learning rate is multiplied by 0.9 every 100 epochs. Every ``saveStep``
    epochs, the model weights are saved to ``params['ckpt']`` and ``params``
    (including the loss history) is dumped to ``params['paramPath']``. The
    loss curve is redrawn to ``params['figPath']`` after every epoch.

    Args:
        maxEpoch: Number of training epochs.
        params: Experiment configuration dict. Keys read: ``dim``, ``w``,
            ``h``, ``batchSize``, ``T``, ``funSet``, ``requireOffset``,
            ``modelname``, ``useRepaire``, ``reload``, ``ckpt``, ``lr``,
            ``losslist``, ``paramPath``, ``figPath``. The list in
            ``params['losslist']`` is appended to in place.
        saveStep: Interval, in epochs, between checkpoints and parameter dumps.
        blackbox: If True the loss is ``lossFun(batchPop, batchOffPop)``;
            otherwise it is ``convexLoss(batchOffPop)``.
        idoffset: If not None, regenerated offsets are set directly as
            ``fun['offset'] = np.random.choice(idoffset)`` instead of via
            ``reOffSet``.

    Raises:
        ValueError: If ``params['funSet']`` is empty.
    """
    dim=params['dim']
    w=params['w']
    h=params['h']
    batchSize=params['batchSize']
    T=params['T']
    funSet=params['funSet']
    requireOffset=params['requireOffset']
    if not funSet:
        # With no functions there is nothing to train on, and the averaged
        # loss below would be undefined.
        raise ValueError("params['funSet'] must contain at least one function")

    detn=MODEL_DICT[params['modelname']](batchSize=batchSize,w=w,h=h,
                                         useRepaire=params['useRepaire']).to(DEVICE)

    if params['reload']:
        try:
            # map_location lets a checkpoint written on a GPU be restored on a
            # CPU-only machine.
            detn.load_state_dict(torch.load(params['ckpt'],map_location=DEVICE))
        except FileNotFoundError:
            # Only a missing file is tolerated. A corrupt or incompatible
            # checkpoint raises instead of silently training from scratch and
            # later overwriting it.
            print('no checkpoint file found!')

    bar=tqdm(range(maxEpoch))
    lossList=params['losslist']
    opt=torch.optim.Adam(detn.parameters(),lr=params['lr'])
    fig=plt.figure(figsize=(15,10))
    try:
        for epoch in bar:
            # Decay the learning rate by 10% every 100 epochs.
            if (epoch+1)%100==0:
                for param_group in opt.param_groups:
                    param_group['lr']*=0.9
            # Offsets are regenerated on the first epoch and then every T epochs.
            isoffsetFun=(epoch+1)%T==0 or epoch==0

            # Phase 1: one optimiser step per training function.
            for fun in funSet:
                # Module-level state set before the fitness/model calls;
                # presumably read by code in other modules - confirm.
                imports.TRAINFUN=fun
                if isoffsetFun and requireOffset:
                    if idoffset is None:
                        fun=reOffSet(fun,False)
                    else:
                        fun['offset']=np.random.choice(idoffset)
                opt.zero_grad()
                batchPop=sampleBatchPop(batchSize,w,h,dim,fun['xlb'],fun['xub'])
                batchPop=calFitness(batchPop,fun)
                batchOffPop=detn(batchPop,fun['xlb'],fun['xub'])
                loss=lossFun(batchPop,batchOffPop) if blackbox else convexLoss(batchOffPop)
                loss.backward()
                nn.utils.clip_grad_norm_(detn.parameters(),10,norm_type=2)
                opt.step()

            # Phase 2: one step on the mean loss over all functions, evaluated
            # on one shared batch. The batch is sampled with the bounds of the
            # last function handled in phase 1, so this assumes every function
            # in funSet has the same xlb/xub.
            boundsFun=fun
            batchChrom=sampleBatchPop(batchSize,w,h,dim,boundsFun['xlb'],boundsFun['xub'])
            losses=[]
            for fun in funSet:
                imports.TRAINFUN=fun
                batchPop=calFitness(batchChrom,fun)
                batchOffPop=detn(batchPop,fun['xlb'],fun['xub'])
                losses.append(lossFun(batchPop,batchOffPop) if blackbox else convexLoss(batchOffPop))
            # Summed left to right starting from the first loss, then averaged.
            tloss=sum(losses[1:],losses[0])/len(funSet)
            opt.zero_grad()
            tloss.backward()
            nn.utils.clip_grad_norm_(detn.parameters(),10,norm_type=2)
            opt.step()
            tlossValue=tloss.item()
            lossList.append(tlossValue)
            bar.set_description('epoch | %d,TrainLoss:%.8f'%(epoch,tlossValue))

            if (epoch+1)%saveStep==0:
                torch.save(detn.state_dict(),params['ckpt'])
                params['losslist']=lossList
                dump(params,params['paramPath'])

            # Redraw the full loss history after every epoch.
            plt.cla()
            plt.title('TrainLoss')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.plot(lossList,color='red',label='loss')
            plt.legend()
            plt.savefig(params['figPath'])
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        
        
        
def eval( params=None,runs=10,fun=None,reload=True):
    """Evaluate the configured model on one test function and record statistics.

    Note: this module-level name shadows Python's builtin ``eval`` within this
    module.

    Each of the ``runs`` trials proceeds as follows:

    - a fresh batch is sampled in ``[fun['xlb'], fun['xub']]``;
    - the function offset is regenerated with ``reOffSet`` (random when
      ``params['requireOffset']`` is truthy, zero otherwise);
    - the minimum of index 0 along dimension 1 of the model output is
      recorded (presumably the fitness channel).

    The mean, best and standard deviation are stored in
    ``params['eval'][fun['id']]`` and printed. ``params`` is then dumped to
    ``params['paramPath']``.

    Args:
        params: Experiment configuration dict. Keys read: dim, w, h,
            batchSize, requireOffset, modelname, useRepaire, ckpt, paramPath.
            ``params['eval']`` is created if missing.
        runs: Number of independent trials; must be >= 1.
        fun: Test-function dict (keys read: xlb, xub, id).
        reload: Whether to load weights from ``params['ckpt']``.

    Raises:
        ValueError: If ``fun`` is None or ``runs`` < 1.
    """
    if fun is None:
        raise ValueError('a test-function dict must be passed as `fun`')
    if runs<1:
        # At least one trial is needed for min/mean/std to be defined.
        raise ValueError('runs must be at least 1, got %r'%(runs,))
    imports.TRAINFUN=fun
    dim=params['dim']
    w=params['w']
    h=params['h']
    batchSize=params['batchSize']
    requireOffset=params['requireOffset']
    detn=MODEL_DICT[params['modelname']](batchSize=batchSize,w=w,h=h,
                                         useRepaire=params['useRepaire']).to(DEVICE)
    if reload:
        try:
            # map_location lets a GPU checkpoint be restored on a CPU-only host.
            detn.load_state_dict(torch.load(params['ckpt'],map_location=DEVICE))
        except FileNotFoundError:
            print('no checkpoint file found!')
    # Always evaluate in inference mode. Previously the model was switched only
    # after a successful load, so reload=False runs stayed in training mode.
    detn.eval()
    evals=params.setdefault('eval',{})
    bar=tqdm(range(runs),ncols=100)
    fitlist=[]
    for i in bar:
        batchPop=sampleBatchPop(batchSize,w,h,dim,fun['xlb'],fun['xub'])
        # Random offset when required, otherwise reset the offset to zero.
        fun=reOffSet(fun,zeroOffset=not requireOffset)
        imports.TRAINFUN=fun
        batchPop=calFitness(batchPop,fun)
        offPop=detn(batchPop,fun['xlb'],fun['xub'])
        fit=torch.min(offPop[:,0,:,:]).item()
        fitlist.append(fit)
        bar.set_description('epoch %d bestfit is :%.5f'%(i+1,fit))
    tmp={
        'ave':np.mean(fitlist),
        'best':min(fitlist),
        'std':np.std(fitlist),
    }
    evals[fun['id']]=tmp
    print('fun %d best:%.5f mean:%.2E(%.2E)'%(fun['id'],tmp['best'], tmp['ave'],tmp['std']))
    dump(params,params['paramPath'])
        




      
def eval_protein( params=None,runs=10,fun=None):
    """Evaluate the trained model on the protein-docking objective ``fun``.

    Weights are always loaded from ``params['ckpt']`` (a missing file is
    reported and the freshly initialised model is used). Each of the ``runs``
    trials proceeds as follows:

    - a fresh batch is sampled in ``[fun['xlb'], fun['xub']]``;
    - its fitness is computed;
    - the minimum of index 0 along dimension 1 of the model output is
      recorded (presumably the fitness channel).

    Unlike ``eval``, the offset is not regenerated. The mean, best and
    standard deviation are stored in ``params['eval'][fun['id']]`` and
    printed. ``params`` is then dumped to ``params['paramPath']``.

    Args:
        params: Experiment configuration dict. Keys read: dim, w, h,
            batchSize, modelname, useRepaire, ckpt, paramPath.
            ``params['eval']`` is created if missing.
        runs: Number of independent trials; must be >= 1.
        fun: Objective dict (keys read: xlb, xub, id).

    Raises:
        ValueError: If ``fun`` is None or ``runs`` < 1.
    """
    if fun is None:
        raise ValueError('a test-function dict must be passed as `fun`')
    if runs<1:
        # At least one trial is needed for min/mean/std to be defined.
        raise ValueError('runs must be at least 1, got %r'%(runs,))
    imports.TRAINFUN=fun
    dim=params['dim']
    w=params['w']
    h=params['h']
    batchSize=params['batchSize']
    detn=MODEL_DICT[params['modelname']](batchSize=batchSize,w=w,h=h,
                                         useRepaire=params['useRepaire']).to(DEVICE)
    try:
        # map_location lets a GPU checkpoint be restored on a CPU-only host.
        detn.load_state_dict(torch.load(params['ckpt'],map_location=DEVICE))
    except FileNotFoundError:
        print('no checkpoint file found!')
    # Evaluate in inference mode regardless of whether weights were loaded.
    detn.eval()
    evals=params.setdefault('eval',{})
    bar=tqdm(range(runs))
    fitlist=[]
    for i in bar:
        batchPop=sampleBatchPop(batchSize,w,h,dim,fun['xlb'],fun['xub'])
        batchPop=calFitness(batchPop,fun)
        offPop=detn(batchPop,fun['xlb'],fun['xub'])
        fit=torch.min(offPop[:,0,:,:]).item()
        fitlist.append(fit)
        bar.set_description('epoch %d bestfit is :%.5f'%(i+1,fit))
    tmp={
        'ave':np.mean(fitlist),
        'best':min(fitlist),
        'std':np.std(fitlist),
    }
    evals[fun['id']]=tmp
    print('fun %d best:%.5f mean:%.5f std:%.5f'%(fun['id'],tmp['best'], tmp['ave'],tmp['std']))
    dump(params,params['paramPath'])
        








def blackbox_train(maxEpoch=1000):
    """Train on synthetic benchmarks fun1-fun3, then evaluate on fun4-fun9.

    Training uses 10-dimensional fun1-fun3 with search box [-10, 10].
    Evaluation then runs 200 trials on each of fun4-fun9. Before evaluation,
    ``requireOffset`` is switched off (offsets are reset to zero) and the
    batch size is raised to 256.

    Args:
        maxEpoch: Number of training epochs passed to ``Train``.
    """
    # set_random_seed(42)  # uncomment for reproducible runs
    dim=10
    # Two-layer, bias-free linear network stored on the functionset module;
    # presumably read by one of the benchmark functions there - confirm.
    functionset.f1w=torch.nn.Sequential(
        torch.nn.Linear(dim,32,bias=False),
        torch.nn.Linear(32,1,bias=False),
    ).to(DEVICE)

    # Each row: (id, objective, x half-width -> xlb/xub, b half-width -> blb/bub).
    # bub/blb are presumably the bounds used when drawing offsets - confirm.
    specs=(
        (1,fun1,10,10),
        (2,fun2,10,10),
        (3,fun3,10,10),
        (4,fun4,100,50),
        (5,fun5,100,50),
        (6,fun6,100,50),
        (7,fun7,5,2.5),
        (8,fun8,600,300),
        (9,fun9,32,16),
    )
    funs={}
    for fid,objective,xHalf,bHalf in specs:
        funs[fid]={
            'id':fid,
            'fun':objective,
            'xub':xHalf,
            'xlb':-xHalf,
            'bub':bHalf,
            'blb':-bHalf,
            'dim':dim,
            'bias':None,
        }

    # Training set; funs[4]..funs[6] may be appended to train on more functions.
    funset=[funs[1],funs[2],funs[3]]

    params={
        'dim':dim,          # dimensionality of the problem
        'lr':0.001,
        'losslist':[],
        'w':10,
        'h':10,
        'batchSize':64,
        'T':1,
        'funSet':funset,
        'requireOffset':True,
        'reload':False,
        'modelname':'decnws3',
        'paramPath':'./CKPT/dim2.pkl',
        'figPath':'./OUTPUT/figs/dim2.png',
        'ckpt':'./CKPT/dim2.pth',
        'useRepaire':True,
    }

    Train(maxEpoch=maxEpoch,params=params)

    # Evaluation settings: no offset regeneration, larger batch.
    params['requireOffset']=False
    params['batchSize']=256
    # NOTE(review): with reload=False, eval() builds a freshly initialised model
    # and never loads params['ckpt'], so the weights trained above are not used
    # here. This may be intentional if the model's parameters depend on
    # batchSize - confirm.
    for fid in range(4,10):
        eval(params=params,runs=200,fun=funs[fid],reload=False)



def ConvexL2O_Train(maxEpoch=1000):
    """Train on the protein-docking problem, then test on three held-out complexes.

    A ``Protein_dock(100, dim=12)`` instance supplies the objective
    (``problem.forward``). After training, the problem is taken out of
    training mode, and ``eval_protein`` runs 200 trials on each of
    '1ATN_7', '2JEL_1' and '7CEI_1'.

    Args:
        maxEpoch: Number of training epochs passed to ``Train``.
    """
    dim=12
    problem=Protein_dock(100,dim=dim)
    fun={
        'id':1,
        'fun':problem.forward,
        'xub':1,
        'xlb':-1,
        'bub':10,
        'blb':-10,
        'dim':dim,
        'bias':0,
    }

    params={
        'dim':dim,                          # dimensionality of the problem
        'lr':0.01,
        'losslist':[],
        'w':10,
        'h':10,                             # population size is w*h
        'batchSize':1,                      # batch size is 1 for this problem
        'T':50,
        'funSet':[fun],                     # training function set
        'requireOffset':True,               # regenerate the offset every T epochs during training
        'reload':False,                     # whether to load existing model weights
        'modelname':'decnnws15',            # model type (key of MODEL_DICT)
        'paramPath':'./CKPT/dim2.pkl',      # where experiment parameters are dumped
        'figPath':'./OUTPUT/figs/dim2.png', # where the training-loss plot is saved
        'ckpt':'./CKPT/dim2.pth',           # where model weights are saved
        'useRepaire':False,                 # whether to repair out-of-bounds solutions
    }

    # Passing idoffset makes Train set fun['offset'] via
    # np.random.choice(problem.get_len()) instead of calling reOffSet.
    Train(maxEpoch=maxEpoch,params=params,blackbox=True,idoffset=problem.get_len())

    for testName in ('1ATN_7','2JEL_1','7CEI_1'):
        problem.set_training(False)
        problem.set_testfun(testName)
        print('testing on %s ......'%testName)
        eval_protein(params=params,runs=200,fun=fun)


def parsargs():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with ``problem`` (str, default 'functions') and
        ``maxEpoch`` (int, default 1000).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("-problem", dict(help="problem to test,", default='functions')),
        ("-maxEpoch", dict(help="maximum number of iterations", default=1000, type=int)),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


if __name__=='__main__':
    # Map each accepted -problem value to its train-and-evaluate entry point.
    entry_points={
        'functions':blackbox_train,
        'protein':ConvexL2O_Train,
    }
    args=parsargs()
    runner=entry_points.get(args.problem)
    if runner is None:
        print('no such problem!')
    else:
        runner(maxEpoch=args.maxEpoch)
        

    


    
    

    


