import torch
from torch import nn
import torch.optim as optim
import pickle
import numpy as np
from all_approximators import *
from selector import SoftSelector
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--x_data_file", type=str, required=True, help="data embeddings file")
parser.add_argument("--targets_data_file", type=str, required=True)
parser.add_argument("--y_onehot_file", type=str, required=True)
parser.add_argument("--arch_embeddings_file", type=str, required=True)
parser.add_argument("--model_encoder_file", type=str, required=True)
# Parsed as int: with the previous ``type=str`` a command-line value arrived
# as e.g. "500" while the default was the int 500, so SoftSelector received
# a different type depending on how the script was invoked.
parser.add_argument(
    "--subset_size",
    type=int,
    default=500,
    help="subset_size passed to SoftSelector (also used in the saved weights filename)",
)
# Previously hard-coded; the default keeps the old behaviour.
parser.add_argument(
    "--device",
    type=str,
    default="cuda:3",
    help="torch device to use when CUDA is available (CPU is used otherwise)",
)
args = parser.parse_args()

# Fixed seeds so repeated runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)

device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
print(device)

# NOTE(review): torch.load unpickles its input; only point these arguments at
# trusted files.
x_data = torch.load(args.x_data_file)
targets_data = torch.load(args.targets_data_file)

y_onehot = torch.load(args.y_onehot_file)
arch_embeddings = torch.load(args.arch_embeddings_file)
# Indexed by architecture position below, i.e. y_pred[i] pairs with
# arch_embeddings[i].
y_pred = torch.load(args.model_encoder_file)

# reduction='none' keeps one loss value per sample instead of averaging.
loss_fn = nn.CrossEntropyLoss(reduction='none')

# Loss table: row i holds the per-sample loss of y_pred[i] against
# targets_data. It is computed once without gradients and used as a fixed
# input to the selector during training.
with torch.no_grad():
    loss_with_ma = torch.stack(
        [loss_fn(y_pred[i], targets_data) for i in range(len(arch_embeddings))],
        dim=0,
    )

print("Started Training...")

# Coerced defensively: the CLI option may have been parsed as a string.
subset_size = int(args.subset_size)

if len(arch_embeddings) == 0:
    raise ValueError("arch_embeddings is empty; nothing to train the selector on")

selector = SoftSelector(targets_data, subset_size=subset_size, lambda1=0.1, lambda2=0.1, num_classes=10).to(device)
optimizer = optim.Adam(selector.parameters(), lr=0.001)

checkpoints = []  # summed objective value recorded once per epoch
pi_list = 0.0     # sum of pi_xy over all architectures, collected in the final epoch only
num_epochs = 20
# NOTE(review): divisor for the final pi_list report; how it relates to the
# number of architectures is not visible here -- confirm the intent.
sample_size = 20

# These inputs are the same for every architecture and every epoch, so copy
# them to the device once rather than on every selector call.
x_data_dev = x_data.to(device)
y_onehot_dev = y_onehot.to(device)
loss_with_ma_dev = loss_with_ma.to(device)

for epoch in range(num_epochs):
    print(f"========== epoch: {epoch+1}==========")

    main_obj_func = 0.0
    optimizer.zero_grad()

    # The objectives of all architectures are summed, and a single
    # backward/step is taken on the total once per epoch.
    for arch_emb, arch_pred, arch_loss in zip(arch_embeddings, y_pred, loss_with_ma_dev):
        obj_fn, pi_xy, top_k = selector(arch_emb.to(device), arch_pred.to(device), x_data_dev, y_onehot_dev, arch_loss)

        main_obj_func += obj_fn

        if epoch == num_epochs - 1:
            pi_list += pi_xy.detach().cpu().numpy()

    print(main_obj_func.item())
    checkpoints.append(main_obj_func.item())
    main_obj_func.backward()
    optimizer.step()


print('Finished Training')

print(pi_list.sum()/sample_size)
# pi_xy here is the value from the last architecture of the final epoch.
print(pi_xy.sum())
torch.save(selector.state_dict(), f"selector_weights_{subset_size}.pt")
