from image_param_list import *
import subprocess
import argparse
import os


# Build the experiment configuration. params_mlnp comes from the star import of
# image_param_list and returns a 3-tuple; only the first element is kept
# (presumably the parsed argument namespace — confirm in image_param_list).
param_list=params_mlnp
args,_,_=param_list()
# Force GPU execution, overriding whatever device the parameter list set.
args.device = 'cuda'

import random
import numpy as np
import torch
import torch.optim as optim
torch.set_num_threads(4)
from train import *
from net import *
from test import evalution

def _seed_everything(seed):
    """Seed all RNGs used by training/evaluation and make cuDNN deterministic.

    Seeds NumPy, PyTorch (CPU), Python's ``random`` module and PyTorch CUDA
    (current device and all devices), then turns on deterministic cuDNN and
    turns off cuDNN benchmarking.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def compute_result(args,
                   meta_net,
                   check_lvm,
                   image_data,
                   rand_eval,
                   writer=1, regularization=False, seed=7,
                   alpha_te=(0., 0.5, 0.7, 0.9),
                   points=(10, 100, 300, 500, 700, 800),
                   eval_seeds=(1, 2, 3, 4, 5)):
    """Train one meta-model, then evaluate it over two test sweeps.

    The model is built from ``meta_net(args)`` and trained with ``run_tr_te``.
    It is then evaluated with ``evalution`` in two sweeps. Each setting is run
    once per seed in ``eval_seeds``, and every RNG is reseeded before each run.

    1. Alpha sweep: one run per value in ``alpha_te``. ``save=True`` is passed
       to ``evalution`` here.
    2. Context-point sweep: ``alpha=0.``, one run per value in ``points``.

    Args:
        args: Configuration object. Its ``y_dim``, ``is_rfm`` and ``type``
            attributes are overwritten here, and ``device`` is read.
        meta_net: Model class/factory, called as ``meta_net(args)``.
        check_lvm: Method name, forwarded to training and evaluation.
        image_data: Dataset name, forwarded to training and evaluation.
            NOTE(review): the loaders always come from ``cifar10_metadataset()``
            regardless of this value.
        rand_eval: Forwarded to ``run_tr_te``.
        writer: Run identifier, forwarded to ``run_tr_te`` and as
            ``load_writer`` to ``evalution``.
        regularization: Unused; kept for backward compatibility.
        seed: Forwarded unchanged to training and to every ``evalution`` call.
            The per-run RNG seeds come from ``eval_seeds`` instead.
        alpha_te: Alpha values for the first sweep.
        points: Context-point counts for the second sweep.
        eval_seeds: RNG seeds; results are averaged over them.

    Returns:
        Tuple ``(mean, fix, std)``:

        - ``mean`` has shape ``(len(alpha_te), 4)`` and holds the per-alpha
          mean of the metrics.
        - ``fix`` has shape ``(4, len(points))`` and holds the per-point-count
          mean of the metrics.
        - ``std`` has shape ``(len(alpha_te), 4)`` and holds the per-alpha
          standard deviation.
    """
    # Number of metrics evalution() returns per run (fixed by reshape(4, 1) originally).
    n_metrics = 4

    train_loader, val_loader, eval_loader = cifar10_metadataset()
    dim = 32
    args.y_dim = 3
    args.is_rfm = False
    args.type = check_lvm
    model = meta_net(args).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    # The training/validation metrics returned by run_tr_te are not used here.
    run_tr_te(args=args,
              meta_net=model,
              dim=dim,
              net_optim=optimizer,
              train_loader=train_loader,
              val_loader=val_loader,
              eval_loader=eval_loader,
              check_lvm=check_lvm,
              data=image_data,
              rand_eval=rand_eval,
              writer=writer, seed=seed)

    # Sweep 1: vary alpha; keep both mean and std over the evaluation seeds.
    meta_te_results = np.zeros(shape=(len(alpha_te), n_metrics))
    meta_te_results_std = np.zeros(shape=(len(alpha_te), n_metrics))
    for i, alpha in enumerate(alpha_te):
        mid_res = []
        for eval_seed in eval_seeds:
            _seed_everything(eval_seed)
            mid_res.append(evalution(check_lvm=check_lvm, image_data=image_data,
                                     alpha=alpha, load_writer=writer,
                                     writer=alpha, seed=seed, save=True))
        mid_res = np.array(mid_res)
        meta_te_results[i:i + 1, :] = np.mean(mid_res, axis=0)
        meta_te_results_std[i:i + 1, :] = np.std(mid_res, axis=0)

    # Sweep 2: alpha fixed at 0, vary the number of context points; mean only.
    meta_te_results_fix = np.zeros(shape=(n_metrics, len(points)))
    for i, n_points in enumerate(points):
        # Stops at the 800-point setting when dim != 32, leaving it and any later
        # columns at zero. dim is fixed to 32 above, so this never triggers.
        if n_points == 800 and dim != 32:
            break
        mid_res = []
        for eval_seed in eval_seeds:
            _seed_everything(eval_seed)
            mid_res.append(evalution(check_lvm=check_lvm, image_data=image_data,
                                     alpha=0., load_writer=writer, writer=0.,
                                     num_c_points=n_points, seed=seed))
        mid_res = np.array(mid_res)
        meta_te_results_fix[:, i:i + 1] = np.mean(mid_res, axis=0).reshape(n_metrics, 1)
    return meta_te_results, meta_te_results_fix, meta_te_results_std

seed = 2
# Methods to train and evaluate in this run. The full set previously used was:
# 'IWNPs', 'CVaR-NPs', 'GDRO-NPs', 'TDRO-NPs', 'OS-NPs'
check_lvm_list = [
    'OS-NPs',
]
data_list = ['CIFAR10']
# Run identifier: used in the output path and forwarded to compute_result.
# Falls back to 10271 when args.writer is unset.
writer = 10271 if args.writer is None else args.writer
for image_data in data_list:
    # Column names for one row of table_results*.txt: 'Method' followed by the
    # 16 flattened metric values. Informational only; not written to disk.
    headers = [
        'Method','llt_avg','llc_avg','llp_avg','mse_avg','llt_0.5','llc_0.5','llp_0.5','mse_0.5','llt_0.7','llc_0.7','llp_0.7','mse_0.7','llt_0.9','llc_0.9','llp_0.9','mse_0.9'
    ]
    rows = {}   # method -> rounded mean metrics, flattened row-major
    rowsf = {}  # method -> [first metric row of the context-point sweep]
    rowss = {}  # method -> rounded metric std, same layout as rows
    for check_lvm in check_lvm_list:
        # Reset every RNG before each method so all methods start from the same state.
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        meta_te_results, meta_te_results_fix, meta_te_results_std = compute_result(
            args,
            conv_net,
            check_lvm,
            image_data=image_data,
            rand_eval=True,
            writer=writer,
            seed=seed,
        )
        meta_te_results = np.round(meta_te_results, decimals=3).reshape(-1)
        meta_te_results_std = np.round(meta_te_results_std, decimals=3).reshape(-1)
        # float() turns NumPy scalars into Python floats. Without it, NumPy >= 2
        # writes the list repr below as "np.float64(...)" instead of plain numbers.
        rows[check_lvm] = [round(float(v), 3) for v in meta_te_results]
        rowss[check_lvm] = [round(float(v), 3) for v in meta_te_results_std]
        rowsf[check_lvm] = [meta_te_results_fix[0]]
    arch_dir = 'M' if args.arch == 'm' else 'C'
    save_path = f'./final_result/{image_data}/{arch_dir}/{seed}/{writer}'
    # NOTE(review): only the LAST method's context-point row reaches
    # fix_points.csv. Earlier methods appear in fp_results.txt only.
    results_fix = np.array(meta_te_results_fix[0])
    os.makedirs(save_path, exist_ok=True)
    for file_name, table in (('table_results.txt', rows),
                             ('table_results_s.txt', rowss),
                             ('fp_results.txt', rowsf)):
        with open(os.path.join(save_path, file_name), 'w', encoding='utf-8') as f:
            for key, value in table.items():
                f.write(f"{key}: {value}\n")
    np.savetxt(os.path.join(save_path, 'fix_points.csv'), results_fix)
