import math

import numpy as np
import time
import torch
import torch.nn.functional as F
import gc
import random
import wandb

from sklearn import random_projection

from copy import deepcopy
from min_norm_solvers import MinNormSolver
from scipy.optimize import minimize, Bounds, minimize_scalar
import  matplotlib
from collections import OrderedDict
matplotlib.use('Agg')
import matplotlib.pyplot as plt


"""
Define task metrics, loss functions and model trainer here.
"""
def get_variable_name(variable, loc):
    """Return the first key of ``loc`` whose value compares equal to ``variable``.

    ``loc`` is a mapping (presumably ``locals()`` / ``globals()`` of the
    caller -- TODO confirm).  Returns ``None`` when no value matches.
    """
    return next((name for name, value in loc.items() if value == variable), None)

class Averager():
    """Running (optionally weighted) average that also keeps a history of raw values.

    ``reset()`` clears the running statistics but leaves ``data`` untouched,
    so ``len()`` and ``obtain_data()`` still include values added before it.
    """

    def __init__(self):
        self.data = []  # raw values passed to add(), in insertion order
        self.reset()

    def reset(self):
        """Zero the running statistics (the value history is kept)."""
        self.val = 0    # most recently added value
        self.avg = 0    # sum / count
        self.sum = 0    # weighted sum of added values
        self.count = 0  # total weight

    def add(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        self.data.append(val)

    def item(self):
        """Current weighted average."""
        return self.avg

    def obtain_data(self):
        """The list of every raw value added so far."""
        return self.data

    def __len__(self):
        return len(self.data)

class Running_Avg():
    """Mean and standard deviation over the most recent ``max_cnt`` values.

    The full history is kept in ``data``; only the trailing window is used
    by ``item()`` and ``std()``.
    """

    def __init__(self, max_cnt):
        self.max_cnt = max_cnt  # window length
        self.data = []          # every value ever added
        self.cnt = 0            # number of values added

    def add(self, val):
        self.data.append(val)
        self.cnt += 1

    def _window(self):
        # Until the window is full use everything, afterwards only the tail.
        return self.data if self.cnt < self.max_cnt else self.data[-self.max_cnt:]

    def item(self):
        """Arithmetic mean of the current window."""
        window = self._window()
        return sum(window) / len(window)

    def std(self):
        """Population standard deviation (``np.std``) of the current window."""
        return np.std(self._window())

class Timer():
    """Wall-clock stopwatch started at construction time."""

    def __init__(self):
        self.o = time.time()  # start timestamp

    def measure(self, p=1):
        """Elapsed time divided by ``p``, formatted as ``'Ns'``, ``'Nm'`` or ``'N.Nh'``.

        The scaled duration is truncated to whole seconds before formatting.
        """
        secs = int((time.time() - self.o) / p)
        if secs < 60:
            return f'{secs}s'
        if secs < 3600:
            return f'{round(secs / 60)}m'
        return f'{secs / 3600:.1f}h'

class ConfMatrix(object):
    """Accumulate a confusion matrix and derive segmentation metrics.

    Rows index the ground-truth class and columns the predicted class.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # (num_classes, num_classes) int64 counts; created lazily on pred's device
        self.mat = None

    def update(self, pred, target):
        """Add one batch of flattened class indices.

        Target entries outside ``[0, num_classes)`` (e.g. an ignore label of
        -1) are skipped together with the matching predictions.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
        with torch.no_grad():
            k = (target >= 0) & (target < n)
            # Cast both operands so the flat index is int64 whatever dtype pred has.
            inds = n * target[k].to(torch.int64) + pred[k].to(torch.int64)
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def get_metrics(self):
        """Return ``(mean IoU, pixel accuracy)`` as 0-dim float tensors.

        A class that never occurs in either predictions or targets has an
        undefined IoU (0/0); it is left out of the mean instead of turning the
        whole mIoU into NaN.

        Raises:
            RuntimeError: if called before any ``update()``.
        """
        if self.mat is None:
            raise RuntimeError('ConfMatrix.get_metrics() called before any update()')
        h = self.mat.float()
        diag = torch.diag(h)
        acc = diag.sum() / h.sum()
        union = h.sum(1) + h.sum(0) - diag
        present = union > 0
        iu = diag[present] / union[present]
        return torch.mean(iu), acc

def cos(t1, t2):
    """Cosine similarity of ``t1`` and ``t2`` along dim 0.

    Each input is L2-normalised with ``F.normalize`` (which guards against
    zero-norm vectors) before the dot product.
    """
    unit1 = F.normalize(t1, dim=0)
    unit2 = F.normalize(t2, dim=0)
    return torch.sum(unit1 * unit2, dim=0)

def dot(t1, t2):
    """Element-wise product of ``t1`` and ``t2`` summed over dim 0."""
    return torch.sum(t1 * t2, dim=0)

def pair_cos_with_d(pair, d):
    """Cosine similarity between ``d`` and each entry ``pair[i]`` along dim 0 of ``pair``.

    Returns a flat 1-D tensor with one value per entry.
    """
    similarities = [cos(d, pair[row]) for row in range(pair.size(0))]
    return torch.stack(similarities).view(-1)

def pair_dot_with_d(pair, d):
    """Dot product between ``d`` and each entry ``pair[i]`` along dim 0 of ``pair``.

    Returns a flat 1-D tensor with one value per entry.
    """
    products = [dot(d, pair[row]) for row in range(pair.size(0))]
    return torch.stack(products).view(-1)

def pair_cos(pair):
    """Cosine similarity for every unordered pair ``(pair[i], pair[j])`` with ``i < j``.

    Pairs are enumerated in row-major order (0-1, 0-2, ..., 1-2, ...) and
    returned as a flat 1-D tensor.
    """
    n = pair.size(0)
    similarities = [cos(pair[i], pair[j]) for i in range(n - 1) for j in range(i + 1, n)]
    return torch.stack(similarities).view(-1)

def pair_dot(pair):
    """Dot product for every unordered pair ``(pair[i], pair[j])`` with ``i < j``.

    Pairs are enumerated in row-major order (0-1, 0-2, ..., 1-2, ...) and
    returned as a flat 1-D tensor.
    """
    n = pair.size(0)
    products = [dot(pair[i], pair[j]) for i in range(n - 1) for j in range(i + 1, n)]
    return torch.stack(products).view(-1)

def remove_layer_gradient(grad_list, pos):
    """Concatenate the tensors of ``grad_list`` except those at index ``pos``.

    ``pos`` is a single index or a list of indices to drop.  Presumably each
    entry is a flattened per-layer gradient (``torch.cat`` joins along dim 0)
    -- TODO confirm against callers.
    """
    skipped = pos if isinstance(pos, list) else [pos]
    kept = [grad for idx, grad in enumerate(grad_list) if idx not in skipped]
    return torch.cat(kept)

def LFT_min_cos(grad_all, grad_target, n_tasks, pos):
    """Smallest cosine similarity between the target gradient and any task gradient.

    Both sides have the layers at ``pos`` removed first (see
    ``remove_layer_gradient``).  Returns the sentinel 1000 when ``n_tasks`` is 0.
    """
    target = remove_layer_gradient(grad_target, pos)
    smallest = 1000
    for task in range(n_tasks):
        candidate = cos(remove_layer_gradient(grad_all[task], pos), target)
        if candidate < smallest:
            smallest = candidate
    return smallest

def LFT_task_cos_diff(grad_all, grad_target, n_tasks, pos):
    """Per-task change in cosine similarity to the target when layers ``pos`` are removed.

    For each task ``i`` this computes
    ``cos(reduced task grad, reduced target) - cos(full task grad, full target)``,
    where "full" concatenates every layer and "reduced" drops the layers at
    ``pos``.

    Returns:
        1-D tensor with one difference per task.
    """
    # The target vectors are the same for every task, so build them once
    # instead of re-concatenating the whole target gradient in each iteration.
    target_full = torch.cat(grad_target)
    target_reduced = remove_layer_gradient(grad_target, pos)

    diff_list = []
    for i in range(n_tasks):
        cos_full = cos(torch.cat(grad_all[i]), target_full)
        cos_reduced = cos(remove_layer_gradient(grad_all[i], pos), target_reduced)
        diff_list.append(cos_reduced - cos_full)

    return torch.stack(diff_list).view(-1)

def LFT_cos_diff(grad_all, n_tasks, pos):
    """Change in pairwise task-gradient cosine similarity when layers ``pos`` are removed.

    For every task pair ``a < b`` (row-major order) this computes
    ``cos(reduced a, reduced b) - cos(full a, full b)``.

    Returns:
        1-D tensor with one difference per task pair.
    """
    diffs = []
    for a in range(n_tasks - 1):
        for b in range(a + 1, n_tasks):
            full_cos = cos(torch.cat(grad_all[a]), torch.cat(grad_all[b]))
            reduced_cos = cos(remove_layer_gradient(grad_all[a], pos),
                              remove_layer_gradient(grad_all[b], pos))
            diffs.append(reduced_cos - full_cos)
    return torch.stack(diffs).view(-1)

def topK_prob(task_diff, topK):
    """Return the ``topK`` keys whose recorded rows are most often all-positive.

    Each ``task_diff[key]`` is a list of equally-sized 1-D tensors, stacked
    into rows.  A key's score is the fraction of rows in which every entry is
    > 0.  Keys are ranked by descending score; ties keep insertion order.
    """
    scores = {}
    for key, rows in task_diff.items():
        stacked = torch.stack(rows)
        width = stacked.size(1)
        all_positive = (stacked > 0).sum(dim=1) == width
        scores[key] = all_positive.sum() / all_positive.size(0)
    ranked = sorted(scores, key=lambda name: scores[name], reverse=True)
    return ranked[:topK]

def topK_value(task_diff, topK):
    """Return the ``topK`` keys of ``task_diff`` with the smallest mean value.

    Each value is a list of tensors that is stacked and averaged over all
    elements.  Keys with equal means keep their insertion order.
    """
    means = {key: torch.stack(values).mean() for key, values in task_diff.items()}
    return sorted(means, key=means.get)[:topK]

def avg_tensor(list, dim=0):
    """Stack equally-shaped tensors along a new leading axis and average over ``dim``.

    With the default ``dim=0`` this is the element-wise mean across the list.
    (The parameter name shadows the builtin ``list``; it is kept so keyword
    callers keep working.)
    """
    return torch.stack(list).mean(dim=dim)

def quantile_topK(dict, q, topK, interpolation='linear'):
    """Return the ``topK`` keys with the smallest ``q``-quantile of their row means.

    For each key the list of tensors is stacked and averaged over dim 1
    (``avg_tensor(value, dim=1)``), then ``torch.quantile`` is taken over the
    result.  Keys are ranked ascending; ties keep insertion order.  (The
    parameter name shadows the builtin ``dict``; kept for keyword callers.)
    """
    quantiles = {}
    for key, value in dict.items():
        row_means = avg_tensor(value, dim=1)
        quantiles[key] = torch.quantile(row_means, q, interpolation=interpolation)
    return sorted(quantiles, key=quantiles.get)[:topK]

def define_wandb_metric():
    """Register wandb metrics and their x-axes.

    Per-epoch train/test metrics and averaged angles are plotted against
    ``epoch``; per-iteration angle metrics against ``inner_index``.
    """
    epoch_metrics = (
        "Train_SegLoss", "Train_mIoU", "Train_Pix_Acc",
        "Train_DepthLoss", "Train_Abs_Err", "Train_Rel_Err",
        "Train_NormalLoss", "Train_AD_mean", "Train_AD_median",
        "Train_WithinT_11d25", "Train_WithinT_22d5", "Train_WithinT_30",
        "avg_t1_angle", "avg_t2_angle", "avg_t3_angle",
        "Test_SegLoss", "Test_mIoU", "Test_Pix_Acc",
        "Test_DepthLoss", "Test_Abs_Err", "Test_Rel_Err",
        "Test_NormalLoss", "Test_AD_mean", "Test_AD_median",
        "Test_WithinT_11d25", "Test_WithinT_22d5", "Test_WithinT_30",
    )
    iteration_metrics = (
        't1_angle', 't2_angle', 't3_angle',
        't1_opposite', 't2_opposite', 't3_opposite',
    )

    wandb.define_metric("epoch")
    for metric in epoch_metrics:
        wandb.define_metric(metric, step_metric="epoch")

    wandb.define_metric('inner_index')
    for metric in iteration_metrics:
        wandb.define_metric(metric, step_metric='inner_index')

def define_wandb_metric_single_task(task_id):
    """Register the per-epoch wandb metrics for one task.

    ``task_id`` 0 registers segmentation metrics, 1 depth metrics and 2
    surface-normal metrics; any other value registers only ``epoch``.
    """
    wandb.define_metric("epoch")

    if task_id == 0:
        metrics = (
            "Train_SegLoss", "Train_mIoU", "Train_Pix_Acc",
            "Test_SegLoss", "Test_mIoU", "Test_Pix_Acc",
        )
    elif task_id == 1:
        metrics = (
            "Train_DepthLoss", "Train_Abs_Err", "Train_Rel_Err",
            "Test_DepthLoss", "Test_Abs_Err", "Test_Rel_Err",
        )
    elif task_id == 2:
        metrics = (
            "Train_NormalLoss", "Train_AD_mean", "Train_AD_median",
            "Train_WithinT_11d25", "Train_WithinT_22d5", "Train_WithinT_30",
            "Test_NormalLoss", "Test_AD_mean", "Test_AD_median",
            "Test_WithinT_11d25", "Test_WithinT_22d5", "Test_WithinT_30",
        )
    else:
        metrics = ()

    for metric in metrics:
        wandb.define_metric(metric, step_metric="epoch")

def print_details_single_tasks(avg_cost, index, task_id, mode='train'):
    """Print one epoch of single-task metrics and log them to wandb.

    Args:
        avg_cost: 2-D array indexed as ``avg_cost[index, column]``.
        index: epoch number (row of ``avg_cost``), also logged as ``epoch``.
        task_id: 0 for segmentation (train columns 0-2, test columns 6-8) or
            1 for depth (train columns 3-5, test columns 9-11).
        mode: ``'train'`` reports the train columns; ``'val'`` also reports
            the test columns.

    Raises:
        ValueError: for an unsupported ``task_id`` or ``mode`` (previously an
            unknown mode was silently ignored).
    """
    # task_id -> (train columns, test columns, metric name suffixes)
    layouts = {
        0: ((0, 1, 2), (6, 7, 8), ('SegLoss', 'mIoU', 'Pix_Acc')),
        1: ((3, 4, 5), (9, 10, 11), ('DepthLoss', 'Abs_Err', 'Rel_Err')),
    }
    if task_id not in layouts:
        raise ValueError(f'Unsupported task_id {task_id!r}; expected 0 or 1')
    if mode not in ('train', 'val'):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'train' or 'val'")

    train_cols, test_cols, names = layouts[task_id]
    train_vals = [avg_cost[index, col] for col in train_cols]

    log_dict = {'epoch': index}
    log_dict.update(('Train_' + name, val) for name, val in zip(names, train_vals))

    if mode == 'train':
        print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f}'.format(index, *train_vals))
    else:
        test_vals = [avg_cost[index, col] for col in test_cols]
        print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} TEST: {:.4f} {:.4f} {:.4f}'
              .format(index, *train_vals, *test_vals))
        log_dict.update(('Test_' + name, val) for name, val in zip(names, test_vals))

    wandb.log(log_dict)

def print_details(avg_cost, index, avg_t1_angle, avg_t2_angle, mode='train'):
    """Print one epoch of two-task (segmentation + depth) metrics and log to wandb.

    Columns 0-5 of ``avg_cost[index]`` are train metrics and 6-11 test
    metrics.  In ``'train'`` mode all twelve columns are printed but only
    0-5 are logged; in ``'val'`` mode all twelve are logged and the printed
    test mIoU / pixel accuracy (columns 7, 8) are scaled by 100.  The two
    angle arguments must support ``.item()``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'val'``.
    """
    if mode not in ('train', 'val'):
        raise ValueError('Error')

    train_keys = ('Train_SegLoss', 'Train_mIoU', 'Train_Pix_Acc',
                  'Train_DepthLoss', 'Train_Abs_Err', 'Train_Rel_Err')
    test_keys = ('Test_SegLoss', 'Test_mIoU', 'Test_Pix_Acc',
                 'Test_DepthLoss', 'Test_Abs_Err', 'Test_Rel_Err')
    row = [avg_cost[index, col] for col in range(12)]

    if mode == 'train':
        print(
            'Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'
            .format(index, *row))
        keys = train_keys
    else:
        print(
            'Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} ||'
            'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f}'
            .format(index, *row[:7], row[7] * 100, row[8] * 100, *row[9:]))
        keys = train_keys + test_keys

    log_dict = {'epoch': index}
    log_dict.update(zip(keys, row))  # zip stops after the logged columns
    log_dict['avg_t1_angle'] = avg_t1_angle.item()
    log_dict['avg_t2_angle'] = avg_t2_angle.item()

    wandb.log(log_dict)

def count_parameters(model):
    """Number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

def model_fit(x_pred, x_output, task_type):
    """Mean loss of one task prediction.

    Args:
        x_pred: network output.  For ``'semantic'`` it is passed through
            ``log_softmax`` over dim 1 (class scores).  For ``'depth'`` /
            ``'normal'`` it is compared element-wise with ``x_output``.
        x_output: ground truth.  For ``'semantic'`` a label map in which -1
            marks ignored pixels.  For ``'depth'`` / ``'normal'`` a dense map
            in which pixels whose values sum to zero over dim 1 are undefined
            and masked out.
        task_type: ``'semantic'`` (cross entropy), ``'depth'`` (masked L1) or
            ``'normal'`` (1 - masked mean dot product).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: for an unsupported ``task_type`` (previously an
            UnboundLocalError).
    """
    if task_type == 'semantic':
        # log_softmax is idempotent, so this is also correct when x_pred
        # already holds log-probabilities.
        x_pred = F.log_softmax(x_pred, dim=1)
        return F.nll_loss(x_pred, x_output, ignore_index=-1)

    if task_type not in ('depth', 'normal'):
        raise ValueError(f'Unsupported task_type: {task_type!r}')

    # Binary mask of defined pixels; keep it on the prediction's device
    # instead of assuming the current CUDA device.
    binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
    num_valid = torch.nonzero(binary_mask, as_tuple=False).size(0)

    if task_type == 'depth':
        # L1 distance averaged over defined pixels.
        return torch.sum(torch.abs(x_pred - x_output) * binary_mask) / num_valid

    # Surface normals: 1 - mean dot product over defined pixels.
    return 1 - torch.sum((x_pred * x_output) * binary_mask) / num_valid

def model_fit_top_q(x_pred, x_output, task_type, top_q=-1, std_b=0.0):
    """Per-sample task loss with optional hard-sample selection.

    Per-sample losses are computed as in ``model_fit_per_samples`` (semantic
    cross entropy, masked L1 depth, or 1 - masked dot product for normals).
    If their population standard deviation exceeds ``std_b``, the ``top_q``
    largest per-sample losses are returned as an unreduced 1-D tensor.
    Otherwise their mean is returned as a scalar.

    Args:
        x_pred: network output; dim 0 is the sample axis.
        x_output: ground truth (label map with -1 = ignore for
            ``'semantic'``; dense map for ``'depth'`` / ``'normal'`` where
            pixels summing to zero over dim 1 are masked out).
        task_type: ``'semantic'``, ``'depth'`` or ``'normal'``.
        top_q: ``k`` passed to ``torch.topk``.  NOTE(review): the default -1
            is not a valid ``k``, so it only works while std <= ``std_b``.
        std_b: standard-deviation threshold that triggers top-q selection.

    Returns:
        A tensor as described above, or ``None`` for an unknown ``task_type``.
    """
    num_samples = x_pred.size(0)

    if task_type == 'semantic':
        log_prob = F.log_softmax(x_pred, dim=1)
        per_pixel = F.nll_loss(log_prob, x_output, ignore_index=-1, reduction='none')
        loss = per_pixel.sum(dim=[1, 2]) / ((x_output != -1).sum() / num_samples)
    elif task_type in ('depth', 'normal'):
        # Binary mask of defined pixels, kept on the prediction's device.
        binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
        pixels_per_sample = binary_mask.sum() / num_samples
        if task_type == 'depth':
            masked = torch.abs(x_pred - x_output) * binary_mask
            loss = torch.sum(masked, dim=[1, 2, 3]) / pixels_per_sample
        else:
            masked = (x_pred * x_output) * binary_mask
            loss = 1 - torch.sum(masked, dim=[1, 2, 3]) / pixels_per_sample
    else:
        return None

    # Mine the hardest samples only when the losses are spread out enough.
    if torch.std(loss, unbiased=False) > std_b:
        top_values, _ = torch.topk(loss, k=top_q)
        return top_values
    return torch.mean(loss)

def model_fit_per_samples(x_pred, x_output, task_type):
    """Per-sample task losses (1-D tensor indexed by dim 0 of ``x_pred``).

    Each sample's summed loss is divided by the batch-average number of valid
    pixels.  The mean of the returned vector therefore equals the batch loss
    of ``model_fit``.

    Args:
        x_pred: network output.  For ``'semantic'`` it is passed through
            ``log_softmax`` over dim 1.  This matches ``model_fit`` and is a
            no-op when the input already holds log-probabilities.
        x_output: ground truth (label map with -1 = ignore for
            ``'semantic'``; dense map for ``'depth'`` / ``'normal'`` where
            pixels summing to zero over dim 1 are masked out).
        task_type: ``'semantic'``, ``'depth'`` or ``'normal'``.

    Raises:
        ValueError: for an unsupported ``task_type`` (previously an
            UnboundLocalError).
    """
    num_samples = x_pred.size(0)

    if task_type == 'semantic':
        # Depth-wise cross entropy, summed per sample.
        log_prob = F.log_softmax(x_pred, dim=1)
        per_pixel = F.nll_loss(log_prob, x_output, ignore_index=-1, reduction='none')
        return per_pixel.sum(dim=[1, 2]) / ((x_output != -1).sum() / num_samples)

    if task_type not in ('depth', 'normal'):
        raise ValueError(f'Unsupported task_type: {task_type!r}')

    # Binary mask of defined pixels, kept on the prediction's device.
    binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
    pixels_per_sample = binary_mask.sum() / num_samples

    if task_type == 'depth':
        masked = torch.abs(x_pred - x_output) * binary_mask
        return torch.sum(masked, dim=[1, 2, 3]) / pixels_per_sample

    # Surface normals: 1 - dot product.
    masked = (x_pred * x_output) * binary_mask
    return 1 - torch.sum(masked, dim=[1, 2, 3]) / pixels_per_sample

def depth_error(x_pred, x_output):
    """Mean absolute and mean relative depth error over defined pixels.

    A pixel is defined when ``x_output`` summed over dim 1 is non-zero.  Both
    sums are divided by the number of defined pixel positions.

    Returns:
        ``(abs_err, rel_err)`` as Python floats.
    """
    # Metric only: no autograd graph is needed, since results leave via .item().
    with torch.no_grad():
        # Mask on the prediction's device instead of assuming the current CUDA device.
        binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(x_pred.device)
        x_pred_true = x_pred.masked_select(binary_mask)
        x_output_true = x_output.masked_select(binary_mask)
        num_valid = torch.nonzero(binary_mask, as_tuple=False).size(0)

        abs_err = torch.abs(x_pred_true - x_output_true)
        rel_err = abs_err / torch.abs(x_output_true)
        return (torch.sum(abs_err) / num_valid).item(), \
               (torch.sum(rel_err) / num_valid).item()

"""
=========== Universal Multi-task Trainer ===========
"""

# def multi_task_trainer(train_loader, test_loader, multi_task_model, device, optimizer, scheduler, opt, total_epoch=200):
#     start_time = time.time()
#     train_batch = len(train_loader)
#     test_batch = len(test_loader)
#     T = opt.temp
#     avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)
#     lambda_weight = np.ones([3, total_epoch])
#     for index in range(total_epoch):
#         epoch_start_time = time.time()
#         cost = np.zeros(24, dtype=np.float32)
#
#         # apply Dynamic Weight Average
#         if opt.weight == 'dwa':
#             if index == 0 or index == 1:
#                 lambda_weight[:, index] = 1.0
#             else:
#                 w_1 = avg_cost[index - 1, 0] / avg_cost[index - 2, 0]
#                 w_2 = avg_cost[index - 1, 3] / avg_cost[index - 2, 3]
#                 w_3 = avg_cost[index - 1, 6] / avg_cost[index - 2, 6]
#                 lambda_weight[0, index] = 3 * np.exp(w_1 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#                 lambda_weight[1, index] = 3 * np.exp(w_2 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#                 lambda_weight[2, index] = 3 * np.exp(w_3 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#
#         # iteration for all batches
#         multi_task_model.train()
#         train_dataset = iter(train_loader)
#         conf_mat = ConfMatrix(multi_task_model.class_nb)
#         for k in range(train_batch):
#             train_data, train_label, train_depth, train_normal = train_dataset.next()
#             train_data, train_label = train_data.to(device), train_label.long().to(device)
#             train_depth, train_normal = train_depth.to(device), train_normal.to(device)
#
#             train_pred, logsigma = multi_task_model(train_data)
#
#             optimizer.zero_grad()
#             train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
#                           model_fit(train_pred[1], train_depth, 'depth'),
#                           model_fit(train_pred[2], train_normal, 'normal')]
#
#             if opt.weight == 'equal' or opt.weight == 'dwa':
#                 loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
#                 #loss = sum([w[i] * train_loss[i] for i in range(3)])
#             else:
#                 loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))
#
#             loss.backward()
#             optimizer.step()
#
#             # accumulate label prediction for every pixel in training images
#             conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
#
#             cost[0] = train_loss[0].item()
#             cost[3] = train_loss[1].item()
#             cost[4], cost[5] = depth_error(train_pred[1], train_depth)
#             cost[6] = train_loss[2].item()
#             cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
#             avg_cost[index, :12] += cost[:12] / train_batch
#
#         # compute mIoU and acc
#         avg_cost[index, 1:3] = conf_mat.get_metrics()
#
#         # evaluating test data
#         multi_task_model.eval()
#         conf_mat = ConfMatrix(multi_task_model.class_nb)
#         with torch.no_grad():  # operations inside don't track history
#             test_dataset = iter(test_loader)
#             for k in range(test_batch):
#                 test_data, test_label, test_depth, test_normal = test_dataset.next()
#                 test_data, test_label = test_data.to(device), test_label.long().to(device)
#                 test_depth, test_normal = test_depth.to(device), test_normal.to(device)
#
#                 test_pred, _ = multi_task_model(test_data)
#                 test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
#                              model_fit(test_pred[1], test_depth, 'depth'),
#                              model_fit(test_pred[2], test_normal, 'normal')]
#
#                 conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
#
#                 cost[12] = test_loss[0].item()
#                 cost[15] = test_loss[1].item()
#                 cost[16], cost[17] = depth_error(test_pred[1], test_depth)
#                 cost[18] = test_loss[2].item()
#                 cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)
#                 avg_cost[index, 12:] += cost[12:] / test_batch
#
#             # compute mIoU and acc
#             avg_cost[index, 13:15] = conf_mat.get_metrics()
#
#         scheduler.step()
#         epoch_end_time = time.time()
#         print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} ||'
#             'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} | {:.4f}'
#             .format(index, avg_cost[index, 0], avg_cost[index, 1], avg_cost[index, 2], avg_cost[index, 3],
#                     avg_cost[index, 4], avg_cost[index, 5], avg_cost[index, 6], avg_cost[index, 7], avg_cost[index, 8],
#                     avg_cost[index, 9], avg_cost[index, 10], avg_cost[index, 11], avg_cost[index, 12], avg_cost[index, 13],
#                     avg_cost[index, 14], avg_cost[index, 15], avg_cost[index, 16], avg_cost[index, 17], avg_cost[index, 18],
#                     avg_cost[index, 19], avg_cost[index, 20], avg_cost[index, 21], avg_cost[index, 22], avg_cost[index, 23], epoch_end_time-epoch_start_time))
#     end_time = time.time()
#     print("Training time: ", end_time - start_time)

"""
=========== Universal Single-task Trainer ===========
"""

def single_task_trainer(train_loader, test_loader, single_task_model, optimizer, scheduler, opt, total_epoch=200):
    """Train and periodically evaluate a single-task model, printing and logging to wandb.

    Args:
        train_loader, test_loader: iterables yielding ``(data, label, depth)``
            batches; their ``len()`` is used to average per-batch costs.
        single_task_model: model returning ``(predictions, logsigma)``.
            ``predictions[0]`` is used for segmentation and ``predictions[1]``
            for depth.  It must expose ``class_nb``.
        optimizer: stepped once per training batch.
        scheduler: stepped once per epoch.
        opt: options object; reads ``opt.task_id`` (0 = segmentation,
            1 = depth) and ``opt.eval_freq`` (evaluate every N epochs and on
            the last epoch).
        total_epoch: number of epochs.

    Raises:
        ValueError: if ``opt.task_id`` is not 0 or 1.  This is checked up front
            instead of surfacing from ``print_details_single_tasks`` after a
            full, useless training epoch.
    """
    task_id = opt.task_id
    if task_id not in (0, 1):
        raise ValueError(f'single_task_trainer supports task_id 0 or 1, got {task_id!r}')

    train_batch = len(train_loader)
    test_batch = len(test_loader)
    # Per-epoch columns: 0-2 train seg (loss, mIoU, pix acc), 3-5 train depth
    # (loss, abs err, rel err), 6-8 test seg, 9-11 test depth.
    avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)

    define_wandb_metric_single_task(task_id)

    for index in range(total_epoch):
        cost = np.zeros(24, dtype=np.float32)

        # iteration for all batches
        single_task_model.train()
        conf_mat = ConfMatrix(single_task_model.class_nb)

        for train_data, train_label, train_depth in train_loader:
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth = train_depth.cuda()

            train_pred, _ = single_task_model(train_data)

            optimizer.zero_grad()
            if task_id == 0:
                train_loss = model_fit(train_pred[0], train_label, 'semantic')
                train_loss.backward()
                optimizer.step()

                # NOTE(review): presumably moved to CPU so torch.bincount does
                # not run on CUDA (nondeterministic there) -- confirm.
                conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())
                cost[0] = train_loss.item()
            else:
                train_loss = model_fit(train_pred[1], train_depth, 'depth')
                train_loss.backward()
                optimizer.step()

                cost[3] = train_loss.item()
                cost[4], cost[5] = depth_error(train_pred[1], train_depth)

            avg_cost[index, :6] += cost[:6] / train_batch

        if task_id == 0:
            avg_cost[index, 1:3] = conf_mat.get_metrics()

        # evaluating test data
        single_task_model.eval()
        if (index + 1) % opt.eval_freq == 0 or index == total_epoch - 1:
            conf_mat = ConfMatrix(single_task_model.class_nb)
            with torch.no_grad():  # operations inside don't track history
                for test_data, test_label, test_depth in test_loader:
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth = test_depth.cuda()

                    test_pred, _ = single_task_model(test_data)

                    if task_id == 0:
                        test_loss = model_fit(test_pred[0], test_label, 'semantic')

                        conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())
                        cost[6] = test_loss.item()
                    else:
                        test_loss = model_fit(test_pred[1], test_depth, 'depth')

                        cost[9] = test_loss.item()
                        cost[10], cost[11] = depth_error(test_pred[1], test_depth)

                    avg_cost[index, 6:] += cost[6:] / test_batch

                if task_id == 0:
                    avg_cost[index, 7:9] = conf_mat.get_metrics()

            print_details_single_tasks(avg_cost=avg_cost, index=index, task_id=task_id, mode='val')
        else:
            print_details_single_tasks(avg_cost=avg_cost, index=index, task_id=task_id, mode='train')

        scheduler.step()


''' ===== multi task MGD trainer ==== '''
def multi_task_mgd_trainer(train_loader, test_loader, multi_task_model,
                           optimizer, scheduler, opt,
                           epoch=200, method='sumloss', alpha=0.5, seed=0, degree=0.5, eval_freq=4):
    """Train a two-task (semantic segmentation + depth) model with a gradient-combination method.

    For every training batch one loss per task is computed. Depending on
    ``method`` either the summed loss is back-propagated directly, or the
    per-task gradients w.r.t. the model's shared parameters are collected into a
    ``[num_shared_params, 2]`` matrix, combined (GradDrop, MGDA, PCGrad, CAGrad,
    plain average) and written back into ``.grad`` before ``optimizer.step()``.
    Test metrics are computed every ``eval_freq`` epochs and on the last epoch;
    gradient-angle statistics are saved under ``./saved/`` after the last epoch.

    Args:
        train_loader: iterable yielding ``(image, semantic_label, depth)`` batches.
        test_loader: same batch format as ``train_loader``; used for evaluation.
        multi_task_model: model returning ``(predictions, logsigma)`` and exposing
            ``shared_parameters()``, ``zero_grad_shared_modules()`` and ``class_nb``.
        optimizer: torch optimizer, stepped once per batch.
        scheduler: LR scheduler, stepped once per epoch.
        opt: parsed options; reads ``temp``, ``weight``, ``top_q``, ``std``,
            ``method_sub``, ``seed``, ``base_model``, ``flag``, ``topK``,
            ``method`` and ``optimizer``.
        epoch: number of training epochs.
        method: one of ``'graddrop'``, ``'mgd'``, ``'pcgrad'``, ``'cagrad'``,
            ``'branch_layer'``, ``'nothing_v2'``, ``'nothing'`` or ``'sumloss'``.
            ``'sumloss'`` is an alias of ``'nothing'`` (back-propagate the sum of
            the task losses).
        alpha: CAGrad ``alpha`` (also part of the saved file names).
        seed, degree: not used by this function; kept for call compatibility.
        eval_freq: evaluate on ``test_loader`` every ``eval_freq`` epochs.

    Raises:
        ValueError: if ``method`` is unknown, or ``opt.method_sub`` is unknown when
            ``method == 'branch_layer'``.
    """

    def get_lr(optimizer):
        """Return the learning rate of the optimizer's first parameter group."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def graddrop(grads):
        """GradDrop combination of the per-task gradient matrix ``grads`` ([dim, n_tasks]).

        ``P`` is the per-coordinate sign purity in [0, 1]. A coordinate keeps the
        positive task entries when ``P > U`` and the negative ones when ``P < U``
        (``U`` uniform random); the kept entries are averaged over tasks.
        """
        P = 0.5 * (1. + grads.sum(1) / (grads.abs().sum(1) + 1e-8))
        U = torch.rand_like(grads[:, 0])
        M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0)
        g = (grads * M.float()).mean(1)
        return g

    def mgd(grads):
        """MGDA: combine the task gradients with weights from ``MinNormSolver``.

        The weights are computed on CPU by ``MinNormSolver.find_min_norm_element``
        and the weighted sum is formed on ``grads``' device.
        """
        grads_cpu = grads.t().cpu()
        sol, _ = MinNormSolver.find_min_norm_element([grads_cpu[t] for t in range(grads.shape[-1])])
        w = torch.FloatTensor(sol).to(grads.device)
        g = grads.mm(w.view(-1, 1)).view(-1)
        return g

    def pcgrad(grads, rng):
        """PCGrad: remove conflicting components between task gradients, then average.

        Each task gradient is projected, in an ``rng``-shuffled order, onto the
        normal plane of every other task gradient it has a negative dot product with.
        """
        grad_vec = grads.t()
        num_tasks = 2

        # Row i lists every task except i, in random order; after the transpose each
        # row holds, for every task, the "other" task to project against in that round.
        shuffled_task_indices = np.zeros((num_tasks, num_tasks - 1), dtype=int)
        for i in range(num_tasks):
            task_indices = np.arange(num_tasks)
            task_indices[i] = task_indices[-1]
            shuffled_task_indices[i] = task_indices[:-1]
            rng.shuffle(shuffled_task_indices[i])
        shuffled_task_indices = shuffled_task_indices.T

        normalized_grad_vec = grad_vec / (
                grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )  # num_tasks x dim
        modified_grad_vec = deepcopy(grad_vec)
        for task_indices in shuffled_task_indices:
            normalized_shuffled_grad = normalized_grad_vec[
                task_indices
            ]  # num_tasks x dim
            dot = (modified_grad_vec * normalized_shuffled_grad).sum(
                dim=1, keepdim=True
            )  # num_tasks x 1
            # only negative dot products (conflicts) are projected out
            modified_grad_vec -= torch.clamp_max(dot, 0) * normalized_shuffled_grad
        g = modified_grad_vec.mean(dim=0)
        return g

    def cagrad(grads, alpha=0.5, rescale=1, n_tasks=2):
        """CAGrad combination of the per-task gradient matrix ``grads`` ([dim, n_tasks]).

        Solves ``min_w w^T G b + c * sqrt(w^T G w)`` over the probability simplex
        (``G`` = Gram matrix, ``b`` = uniform weights, ``c = alpha * ||g0||``) with
        ``scipy.optimize.minimize``, then returns ``g0 + c / ||g_w|| * g_w`` where
        ``g0`` is the average gradient. ``rescale`` selects the final scaling:
        0 -> none, 1 -> ``/ (1 + alpha^2)``, otherwise ``/ (1 + alpha)``.
        """
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(n_tasks) / n_tasks
        bnds = tuple((0, 1) for x in x_start)
        cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (x.reshape(1, n_tasks).dot(A).dot(b.reshape(n_tasks, 1)) + c * np.sqrt(
                x.reshape(1, n_tasks).dot(A).dot(x.reshape(n_tasks, 1)) + 1e-8)).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x
        ww = torch.Tensor(w_cpu).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm + 1e-8)
        g = grads.mean(1) + lmbda * gw
        if rescale == 0:
            return g
        elif rescale == 1:
            return g / (1 + alpha ** 2)
        else:
            return g / (1 + alpha)

    def grad2vec(m, grads, grad_dims, task):
        """Copy the current shared-parameter gradients of ``m`` into column ``task`` of ``grads``.

        Parameters whose ``.grad`` is None leave their rows at zero.
        """
        grads[:, task].fill_(0.0)
        cnt = 0
        for name, p in m.shared_parameters().items():
            grad = p.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone()
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[:cnt + 1])
                grads[beg:en, task].copy_(grad_cur.data.view(-1))
            cnt += 1

    def overwrite_grad(m, newgrad, grad_dims, n_tasks=2):
        """Write the flat combined gradient ``newgrad`` back into the shared parameters' ``.grad``."""
        newgrad = newgrad * n_tasks  # to match the sum loss
        cnt = 0
        for name, param in m.shared_parameters().items():
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            this_grad = newgrad[beg: en].contiguous().view(param.data.size())
            param.grad = this_grad.data.clone()
            cnt += 1

    def grad2vec_list(m):
        """Return flattened copies of the non-None shared-parameter gradients of ``m``."""
        grad_list = []
        for name, param in m.shared_parameters().items():
            grad = param.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone().view(-1)
                grad_list.append(grad_cur)
        return grad_list

    def get_layer_dict(m):
        """Group shared-parameter indices by layer name (the name without '.weight' / '.bias')."""
        shared_parameters = m.shared_parameters()
        name_list = list(shared_parameters.keys())

        layer_dict = {}
        for i, name in enumerate(name_list):
            if '.weight' in name:
                name = name.replace('.weight', '')
            elif '.bias' in name:
                name = name.replace('.bias', '')

            if name not in layer_dict:
                layer_dict[name] = [i]
            else:
                layer_dict[name].append(i)

        return layer_dict

    def split_layer(grad_list, name_dict):
        """Concatenate per-parameter gradient vectors into one vector per layer of ``name_dict``.

        NOTE(review): indices in ``name_dict`` assume ``grad_list`` has one entry per
        shared parameter; ``grad2vec_list`` skips None gradients, which would shift them.
        """
        grad_new = []
        for key, value in name_dict.items():
            grad = [grad_list[i] for i in value]
            grad = torch.cat(grad)
            grad_new.append(grad)

        return grad_new

    def reshape_grad(g, grad_dims):
        """Split the flat gradient ``g`` into detached per-parameter chunks of sizes ``grad_dims``."""
        grad_all = []
        for i in range(len(grad_dims)):
            beg = 0 if i == 0 else sum(grad_dims[:i])
            en = sum(grad_dims[:i+1])
            this_grad = g[beg:en].data.detach().clone()
            grad_all.append(this_grad)
        return grad_all

    def collect_task_grads(losses, grads):
        """Back-propagate each task loss separately and store its shared gradient in ``grads[:, i]``."""
        for i, loss in enumerate(losses):
            # the graph is retained because every task loss shares the same forward pass
            loss.backward(retain_graph=True)
            grad2vec(multi_task_model, grads, grad_dims, i)
            multi_task_model.zero_grad_shared_modules()

    valid_methods = ('graddrop', 'mgd', 'pcgrad', 'cagrad', 'branch_layer', 'nothing_v2', 'nothing', 'sumloss')
    # methods that back-propagate the summed loss and never need the gradient matrix
    sum_loss_methods = ('nothing', 'sumloss')
    if method not in valid_methods:
        # fail early instead of silently training without any optimizer step
        raise ValueError(f'Error: unknown method {method}')

    define_wandb_metric()
    rng = np.random.default_rng()

    # numel of every shared parameter, in shared_parameters() order: the row
    # layout of the flattened per-task gradient matrix
    grad_dims = [param.data.numel() for _, param in multi_task_model.shared_parameters().items()]
    total_grad_dim = sum(grad_dims)

    layer_dict = get_layer_dict(multi_task_model)
    layer_name = list(layer_dict.keys())

    # per-layer cosine statistics, only filled by the 'branch_layer' method
    layer_wise_angle = OrderedDict((name, []) for name in layer_name)
    layer_wise_task_angle = OrderedDict((name, []) for name in layer_name)

    train_batch = len(train_loader)
    test_batch = len(test_loader)
    T = opt.temp
    # columns 0-5: train metrics, columns 6-11: test metrics
    avg_cost = np.zeros([epoch, 12], dtype=np.float32)
    lambda_weight = np.ones([2, epoch])

    timer = Timer()
    inner_index = 0

    # cosine between the two task gradients for every batch (gradient-combining methods only)
    angle_all = []

    for index in range(epoch):

        avg_t1_angle = Averager()
        avg_t2_angle = Averager()

        cost = np.zeros(12, dtype=np.float32)

        # apply Dynamic Weight Average
        # NOTE(review): lambda_weight is computed but not applied to the losses below.
        if opt.weight == 'dwa':
            if index == 0 or index == 1:
                lambda_weight[:, index] = 1.0
            else:
                w_1 = avg_cost[index - 1, 0] / avg_cost[index - 2, 0]
                w_2 = avg_cost[index - 1, 3] / avg_cost[index - 2, 3]
                lambda_weight[0, index] = 2 * np.exp(w_1 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T))
                lambda_weight[1, index] = 2 * np.exp(w_2 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T))

        # iteration for all batches
        multi_task_model.train()
        conf_mat = ConfMatrix(multi_task_model.class_nb)

        for data in train_loader:
            # [num_shared_params, 2] per-task gradient matrix, allocated directly on the GPU
            grads = None if method in sum_loss_methods else torch.zeros(total_grad_dim, 2, device='cuda')
            train_data, train_label, train_depth = data
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth = train_depth.cuda()

            train_pred, logsigma = multi_task_model(train_data)

            if opt.top_q <= 0:
                train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
                              model_fit(train_pred[1], train_depth, 'depth')]
            else:
                train_loss = [model_fit_top_q(train_pred[0], train_label, 'semantic', top_q=opt.top_q, std_b=opt.std),
                              model_fit_top_q(train_pred[1], train_depth, 'depth', top_q=opt.top_q, std_b=opt.std)]

            if opt.weight == 'equal' or opt.weight == 'dwa':
                train_loss_tmp = [train_loss[0], train_loss[1]]
            else:
                # uncertainty weighting: L / (2 * exp(logsigma)) + logsigma / 2
                train_loss_tmp = [1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2
                                  for i in range(2)]

            optimizer.zero_grad()
            if method in ('graddrop', 'mgd', 'pcgrad', 'nothing_v2'):
                collect_task_grads(train_loss_tmp, grads)

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                angle_all.append(cos(g1, g2))

                if method == 'graddrop':
                    g = graddrop(grads)
                elif method == 'mgd':
                    g = mgd(grads)
                elif method == 'pcgrad':
                    g = pcgrad(grads, rng)
                else:  # 'nothing_v2': plain average of the two task gradients
                    g = 0.5 * g1 + 0.5 * g2

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g, g1, g2

            elif method == 'cagrad':
                collect_task_grads(train_loss_tmp, grads)

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                angle_all.append(cos(g1, g2))

                g = cagrad(grads, alpha, rescale=1)

                # cosine between the combined update and each task gradient
                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)

                log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item()}
                wandb.log(log_dict)

                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g, g1, g2

            elif method == 'branch_layer':
                n_tasks = 2
                grad_all = []  # grad_all[task][layer]: per-layer gradient vector of each task
                for i in range(n_tasks):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    grad = split_layer(grad_list=grad2vec_list(multi_task_model), name_dict=layer_dict)
                    grad_all.append(grad)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()

                if opt.method_sub == 'nothing':
                    g = (g1 + g2) / 2.0
                elif opt.method_sub == 'cagrad':
                    g = cagrad(grads, alpha, rescale=1)
                else:
                    raise ValueError(f'Error: {opt.method_sub}')

                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)

                log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item()}

                wandb.log(log_dict)
                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())

                # combined update split the same way as the task gradients
                target_g_list = reshape_grad(g, grad_dims)
                target_g_list = split_layer(target_g_list, name_dict=layer_dict)

                # pair_grad[l]: task gradients of layer l stacked to [n_tasks, layer_dim]
                pair_grad = [torch.stack([grad_all[j][i] for j in range(n_tasks)])
                             for i in range(len(grad_all[0]))]

                for i, pair in enumerate(pair_grad):
                    layer_wise_task_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
                    layer_wise_cos = pair_cos(pair).cpu()

                    layer_wise_angle[layer_name[i]].append(layer_wise_cos)
                    layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g, g1, g2

            else:  # 'nothing' / 'sumloss': back-propagate the summed loss directly
                loss = sum(train_loss_tmp)
                loss.backward()
                optimizer.step()

            # accumulate label prediction for every pixel in training images
            conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())

            cost[0] = train_loss[0].item()
            cost[3] = train_loss[1].item()
            cost[4], cost[5] = depth_error(train_pred[1], train_depth)
            avg_cost[index, :6] += cost[:6] / train_batch

            # release per-batch GPU tensors before the next allocation
            del data, logsigma, train_data, train_depth, train_label
            del train_loss, train_loss_tmp, train_pred, grads
            inner_index += 1

        # compute mIoU and acc
        avg_cost[index, 1:3] = conf_mat.get_metrics()

        if (index + 1) % eval_freq == 0 or index == epoch - 1:
            # evaluating test data
            multi_task_model.eval()
            conf_mat = ConfMatrix(multi_task_model.class_nb)
            with torch.no_grad():  # operations inside don't track history
                for data in test_loader:
                    test_data, test_label, test_depth = data
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth = test_depth.cuda()

                    test_pred, _ = multi_task_model(test_data)
                    test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
                                 model_fit(test_pred[1], test_depth, 'depth')]

                    conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())

                    cost[6] = test_loss[0].item()
                    cost[9] = test_loss[1].item()
                    cost[10], cost[11] = depth_error(test_pred[1], test_depth)
                    avg_cost[index, 6:] += cost[6:] / test_batch

                # compute mIoU and acc
                avg_cost[index, 7:9] = conf_mat.get_metrics()

            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle, mode='val')
        else:
            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle, mode='train')

        # read the LR from the optimizer: scheduler.get_lr() is deprecated and may
        # report a wrong value when called outside scheduler.step()
        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)} | lr: {get_lr(optimizer)}')

        scheduler.step()

        if (index + 1) == epoch and method == 'branch_layer':
            saved_dict = {'cos': layer_wise_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{alpha}_{epoch}_lw_cos_cityspace.pt')
            saved_dict = {'task_cos': layer_wise_task_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{alpha}_{epoch}_lw_task_cos_cityspace.pt')

        if (index + 1) == epoch:
            if opt.base_model == 'fw_ablation':
                torch.save(angle_all,
                           f'./saved/{opt.seed}{opt.base_model}{opt.flag}_{opt.topK}{opt.method}{opt.optimizer}_{epoch}_angle.pt')
            else:
                torch.save(angle_all, f'./saved/{opt.seed}{opt.base_model}{opt.topK}{opt.method}{opt.optimizer}_{epoch}_angle.pt')

def train_rotograd(train_loader, test_loader, model,
                           optimizer_main, scheduler_main, optimizer_R, scheduler_R, opt):
    """Train a two-task (semantic segmentation + depth) model with RotoGrad.

    Gradient handling is delegated to ``model.backward(losses)``. Both the main
    optimizer and ``optimizer_R`` step after every batch; both schedulers step
    once per epoch. Test metrics are computed every ``opt.eval_freq`` epochs and
    on the last epoch.

    Args:
        train_loader: iterable yielding ``(image, semantic_label, depth)`` batches.
        test_loader: same batch format as ``train_loader``; used for evaluation.
        model: callable whose output ``[0]`` is the semantic prediction and ``[1]``
            the depth prediction, and which exposes ``backward(losses)``.
        optimizer_main, scheduler_main: optimizer / scheduler for the main network.
        optimizer_R, scheduler_R: optimizer / scheduler for the RotoGrad-specific
            parameters (presumably the rotations — confirm with the caller).
        opt: parsed options; reads ``n_epoch`` and ``eval_freq``.
    """
    epoch = opt.n_epoch
    train_batch = len(train_loader)
    test_batch = len(test_loader)
    # columns 0-5: train metrics, columns 6-11: test metrics
    avg_cost = np.zeros([epoch, 12], dtype=np.float32)
    # number of segmentation classes; 7 remains the fallback when the model does
    # not expose ``class_nb``
    num_classes = getattr(model, 'class_nb', 7)

    timer = Timer()
    # never updated in this trainer; only passed to print_details so the log
    # format matches the other trainers
    avg_t1_angle = Averager()
    avg_t2_angle = Averager()
    for index in range(epoch):
        cost = np.zeros(12, dtype=np.float32)

        # iteration for all batches
        model.train()
        conf_mat = ConfMatrix(num_classes)

        for data in train_loader:
            train_data, train_label, train_depth = data
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth = train_depth.cuda()

            train_pred = model(train_data)

            train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
                          model_fit(train_pred[1], train_depth, 'depth')]

            optimizer_main.zero_grad()
            optimizer_R.zero_grad()

            # RotoGrad computes and combines the task gradients itself
            model.backward(train_loss)

            optimizer_main.step()
            optimizer_R.step()

            conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())

            cost[0] = train_loss[0].item()
            cost[3] = train_loss[1].item()
            cost[4], cost[5] = depth_error(train_pred[1], train_depth)
            avg_cost[index, :6] += cost[:6] / train_batch

        # compute training mIoU and acc (the training confusion matrix was otherwise
        # discarded and columns 1:3 stayed zero)
        avg_cost[index, 1:3] = conf_mat.get_metrics()

        if (index + 1) % opt.eval_freq == 0 or index == epoch - 1:
            # evaluating test data
            model.eval()
            conf_mat = ConfMatrix(num_classes)
            with torch.no_grad():  # operations inside don't track history
                for data in test_loader:
                    test_data, test_label, test_depth = data
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth = test_depth.cuda()

                    test_pred = model(test_data)
                    test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
                                 model_fit(test_pred[1], test_depth, 'depth')]

                    conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())

                    cost[6] = test_loss[0].item()
                    cost[9] = test_loss[1].item()
                    cost[10], cost[11] = depth_error(test_pred[1], test_depth)
                    avg_cost[index, 6:] += cost[6:] / test_batch

                # compute mIoU and acc
                avg_cost[index, 7:9] = conf_mat.get_metrics()

            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle,
                          mode='val')
        else:
            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle,
                          mode='train')

        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)}')

        scheduler_main.step()
        scheduler_R.step()