from pickle import FALSE
import random
import time
from tkinter.tix import Tree
from detn import *
from imports import *
from utils import *
from functionset import *
import functionset
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
from convex_torch import *
# from torch.utils.data import DataLoader
import argparse

# Let the Intel OpenMP runtime tolerate being loaded more than once in this
# process. NOTE(review): presumably needed because several native libraries
# (torch / numpy+MKL) each bring their own copy — confirm on the target setup.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")



def set_random_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (the CPU
    generator and every CUDA device), forces cuDNN into deterministic mode,
    and exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read when the interpreter starts. Assigning it
    here only affects child processes started afterwards, not hash
    randomisation (e.g. ``set`` iteration order) of the current process. To
    pin hashing for this run, export it in the shell before launching.
    """
    # Python-level and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # torch: CPU generator first, then all visible GPUs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=False keeps cuDNN from auto-tuning (which may choose a different
    # algorithm per run); deterministic=True restricts it to deterministic ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

# Registry: the value of params['modelname'] selects which model class
# DetnTrain / eval instantiate.
modelDict = dict(
    detn=DETN,
    detnws3onlymsha=DETNws3OnlyMhsa,
)









def DetnTrain(maxEpoch=10000,
          params=None,saveStep=50,idoffset=None):
    """Train a model from ``modelDict`` on a set of objective functions.

    Args:
        maxEpoch: number of optimisation steps (one Adam step per epoch).
        params: experiment configuration dict. Keys read: 'num_heads', 'dim',
            'hidden_dim', 'popSize', 'batchSize', 'T', 'funSet', 'ems', 'ws',
            'modelname', 'reload', 'ckpt', 'lr', 'losslist', 'paramPath',
            'figPath'. ``params['losslist']`` is appended to in place, and
            ``params`` is dumped to ``params['paramPath']`` whenever a new best
            checkpoint is written to ``params['ckpt']``.
        saveStep: the loss-curve figure at ``params['figPath']`` is redrawn
            every ``saveStep`` epochs and on the final epoch.
        idoffset: if None, every step trains on 6 functions drawn (with
            replacement) from ``funSet``. Otherwise only ``funSet[0]`` is used,
            and its 'offset' entry is re-drawn with
            ``np.random.choice(idoffset)`` at epoch 0 and every ``T`` epochs.
            This mutates the caller's function dict.

    Checkpointing: only during the last 20% of epochs (the final epoch always
    qualifies), and only when the step loss improves on the best saved so far.
    """
    num_heads = params['num_heads']
    dim = params['dim']
    hidden_dim = params['hidden_dim']
    popSize = params['popSize']
    batchSize = params['batchSize']
    T = params['T']
    ems = params['ems']
    ws = params['ws']
    detn = modelDict[params['modelname']](num_heads=num_heads, dim=dim,
                                          hidden_dim=hidden_dim, popSize=popSize,
                                          ems=ems, ws=ws).to(DEVICE)
    if params['reload']:
        try:
            detn.load_state_dict(torch.load(params['ckpt'], map_location=DEVICE))
        except FileNotFoundError:
            print('no checkpoint file found!')
        except Exception as err:
            # Still best-effort (train from scratch), but report the real cause
            # instead of mislabelling e.g. a shape mismatch as a missing file.
            print('failed to load checkpoint %s: %s' % (params['ckpt'], err))

    # Create output directories up front; otherwise a missing ./ckpt or ./imgs
    # only crashes the run when the first checkpoint is written, late in training.
    for key in ('ckpt', 'paramPath', 'figPath'):
        outDir = os.path.dirname(params[key])
        if outDir:
            os.makedirs(outDir, exist_ok=True)

    bar = tqdm(range(maxEpoch), ncols=100)
    lossList = params['losslist']  # alias: appends are visible in params
    opt = torch.optim.Adam(detn.parameters(), lr=params['lr'])
    fig = plt.figure(figsize=(15, 10))
    funSet = np.array(params['funSet'])  # object array of function dicts
    plotEvery = max(1, saveStep)
    minloss = None
    funs = None
    for epoch in bar:
        # Step decay: lr *= 0.9 every 100 epochs.
        if (epoch + 1) % 100 == 0:
            for param_group in opt.param_groups:
                param_group['lr'] *= 0.9

        if idoffset is None:
            funs = funSet[np.random.randint(0, len(funSet), 6)]
        elif (epoch + 1) % T == 0 or epoch == 0:
            fun = funSet[0]
            fun['offset'] = np.random.choice(idoffset)
            funs = [fun]

        # Accumulate out of place: `+=` / `/=` on the tensor returned by
        # lossFunc mutate an autograd output and can break backward().
        losses = []
        for f in funs:
            batchPop = sampleBatchPop(batchSize, popSize, dim, f['xlb'], f['xub'])
            batchOffPop = detn(batchPop, f, f['xlb'], f['xub'])
            losses.append(lossFunc(batchPop, batchOffPop, f))
        tloss = sum(losses) / len(losses)

        opt.zero_grad()
        tloss.backward()
        nn.utils.clip_grad_norm_(detn.parameters(), 10, norm_type=2)
        opt.step()

        curLoss = tloss.item()
        lossList.append(curLoss)
        # Save window: last 20% of training. The final epoch is always included
        # so short runs (maxEpoch <= 5) still produce a checkpoint for eval().
        inSaveWindow = epoch / maxEpoch > 0.8 or epoch == maxEpoch - 1
        if inSaveWindow and (minloss is None or curLoss < minloss):
            minloss = curLoss
            torch.save(detn.state_dict(), params['ckpt'])
            dump(params, params['paramPath'])
        if minloss is not None:
            bar.set_description('epoch | %d,TrainLoss:%.8f,savedMinloss:%.5f' % (epoch, curLoss, minloss))
        else:
            bar.set_description('epoch | %d,TrainLoss:%.8f,savedMinloss:%s' % (epoch, curLoss, 'None'))

        # Re-rendering and saving the full curve every epoch dominates runtime
        # on long runs; refresh periodically and always on the last epoch.
        if (epoch + 1) % plotEvery == 0 or epoch == maxEpoch - 1:
            plt.cla()
            plt.title('TrainLoss')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.plot(lossList, color='red', label='loss')
            plt.legend()
            plt.savefig(params['figPath'])
    plt.close(fig)
        
        
        

def eval( params=None,runs=10,fun=None,reload=True,idx=None):
    """Evaluate a model from ``modelDict`` on one objective function.

    Runs the model ``runs`` times and records the best objective value of each
    produced population. Summary statistics ('ave', 'best', 'std') are stored
    in ``params['eval'][fun['id']]``, printed, and ``params`` is dumped to
    ``params['paramPath']``.

    Args:
        params: experiment configuration dict (same keys as ``DetnTrain`` uses
            for the model, plus 'requireOffset', 'ckpt', 'paramPath').
        runs: number of evaluation runs; must be >= 1.
        fun: function dict with at least 'fun', 'bias', 'xlb', 'xub', 'id'.
        reload: load weights from ``params['ckpt']`` before evaluating.
        idx: when None, ``fun`` is passed through ``reOffSet`` each run
            (non-zero offset iff ``params['requireOffset']``); otherwise it is
            used unchanged.

    Raises:
        ValueError: if ``runs`` < 1.
    """
    if runs < 1:
        raise ValueError('runs must be >= 1, got %r' % (runs,))
    dim = params['dim']
    popSize = params['popSize']
    batchSize = params['batchSize']
    requireOffset = params['requireOffset']
    detn = modelDict[params['modelname']](num_heads=params['num_heads'], dim=dim,
                                          hidden_dim=params['hidden_dim'], popSize=popSize,
                                          ems=params['ems'], ws=params['ws']).to(DEVICE)
    if reload:
        try:
            detn.load_state_dict(torch.load(params['ckpt'], map_location=DEVICE))
        except FileNotFoundError:
            print('no checkpoint file found!')
        except Exception as err:
            print('failed to load checkpoint %s: %s' % (params['ckpt'], err))
    # Always evaluate in inference mode; previously this only happened after a
    # successful reload, leaving the model in training mode otherwise.
    detn = detn.eval()
    # NOTE(review): no torch.no_grad() here because it is not visible whether
    # the model's forward relies on autograd; consider adding it if not.

    evals = params.setdefault('eval', {})
    testPopPath = './testpopdim%d.npy' % params['dim']
    fitlist = []
    for i in range(runs):
        batchPop = sampleBatchPop(batchSize, popSize, dim, fun['xlb'], fun['xub'])
        if idx is None:
            fun = reOffSet(fun, zeroOffset=not requireOffset)

        if i == 0:
            # First run uses a population cached on disk (created on first use),
            # presumably so the recorded trajectory is reproducible across models.
            try:
                batchPop = torch.from_numpy(np.load(testPopPath)).to(DEVICE)
            except (OSError, ValueError):
                batchPop = sampleBatchPop(1, popSize, dim, fun['xlb'], fun['xub'])
                np.save(testPopPath, batchPop.detach().cpu().numpy())
            # NOTE(review): `fun['id'] - 3` maps f4..f9 to f_1..f_6; for other
            # ids (e.g. protein, id 1) the tag is negative.
            offPop = detn(batchPop, fun, fun['xlb'], fun['xub'],
                          recordFit='detn_em%d_ws%d_f_%d' % (params['ems'], params['ws'], fun['id'] - 3))
        else:
            offPop = detn(batchPop, fun, fun['xlb'], fun['xub'])
        fitlist.append(torch.min(fun['fun'](offPop, fun['bias'])).item())

    tmp = {
        'ave': np.mean(fitlist),
        'best': min(fitlist),
        'std': np.std(fitlist),
    }
    evals[fun['id']] = tmp
    print('fun %d best:%.5f mean:%.2E(%.2E)'%(fun['id'],tmp['best'], tmp['ave'],tmp['std']))
    dump(params,params['paramPath'])
        

def get_TrianFunset(baseFunset,n=10):
    """Build a training function set from ``n`` passes over ``baseFunset``.

    Each pass appends ``reOffSet(fun, False)`` for every ``fun`` in
    ``baseFunset``, in order, so the result holds ``n * len(baseFunset)``
    entries.
    """
    return [reOffSet(base, False) for _ in range(n) for base in baseFunset]



def blackbox(maxEpoch=100):
    """Train on synthetic benchmarks f1-f3, then evaluate on f4-f9.

    Side effects: assigns a freshly initialised two-layer linear network to
    ``functionset.f1w``. ``DetnTrain`` writes the checkpoint, loss figure and
    params file to the paths in ``params``. Each ``eval`` call prints a
    summary and re-dumps ``params``.

    Args:
        maxEpoch: number of training epochs passed to ``DetnTrain``.
    """
    dim = 10
    # NOTE(review): module-level network on `functionset`, presumably consumed
    # by fun1 — confirm there.
    functionset.f1w = torch.nn.Sequential(
        torch.nn.Linear(dim, 32, bias=False),
        torch.nn.Linear(32, 1, bias=False),
    ).to(DEVICE)

    # (id, objective, x bound, b bound) -> symmetric ranges stored as
    # 'xlb'/'xub' = -/+ x bound and 'blb'/'bub' = -/+ b bound.
    specs = (
        (1, fun1, 10, 10),
        (2, fun2, 10, 10),
        (3, fun3, 10, 10),
        (4, fun4, 100, 50),
        (5, fun5, 100, 50),
        (6, fun6, 100, 50),
        (7, fun7, 5, 2.5),
        (8, fun8, 600, 300),
        (9, fun9, 32, 16),
    )
    funs = {
        fid: {
            'id': fid,
            'fun': objective,
            'xub': xBound,
            'xlb': -xBound,
            'bub': bBound,
            'blb': -bBound,
            'dim': dim,
            'bias': None,
        }
        for fid, objective, xBound, bBound in specs
    }

    # Only f1-f3 are used for training; 1000 passes of reOffSet'd copies.
    trainBase = [funs[fid] for fid in (1, 2, 3)]
    funset = get_TrianFunset(trainBase, 1000)

    params = {
        'num_heads': 1,       # fixed for this model; leave as is
        'ems': 30,            # number of OBs
        'ws': True,           # whether the OBs share parameters
        'dim': dim,
        'lr': 0.01,
        'losslist': [],
        'hidden_dim': 200,    # width of the SA query/key projection
        'popSize': 100,
        'batchSize': 64,
        'T': 50,              # deprecated
        'funSet': funset,     # training functions
        'requireOffset': True,  # deprecated
        'reload': False,      # load existing weights before training?
        'modelname': 'detn',  # fixed; key into modelDict
        'paramPath': './ckpt/dim10ws3.pkl',  # pickled experiment params
        'figPath': './imgs/dim10ws3.png',    # training-loss plot
        'ckpt': './ckpt/dim10ws3.pth',       # model weights
    }

    DetnTrain(maxEpoch=maxEpoch, params=params)

    # Evaluation settings are identical for every held-out function.
    params['requireOffset'] = False
    params['batchSize'] = 512
    for fid in range(4, 10):
        eval(params=params, runs=400, fun=funs[fid], reload=True)




def ConvexL2O_Train(maxEpoch=1000):
    """Train on the protein-docking problem, then test on three held-out targets.

    Training uses the single objective ``problem.forward``; ``DetnTrain``
    re-draws its 'offset' from ``problem.get_len()``. For testing, the problem
    is switched out of training mode and each target in ``testset`` is
    evaluated with ``eval(..., idx=True)``, which leaves ``fun`` un-re-offset.

    Args:
        maxEpoch: number of training epochs passed to ``DetnTrain``.
    """
    dim = 12
    popSize = 128
    problem = Protein_dock(popSize, dim=dim)

    fun = {
        'id': 1,
        'fun': problem.forward,
        'xub': 1,
        'xlb': -1,
        'bub': 10,
        'blb': -10,
        'dim': dim,
        'bias': 0,
    }
    funset = [fun]

    # NOTE(review): file names say "dim10" although dim is 12; left unchanged
    # so existing checkpoints are still found.
    params = {
        'num_heads': 2,
        'ems': 5,
        'ws': False,
        'dim': dim,
        'lr': 0.000001,
        'losslist': [],
        'hidden_dim': 200,
        'popSize': popSize,
        'batchSize': 1,  # in this case, batchsize is set to 1
        'T': 50,
        'funSet': funset,
        'requireOffset': True,
        'reload': False,
        'modelname': 'detn',
        'paramPath': './ckpt/dim10ws30_protein.pkl',
        'figPath': './imgs/dim10ws30_protein.png',
        'ckpt': './ckpt/dim10ws30_protein.pth',
    }

    DetnTrain(maxEpoch=maxEpoch, params=params, idoffset=problem.get_len())

    # eval() reads params['requireOffset'], not fun['requireOffset']; set it
    # there (as blackbox does) so the dumped params describe the test setup.
    # The fun key is kept in case other code reads it.
    params['requireOffset'] = False
    fun['requireOffset'] = False
    testset = ['1ATN_7', '2JEL_1', '7CEI_1']
    for target in testset:
        problem.set_training(False)
        problem.set_testfun(target)
        print('testing on %s ......' % target)
        eval(params=params, runs=200, fun=fun, reload=True, idx=True)




def parsargs():
    """Parse the command line.

    Options:
        -problem: which experiment to run (default 'functions').
        -maxEpoch: number of training epochs (int, default 1000).
    """
    parser = argparse.ArgumentParser()
    options = (
        ("-problem", {"help": "problem to test,", "default": 'functions'}),
        ("-maxEpoch", {"help": "maximum number of iterations", "default": 1000, "type": int}),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()

if __name__=='__main__':
    args = parsargs()
    problemname = args.problem
    maxEpoch = args.maxEpoch
    # Map each supported -problem value to its experiment entry point.
    runners = {
        'functions': blackbox,
        'protein': ConvexL2O_Train,
    }
    runner = runners.get(problemname)
    if runner is None:
        print('no such problem!')
    else:
        runner(maxEpoch=maxEpoch)
        
    


    
    

    


