import numpy as np

import torch
import torch.utils.data
from torch.autograd import Variable

from model_lenet import RegressionModel, RegressionTrain
from pareto_utility import get_valid_direction, circle_points, circle_points_random

import pickle


##### set problem info ####
dataset = 'mnist'  # selects which data file the loading section reads
n_tasks = 2        # number of tasks handed to RegressionModel

niter = 200        # number of passes of the outer training loop
npref = 3          # number of preference vectors requested per training iteration
npref_test = 9     # number of preference vectors evaluated at report time

# Fixed set of test preference vectors, requested from circle_points.
# Only move it to the GPU when one is present, matching the
# torch.cuda.is_available() guards used for the model and the batches;
# an unconditional .cuda() makes the script crash on CPU-only hosts.
ref_vec_test = torch.tensor(circle_points([1], [npref_test])[0]).float()
if torch.cuda.is_available():
    ref_vec_test = ref_vec_test.cuda()



##### load data ####

# MultiMNIST: multi_mnist.pickle
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# load data files that come from a trusted source.
if dataset == 'mnist':
    with open('data/multi_mnist.pickle', 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)
else:
    # Fail here with a clear message instead of a NameError on trainX below.
    raise ValueError('unsupported dataset: {!r}'.format(dataset))


# Reshape the images to (N, 1, 36, 36) float tensors; -1 lets the sample
# count follow the loaded arrays instead of being hard-coded.
trainX = torch.from_numpy(trainX.reshape(-1, 1, 36, 36)).float()
trainLabel = torch.from_numpy(trainLabel).long()
testX = torch.from_numpy(testX.reshape(-1, 1, 36, 36)).float()
testLabel = torch.from_numpy(testLabel).long()


# Pair every image tensor with its label tensor.
train_set = torch.utils.data.TensorDataset(trainX, trainLabel)
test_set = torch.utils.data.TensorDataset(testX, testLabel)


# Mini-batch iterators: training batches are reshuffled on every pass,
# test batches keep the stored order.
batch_size = 256
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Report how many batches one pass over each loader yields.
n_train_batches = len(train_loader)
n_test_batches = len(test_loader)
print('==>>> total trainning batch number: {}'.format(n_train_batches))
print('==>>> total testing batch number: {}'.format(n_test_batches))


#### define MTL model ####

# RegressionModel is built for n_tasks tasks and wrapped by RegressionTrain;
# the wrapped network stays reachable as model.model.
backbone = RegressionModel(n_tasks)
model = RegressionTrain(backbone)
if torch.cuda.is_available():
    # Move the parameters to the GPU in place.
    model.cuda()


#### define optimizer ####

# Adam over all trainable parameters; the learning rate is multiplied by 0.5
# after every 50 calls to scheduler.step().
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=50)
    

#### run the algorithm for niter epoch ####

def _evaluate(loader, pref_idx, pref_vec):
    """Evaluate the global `model` on every batch of `loader`.

    Args:
        loader: DataLoader yielding (inputs, labels) batches; column 0 and
            column 1 of the labels are compared against the predictions of
            the first and second task.
        pref_idx: preference index forwarded to the model.
        pref_vec: preference vector forwarded to the model.

    Returns:
        A tuple (average_loss, acc): `average_loss` is the mean over batches
        of the loss tensor returned by `model`, and `acc` is a NumPy array
        with the fraction of correct predictions for task 1 and task 2.
    """
    use_cuda = torch.cuda.is_available()
    batch_losses = []
    correct1 = 0
    correct2 = 0

    with torch.no_grad():
        for X, ts in loader:
            if use_cuda:
                X = X.cuda()
                ts = ts.cuda()

            batch_losses.append(model(X, ts, pref_idx, pref_vec))

            # One forward pass of the wrapped network serves both tasks:
            # argmax along dim 2, then pick the task along dim 1.
            preds = model.model(X, pref_idx, pref_vec).max(2, keepdim=True)[1]
            output1 = preds[:, 0]
            output2 = preds[:, 1]

            correct1 += output1.eq(ts[:, 0].view_as(output1)).sum().item()
            correct2 += output2.eq(ts[:, 1].view_as(output2)).sum().item()

    n_samples = len(loader.dataset)
    acc = np.array([1.0 * correct1 / n_samples, 1.0 * correct2 / n_samples])
    average_loss = torch.mean(torch.stack(batch_losses), dim=0)
    return average_loss, acc


# Queried once: decides whether batches are moved to the GPU.
use_cuda = torch.cuda.is_available()

for t in range(niter):

    model.train()
    for X, ts in train_loader:
        if use_cuda:
            X = X.cuda()
            ts = ts.cuda()

        # generate 1 random preference vector and (npref - 1) reference vectors at each iteration
        # (kept on the same device as the batch so CPU-only runs work too)
        ref_vec = torch.tensor(circle_points_random([1], [npref])[0]).float().to(X.device)
        pref_idx = np.random.randint(npref)

        task_loss = model(X, ts, pref_idx, ref_vec[pref_idx])

        # obtain and store the gradient for each task
        losses_vec = []
        grads_list = []
        for i in range(n_tasks):
            losses_vec.append(task_loss[i].detach())

            optimizer.zero_grad()
            # retain_graph: the same graph is backpropagated again below
            task_loss[i].backward(retain_graph=True)

            # flatten every available parameter gradient into one vector
            grads_list.append(torch.cat([
                param.grad.detach().clone().flatten()
                for param in model.parameters()
                if param.grad is not None
            ]))

        grads = torch.stack(grads_list)
        losses_vec = torch.stack(losses_vec)

        #### calculate a valid gradient direction, turn it into weights ####
        weight_vec = get_valid_direction(grads, losses_vec, ref_vec, pref_idx)

        # rescale the weights so that they sum to n_tasks
        normalize_coeff = n_tasks / torch.sum(weight_vec)
        weight_vec = weight_vec * normalize_coeff

        #### opt step ####
        optimizer.zero_grad()
        loss_total = weight_vec[0] * task_loss[0]
        for i in range(1, len(task_loss)):
            loss_total = loss_total + weight_vec[i] * task_loss[i]

        loss_total.backward()
        optimizer.step()

        #### trick for diversity ####
        #### small chance to directly optimize the end point ####
        # Both random numbers are drawn before either branch runs.
        r0 = np.random.rand(1)
        r1 = np.random.rand(1)

        # NOTE(review): preference [1, 0] is paired with task_loss[1] and
        # [0, 1] with task_loss[0]; kept as in the original -- confirm that
        # this pairing is intended.
        if r0 < 0.1:
            optimizer.zero_grad()

            ref_vec_temp = torch.tensor([1, 0], dtype=torch.float, device=X.device)
            task_loss = model(X, ts, 0, ref_vec_temp)
            loss_total = 1 * task_loss[1]

            loss_total.backward()
            optimizer.step()

        if r1 < 0.1:
            optimizer.zero_grad()

            ref_vec_temp = torch.tensor([0, 1], dtype=torch.float, device=X.device)
            task_loss = model(X, ts, 0, ref_vec_temp)
            loss_total = 1 * task_loss[0]

            loss_total.backward()
            optimizer.step()

    scheduler.step()

    # check performance every 10 iterations
    if t == 0 or (t + 1) % 10 == 0:

        model.eval()

        task_train_losses = []
        train_accs = []

        task_test_losses = []
        test_accs = []

        for pref_idx_test in range(npref_test):
            pref_vec_test = ref_vec_test[pref_idx_test]

            #### train loss and acc ####
            average_train_loss, train_acc = _evaluate(train_loader, pref_idx_test, pref_vec_test)

            #### test loss and acc ####
            average_test_loss, test_acc = _evaluate(test_loader, pref_idx_test, pref_vec_test)

            # record and print (on any device; .cpu() is a no-op for CPU tensors)
            task_train_losses.append(average_train_loss.data.cpu().numpy())
            train_accs.append(train_acc)

            task_test_losses.append(average_test_loss.data.cpu().numpy())
            test_accs.append(test_acc)

            print('{}/{}: pref_vecs={}, train_loss={}, train_acc={},test_loss={}, test_acc={}'.format(
                    t + 1, niter,  pref_vec_test.data.cpu().numpy(), task_train_losses[-1],train_accs[-1], task_test_losses[-1],test_accs[-1]))

        print("************************************************************")

