import numpy as np

import torch
import torch.utils.data
from torch.autograd import Variable

from create_dataset import cityscapes

from model_hyper_segnet import HyperSegModel, SegNet

from pareto_utility import get_d_paretomtl, circle_points, circle_points_random 
from pareto_utility import uniform_points, circle_points_random_from_uniform

# hyperparameters
n_tasks = 2         # number of task losses balanced in the training loop
ref_type = 'unif'   # preference-vector generator: 'unif' or 'circle'
npref = 3           # preference vectors sampled at each training iteration
npref_test = 9      # fixed preference vectors used for evaluation
niter = 200         # number of training epochs
batch_size = 12

# Generate the npref_test fixed preference vectors used for evaluation with
# method ref_type. The count is taken from npref_test (not a literal) so it
# always matches the evaluation loops, and the tensor is moved to the GPU
# only when one is available, mirroring how the model is placed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if ref_type == 'unif':
    ref_vec_test = torch.tensor(uniform_points(npref_test)).float().to(device)
elif ref_type == 'circle':
    ref_vec_test = torch.tensor(circle_points([1], [npref_test])[0]).float().to(device)
else:
    # Fail here rather than with a NameError on first use of ref_vec_test.
    raise ValueError(
        "ref_type must be 'unif' or 'circle', got {!r}".format(ref_type))


# Location of the Cityscapes data.
# Download the dataset from https://github.com/lorenmt/mtan and place it
# under ./data before running this script.
dataset_path = 'data/cityscapes2'
cityscapes_train_set = cityscapes(root=dataset_path, train=True)
cityscapes_test_set = cityscapes(root=dataset_path, train=False)


# Data loaders: both splits share the batch size; only the training split
# is shuffled.
_loader_kwargs = {'batch_size': batch_size}
train_loader = torch.utils.data.DataLoader(
    cityscapes_train_set, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(
    cityscapes_test_set, shuffle=False, **_loader_kwargs)

# Report how many batches each loader yields per pass.
n_train_batches = len(train_loader)
n_test_batches = len(test_loader)
print('==>>> total trainning batch number: {}'.format(n_train_batches))
print('==>>> total testing batch number: {}'.format(n_test_batches))

# Build the hypernetwork-based SegNet and move it to the GPU when one is
# available.
model = HyperSegModel(SegNet())
if torch.cuda.is_available():
    model = model.cuda()

# Adam optimizer whose learning rate is halved every 50 scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=50,
    gamma=0.5,
)

# Device used for every tensor created or moved below; falls back to the CPU
# when CUDA is not available so the script does not crash on `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# optional: optimize the model with preference on each task with prob_task
# can be treated as a biased distribution for preference sampling
# a large prob_task at the beginning could be good for expanding the trade-off curve
# TODO: better sampling methods
prob_task = 0.5


def _unpack_batch(batch):
    """Move one loader batch to ``device``.

    Returns the input ``X`` and the list ``[ts1, ts2]`` of task targets.
    The first target is always cast to ``long`` (not only on the GPU path).
    """
    X = batch[0].to(device)
    ts1 = batch[1].long().to(device)
    ts2 = batch[2].to(device)
    return X, [ts1, ts2]


def _sample_train_prefs():
    """Randomly generate ``npref`` preference vectors with method ``ref_type``."""
    if ref_type == 'unif':
        points = circle_points_random_from_uniform(npref)
    elif ref_type == 'circle':
        points = circle_points_random([1], [npref])[0]
    else:
        raise ValueError(
            "ref_type must be 'unif' or 'circle', got {!r}".format(ref_type))
    return torch.tensor(points).float().to(device)


def _flat_task_grad(loss):
    """Return the gradient of ``loss`` w.r.t. the model as one flat vector.

    The graph is retained because the same forward pass is back-propagated
    once per task and once more for the weighted total loss.
    """
    optimizer.zero_grad()
    loss.backward(retain_graph=True)

    pieces = []
    for param in model.parameters():
        if param.grad is not None:
            pieces.append(param.grad.detach().flatten())
        elif param.requires_grad:
            # zero_grad may reset .grad to None, so a parameter this task does
            # not reach has no gradient. Zero-fill it so that every task
            # yields a vector of the same length and they can be stacked.
            # NOTE(review): assumes get_d_paretomtl only uses inner products /
            # norms of the gradient rows, which zero entries leave unchanged.
            pieces.append(torch.zeros_like(param).flatten())
    # torch.cat copies, so later zero_grad/backward calls cannot alter the result.
    return torch.cat(pieces)


def _evaluate(loader, split_name, epoch):
    """Print the batch-averaged metrics on ``loader`` for each test preference.

    ``split_name`` is the label shown in the printed line ('Training' or
    'Testing'); ``epoch`` is the zero-based epoch index.
    """
    n_batch = len(loader)
    with torch.no_grad():
        for pref_idx_test in range(npref_test):
            pref = ref_vec_test[pref_idx_test]
            cost = np.zeros(6, dtype=np.float32)
            avg_cost = np.zeros(6, dtype=np.float32)

            for batch in loader:
                X, ts = _unpack_batch(batch)

                test_pred = model.model(X, pref)
                test_loss = model.model.model_fit(test_pred[0], ts[0], test_pred[1], ts[1])

                cost[0] = test_loss[0].item()
                cost[1] = model.model.compute_miou(test_pred[0], ts[0]).item()
                cost[2] = model.model.compute_iou(test_pred[0], ts[0]).item()
                cost[3] = test_loss[1].item()
                cost[4], cost[5] = model.model.depth_error(test_pred[1], ts[1])

                # every batch contributes equally to the reported average
                avg_cost += cost / n_batch

            # print results
            pref_np = pref.detach().cpu().numpy()
            print('{}/{}: {} pref_vecs= [{:.4f},{:.4f}] | Task1: {:.4f} {:.4f} {:.4f} | Task2: {:.4f} {:.4f} {:.4f} |'.format(
                epoch + 1, niter, split_name, pref_np[0], pref_np[1],
                avg_cost[0], avg_cost[1], avg_cost[2], avg_cost[3], avg_cost[4], avg_cost[5]))


# training the model for niter epochs
for t in range(niter):

    model.train()

    for batch in train_loader:
        X, ts = _unpack_batch(batch)

        # randomly generate npref preference vectors at each iteration and
        # pick one of them for this step
        ref_vec = _sample_train_prefs()
        pref_idx = np.random.randint(npref)

        # obtain the losses and the flattened gradients for each task
        task_loss = model(X, ts, ref_vec[pref_idx])
        losses_vec = torch.stack([task_loss[i].detach() for i in range(n_tasks)])
        grads = torch.stack([_flat_task_grad(task_loss[i]) for i in range(n_tasks)])

        # normalize the ref_vec into unit vectors, for unif method
        ref_vec = ref_vec / torch.norm(ref_vec, dim=1).view(len(ref_vec), 1)

        # calculate the weight for each task
        weight_vec = get_d_paretomtl(grads, losses_vec, ref_vec, pref_idx)

        # normalize the weights so their absolute values sum to n_tasks
        normalize_coeff = n_tasks / torch.sum(torch.abs(weight_vec))
        weight_vec = weight_vec * normalize_coeff

        # update the model with the weighted sum of the task losses
        optimizer.zero_grad()
        loss_total = sum(weight_vec[i] * task_loss[i] for i in range(len(task_loss)))
        loss_total.backward()
        optimizer.step()

        # optional single-task steps, each taken with probability prob_task;
        # both random draws are made before either step runs
        r0 = np.random.rand(1)
        r1 = np.random.rand(1)

        # in pareto mtl, 0 is for the preferred task
        single_task_steps = ((0, r0, [0, 1]), (1, r1, [1, 0]))
        for task_idx, draw, pref in single_task_steps:
            if draw < prob_task:
                optimizer.zero_grad()
                ref_vec_temp = torch.tensor(pref).float().to(device)
                task_loss = model(X, ts, ref_vec_temp)
                loss_total = n_tasks * task_loss[task_idx]

                loss_total.backward()
                optimizer.step()

    scheduler.step()

    # print the performance on the first epoch and then every 5 epochs
    if t == 0 or (t + 1) % 5 == 0:
        model.eval()

        # performance on the train set
        _evaluate(train_loader, 'Training', t)
        print("************************************************************")

        # performance on the test set
        _evaluate(test_loader, 'Testing', t)
        print("************************************************************")
