import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.nn.functional import one_hot
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import yaml
import time

# Device used for every tensor created by this script. Several helpers also call
# ``.cuda()`` directly, so a CUDA-capable GPU is required.
device = "cuda"

import argparse
parser=argparse.ArgumentParser()
# NOTE: boolean switches are plain strings and are compared against 'True' / 'False'.
parser.add_argument('--noisy',dest='noisy',default='False',type=str)  # selects the results folder; 'True' also loads r_i.mat
parser.add_argument('--dic',dest='dic',default='True',type=str)
# ``choices`` makes a typo fail at parse time instead of surfacing later as an
# undefined ``num_classes``.
parser.add_argument('--dataset',dest='dataset',default='cifar10',type=str,choices=['cifar10','cifar100'])
parser.add_argument('--batch_size', dest='batch_size', default=128, type=int)
parser.add_argument('--epoch_change', dest='epoch_change', default=500, type=int)  # epoch at which training switches phase
parser.add_argument('--epoch', dest='epoch', default=1000, type=int)
parser.add_argument('--runs', dest='runs', default='1_5000', type=str)  # sub-folder of the results directory
parser.add_argument('--train_rho', dest='train_rho', default=0.01, type=float)
parser.add_argument('--dy', dest='dy', default='True', type=str)  # 'True' -> dy weights are trainable
parser.add_argument('--ly', dest='ly', default='False', type=str)  # 'True' -> ly weights are trainable
# Which per-class statistics are turned into dictionary features.
parser.add_argument('--dic_error', dest='dic_error', default='True', type=str)
parser.add_argument('--dic_freq', dest='dic_freq', default='False', type=str)
parser.add_argument('--dic_norm', dest='dic_norm', default='False', type=str)
parser.add_argument('--dic_noise', dest='dic_noise', default='False', type=str)
parser.add_argument('--n_dic', dest='n_dic', default=None, type=int)  # number of PCA components; None keeps all features
# --test / --test_size / --aug_test decide which logits file is used as the fitting split.
parser.add_argument('--test_size', dest='test_size', default=50, type=int)
parser.add_argument('--test', dest='test', default='True', type=str)
parser.add_argument('--aug_test', dest='aug_test',default='True', type=str)
parser.add_argument('--earlystop', dest='earlystop',default='False', type=str)
# Order of the two training phases: 'dy_ly' trains dy first, 'ly_dy' trains ly first.
# Restricted for the same reason as --dataset: any other value used to leave
# ``args.epoch_dy`` / ``args.epoch_ly`` undefined.
parser.add_argument('--joint', dest='joint',default='dy_ly', type=str,choices=['dy_ly','ly_dy'])
#parser.add_argument('--lr_warmup', dest='lr_warmup', default=10, type=int)
#parser.add_argument('--lr_schedule', dest='lr_schedule', default='False', type=str)
args=parser.parse_args()

dataset=args.dataset  # Dataset, either Cifar10 or Cifar100
batch_size=args.batch_size # Training batchsize
total_epoch=args.epoch # Total training epoch
switch_epoch =args.epoch_change
train_rho=args.train_rho # Imbalance ratio : Min/Max
runs=args.runs
if args.noisy == 'True':
    save_path= f'./results/{dataset}/noisy/{runs}/post_hoc'
else:
    save_path= f'./results/{dataset}/simple/split_no/{runs}/post_hoc'
    #save_path= f'./results/{dataset}/simple/split/{runs}/post_hoc'
    #save_path= f'./results/{dataset}/dyly_dic_warmup/{runs}/post_hoc'
#save_path= f'./results/{dataset}/simple/balance/{runs}/post_hoc'

# ``dataset`` is guaranteed to be one of the keys by ``choices`` above.
num_classes = {'cifar10': 10, 'cifar100': 100}[dataset]

# LR-scheduler milestones: the parameter trained in the first phase gets its
# milestone at the switch epoch, the other one at the final epoch.
if args.joint == 'ly_dy':
    args.epoch_ly = args.epoch_change
    args.epoch_dy = args.epoch
else:  # 'dy_ly'
    args.epoch_dy = args.epoch_change
    args.epoch_ly = args.epoch
num_workers = 2

# load data
# Each split is a .mat file from which 'label_all', 'loss_all', 'acc_all' and
# 'logits_all' are read.
train_mdic = sio.loadmat(f'{save_path}/train_logits.mat')

# Choose the file that plays the role of the fitting ("val") split.
if args.test != 'True':
    val_mdic = sio.loadmat(f'{save_path}/val_logits.mat')
elif args.test_size >= 10000/num_classes:
    # test_size at or above 10000/num_classes: fall back to the full test logits.
    val_mdic = sio.loadmat(f'{save_path}/test_logits.mat')
elif args.aug_test == 'True':
    val_mdic = sio.loadmat(f'{save_path}/aug{args.aug_test}_{args.test_size}_test_logits.mat')
else:
    val_mdic = sio.loadmat(f'{save_path}/{args.test_size}_test_logits.mat')

test_mdic = sio.loadmat(f'{save_path}/test_logits.mat')
model_mdic = sio.loadmat(f'{save_path}/normB.mat')
model_norm = model_mdic['normB'][0]
if args.noisy == 'True':
    model_r_i = sio.loadmat(f'{save_path}/r_i.mat')
    r_i = model_r_i['r_i'][0]
    print(r_i)


# ---- train split (only used for the class-count printout) ----
train_label_all = train_mdic['label_all'][0]
train_loss_all = train_mdic['loss_all'][0]
train_acc_all = train_mdic['acc_all'][0]
train_logits_all = train_mdic['logits_all']
num_class_i = np.zeros([num_classes])
if args.noisy == 'True':
    for label_value in train_label_all:
        num_class_i[label_value] += 1
    print('train class num',num_class_i)

# ---- fitting ("val") split ----
val_label_all = val_mdic['label_all'][0]
val_loss_all = val_mdic['loss_all'][0]
val_acc_all = val_mdic['acc_all'][0]
val_logits_all = val_mdic['logits_all']
if args.noisy == 'True':
    print('val size',np.shape(val_label_all))
    num_class_i = np.zeros([num_classes])
    for label_value in val_label_all:
        num_class_i[label_value] += 1
    print('val class num',num_class_i)
# Datasets yield (label, logits) pairs, in that order.
val_dataset = TensorDataset(
    torch.Tensor(val_label_all).to(dtype=torch.long),
    torch.Tensor(val_logits_all),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)

# ---- test split ----
test_label_all = test_mdic['label_all'][0]
test_loss_all = test_mdic['loss_all'][0]
test_acc_all = test_mdic['acc_all'][0]
test_logits_all = test_mdic['logits_all']
if args.noisy == 'True':
    print('test size',np.shape(test_label_all))
    num_class_i = np.zeros([num_classes])
    for label_value in test_label_all:
        num_class_i[label_value] += 1
    print('test class num',num_class_i)
test_dataset = TensorDataset(
    torch.Tensor(test_label_all).to(dtype=torch.long),
    torch.Tensor(test_logits_all),
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)

def logit_adjust(logits,params):
    """Apply the post-hoc adjustment ``logits / sigmoid(dy) - ly``.

    When the module-level ``args.dy`` is the string ``'False'`` the scaling is
    skipped and only ``ly`` is subtracted.

    Args:
        logits: class scores, either a ``torch.Tensor`` or a NumPy array
            (presumably shaped ``(num_samples, num_classes)`` -- confirm against callers).
        params: sequence whose first two items are the tensors ``dy`` and ``ly``;
            they are combined with ``logits`` by ordinary broadcasting.

    Returns:
        torch.Tensor: the adjusted logits.  ``logits`` must live on the same
        device as ``dy`` / ``ly``; NumPy input stays on the CPU.
    """
    dy=params[0]
    ly=params[1]
    # ``as_tensor`` returns tensors unchanged (no copy, no "copy construct" warning)
    # and wraps NumPy arrays, so both branches accept either input type.
    x=torch.as_tensor(logits)
    if args.dy != 'False':
        x=x/torch.sigmoid(dy)
    return x-ly

def eval_final(logits,label,logit_adjust,params,class_wise,num_classes):
    """Compute per-class and overall top-1 accuracy.

    Args:
        logits: array of class scores, one row per sample.
        label: 1-D array with the ground-truth class index of every sample.
        logit_adjust: optional callable ``(tensor, params) -> tensor`` applied to
            the logits (moved to CUDA first) before the argmax; ``None`` skips it.
        params: forwarded unchanged to ``logit_adjust``.
        class_wise: when false the per-class counters stay at zero, so the
            first return value is all-NaN.
        num_classes: number of classes to report.

    Returns:
        tuple: ``(per_class_accuracy, overall_accuracy)``.  A class without
        samples yields NaN in the per-class array.
    """
    if logit_adjust is not None:
        adjusted = logit_adjust(torch.Tensor(logits).cuda(),params)
        logits = adjusted.cpu().detach().numpy()

    preds = np.argmax(logits,axis = 1)
    hits = preds==label
    n_samples = np.shape(label)[0]

    class_correct=np.zeros(num_classes,dtype=float)
    class_total=np.zeros(num_classes,dtype=float)
    if class_wise:
        for cls in range(num_classes):
            members = np.where(label==cls)[0]
            class_total[cls] = np.shape(members)[0]
            class_correct[cls] = np.sum(hits[members])
    return class_correct/class_total,np.sum(hits)/n_samples
@torch.no_grad()
def compute_loss(data,target,params,logit_adjust,num_classes):
    """Return the class-balanced cross-entropy of ``data`` against ``target``.

    The per-sample cross-entropy is averaged inside every class that occurs in
    ``target`` and those class means are then averaged, so every present class
    contributes equally regardless of its size.

    Args:
        data: array-like of logits, one row per sample.
        target: array-like of integer class labels.
        params: forwarded to ``logit_adjust``.
        logit_adjust: optional callable ``(tensor, params) -> tensor`` applied to
            the logits first; ``None`` evaluates the raw logits.
        num_classes: number of classes (width of the one-hot encoding).

    Returns:
        float: the balanced loss as a Python number.
    """
    # Same placement as before when CUDA is present; falls back to the CPU otherwise.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logits = torch.Tensor(data).to(run_device)
    labels = torch.Tensor(target).to(dtype=torch.long, device=run_device)
    if logit_adjust is not None:
        logits = logit_adjust(logits,params)

    # Per-sample losses are required here: a mean-reduced criterion would make the
    # class-wise averaging below collapse to the plain (unbalanced) mean.
    per_sample = F.cross_entropy(logits, labels, reduction='none')

    target_oh = F.one_hot(labels, num_classes=num_classes)
    class_counts = torch.sum(target_oh, dim=0)
    class_loss_sum = torch.sum(per_sample.unsqueeze(1) * target_oh, dim=0)

    # Classes without samples are excluded to avoid dividing by zero.
    present = class_counts > 0
    loss = torch.mean(class_loss_sum[present] / class_counts[present])
    return loss.cpu().item()

# Geometric class-frequency profile: class i gets the relative frequency val_mu**i,
# so the last class is val_rho times as frequent as the first one.
val_rho = 0.01
val_mu=val_rho**(1./(num_classes-1))
class_freq=[val_mu**i for i in range(num_classes)]

# Disabled alternative, kept for reference:
# if num_classes == 100:
#     for i in range(num_classes):
#         if i%10 == 0:
#             class_freq.append(val_mu**i)
# else:
#     for i in range(num_classes):
#         class_freq.append(val_mu**i)
#class_freq = torch.Tensor(class_freq).cuda()

# no posthoc
class_wise_acc,acc = eval_final(test_logits_all,test_label_all,logit_adjust = None,params = None,class_wise = True,num_classes = num_classes)
print('test_class_wise_acc',class_wise_acc)
print('test acc',acc)
print('test min acc',np.min(class_wise_acc))
print('test std',np.std(class_wise_acc))
# Mean over the last 20% of the classes.  The slice is open-ended on purpose:
# the former ``[-k:-1]`` silently dropped the very last class.
print('test few',np.mean(class_wise_acc[-int(0.2*num_classes):]))
# Mean over the 20% of classes with the lowest accuracy.
print('test worst',np.mean(sorted(class_wise_acc)[0:int(0.2*num_classes)]))
# Unadjusted per-class test accuracy, reused later as a dictionary feature.
class_wise_acc_dic = class_wise_acc
# no posthoc
class_wise_acc,acc = eval_final(val_logits_all,val_label_all,logit_adjust = None,params = None,class_wise = True,num_classes = num_classes)
print('val class_wise_acc',class_wise_acc)
print('val acc',acc)
print('val mean acc',np.mean(class_wise_acc))



def make_model(num_dic):
    """Create the trainable dictionary weights with their optimisers and schedulers.

    Both weight vectors have ``num_dic`` entries and start at zero, except that the
    last entry of ``dy_w`` starts at one.  Whether each vector is trainable is
    taken from the module-level ``args.dy`` / ``args.ly`` strings.

    Args:
        num_dic: number of dictionary rows, i.e. the length of each weight vector.

    Returns:
        tuple: ``(dy_w, ly_w, criterion, optimizer_dy, optimizer_ly,
        scheduler_dy, scheduler_ly)``.
    """
    def _sgd(weight):
        # Same hyper-parameters for both weight vectors.
        return optim.SGD([weight],lr=0.001,momentum=0.9, weight_decay=1e-4)

    def _halve_at(optimizer, milestone):
        # Halve the learning rate once, at the given epoch.
        return optim.lr_scheduler.MultiStepLR(optimizer,milestones=[milestone],gamma=0.5)

    # simple training(d and l)
    dy_w=torch.zeros([num_dic],dtype=torch.float32,device=device)
    ly_w=torch.zeros([num_dic],dtype=torch.float32,device=device)
    # The in-place initialisation has to happen before gradients are enabled.
    dy_w[-1] = 1
    dy_w.requires_grad_(args.dy == 'True')
    ly_w.requires_grad_(args.ly == 'True')

    criterion = nn.CrossEntropyLoss()
    optimizer_dy = _sgd(dy_w)
    optimizer_ly = _sgd(ly_w)
    scheduler_dy = _halve_at(optimizer_dy, args.epoch_dy)
    scheduler_ly = _halve_at(optimizer_ly, args.epoch_ly)
    #scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep)
    return dy_w,ly_w,criterion,optimizer_dy,optimizer_ly,scheduler_dy,scheduler_ly


def func_dic(dic,x):
    """Append five element-wise transforms of ``x`` to ``dic`` and return ``dic``.

    With ``p = x / sum(x)`` and ``r = x / max(x)`` the appended arrays are, in
    order: ``log p``, ``(log p - min log p)**2``, ``(log p - min log p)**0.5``,
    ``r**0.15`` and ``r**0.3``.

    Args:
        dic: list that is extended in place.
        x: NumPy array of per-class values.

    Returns:
        list: the same ``dic`` object.
    """
    log_share = np.log(x/np.sum(x))
    shifted = log_share-np.min(log_share)
    rel_to_max = np.divide(x,np.ones_like(x)*np.max(x))
    dic.extend([
        log_share,
        shifted**2,
        shifted**0.5,
        rel_to_max**0.15,
        rel_to_max**0.3,
    ])
    return dic
def dic_PCA(dic,n_dic):
    """Standardise the dictionary features and optionally reduce them with PCA.

    ``dic`` is stacked into an array and transposed before being passed to
    ``StandardScaler``; the result is transposed back before it is returned.

    Args:
        dic: sequence of equally long 1-D feature arrays.
        n_dic: number of PCA components to keep, or ``None`` to skip PCA.

    Returns:
        numpy.ndarray: one row per standardised feature (``n_dic is None``) or
        one row per principal component.
    """
    scaled = StandardScaler().fit_transform(np.array(dic).T)
    if n_dic is None:
        return np.array(scaled.T)
    components = PCA(n_components=n_dic).fit_transform(scaled).T
    return np.array([components[k, :] for k in range(n_dic)])


def make_dic(class_freq,class_wise_acc_dic,class_weight_norm,class_noise_dic,num_classes):
    """Build the per-class feature dictionary.

    Every statistic that is not ``None`` is normalised and expanded by
    ``func_dic``; the collected features go through ``dic_PCA`` (using the
    module-level ``args.n_dic``) and a constant row of ones is appended last.

    Args:
        class_freq: per-class frequencies, or ``None``.
        class_wise_acc_dic: per-class accuracies, or ``None``.
        class_weight_norm: per-class weight norms, or ``None``.
        class_noise_dic: per-class noise values, or ``None``.
        num_classes: length of the appended row of ones.

    Returns:
        numpy.ndarray: the dictionary, one feature per row.
    """
    features = []

    if class_freq is not None:
        features = func_dic(features, np.array(class_freq))

    if class_wise_acc_dic is not None:
        # Smallest class error divided by each class error.
        class_err = np.ones_like(class_wise_acc_dic)-class_wise_acc_dic
        class_diff = np.divide(np.ones_like(class_wise_acc_dic)*np.min(class_err),class_err)
        features = func_dic(features, class_diff)

    if class_weight_norm is not None:
        features = func_dic(features, class_weight_norm/np.max(class_weight_norm))

    if class_noise_dic is not None:
        # NOTE(review): this divides (1 - noise) by itself, i.e. it is 1 wherever
        # it is defined -- possibly a normalisation like the one above was
        # intended; confirm.
        keep_rate = np.ones_like(class_noise_dic)-class_noise_dic
        features = func_dic(features, np.divide(keep_rate,keep_rate))

    reduced = dic_PCA(features,args.n_dic)
    #for i in range(np.shape(params_dic_origin)[0]):
    #    params_dic_origin[i] = norm_dic(params_dic_origin[i])
    return np.vstack((reduced,np.ones([num_classes])))

# Per-class weights stored on ``args.wy``: uniform when fitting on test logits,
# otherwise inverse class frequency scaled to unit L2 norm.
if args.test != 'True':
    w_val=np.sum(class_freq)/class_freq
    w_val=w_val/np.linalg.norm(w_val)
    w_val=torch.tensor(w_val,dtype=torch.float32, device=device)
else:
    w_val = torch.ones([num_classes],dtype=torch.float32, device=device)
w_val.requires_grad=False

args.wy = w_val
#def dic_svd(params_dic_origin):
#print('class_freq',class_freq)
#params_dic_origin = make_dic(class_freq,class_wise_acc_dic,num_classes)

# Select the statistics that feed the dictionary; ``None`` disables a source.
class_freq_dic = class_freq if args.dic_freq == 'True' else None
class_error_dic = class_wise_acc_dic if args.dic_error == 'True' else None
class_norm_dic = model_norm if args.dic_norm == 'True' else None
#class_noise_dic = args.r_class_noise
# ``r_i`` only exists when --noisy True was given.
class_noise_dic = r_i if args.dic_noise == 'True' else None

params_dic_origin = make_dic(class_freq_dic,class_error_dic,class_norm_dic,class_noise_dic,num_classes)
print(np.shape(params_dic_origin))
params_dic = params_dic_origin
num_dic = np.shape(params_dic)[0]
(
    dy_w,
    ly_w,
    criterion,
    optimizer_dy,
    optimizer_ly,
    scheduler_dy,
    scheduler_ly,
) = make_model(num_dic = num_dic)

import os
result_path = f'{save_path}/result/2_joint{args.joint}_earlystop{args.earlystop}_sigmoid_test{args.test_size}_freq{args.dic_freq}_error{args.dic_error}_norm{args.dic_norm}_noise{args.dic_noise}_dy{args.dy}_ly{args.ly}'
if not os.path.exists(result_path):
    os.makedirs(result_path)

# Scatter every dictionary row over the class index.
plt.figure()
for dic_row in params_dic_origin:
    plt.scatter(range(num_classes), dic_row)
plt.grid()
plt.savefig(f'{result_path}/try_dic.png')
plt.show()

def train(params_dic,dy_w,ly_w,criterion,optimizer_1,optimizer_2,scheduler_1,scheduler_2,total_epoch,switch_epoch,num_classes):
    """Fit the dictionary weights ``dy_w`` / ``ly_w`` on the validation logits.

    Training has two phases: epochs before ``switch_epoch`` step ``optimizer_1`` /
    ``scheduler_1``, the remaining epochs step ``optimizer_2`` / ``scheduler_2``.
    An epoch is skipped completely (nothing is trained or recorded) when the
    parameter owned by the active phase is frozen through ``args.dy`` / ``args.ly``.
    With ``args.earlystop == 'True'`` the first-phase weights are reset to their
    best-validation values when the second phase starts.

    Reads the module-level ``args``, ``device``, ``val_loader`` and the cached
    validation / test logits and labels.

    Args:
        params_dic: feature dictionary, one row per feature, one column per class.
        dy_w, ly_w: weight tensors; ``dictionary.T @ w`` gives the per-class dy / ly.
        criterion: loss applied to the adjusted logits of every mini-batch.
        optimizer_1, scheduler_1: optimiser / LR scheduler of the first phase.
        optimizer_2, scheduler_2: optimiser / LR scheduler of the second phase.
        total_epoch: total number of epochs.
        switch_epoch: first epoch of the second phase.
        num_classes: number of classes.

    Returns:
        tuple: ``(best_test_acc, best_val_acc, best_epoch, best_dy, best_ly,
        test_class_wise_acc_all, test_mean_acc_all, val_class_wise_acc_all,
        val_mean_acc_all, test_worst_acc_all, val_worst_acc_all, val_loss_all,
        test_loss_all, dy_w, ly_w, dy_w_all, ly_w_all, dy_all, ly_all)``.
        The ``*_all`` lists hold one entry per evaluated (non-skipped) epoch,
        whereas ``best_epoch`` is the absolute epoch index.

    Raises:
        RuntimeError: if every epoch was skipped, so that no result exists.
    """
    test_class_wise_acc_all = []
    test_mean_acc_all = []
    val_class_wise_acc_all = []
    val_mean_acc_all = []
    test_worst_acc_all = []
    val_worst_acc_all = []
    val_loss_all = []
    test_loss_all = []
    dy_w_all = []
    ly_w_all = []
    dy_all = []
    ly_all = []
    best_val_acc = 0
    best_epoch = None
    best_test_acc = best_dy = best_ly = best_dy_w = best_ly_w = None

    # The dictionary never changes during training: build its tensor once
    # instead of twice per mini-batch.
    dic_t = torch.tensor(np.array(params_dic).T,dtype=torch.float32,device=device)

    for epoch in range(total_epoch):
        if epoch <switch_epoch:
            optimizer = optimizer_1
            scheduler = scheduler_1
            if args.joint == 'ly_dy' and args.ly == 'False':
                continue
            if args.joint == 'dy_ly' and args.dy == 'False':
                continue
        else:
            optimizer = optimizer_2
            scheduler = scheduler_2
            if args.joint == 'ly_dy' and args.dy == 'False':
                continue
            if args.joint == 'dy_ly' and args.ly == 'False':
                continue
        if args.earlystop == 'True' and epoch == switch_epoch:
            print('previous_ly_w',ly_w)
            # Restore the best first-phase weights as a detached (frozen) copy.
            # Nothing is restored when the first phase was skipped entirely.
            if args.joint == 'ly_dy':
                if best_ly_w is not None:
                    ly_w = best_ly_w.clone().detach()
            elif best_dy_w is not None:
                dy_w = best_dy_w.clone().detach()
            print('best_ly_w',ly_w)
        for val_label_i,val_logits_i in val_loader:
            data,target = val_logits_i.cuda(),val_label_i.cuda(non_blocking=True)
            dy=torch.matmul(dic_t,dy_w)
            ly=torch.matmul(dic_t,ly_w)
            #params = [dy,ly]

            data_adj = logit_adjust(data,[dy,ly])
            loss = criterion(data_adj, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Recompute dy / ly from the weights *after* the last optimiser step so
        # that the recorded and evaluated values match the recorded weights.
        with torch.no_grad():
            dy=torch.matmul(dic_t,dy_w)
            ly=torch.matmul(dic_t,ly_w)

        dy_w_all.append(dy_w.cpu().detach().numpy())
        ly_w_all.append(ly_w.cpu().detach().numpy())
        dy_all.append(dy.cpu().detach().numpy())
        ly_all.append(ly.cpu().detach().numpy())
        val_loss = compute_loss(val_logits_all,val_label_all,logit_adjust = logit_adjust,num_classes = num_classes,params = [dy,ly])
        test_loss = compute_loss(test_logits_all,test_label_all,logit_adjust = logit_adjust,num_classes = num_classes,params = [dy,ly])
        scheduler.step()
        class_wise_acc,acc = eval_final(test_logits_all,test_label_all,logit_adjust = logit_adjust,params = [dy,ly],class_wise = True,num_classes = num_classes)
        val_class_wise_acc,val_acc = eval_final(val_logits_all,val_label_all,logit_adjust = logit_adjust,params = [dy,ly],class_wise = True,num_classes = num_classes)
        # The first evaluated epoch always becomes the initial best, so the
        # best_* values exist even if the validation accuracy never exceeds 0.
        if best_epoch is None or val_acc>best_val_acc:
            best_test_acc = acc
            best_val_acc = val_acc
            best_epoch = epoch
            best_dy = dy.clone().detach()
            best_ly = ly.clone().detach()
            best_dy_w = dy_w.clone().detach()
            best_ly_w = ly_w.clone().detach()
        val_class_wise_acc_all.append(val_class_wise_acc)
        val_mean_acc_all.append(val_acc)
        test_class_wise_acc_all.append(class_wise_acc)
        test_mean_acc_all.append(acc)
        test_worst_acc_all.append(np.min(class_wise_acc))
        val_worst_acc_all.append(np.min(val_class_wise_acc))
        val_loss_all.append(val_loss)
        test_loss_all.append(test_loss)

    if best_epoch is None:
        raise RuntimeError(
            'train() evaluated no epoch: every epoch was skipped because the '
            'parameter of each phase is frozen (check --dy / --ly / --joint).'
        )
    return best_test_acc,best_val_acc,best_epoch,best_dy,best_ly,test_class_wise_acc_all,test_mean_acc_all,val_class_wise_acc_all,val_mean_acc_all,test_worst_acc_all,val_worst_acc_all,val_loss_all,test_loss_all,dy_w,ly_w,dy_w_all,ly_w_all,dy_all,ly_all
#params_0 = np.ones((2*10))
#params_0 = np.hstack((np.ones_like(class_freq)-class_freq,np.ones_like(class_wise_acc_10)-class_wise_acc_10))
#params_0 = np.hstack((np.ones_like(class_freq)-class_freq,np.ones_like(class_wise_acc_10)))
#params_0 = np.hstack((np.ones_like(class_wise_acc_10)-class_wise_acc_10,np.ones_like(class_freq)))
print('train')
# Map the two training phases onto the dy / ly optimisers according to --joint:
# the parameter named first is trained in the first phase.
if args.joint == 'ly_dy':
    optimizer_1 = optimizer_ly
    optimizer_2 = optimizer_dy
    scheduler_1 = scheduler_ly
    scheduler_2 = scheduler_dy
if args.joint == 'dy_ly':
    optimizer_1 = optimizer_dy
    optimizer_2 = optimizer_ly
    scheduler_1 = scheduler_dy
    scheduler_2 = scheduler_ly
print('train')
start_time = time.time()
(
    best_test_acc, best_val_acc, best_epoch, best_dy, best_ly,
    test_class_wise_acc_all, test_mean_acc_all,
    val_class_wise_acc_all, val_mean_acc_all,
    test_worst_acc_all, val_worst_acc_all,
    val_loss_all, test_loss_all,
    dy_w, ly_w, dy_w_all, ly_w_all, dy_all, ly_all,
) = train(params_dic,dy_w,ly_w,criterion,optimizer_1,optimizer_2,scheduler_1,scheduler_2,total_epoch,switch_epoch,num_classes=num_classes)
end_time = time.time()

# Values written to result.yaml: the best-validation parameters with early
# stopping, otherwise the ones recorded for the last evaluated epoch.
if args.dy == 'False':
    dy_final = np.ones([num_classes]).tolist()
else:
    if args.earlystop == 'True':
        # Was ``best_dy).tolist()``: an unbalanced parenthesis that stopped the
        # whole file from parsing.
        dy_final = best_dy.tolist()
    else:
        dy_final = torch.tensor(np.array(dy_all)[-1,:]).tolist()
if args.earlystop == 'True':
    ly_final = best_ly.tolist()
else:
    ly_final = (np.array(ly_all)[-1,:]).tolist()

#dy_final = torch.sigmoid(torch.tensor(np.array(dy_all)[-1,:])).tolist()
result_file = {'dy' : dy_final,'ly':ly_final,'w_train':(np.ones([num_classes])).tolist()}
#result_file = [{'dy' : (best_dy).tolist()},{'ly':(best_ly).tolist()},{'w_train',(np.ones([num_classes])).tolist()},{'w_val',(np.ones([num_classes])).tolist()}]
with open(f'{result_path}/result.yaml', mode='w') as file:
    # yaml.dump writes to ``file`` and returns None when a stream is given.
    documents = yaml.dump(result_file, file)
# One figure per recorded history: every column is drawn as a curve over the
# evaluated epochs and a marker is placed at ``best_epoch``.  The last three
# figures additionally label each curve with its column index.
for history, out_file, labelled in (
    (dy_w_all, f'{result_path}/try_dic_dy_w.png', False),
    (ly_w_all, f'{result_path}/try_dic_ly_w.png', False),
    (dy_all, f'{result_path}/try_dic_dy.png', True),
    (ly_all, f'{result_path}/try_dic_ly.png', True),
    (test_class_wise_acc_all, f'{result_path}/test_class_wise_acc_all.png', True),
):
    plt.figure()
    curves = np.array(history)
    for ii in range(np.shape(history)[1]):
        if labelled:
            plt.plot(curves[:,ii],label=str(ii))
        else:
            plt.plot(curves[:,ii])
        plt.scatter([best_epoch],curves[best_epoch,ii])
    if labelled:
        plt.legend()
    plt.grid()
    plt.savefig(out_file)
    plt.show()


# Per-class dy / ly implied by the weights returned from training.
dic_tensor = torch.tensor(np.array(params_dic).T,dtype=torch.float32,device=device)
dy=torch.matmul(dic_tensor,dy_w)
ly=torch.matmul(dic_tensor,ly_w)
params = [dy,ly]

# Accuracy with these final parameters.
class_wise_acc,acc = eval_final(test_logits_all,test_label_all,logit_adjust = logit_adjust,params = params,class_wise = True,num_classes = num_classes)
print('dy',dy)
print('ly',ly)
print('test_class_wise_acc',class_wise_acc)
print('test_acc',acc)
print('test_min acc',np.min(class_wise_acc))
val_class_wise_acc,val_acc = eval_final(val_logits_all,val_label_all,logit_adjust = logit_adjust,params = params,class_wise = True,num_classes = num_classes)
#print('val_class_wise_acc',class_wise_acc)
print('val_acc',val_acc)
print('val_min acc',np.min(val_class_wise_acc))

# Loss curves with a marker at the best-validation epoch.
plt.figure()
plt.plot(val_loss_all,label = 'val')
plt.plot(test_loss_all,label = 'test')
for loss_curve in (val_loss_all, test_loss_all):
    plt.scatter([best_epoch],np.array(loss_curve)[best_epoch])
plt.legend()
plt.grid()
plt.savefig(f'{result_path}/try_loss.png')
plt.show()

# Mean / worst-class accuracy curves.
plt.figure()
plt.plot(test_mean_acc_all,label='test mean')
plt.scatter([best_epoch],np.array(test_mean_acc_all)[best_epoch])
plt.plot(test_worst_acc_all,label='test min')
plt.scatter([best_epoch],np.array(test_worst_acc_all)[best_epoch])
plt.scatter([best_epoch],np.array(val_mean_acc_all)[best_epoch])
plt.plot(val_mean_acc_all,label='val mean')
plt.plot(val_worst_acc_all,label='val min')
plt.legend()
plt.grid()
plt.savefig(f'{result_path}/try_acc.png')
plt.show()

# Final parameters per class; dy is shown after the sigmoid.
class_axis = range(num_classes)
plt.figure()
plt.scatter(class_axis, torch.sigmoid(params[0]).cpu().detach().numpy(),label = 'dy')
plt.scatter(class_axis,params[1].cpu().detach().numpy(),label = 'ly')
plt.legend()
plt.grid()
plt.savefig(f'{result_path}/try_params.png')
plt.show()

# The values that were written to result.yaml.
plt.figure()
plt.scatter(class_axis, dy_final,label = 'dy')
plt.scatter(class_axis,ly_final,label = 'ly')
plt.legend()
plt.grid()
plt.savefig(f'{result_path}/try_params_result.png')
plt.show()
from scipy.io import savemat
#mdic = {'dy_w':dy_w,'ly_w':ly_w,'dy':dy,'ly':ly,'test_class_wise_acc':class_wise_acc,'test_acc':np.mean(class_wise_acc),'val_class_wise_acc':val_class_wise_acc,'val_acc':np.mean(val_class_wise_acc),'test_class_wise_acc_all':test_class_wise_acc_all}
#savemat(f'{save_path}/joint{args.joint}_earlystop{args.earlystop}_sigmoid_test{args.test_size}_freq{args.dic_freq}_error{args.dic_error}_norm{args.dic_norm}_noise{args.dic_noise}_dy{args.dy}_ly{args.ly}/post_hoc_result.mat',mdic)
print(best_test_acc,1-best_test_acc,best_val_acc,best_epoch)
#text=f'dy_w:{best_dy_w}\nly_w:{best_ly_w}\ndy:{best_dy}\nly:{best_ly}\ntest_class_wise_acc:{class_wise_acc.tolist()}\ntest_acc:{np.mean(class_wise_acc)}\nval_acc:{val_acc}\n'

# Summary written to logs.txt, one "key:value" entry per line.
text='\n'.join([
    f'dy:{dy_final}',
    f'ly:{ly_final}',
    f'test_class_wise_acc:{class_wise_acc.tolist()}',
    f'test_acc:{np.mean(class_wise_acc)}',
    f'best_test_acc:{best_test_acc}',
    f'best_epoch:{best_epoch}',
    f'best_val_acc:{best_val_acc}',
    f'test_error:{1-np.mean(best_test_acc)}',
    f'time:{end_time-start_time}',
    f'test_std:{np.std(class_wise_acc).tolist()}',
    # Mean over the last 20% of the classes; the former ``[-k:-1]`` slice
    # silently dropped the very last class.
    f'few_acc:{np.mean(class_wise_acc[-int(0.2*num_classes):]).tolist()}',
    # Mean over the 20% of classes with the lowest accuracy.
    f'worst_acc:{np.mean(sorted(class_wise_acc)[0:int(0.2*num_classes)]).tolist()}',
])
# The context manager closes (and flushes) the file; it used to be left open.
with open(f'{result_path}/logs.txt',mode='w') as logfile:
    logfile.write(text)

# plt.figure()
# norm_error = (np.ones_like(test_mean_acc_all)-test_mean_acc_all)/(np.ones_like(test_mean_acc_all)*np.max((np.ones_like(test_mean_acc_all)-test_mean_acc_all)))
# plt.plot(norm_error,linewidth=5,label='Test balanced error')
# plt.plot(test_loss_all/(np.ones_like(test_loss_all)*np.max(test_loss_all)),linewidth=5,label = 'Test balanced CE')
# plt.legend(fontsize=20)
# plt.grid()
# plt.tick_params(labelsize=20)
# plt.savefig(f'{result_path}/error_loss.png')
# plt.show()

# Test error (left axis) and test loss (right axis) over the evaluated epochs.
fig, ax = plt.subplots()
ax2 = ax.twinx()
ax.plot(
    1 - np.asarray(test_mean_acc_all),
    label='Test balanced error',
    linewidth=6,
    alpha=0.7,
)
ax2.plot(test_loss_all,'C1',label='Test balanced CE',linewidth=5,alpha=0.7)
ax2.grid()
for axis in (ax, ax2):
    axis.tick_params(labelsize=20)
fig.legend(loc=1, bbox_to_anchor=(1,1), bbox_transform=ax.transAxes,fontsize=20)
ax.set_ylabel('Error')
ax2.set_ylabel('Loss')
# NOTE(review): this only sets an attribute called LABELPAD on the axis objects;
# presumably a label padding was intended -- confirm.
for axis in (ax, ax2):
    axis.xaxis.LABELPAD = 8
plt.savefig(f'{result_path}/error_loss.png')
plt.show()


# Test samples of the first and the last class, and their raw logits for the
# first class (x axis, "majority") and the last class (y axis, "minority").
index_class0 = np.where(test_label_all==0)[0]
index_class9 = np.where(test_label_all==(num_classes-1))[0]
logits_majority_class0_origin = test_logits_all[index_class0,0]
logits_minority_class0_origin = test_logits_all[index_class0,-1]
logits_majority_class9_origin = test_logits_all[index_class9,0]
logits_minority_class9_origin = test_logits_all[index_class9,-1]

# dy / ly from the weights recorded for the last evaluated epoch.  (A stray
# tensor construction whose result was discarded has been removed here.)
dy_new = torch.matmul(torch.tensor(np.array(params_dic).T,dtype=torch.float32,device=device),torch.tensor(np.array(dy_w_all[-1]),dtype=torch.float32,device=device))
ly_new = torch.matmul(torch.tensor(np.array(params_dic).T,dtype=torch.float32,device=device),torch.tensor(np.array(ly_w_all[-1]),dtype=torch.float32,device=device))
# Hand logit_adjust a CPU tensor instead of relying on NumPy/Tensor mixed
# arithmetic, and convert the result back to NumPy for indexing and plotting.
data_adj = logit_adjust(torch.as_tensor(test_logits_all),[dy_new.cpu(),ly_new.cpu()]).detach().cpu().numpy()


logits_majority_class0 = data_adj[index_class0,0]
logits_minority_class0 = data_adj[index_class0,-1]
logits_majority_class9 = data_adj[index_class9,0]
logits_minority_class9 = data_adj[index_class9,-1]

# Raw and adjusted logits in one figure, with the diagonal as reference line.
fig, ax = plt.subplots()
ax.scatter(logits_majority_class0_origin,logits_minority_class0_origin,alpha=0.7,label='class0, pretrain')
ax.scatter(logits_majority_class9_origin,logits_minority_class9_origin,alpha=0.7,label='class9, pretrain')
ax.scatter(logits_majority_class0,logits_minority_class0,alpha=0.7,label='class0, post-hoc')
ax.scatter(logits_majority_class9,logits_minority_class9,alpha=0.7,label='class9, post-hoc')
ax.set_xlim([-25, 80])
ax.set_ylim([-25, 80])
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
]
ax.plot(lims, lims)
ax.grid()
ax.legend(fontsize=20)
ax.tick_params(labelsize=20)
fig.savefig(f'{result_path}/logits_all.png')
plt.show()

# Rescaled test error and min-shifted test loss on a shared axis.
plt.figure()
test_error = np.ones_like(test_mean_acc_all)-test_mean_acc_all
# NOTE(review): the error is divided by (1 - min(error)); confirm that this is the
# intended normalisation.
norm_error = test_error/(np.ones_like(test_mean_acc_all)-np.min(test_error))
shifted_loss = np.asarray(test_loss_all)-np.min(test_loss_all)
plt.plot(norm_error,label='Test balanced error',linewidth=5)
plt.plot(shifted_loss,label = 'Test balanced CE',linewidth=5)
plt.legend(fontsize=20)
plt.grid()
plt.tick_params(labelsize=20)
plt.savefig(f'{result_path}/error_loss_2.png')
plt.show()

# Adjusted (post-hoc) logits only, with the diagonal as reference line.
fig, ax = plt.subplots()
ax.scatter(logits_majority_class0,logits_minority_class0,alpha=0.7,label='class0')
ax.scatter(logits_majority_class9,logits_minority_class9,alpha=0.7,label='class9')
ax.set_xlim([-25, 80])
ax.set_ylim([-25, 80])
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
]
ax.plot(lims, lims)
ax.grid()
ax.legend(fontsize=20)
ax.tick_params(labelsize=20)
# Saved under its own name: this figure used to be written to 'logits_all.png'
# and thereby overwrote the combined pretrain / post-hoc figure saved earlier.
fig.savefig(f'{result_path}/logits_posthoc.png')
plt.show()

# Raw (unadjusted) logits only, with the diagonal as reference line.
fig, ax = plt.subplots()
for xs, ys, tag in (
    (logits_majority_class0_origin, logits_minority_class0_origin, 'class0'),
    (logits_majority_class9_origin, logits_minority_class9_origin, 'class9'),
):
    ax.scatter(xs,ys,alpha=0.7,label=tag)
ax.set_xlim([-25, 80])
ax.set_ylim([-25, 80])
axis_limits = [ax.get_xlim(), ax.get_ylim()]
lims = [np.min(axis_limits), np.max(axis_limits)]  # common min / max of both axes
ax.plot(lims, lims)
ax.grid()
ax.legend(fontsize=20)
ax.tick_params(labelsize=20)
fig.savefig(f'{result_path}/logits_original.png')
plt.show()

#python train_posthoc_order1_v4_joint.py --noisy True --test_size 50 --runs 1_5000_imbalance --dic_freq True 
#python train_posthoc_order1_v4_joint.py --dy False --ly True --joint ly_dy --dic_error False --dic_freq True