# -*- coding: utf-8 -*-
"""post_hoc.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1x4ee3UQ4-EX_3V9m_U0ApdKPVIGSpdGb
"""

# Base path placeholder; not referenced anywhere in the visible part of this file.
path = ""

import numpy as np
import matplotlib.pyplot as plt
import torch
import csv
import numpy as np
import random
import torch
from PIL import Image
from torch.utils.data import DataLoader,TensorDataset
#from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine,RandomResizedCrop,CenterCrop
#from torchvision.transforms import ToTensor, Normalize,transforms
from PIL.Image import BICUBIC
import scipy.io as sio

# Settings shared by every DataLoader constructed below.
batch_size = 128
num_workers=4

import torch
import torch.nn as nn
import torch.nn.functional as F

# Load pre-extracted arrays from MATLAB .mat files. Each file is read for three
# keys: 'features', 'label' and 'g'.
# (The progress message says ".npy" although the files loaded here are .mat.)
arr = sio.loadmat('saved_test_feature.mat')
print("load .npy done")
test_features = arr['features']
test_label = arr['label']
test_g = arr['g']
arr = sio.loadmat('saved_train_feature.mat')
print("load .npy done")
train_features = arr['features']
train_label = arr['label']
train_g = arr['g']
# NOTE(review): the block below is a string literal, i.e. disabled code. It is
# the only place where validation_features / validation_label / validation_g
# are assigned, yet those names are read further down (the np.vstack calls and
# the val_tensor_* construction), which raises NameError as the file stands.
# Either re-enable this load or point those uses at the intended arrays.
'''
arr = sio.loadmat('120_preds-on_validation.mat')
print("load .npy done")
validation_features = arr['features']
validation_label = arr['label']
validation_g = arr['g']
'''
arr = sio.loadmat('saved_validation_feature.mat')
#print(arr[0])
print("load .npy done")
validation_test_features = arr['features']
validation_test_label = arr['label']
validation_test_g = arr['g']

# Append the validation arrays to the training arrays (row-wise stacking).
# NOTE(review): validation_* are undefined here, see the note above.
train_features = np.vstack((train_features,validation_features))
train_label = np.vstack((train_label,validation_label))
train_g = np.vstack((train_g,validation_g))

# Convert every split to float tensors (torch.Tensor is the float32 constructor).
train_tensor_x = torch.Tensor(train_features)
train_tensor_y = torch.Tensor(train_label)
train_tensor_g = torch.Tensor(train_g)
val_tensor_x = torch.Tensor(validation_features)
val_tensor_y = torch.Tensor(validation_label)
val_tensor_g = torch.Tensor(validation_g)
val_test_tensor_x = torch.Tensor(validation_test_features)
val_test_tensor_y = torch.Tensor(validation_test_label)
val_test_tensor_g = torch.Tensor(validation_test_g)
test_tensor_x = torch.Tensor(test_features)
test_tensor_y = torch.Tensor(test_label)
test_tensor_g = torch.Tensor(test_g)

# One (features, label, group) dataset and loader per split. All loaders,
# including the evaluation ones, are created with shuffle=True.
train_my_dataset = TensorDataset(train_tensor_x,train_tensor_y,train_tensor_g)
train_my_dataloader = DataLoader(train_my_dataset,batch_size=batch_size, num_workers=num_workers,shuffle=True, drop_last=False, pin_memory=True)
test_my_dataset = TensorDataset(test_tensor_x,test_tensor_y,test_tensor_g)
test_my_dataloader = DataLoader(test_my_dataset,batch_size=batch_size, num_workers=num_workers,shuffle=True, drop_last=False, pin_memory=True)
val_my_dataset = TensorDataset(val_tensor_x,val_tensor_y,val_tensor_g)
val_my_dataloader = DataLoader(val_my_dataset,batch_size=batch_size, num_workers=num_workers,shuffle=True, drop_last=False, pin_memory=True)
val_test_my_dataset = TensorDataset(val_test_tensor_x,val_test_tensor_y,val_test_tensor_g)
val_test_my_dataloader = DataLoader(val_test_my_dataset,batch_size=batch_size, num_workers=num_workers,shuffle=True, drop_last=False, pin_memory=True)

# Input width of the classifier: the second dimension of the test feature array.
D_in = np.shape(test_features)[1]

num_classes = 2
# NOTE: num_epochs is reassigned to 100 later in the file, before the training loop.
num_epochs = 500

class CE_loss(nn.Module):
  """Cross-entropy criterion with a (logits, label, group) call signature.

  The `group` argument is accepted so the criterion can be called with a
  group tensor, but it does not influence the returned loss.
  """
  def __init__(self):
    super(CE_loss, self).__init__()

  def forward(self,logits,label,group):
    # Plain mean cross-entropy between the logits and the integer labels.
    ce = F.cross_entropy(input=logits, target=label)
    return ce

# Baseline classifier: one linear layer from the D_in feature columns to
# num_classes logits, wrapped in a Sequential container.
model = nn.Sequential(nn.Linear(in_features=D_in, out_features=num_classes))

def eval_per_class(data_loader, model, text,flag=0):
    """Evaluate `model` overall and per (label, group) bucket.

    Every batch of `data_loader` is a `(data, label, group)` triple; column 0
    of `label` and of `group` is used. Samples are assigned to one of four
    buckets named '00', '01', '10', '11' (label digit first, group digit
    second). A sample whose label/group pair is none of these four still counts
    towards the overall accuracy but towards no bucket.

    NOTE: this function is defined a second time further down in this file;
    the later definition is the one in effect once the whole file has run.

    Args:
        data_loader: iterable yielding `(data, label, group)` batches.
        model: module mapping `data` to logits; it is switched to eval mode.
        text: prefix for the two printed summary lines.
        flag: 0 additionally prints one accuracy line per non-empty bucket;
            any other value prints only the summary (and 'No image' for
            empty buckets).

    Returns:
        Tuple `(accuracy, balanced_accuracy, per_bucket)`: overall accuracy in
        percent, the mean of the per-bucket accuracies, and the list of
        per-bucket accuracies (percent) for the non-empty buckets, in bucket
        order.
    """
    model.eval()
    classes = ('00', '01', '10', '11')
    buckets = ((0, 0), (0, 1), (1, 0), (1, 1))
    class_group_correct = [0.] * 4
    class_group_total = [0.] * 4
    correct = 0.
    total = 0.
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for data, label, group in data_loader:
            label = label.long()[:, 0]
            group = group[:, 0]
            logits = model(data)
            preds = logits.max(1)[1]
            # Kept 1-D on purpose: squeezing would collapse a batch of one
            # sample to a 0-d tensor that cannot be indexed per sample.
            hits = preds.eq(label)
            total += data.size(0)
            correct += hits.sum().item()
            # Count each bucket with a mask instead of a per-sample loop.
            for idx, (label_value, group_value) in enumerate(buckets):
                in_bucket = (label == label_value) & (group == group_value)
                class_group_total[idx] += in_bucket.sum().item()
                class_group_correct[idx] += (hits & in_bucket).sum().item()

    accuracy_4 = []
    for name, n_correct, n_total in zip(classes, class_group_correct, class_group_total):
        if n_total == 0:
            print('No image')
            continue
        bucket_acc = 100 * float(n_correct) / float(n_total)
        if flag == 0:
            print('Accuracy of %5s : %2f %%' % (name, bucket_acc))
        accuracy_4.append(bucket_acc)

    overall_acc = float(correct)/float(total)*100.
    balanced_acc = np.mean(accuracy_4)
    print(f'{text}:ACC = {overall_acc}')
    print(f'{text}:balance ACC = {balanced_acc}')

    return overall_acc,balanced_acc,accuracy_4

# ---------------------------------------------------------------------------
# Train a linear classifier on the pre-extracted features and track accuracy
# on every split after each epoch.
#
# NOTE(review): in the exported notebook everything after `seed = 9` was
# indented and ended with `return val_acc,test_acc,model`, i.e. it was the body
# of a function (or loop) whose header line is missing, so the file could not
# be parsed. The statements are flattened to module level here, which keeps
# `model` and `criterion` global for the code further down that reads them.
# ---------------------------------------------------------------------------
val_acc_all=[]
test_acc_all=[]

num_epochs = 100  # replaces the value assigned near the top of the file

seed = 9
torch.manual_seed(seed)
model = torch.nn.Linear(D_in, num_classes)
print(model.weight)
print(model.bias)

lr = 0.01
criterion = CE_loss()
optimizer = torch.optim.Adam(params=model.parameters(),lr=lr)

# Per-epoch histories.
losses = []
acc_train_all = []
acc_val_all = []
acc_val_test_all = []
acc_test_all = []
train_balanced_acc_all = []
val_balanced_acc_all = []
val_test_balanced_acc_all = []
test_balanced_acc_all = []
# The original also executed `print(la)` here; `la` is not defined anywhere in
# this file (NameError), so the statement was removed.

for epoch in range(num_epochs):
  # eval_per_class() switches the model to eval mode; switch back each epoch.
  model.train()
  n_batches = 0
  loss_all = 0
  for data, label, group in train_my_dataloader:
    label = label.long()[:,0]
    group = group.long()[:,0]
    logist = model(data)
    loss = criterion(logist,label,group)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_all = loss_all+loss.item()
    n_batches = n_batches+1

  # The original guarded this section with `(epoch+5) % 1 == 0`, which is
  # always true, so it runs after every epoch.
  print(n_batches)
  print ("Epoch [{}/{}], Loss: {:.4f}".format(epoch+1, num_epochs,loss_all/n_batches))
  losses.append(loss_all/n_batches)
  print('train')
  train_acc,train_balanced_acc,_=eval_per_class(train_my_dataloader,model,' train_dataset',flag=1)
  print('val')
  val_acc,val_balanced_acc,_=eval_per_class(val_my_dataloader,model,' val_dataset',flag=1)
  print('val(test)')
  val_test_acc,val_test_balanced_acc,_=eval_per_class(val_test_my_dataloader,model,' val_test_dataset',flag=1)
  print('test')
  test_acc,test_balanced_acc,_=eval_per_class(test_my_dataloader,model,' test_dataset',flag=1)
  acc_train_all.append(train_acc)
  acc_val_all.append(val_acc)
  acc_val_test_all.append(val_test_acc)
  acc_test_all.append(test_acc)
  train_balanced_acc_all.append(train_balanced_acc)
  val_balanced_acc_all.append(val_balanced_acc)
  val_test_balanced_acc_all.append(val_test_balanced_acc)
  # Previously this list was declared but never filled.
  test_balanced_acc_all.append(test_balanced_acc)

plt.figure()
plt.plot(losses,label='losses')
plt.show()

plt.figure()
# Fixed: the first curve used to plot acc_val_all under the 'acc_train_all'
# label, and both val(test) curves reused the plain validation labels.
plt.plot(acc_train_all,label='acc_train_all')
plt.plot(acc_val_all,label='acc_val_all')
plt.plot(acc_val_test_all,label='acc_val_test_all')
plt.plot(acc_test_all,label = 'acc_test_all')
plt.plot(train_balanced_acc_all,label = 'train_balanced_acc_all')
plt.plot(val_balanced_acc_all,label = 'val_balanced_acc_all')
plt.plot(val_test_balanced_acc_all,label = 'val_test_balanced_acc_all')
# test_balanced_acc_all is recorded but, as in the original, not plotted.
plt.legend()
plt.show()

# Final verbose evaluation (flag=0 prints the per-bucket accuracies).
# The third return value now gets its own name; the original unpacked it into
# val_acc / test_acc again, silently replacing the overall accuracy with the
# per-bucket list. The first call evaluates the val(test) loader, so its label
# now says so.
val_acc,val_balanced_acc,val_group_acc=eval_per_class(val_test_my_dataloader,model,' val_test_dataset')
test_acc,test_balanced_acc,test_group_acc=eval_per_class(test_my_dataloader,model,' test_dataset')

class Vector(nn.Module):
    """Learnable element-wise affine transform of logits (vector scaling).

    Computes `logits * vector1 + vector2`. `vector1` is initialised to ones and
    `vector2` to zeros, so a fresh instance returns its input unchanged.

    Args:
        num_classes: length of the scale and shift vectors. Defaults to 2,
            the size that was previously hard-coded.
    """
    def __init__(self, num_classes=2):
      super(Vector, self).__init__()
      self.vector1 = nn.Parameter(torch.ones(num_classes))
      self.vector2 = nn.Parameter(torch.zeros(num_classes))
    def vector_scale(self, logits):
        """Return `logits * vector1 + vector2`."""
        return logits*self.vector1+self.vector2
    def forward(self, x):
        """Apply `vector_scale` to `x`."""
        return self.vector_scale(x)
# The original called `.to()` without arguments, which names no device or
# dtype, so the module is used as constructed.
model_vector = Vector()

def eval_per_class(data_loader, model, text,flag=0):
    """Evaluate `model` overall and per (label, group) bucket.

    Every batch of `data_loader` is a `(data, label, group)` triple; column 0
    of `label` and of `group` is used. Samples are assigned to one of four
    buckets named '00', '01', '10', '11' (label digit first, group digit
    second). A sample whose label/group pair is none of these four still counts
    towards the overall accuracy but towards no bucket.

    NOTE: this re-defines the function of the same name that appears earlier in
    this file.

    Args:
        data_loader: iterable yielding `(data, label, group)` batches.
        model: module mapping `data` to logits; it is switched to eval mode.
        text: prefix for the two printed summary lines.
        flag: 0 additionally prints one accuracy line per non-empty bucket;
            any other value prints only the summary (and 'No image' for
            empty buckets).

    Returns:
        Tuple `(accuracy, balanced_accuracy, per_bucket)`: overall accuracy in
        percent, the mean of the per-bucket accuracies, and the list of
        per-bucket accuracies (percent) for the non-empty buckets, in bucket
        order.
    """
    model.eval()
    classes = ('00', '01', '10', '11')
    buckets = ((0, 0), (0, 1), (1, 0), (1, 1))
    class_group_correct = [0.] * 4
    class_group_total = [0.] * 4
    correct = 0.
    total = 0.
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for data, label, group in data_loader:
            label = label.long()[:, 0]
            group = group[:, 0]
            logits = model(data)
            preds = logits.max(1)[1]
            # Kept 1-D on purpose: squeezing would collapse a batch of one
            # sample to a 0-d tensor that cannot be indexed per sample.
            hits = preds.eq(label)
            total += data.size(0)
            correct += hits.sum().item()
            # Count each bucket with a mask instead of a per-sample loop.
            for idx, (label_value, group_value) in enumerate(buckets):
                in_bucket = (label == label_value) & (group == group_value)
                class_group_total[idx] += in_bucket.sum().item()
                class_group_correct[idx] += (hits & in_bucket).sum().item()

    accuracy_4 = []
    for name, n_correct, n_total in zip(classes, class_group_correct, class_group_total):
        if n_total == 0:
            print('No image')
            continue
        bucket_acc = 100 * float(n_correct) / float(n_total)
        if flag == 0:
            print('Accuracy of %5s : %2f %%' % (name, bucket_acc))
        accuracy_4.append(bucket_acc)

    overall_acc = float(correct)/float(total)*100.
    balanced_acc = np.mean(accuracy_4)
    print(f'{text}:ACC = {overall_acc}')
    print(f'{text}:balance ACC = {balanced_acc}')

    return overall_acc,balanced_acc,accuracy_4

def eval_per_class_post(data_loader, model,model_vector, text,flag=0, criterion=None):
    """Evaluate `model` followed by `model_vector`, overall and per bucket.

    Logits from `model` are passed through `model_vector` before the
    prediction is taken. Every batch of `data_loader` is a
    `(data, label, group)` triple; column 0 of `label` and of `group` is used.
    Samples are assigned to one of four buckets named '00', '01', '10', '11'
    (label digit first, group digit second). A sample whose label/group pair is
    none of these four still counts towards the overall accuracy but towards
    no bucket.

    Args:
        data_loader: iterable yielding `(data, label, group)` batches.
        model: module mapping `data` to logits; switched to eval mode.
        model_vector: module applied to the logits; switched to eval mode.
        text: prefix for the two printed summary lines.
        flag: 0 additionally prints one accuracy line per non-empty bucket;
            any other value prints only the summary (and 'No image' for
            empty buckets).
        criterion: optional callable `(logits, label, group) -> loss` used
            only for the printed mean loss. Defaults to plain cross-entropy on
            `(logits, label)`, which is what this file's `CE_loss` computes.
            The original read a global `criterion` and passed a fourth
            argument `la` that is not defined anywhere in this file.

    Returns:
        Tuple `(accuracy, balanced_accuracy, per_bucket)`: overall accuracy in
        percent, the mean of the per-bucket accuracies, and the list of
        per-bucket accuracies (percent) for the non-empty buckets, in bucket
        order.
    """
    if criterion is None:
        def criterion(logits, label, group):
            return F.cross_entropy(logits, label)

    model.eval()
    model_vector.eval()
    classes = ('00', '01', '10', '11')
    buckets = ((0, 0), (0, 1), (1, 0), (1, 1))
    class_group_correct = [0.] * 4
    class_group_total = [0.] * 4
    correct = 0.
    total = 0.
    loss_sum = 0.
    n_batches = 0
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for data, label, group in data_loader:
            label = label.long()[:, 0]
            group = group[:, 0]
            logist1 = model_vector(model(data))
            loss_sum += float(criterion(logist1, label, group))
            n_batches += 1
            preds = logist1.max(1)[1]
            # Kept 1-D on purpose: squeezing would collapse a batch of one
            # sample to a 0-d tensor that cannot be indexed per sample.
            hits = preds.eq(label)
            total += data.size(0)
            correct += hits.sum().item()
            # Count each bucket with a mask instead of a per-sample loop.
            for idx, (label_value, group_value) in enumerate(buckets):
                in_bucket = (label == label_value) & (group == group_value)
                class_group_total[idx] += in_bucket.sum().item()
                class_group_correct[idx] += (hits & in_bucket).sum().item()

    # Mean over the number of batches; the original divided by the index of
    # the last batch, which is one too small and zero for a single batch.
    print('loss',loss_sum/max(n_batches, 1))

    accuracy_4 = []
    for name, n_correct, n_total in zip(classes, class_group_correct, class_group_total):
        if n_total == 0:
            print('No image')
            continue
        bucket_acc = 100 * float(n_correct) / float(n_total)
        if flag == 0:
            print('Accuracy of %5s : %2f %%' % (name, bucket_acc))
        accuracy_4.append(bucket_acc)

    overall_acc = float(correct)/float(total)*100.
    balanced_acc = np.mean(accuracy_4)
    print(f'{text}:ACC = {overall_acc}')
    print(f'{text}:balance ACC = {balanced_acc}')

    return overall_acc,balanced_acc,accuracy_4

# Run the current `model` once over the test loader and cache its logits, the
# labels and the groups as NumPy arrays for the grid searches further down.
# Batches are gathered in lists and concatenated once, instead of re-copying
# the growing arrays on every batch, and no autograd graph is built.
model.eval()
# The empty float64 seed arrays keep the result dtypes of the original code
# (which grew float64 arrays) and keep the results well-defined for an empty
# loader. The logit width 2 is hard-coded, as before.
_logit_chunks = [np.empty((0,2))]
_label_chunks = [np.empty((0,))]
_group_chunks = [np.empty((0,))]
with torch.no_grad():
  for cur_iter, (data, labels,groups) in enumerate(test_my_dataloader):
    labels = labels.long()[:,0]
    groups = groups.long()[:,0]
    _label_chunks.append(labels.numpy())
    _group_chunks.append(groups.numpy())
    logits = model(data)
    _logit_chunks.append(logits.detach().numpy())
data_test = np.vstack(_logit_chunks)
labels_test = np.hstack(_label_chunks)
groups_test = np.hstack(_group_chunks)

# search for best balanced acc
# Exhaustive grid search over a per-class scale (i1, i2) and shift (i3, i4)
# applied to the cached test logits: logits_new = data_test * [i1, i2] + [i3, i4].
# The combination with the highest score is kept in parameter_best.
#
# NOTE(review): the score computed below is np.mean(c), i.e. the plain overall
# accuracy, although the heading and the final prints call it "balanced acc".
# NOTE(review): the search runs on data_test / labels_test, which are collected
# from the test loader, so the selected parameters are tuned on the test split.
# NOTE(review): four nested loops, each iteration scanning the whole test set;
# expect a long runtime.
acc_best = 0.01
entropy_best = 1  # not read anywhere in the visible code
# Placeholder with shape (1, 4); replaced by a flat 4-element list as soon as a
# candidate reaches acc_best.
parameter_best = np.zeros((1,4))
for i1 in np.arange(-5.0,5.0,0.5):
  for i2 in np.arange(-5.0,5.0,0.2):
    for i3 in np.arange(-5.0,5.0,0.2):
      for i4 in np.arange(-5.0,5.0,0.2):
        logits_new = data_test*np.array([i1,i2])+np.array([i3,i4])
        preds = np.argmax(logits_new, axis=1)
        # Boolean array: True where the prediction equals the label.
        c = (labels_test == preds)
        #c = torch.FloatTensor(c)
        #one_hot = F.one_hot(torch.from_numpy(labels_test*2+groups_test).long())
        #acc_class = (c*one_hot.T).T
        #acc_class_mean = torch.mean(acc_class, dim=0)
        # Fraction of correct predictions over all test samples.
        acc_class_mean = np.mean(c)
        #cross_entropy_class_mean_gap = abs(cross_entropy_class_mean[0]-cross_entropy_class_mean[1])     
        # `>=` means that among equally good candidates the last one visited
        # wins; every exact tie is also printed.
        if acc_class_mean>=acc_best:
          if acc_class_mean == acc_best: 
            print('same',[i1,i2,i3,i4])
            print(acc_class_mean)
          parameter_best = [i1,i2,i3,i4]
          #entropy_best = F.cross_entropy(torch.from_numpy(logits_new),torch.from_numpy(labels_test).long())
          acc_best = acc_class_mean
print('best balanced acc parameter:',parameter_best)
print('best balanced acc:',acc_best)

class Vector(nn.Module):
    """Element-wise affine transform of logits, initialised from `parameter_best`.

    The module-level `parameter_best` supplies the initial values: entries 0
    and 1 become the scale `vector1`, entries 2 and 3 the shift `vector2`.
    Both are registered as float parameters.
    """
    def __init__(self):
      super().__init__()
      scale = [parameter_best[0], parameter_best[1]]
      shift = [parameter_best[2], parameter_best[3]]
      self.vector1 = nn.Parameter(torch.FloatTensor(scale))
      self.vector2 = nn.Parameter(torch.FloatTensor(shift))

    def vector_scale(self, logits):
        """Return `logits * vector1 + vector2`."""
        scaled = logits * self.vector1
        return scaled + self.vector2

    def forward(self, x):
        """Apply `vector_scale` to `x`."""
        return self.vector_scale(x)
# Build the scaling module from the parameters currently stored in
# parameter_best and evaluate model + scaling on the validation and test
# loaders.
# The argument-less `.to()` of the original named no device or dtype and was
# dropped. The indentation of the evaluation lines, which made the file
# unparsable, is fixed.
model_vector = Vector()

# eval_per_class_post returns (accuracy, balanced accuracy, per-bucket
# accuracies); the variable names below are kept from the original even though
# they suggest correct/total counts.
print('val')
acc_val,class_correct_val,class_total_val=eval_per_class_post(val_my_dataloader,model,model_vector,'val')
print('test')
acc_test,class_correct_test,class_total_test=eval_per_class_post(test_my_dataloader,model,model_vector,'test')

# search for min DEO
# Exhaustive grid search over a per-class scale (i1, i2) and shift (i3, i4)
# applied to the cached test logits, minimising the accuracy gap between the
# two groups within each label:
#     |acc('00') - acc('01')| + |acc('10') - acc('11')|
# where bucket index = label * 2 + group.
#
# Changes with respect to the original:
#   * Per-bucket accuracy is now (#correct in bucket) / (#samples in bucket).
#     The original averaged the masked hits over ALL samples, so the "gap"
#     mostly reflected differences in bucket size.
#   * The one-hot bucket matrix is built once, outside the loops, with a fixed
#     width of 4, so a missing bucket no longer shifts or drops columns.
#   * The weighted loss that was immediately overwritten by the gap was removed.
#   * The search starts from +inf, so the first candidate is always accepted
#     and parameter_best always ends up as a flat 4-element list.
#   * The final prints no longer call the gap "balanced acc".
# NOTE(review): assumes labels_test and groups_test only hold 0/1 — confirm.
# NOTE(review): like the accuracy search, this tunes on the test split.
bucket_index = (labels_test*2+groups_test).astype(int)
one_hot = np.eye(4)[bucket_index]
# Empty buckets get size 1 so that their accuracy evaluates to 0 instead of NaN.
bucket_sizes = np.maximum(one_hot.sum(axis=0), 1)

acc_best = float('inf')
entropy_best = 1  # not read anywhere in the visible code
parameter_best = np.zeros((1,4))
for i1 in np.arange(-5.0,5.0,0.5):
  for i2 in np.arange(-5.0,5.0,0.2):
    for i3 in np.arange(-5.0,5.0,0.2):
      for i4 in np.arange(-5.0,5.0,0.2):
        logits_new = data_test*np.array([i1,i2])+np.array([i3,i4])
        preds = np.argmax(logits_new, axis=1)
        c = (labels_test == preds).astype(float)
        # Accuracy inside each of the four buckets.
        acc_class_mean = (c @ one_hot) / bucket_sizes
        acc_gap = abs(acc_class_mean[0]-acc_class_mean[1])+abs(acc_class_mean[2]-acc_class_mean[3])
        acc_loss = acc_gap
        # `<=` means that among equally good candidates the last one visited
        # wins; every exact tie is also printed.
        if acc_loss<=acc_best:
          if acc_loss == acc_best:
            print('same',[i1,i2,i3,i4])
            print(acc_loss)
          parameter_best = [i1,i2,i3,i4]
          acc_best = acc_loss
print('min DEO parameter:',parameter_best)
print('min DEO:',acc_best)

class Vector(nn.Module):
    """Element-wise affine transform of logits, initialised from `parameter_best`.

    The module-level `parameter_best` supplies the initial values: entries 0
    and 1 become the scale `vector1`, entries 2 and 3 the shift `vector2`.
    Both are registered as float parameters.
    """
    def __init__(self):
      super().__init__()
      scale = [parameter_best[0], parameter_best[1]]
      shift = [parameter_best[2], parameter_best[3]]
      self.vector1 = nn.Parameter(torch.FloatTensor(scale))
      self.vector2 = nn.Parameter(torch.FloatTensor(shift))

    def vector_scale(self, logits):
        """Return `logits * vector1 + vector2`."""
        scaled = logits * self.vector1
        return scaled + self.vector2

    def forward(self, x):
        """Apply `vector_scale` to `x`."""
        return self.vector_scale(x)
# Build the scaling module from the parameters currently stored in
# parameter_best and evaluate model + scaling on the validation and test
# loaders.
# The argument-less `.to()` of the original named no device or dtype and was
# dropped. The indentation of the evaluation lines, which made the file
# unparsable, is fixed.
model_vector = Vector()

# eval_per_class_post returns (accuracy, balanced accuracy, per-bucket
# accuracies); the variable names below are kept from the original even though
# they suggest correct/total counts.
print('val')
acc_val,class_correct_val,class_total_val=eval_per_class_post(val_my_dataloader,model,model_vector,'val')
print('test')
acc_test,class_correct_test,class_total_test=eval_per_class_post(test_my_dataloader,model,model_vector,'test')