import tune_wishart_wrap
import numpy as np

from tqdm import tqdm
import time

import itertools
import lib

import os


# Bias values swept by the outer loop; a bias equal to 0 selects the
# 'unbias' data type, any other value the 'bias' data type.
biasl = [0.0, 6.0, 8.0, 10.0, 12.0]

# Sizes N (instance['N']) and T values (hyperparams['T']) to sweep.
Nl = [60, 100, 120, 140, 160]
T = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

# Device string forwarded as flags['pt_device'] (presumably a PyTorch
# device name -- confirm in tune_wishart_wrap).
pt_device = 'cuda'
# pt_device = 'cpu'

# Alpha value, kept as text, forwarded as instance['alphatxt'].
alphatxt = '0.80'

# fMH per solver type: embedded in each output folder name and, for
# MHCACm, forwarded as hyperparams['fMH'].
fMHv = {
    'CACm': 0.0,
    'AIM': 0.0,
    'CAC': 0.0,
    'MHCACm': 0.1,
    'MHCAC': 0.1,
}


####################################################
# SOLVER

# Solver types tuned for every bias value, in this order
# ('MHCAC' has settings in fMHv but is not currently swept).
solvertypel = ['MHCACm', 'CACm', 'AIM', 'CAC']

# Not read anywhere in the visible part of this script.
debug = 0



# Sweep every bias value x solver type x (N, T) pair and run
# tune_wishart_wrap.tune_wishart on each, writing results into one folder
# per (solvertype, bias, fMH) triple.
for bias in biasl:

    # Instance-type description forwarded unchanged to tune_wishart.
    data = {}
    if bias == 0:
        data['datatype'] = 'unbias'
        data['D_WPE'] = 1  # 3
        data['R_WPE'] = -1  # 6

        data['bias'] = 0.0
    else:
        data['datatype'] = 'bias'
        # first GS is close to ferromagnetic
        data['D1_WPE'] = 1
        data['R1_WPE'] = 3

        # second GS is random
        data['D2_WPE'] = 1
        data['R2_WPE'] = -1

        data['bias'] = bias

        data['D_WPE'] = data['D1_WPE'] + data['D2_WPE']

    ####################################################

    # Number of (N, T) pairs per solver. itertools.product has no len(),
    # so this is passed to tqdm explicitly to get a percentage/ETA.
    total = len(Nl) * len(T)

    afolder_name = []

    for solvertype in solvertypel:

        fMH = fMHv[solvertype]

        # folder_name = lib.create_timestamped_folder(solvertype)
        folder_name = f"{solvertype}_{bias}_{fMH}"

        # Create the output folder, tolerating one left by a previous run.
        try:
            os.makedirs(folder_name)
            print(f"Folder created: {folder_name}")
        except FileExistsError:
            print(f"Folder already exists: {folder_name}")

        # Every N, with T visited from largest to smallest (np.flip).
        for N_val, T_val in tqdm(
            itertools.product(Nl, np.flip(T)),
            total=total,
            desc="Iterating combinations",
        ):

            # Per-solver tunable parameter names and initial values (x is
            # the log of the starting point).
            if solvertype == 'MHCACm':
                PARAM_NAMES = ["beta", "kappa", "lamb", "xi", "gamma", "a"]

                # sampling (fully benchmarked)
                x = np.log([0.1, 0.1, 1.0, 0.1, 1.0, 1.0])  # lamb=1 important
                hyperparams = {'T': T_val, 'doa': 1, 'dosampling': 0, 'fMH': fMH}

            elif solvertype == 'MHCAC':
                PARAM_NAMES = ["beta", "kappa", "lamb", "xi", "a"]

                x = np.log([0.1, 0.1, 0.1, 0.1, 1.0])
                # NOTE(review): this run uses fMH=1.0, but folder_name was
                # built from fMHv['MHCAC'] (0.1) -- confirm which is intended.
                hyperparams = {'T': T_val, 'doa': 0, 'dosampling': 0, 'fMH': 1.0}

                fMH = hyperparams['fMH']

            elif solvertype == 'CACm':
                PARAM_NAMES = ["beta", "lamb", "xi", "gamma", "a"]
                # alternative start: np.log([0.1,0.1,0.1,1.0,1.0])
                x = np.log([0.1, 1.0, 0.1, 1.0, 1.0])
                hyperparams = {'T': T_val}

            elif solvertype == 'CAC':
                PARAM_NAMES = ["beta", "lamb", "xi", "a"]
                x = np.log([0.1, 1.0, 0.1, 1.0])
                hyperparams = {'T': T_val}

            elif solvertype == 'AIM':
                PARAM_NAMES = ["beta", "lamb", "gamma"]
                x = np.log([0.1, 1.0, 1.0])
                hyperparams = {'T': T_val}

            else:
                # Without this, the previous solver's PARAM_NAMES/x/hyperparams
                # would silently be reused.
                raise ValueError(f"no parameter setup for solvertype {solvertype!r}")

            # generate problem instance
            instance = {'alphatxt': alphatxt, 'N': N_val}
            flags = {'savetraj': 1, 'pt_device': pt_device, 'solvertype': solvertype}
            # tunerparams = {'nsamp_max': 5000, 'R': 200}
            tunerparams = {'nsamp_max': 100000, 'R': 200}

            tune_wishart_wrap.tune_wishart(folder_name, instance, hyperparams, PARAM_NAMES, x, flags, tunerparams, data)

            # Disabled warm start from a previously saved result.
            # NOTE(review): hyperparams['T'] is a scalar here, so np.min/np.max
            # both equal T_val and the first branch is always taken; also x is
            # reassigned at the top of the next iteration, so this warm start
            # would be overwritten before use. Fix both before enabling.
            if 0:
                if T_val == np.min(hyperparams['T']):
                    file_name = f"wishart_{N_val}_{alphatxt}_{np.max(hyperparams['T'])}.txt"
                else:
                    file_name = f"wishart_{N_val}_{alphatxt}_{T_val}.txt"

                p0, opt_params, pvec = lib.read_file(folder_name, file_name)
                x = np.log(opt_params)

        afolder_name.append(folder_name)

    ####################################################
