import math

import numpy as np
import time
import torch
import torch.nn.functional as F
import gc
import random
import wandb

from sklearn import random_projection

from copy import deepcopy
from min_norm_solvers import MinNormSolver
from scipy.optimize import minimize, Bounds, minimize_scalar
import  matplotlib
from collections import OrderedDict
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from random_projection import random_projection_matrix

"""
Define task metrics, loss functions and model trainer here.
"""
def get_variable_name(variable, loc):
    """Return the first key in ``loc`` whose value compares equal to ``variable``.

    Args:
        variable: Object to look for among the values of ``loc``.
        loc: Mapping of names to values (for example the result of ``locals()``).

    Returns:
        The first matching key in iteration order, or ``None`` when no value
        matches. The match uses ``==``, not identity.
    """
    for name, candidate in loc.items():
        if candidate == variable:
            return name
    return None

class Averager():
    """Weighted running mean that also records every raw value passed to ``add``."""

    def __init__(self):
        # History of raw values; note that reset() does not clear it.
        self.data = []
        self.reset()

    def reset(self):
        """Zero the running statistics while leaving the ``data`` history intact."""
        self.val = self.avg = self.sum = self.count = 0

    def add(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Divides by the accumulated weight, so a total weight of 0 raises.
        self.avg = self.sum / self.count
        self.data.append(val)

    def item(self):
        """Return the current weighted average."""
        return self.avg

    def obtain_data(self):
        """Return the list of raw values recorded so far (not a copy)."""
        return self.data

    def __len__(self):
        # Number of recorded values, independent of the weights given to add().
        return len(self.data)

class Running_Avg():
    """Mean / standard deviation over the most recent ``max_cnt`` values."""

    def __init__(self, max_cnt):
        self.max_cnt = max_cnt
        self.data = []
        self.cnt = 0

    def add(self, val):
        """Append ``val`` to the history (the full history is retained)."""
        self.data.append(val)
        self.cnt += 1

    def _window(self):
        # Until max_cnt values have been added, the whole history is the window.
        if self.cnt < self.max_cnt:
            return self.data
        return self.data[-self.max_cnt:]

    def item(self):
        """Return the arithmetic mean of the current window."""
        window = self._window()
        return sum(window) / len(window)

    def std(self):
        """Return ``np.std`` of the current window."""
        return np.std(self._window())

class Timer():
    """Wall-clock stopwatch started at construction time."""

    def __init__(self):
        # Reference timestamp (seconds since the epoch).
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a short human-readable string.

        The value is truncated to whole seconds first, then rendered as
        ``'<n>s'`` below one minute, ``'<n>m'`` (rounded) below one hour and
        ``'<x.y>h'`` otherwise.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return f'{seconds}s'
        if seconds < 3600:
            return f'{round(seconds / 60)}m'
        return f'{seconds / 3600:.1f}h'

class ConfMatrix(object):
    """Accumulates a confusion matrix and derives mean IoU and pixel accuracy.

    Rows index the target class and columns the predicted class.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Allocated lazily on the first update(), on the device of ``pred``.
        self.mat = None

    def update(self, pred, target):
        """Add the (target, pred) label pairs to the matrix.

        Entries whose target lies outside ``[0, num_classes)`` are ignored.
        """
        num = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((num, num), dtype=torch.int64, device=pred.device)
        with torch.no_grad():
            valid = (target >= 0) & (target < num)
            # Encode each (target, pred) pair as one flat index into the matrix.
            flat_index = num * target[valid].to(torch.int64) + pred[valid]
            counts = torch.bincount(flat_index, minlength=num ** 2)
            self.mat += counts.reshape(num, num)

    def get_metrics(self):
        """Return ``(mean IoU, pixel accuracy)`` as tensors.

        A class that never occurs in either target or prediction yields 0/0 for
        its IoU, which propagates as NaN into the mean.
        """
        hist = self.mat.float()
        true_pos = torch.diag(hist)
        pixel_acc = true_pos.sum() / hist.sum()
        per_class_iou = true_pos / (hist.sum(1) + hist.sum(0) - true_pos)
        return torch.mean(per_class_iou), pixel_acc

def cos(t1, t2):
    """Return the cosine similarity of ``t1`` and ``t2`` along dimension 0.

    Both inputs are L2-normalised with ``F.normalize(..., dim=0)`` and then
    multiplied element-wise and summed over dimension 0.
    """
    unit_1 = F.normalize(t1, dim=0)
    unit_2 = F.normalize(t2, dim=0)
    return (unit_1 * unit_2).sum(dim=0)

def dot(t1, t2):
    """Return the element-wise product of ``t1`` and ``t2`` summed over dimension 0."""
    return (t1 * t2).sum(dim=0)

def pair_cos_with_d(pair, d):
    """Return the cosine similarity between ``d`` and every ``pair[i]``.

    The result is a flat tensor with one entry per index along dimension 0 of
    ``pair``, in order.
    """
    similarities = []
    for idx in range(pair.size(0)):
        similarities.append(cos(d, pair[idx]))
    return torch.stack(similarities).view(-1)

def pair_dot_with_d(pair, d):
    """Return the dot product between ``d`` and every ``pair[i]``.

    The result is a flat tensor with one entry per index along dimension 0 of
    ``pair``, in order.
    """
    products = []
    for idx in range(pair.size(0)):
        products.append(dot(d, pair[idx]))
    return torch.stack(products).view(-1)

def pair_cos(pair):
    """Return the cosine similarity of every unordered pair ``(pair[i], pair[j])``, i < j.

    Pairs are emitted in row-major order: (0, 1), (0, 2), ..., (1, 2), ...
    With fewer than two entries there is nothing to stack and ``torch.stack``
    raises.
    """
    count = pair.size(0)
    similarities = [
        cos(pair[first], pair[second])
        for first in range(count - 1)
        for second in range(first + 1, count)
    ]
    return torch.stack(similarities).view(-1)

def pair_dot(pair):
    """Return the dot product of every unordered pair ``(pair[i], pair[j])``, i < j.

    Pairs are emitted in row-major order: (0, 1), (0, 2), ..., (1, 2), ...
    With fewer than two entries there is nothing to stack and ``torch.stack``
    raises.
    """
    count = pair.size(0)
    products = [
        dot(pair[first], pair[second])
        for first in range(count - 1)
        for second in range(first + 1, count)
    ]
    return torch.stack(products).view(-1)

def remove_layer_gradient(grad_list, pos):
    """Concatenate the gradients in ``grad_list`` except those at index ``pos``.

    Args:
        grad_list: Sequence of tensors accepted by ``torch.cat`` (assumed to be
            one flattened gradient per layer -- confirm against the caller).
        pos: A single index, or a list / tuple / set / frozenset / range of
            indices to leave out.

    Returns:
        The remaining tensors concatenated with ``torch.cat`` in their original
        order.
    """
    # Previously only ``list`` was recognised as a collection, so a tuple or set
    # of positions was wrapped as one scalar and silently removed nothing.
    if not isinstance(pos, (list, tuple, set, frozenset, range)):
        pos = [pos]

    kept = [grad for idx, grad in enumerate(grad_list) if idx not in pos]
    return torch.cat(kept)

def LFT_min_cos(grad_all, grad_target, n_tasks, pos):
    """Return the smallest cosine similarity between any task and the target.

    Both the per-task gradients ``grad_all[i]`` and ``grad_target`` have the
    entries at ``pos`` removed (via ``remove_layer_gradient``) before comparison.
    When ``n_tasks`` is 0 the sentinel value ``1000`` is returned unchanged.
    """
    reduced_target = remove_layer_gradient(grad_target, pos)
    # Lazy generator: each task's similarity is computed right before it is compared.
    similarities = (
        cos(remove_layer_gradient(grad_all[task], pos), reduced_target)
        for task in range(n_tasks)
    )

    lowest = 1000
    for similarity in similarities:
        if similarity < lowest:
            lowest = similarity
    return lowest

def LFT_task_cos_diff(grad_all, grad_target, n_tasks, pos):
    """Measure how removing the entries at ``pos`` changes each task/target cosine.

    For every task ``i`` in ``range(n_tasks)`` this computes
    ``cos(reduced task, reduced target) - cos(full task, full target)``, where
    "full" is ``torch.cat`` over all gradient entries and "reduced" drops the
    entries at ``pos`` via ``remove_layer_gradient``.

    Returns:
        A flat tensor with one difference per task.
    """
    # The target-side vectors do not depend on the task index, so build them
    # once instead of re-concatenating them on every iteration.
    target_full = torch.cat(grad_target)
    target_reduced = remove_layer_gradient(grad_target, pos)

    diff_list = []
    for i in range(n_tasks):
        cos_full = cos(torch.cat(grad_all[i]), target_full)
        cos_reduced = cos(remove_layer_gradient(grad_all[i], pos), target_reduced)
        diff_list.append(cos_reduced - cos_full)

    return torch.stack(diff_list).view(-1)

def LFT_cos_diff(grad_all, n_tasks, pos):
    """Measure how removing the entries at ``pos`` changes each pairwise task cosine.

    For every task pair ``(i, j)`` with ``i < j`` this computes
    ``cos(reduced_i, reduced_j) - cos(full_i, full_j)``, where "full" is
    ``torch.cat`` over all gradient entries of a task and "reduced" drops the
    entries at ``pos`` via ``remove_layer_gradient``.

    Returns:
        A flat tensor with one difference per pair, in row-major pair order.
    """
    changes = []
    task_pairs = ((a, b) for a in range(n_tasks - 1) for b in range(a + 1, n_tasks))
    for a, b in task_pairs:
        full_cos = cos(torch.cat(grad_all[a]), torch.cat(grad_all[b]))
        reduced_cos = cos(remove_layer_gradient(grad_all[a], pos),
                          remove_layer_gradient(grad_all[b], pos))
        changes.append(reduced_cos - full_cos)

    return torch.stack(changes).view(-1)

def topK_prob(task_diff, topK):
    """Return the ``topK`` keys with the highest fraction of all-positive rows.

    For each key the list of tensors is stacked; a row counts as a hit when
    every entry along dimension 1 is ``> 0``. Keys are ranked by hit fraction in
    descending order (ties keep their insertion order).
    """
    scores = {}
    for name, diffs in task_diff.items():
        stacked = torch.stack(diffs)
        positives_per_row = (stacked > 0).sum(dim=1)
        all_positive = positives_per_row == stacked.size(1)
        scores[name] = all_positive.sum() / all_positive.size(0)

    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:topK]

def topK_value(task_diff, topK):
    """Return the ``topK`` keys with the smallest mean value.

    Each key's list of tensors is stacked and averaged over all elements; keys
    are then ranked in ascending order of that mean (ties keep insertion order).
    """
    scores = {name: torch.stack(diffs).mean() for name, diffs in task_diff.items()}
    ranked = sorted(scores, key=scores.get)
    return ranked[:topK]

def avg_tensor(list, dim=0):
    """Stack the tensors in ``list`` along a new leading axis and average over ``dim``."""
    # NOTE: the parameter name shadows the builtin but is part of the interface.
    return torch.stack(list).mean(dim=dim)

def quantile_topK(dict, q, topK, interpolation='linear'):
    """Return the ``topK`` keys with the smallest ``q``-quantile score.

    For each key the list of tensors is stacked and averaged over dimension 1
    of the stacked tensor (via ``avg_tensor``); the ``q``-quantile of that
    result is the key's score. Keys are ranked in ascending order of score.
    """
    scores = {}
    for name, tensors in dict.items():
        per_row_mean = avg_tensor(tensors, dim=1)
        scores[name] = torch.quantile(per_row_mean, q, interpolation=interpolation)

    ranked = sorted(scores, key=scores.get)
    return ranked[:topK]

def define_wandb_metric():
    """Register every multi-task metric with wandb.

    Epoch-level metrics (``Train_*``, ``avg_t*_angle``, ``Test_*``) are plotted
    against ``epoch``; the per-step ``t*_angle`` / ``t*_opposite`` metrics are
    plotted against ``inner_index``.
    """
    metric_suffixes = (
        "SegLoss", "mIoU", "Pix_Acc",
        "DepthLoss", "Abs_Err", "Rel_Err",
        "NormalLoss", "AD_mean", "AD_median",
        "WithinT_11d25", "WithinT_22d5", "WithinT_30",
    )
    angle_metrics = ["avg_t1_angle", "avg_t2_angle", "avg_t3_angle"]

    wandb.define_metric("epoch")
    epoch_metrics = (
        [f"Train_{suffix}" for suffix in metric_suffixes]
        + angle_metrics
        + [f"Test_{suffix}" for suffix in metric_suffixes]
    )
    for metric in epoch_metrics:
        wandb.define_metric(metric, step_metric="epoch")

    wandb.define_metric('inner_index')
    for kind in ('angle', 'opposite'):
        for task in ('t1', 't2', 't3'):
            wandb.define_metric(f'{task}_{kind}', step_metric='inner_index')

def define_wandb_metric_single_task(task_id):
    """Register the wandb metrics of one task, all plotted against ``epoch``.

    ``task_id`` 0 registers the ``Seg``/``mIoU``/``Pix_Acc`` metrics, 1 the
    depth metrics and 2 the normal metrics, each as ``Train_*`` followed by
    ``Test_*``. Any other ``task_id`` only registers ``epoch``.
    """
    suffixes_by_task = (
        (0, ("SegLoss", "mIoU", "Pix_Acc")),
        (1, ("DepthLoss", "Abs_Err", "Rel_Err")),
        (2, ("NormalLoss", "AD_mean", "AD_median",
             "WithinT_11d25", "WithinT_22d5", "WithinT_30")),
    )

    wandb.define_metric("epoch")

    for candidate, suffixes in suffixes_by_task:
        if task_id == candidate:
            for split in ("Train", "Test"):
                for suffix in suffixes:
                    wandb.define_metric(f"{split}_{suffix}", step_metric="epoch")

def print_details_single_tasks(avg_cost, index, task_id, mode='train'):
    """Print and wandb-log one epoch of metrics for a single task.

    Args:
        avg_cost: Array indexed as ``avg_cost[index, column]``; train metrics
            of a task start at the task's first column and the matching test
            metrics sit 12 columns further.
        index: Epoch index (also logged as ``epoch``).
        task_id: 0, 1 or 2; anything else raises ``ValueError('Error')``.
        mode: ``'train'`` reports train metrics only, ``'val'`` reports train
            and test metrics. Any other mode does nothing.
    """
    # (task id, first train column, metric name suffixes in column order)
    task_layout = (
        (0, 0, ('SegLoss', 'mIoU', 'Pix_Acc')),
        (1, 3, ('DepthLoss', 'Abs_Err', 'Rel_Err')),
        (2, 6, ('NormalLoss', 'AD_mean', 'AD_median',
                'WithinT_11d25', 'WithinT_22d5', 'WithinT_30')),
    )
    for candidate, first_col, suffixes in task_layout:
        if task_id == candidate:
            break
    else:
        raise ValueError('Error')

    # (label used in the printed line, prefix of the logged keys, first column)
    if mode == 'train':
        splits = (('TRAIN', 'Train', first_col),)
    elif mode == 'val':
        splits = (('TRAIN', 'Train', first_col), ('TEST', 'Test', first_col + 12))
    else:
        return

    width = len(suffixes)
    value_fmt = ' '.join(['{:.4f}'] * width)
    template = 'Epoch: {:04d} |'
    shown = [index]
    for label, _, start in splits:
        template += ' ' + label + ': ' + value_fmt
        shown.extend(avg_cost[index, start + offset] for offset in range(width))
    print(template.format(*shown))

    log_dict = {'epoch': index}
    for _, prefix, start in splits:
        for offset, suffix in enumerate(suffixes):
            log_dict[f'{prefix}_{suffix}'] = avg_cost[index, start + offset]

    wandb.log(log_dict)

def print_details(avg_cost, index, avg_t1_angle, avg_t2_angle, avg_t3_angle, mode='train'):
    """Print and wandb-log one epoch of metrics for all three tasks.

    Args:
        avg_cost: Array indexed as ``avg_cost[index, column]``; columns 0-11
            hold train metrics and columns 12-23 the matching test metrics.
        index: Epoch index (also logged as ``epoch``).
        avg_t1_angle, avg_t2_angle, avg_t3_angle: Objects whose ``item()``
            value is logged as ``avg_t*_angle``.
        mode: ``'train'`` reports train metrics only, ``'val'`` reports train
            and test metrics; anything else raises ``ValueError('Error')``.

    In ``'val'`` mode the printed (not the logged) test values in columns
    13, 14, 21, 22 and 23 are multiplied by 100.
    """
    if mode == 'train':
        with_test = False
    elif mode == 'val':
        with_test = True
    else:
        raise ValueError('Error')

    suffixes = (
        'SegLoss', 'mIoU', 'Pix_Acc',
        'DepthLoss', 'Abs_Err', 'Rel_Err',
        'NormalLoss', 'AD_mean', 'AD_median',
        'WithinT_11d25', 'WithinT_22d5', 'WithinT_30',
    )
    percent_cols = (13, 14, 21, 22, 23)
    # Three groups (3, 3 and 6 values) separated by ' | '.
    row_fmt = ' | '.join(' '.join(['{:.4f}'] * width) for width in (3, 3, 6))

    template = 'Epoch: {:04d} | TRAIN: ' + row_fmt
    shown = [index] + [avg_cost[index, col] for col in range(12)]
    if with_test:
        template += ' ||TEST: ' + row_fmt
        for col in range(12, 24):
            cell = avg_cost[index, col]
            shown.append(cell * 100 if col in percent_cols else cell)
    print(template.format(*shown))

    splits = [('Train', 0)]
    if with_test:
        splits.append(('Test', 12))

    log_dict = {'epoch': index}
    for prefix, start in splits:
        for offset, suffix in enumerate(suffixes):
            log_dict[f'{prefix}_{suffix}'] = avg_cost[index, start + offset]
    log_dict['avg_t1_angle'] = avg_t1_angle.item()
    log_dict['avg_t2_angle'] = avg_t2_angle.item()
    log_dict['avg_t3_angle'] = avg_t3_angle.item()

    wandb.log(log_dict)

def count_parameters(model):
    """Return the total number of elements over ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def model_fit(x_pred, x_output, task_type):
    """Compute the scalar training loss for one task.

    Args:
        x_pred: Network prediction for the task.
        x_output: Ground truth. For ``'semantic'`` it holds class indices with
            ``-1`` marking ignored pixels; for ``'depth'`` / ``'normal'`` a
            pixel whose channel sum (dim 1) is 0 is treated as undefined.
        task_type: ``'semantic'``, ``'depth'`` or ``'normal'``.

    Returns:
        Scalar loss tensor: NLL over log-softmax for ``'semantic'``, masked L1
        averaged over valid pixels for ``'depth'``, and ``1 -`` the masked mean
        dot product for ``'normal'``.

    Raises:
        ValueError: If ``task_type`` is not one of the three known tasks
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if task_type == 'semantic':
        # semantic loss: depth-wise cross entropy
        log_prob = F.log_softmax(x_pred, dim=1)
        return F.nll_loss(log_prob, x_output, ignore_index=-1)

    if task_type in ('depth', 'normal'):
        # Binary mask to drop undefined pixels. It follows the prediction's
        # device rather than forcing CUDA, so CPU-only runs work as well.
        binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
        num_valid = torch.nonzero(binary_mask, as_tuple=False).size(0)

        if task_type == 'depth':
            # depth loss: l1 norm
            return torch.sum(torch.abs(x_pred - x_output) * binary_mask) / num_valid

        # normal loss: dot product
        return 1 - torch.sum((x_pred * x_output) * binary_mask) / num_valid

    raise ValueError('Unknown task_type: {!r}'.format(task_type))

def _reduce_top_q(per_sample_loss, top_q, std_b, marker):
    """Return the ``top_q`` largest per-sample losses, or their mean.

    The top-q selection is used only when the (biased) standard deviation of
    the per-sample losses exceeds ``std_b``.
    """
    std = torch.std(per_sample_loss, unbiased=False)
    if std > std_b:
        selected, _ = torch.topk(per_sample_loss, k=top_q)
        # NOTE(review): looks like a leftover debug print; kept so that stdout
        # stays unchanged -- confirm whether it can be dropped.
        print(marker)
        return selected
    return torch.mean(per_sample_loss)


def model_fit_top_q(x_pred, x_output, task_type, top_q=-1, std_b=0.0):
    """Compute a per-sample loss and keep only the hardest samples when they vary.

    Args:
        x_pred: Network prediction for the task.
        x_output: Ground truth. For ``'semantic'`` it holds class indices with
            ``-1`` marking ignored pixels; for ``'depth'`` / ``'normal'`` a
            pixel whose channel sum (dim 1) is 0 is treated as undefined.
        task_type: ``'semantic'``, ``'depth'`` or ``'normal'``.
        top_q: Number of largest per-sample losses to keep. The default ``-1``
            is not a valid ``k`` for ``torch.topk``, so callers must pass a
            real value whenever the top-q branch can trigger.
        std_b: Threshold on the standard deviation of the per-sample losses.

    Returns:
        A 1-D tensor with the ``top_q`` largest per-sample losses when their
        standard deviation exceeds ``std_b``; otherwise the scalar mean of the
        per-sample losses. ``None`` for an unknown ``task_type``.
    """
    num_samples = x_pred.size(0)

    if task_type == 'semantic':
        # semantic loss: depth-wise cross entropy
        log_prob = F.log_softmax(x_pred, dim=1)
        pixel_loss = F.nll_loss(log_prob, x_output, ignore_index=-1, reduction='none')
        per_sample = pixel_loss.sum(dim=[1, 2]) / ((x_output != -1).sum() / num_samples)
        return _reduce_top_q(per_sample, top_q, std_b, '111')

    if task_type not in ('depth', 'normal'):
        return None

    # Binary mask to drop undefined pixels. It follows the prediction's device
    # rather than forcing CUDA, so CPU-only runs work as well.
    binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
    # Average number of valid pixels per sample (same normaliser for every sample).
    valid_per_sample = binary_mask.sum() / num_samples

    if task_type == 'depth':
        per_sample = torch.sum(torch.abs(x_pred - x_output) * binary_mask, dim=[1, 2, 3]) / valid_per_sample
        marker = '11'
    else:
        # normal loss: dot product
        per_sample = 1 - torch.sum((x_pred * x_output) * binary_mask, dim=[1, 2, 3]) / valid_per_sample
        marker = '1'

    return _reduce_top_q(per_sample, top_q, std_b, marker)

def model_fit_per_samples(x_pred, x_output, task_type):
    """Compute one loss value per sample in the batch.

    Each per-sample sum is divided by the batch-average number of valid pixels
    per sample.

    Args:
        x_pred: Network prediction for the task. For ``'semantic'`` it is fed
            to ``F.nll_loss`` directly, i.e. no log-softmax is applied here
            (presumably the caller passes log-probabilities -- confirm).
        x_output: Ground truth. For ``'semantic'`` it holds class indices with
            ``-1`` marking ignored pixels; for ``'depth'`` / ``'normal'`` a
            pixel whose channel sum (dim 1) is 0 is treated as undefined.
        task_type: ``'semantic'``, ``'depth'`` or ``'normal'``.

    Returns:
        1-D tensor with one loss per sample.

    Raises:
        ValueError: If ``task_type`` is not one of the three known tasks
            (previously this surfaced as an ``UnboundLocalError``).
    """
    num_samples = x_pred.size(0)

    if task_type == 'semantic':
        # semantic loss: depth-wise cross entropy
        pixel_loss = F.nll_loss(x_pred, x_output, ignore_index=-1, reduction='none')
        return pixel_loss.sum(dim=[1, 2]) / ((x_output != -1).sum() / num_samples)

    if task_type not in ('depth', 'normal'):
        raise ValueError('Unknown task_type: {!r}'.format(task_type))

    # Binary mask to drop undefined pixels. It follows the prediction's device
    # rather than forcing CUDA, so CPU-only runs work as well.
    binary_mask = (torch.sum(x_output, dim=1) != 0).float().unsqueeze(1).to(x_pred.device)
    valid_per_sample = binary_mask.sum() / num_samples

    if task_type == 'depth':
        # depth loss: l1 norm
        return torch.sum(torch.abs(x_pred - x_output) * binary_mask, dim=[1, 2, 3]) / valid_per_sample

    # normal loss: dot product
    return 1 - torch.sum((x_pred * x_output) * binary_mask, dim=[1, 2, 3]) / valid_per_sample


# Legacy: compute mIoU and Acc. for each image and average across all images.

# def compute_miou(x_pred, x_output):
#     _, x_pred_label = torch.max(x_pred, dim=1)
#     x_output_label = x_output
#     batch_size = x_pred.size(0)
#     class_nb = x_pred.size(1)
#     device = x_pred.device
#     for i in range(batch_size):
#         true_class = 0
#         first_switch = True
#         invalid_mask = (x_output[i] >= 0).float()
#         for j in range(class_nb):
#             pred_mask = torch.eq(x_pred_label[i], j * torch.ones(x_pred_label[i].shape).long().to(device))
#             true_mask = torch.eq(x_output_label[i], j * torch.ones(x_output_label[i].shape).long().to(device))
#             mask_comb = pred_mask.float() + true_mask.float()
#             union = torch.sum((mask_comb > 0).float() * invalid_mask)  # remove non-defined pixel predictions
#             intsec = torch.sum((mask_comb > 1).float())
#             if union == 0:
#                 continue
#             if first_switch:
#                 class_prob = intsec / union
#                 first_switch = False
#             else:
#                 class_prob = intsec / union + class_prob
#             true_class += 1
#         if i == 0:
#             batch_avg = class_prob / true_class
#         else:
#             batch_avg = class_prob / true_class + batch_avg
#     return batch_avg / batch_size
#
#
# def compute_iou(x_pred, x_output):
#     _, x_pred_label = torch.max(x_pred, dim=1)
#     x_output_label = x_output
#     batch_size = x_pred.size(0)
#     for i in range(batch_size):
#         if i == 0:
#             pixel_acc = torch.div(
#                 torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
#                 torch.sum((x_output_label[i] >= 0).float()))
#         else:
#             pixel_acc = pixel_acc + torch.div(
#                 torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
#                 torch.sum((x_output_label[i] >= 0).float()))
#     return pixel_acc / batch_size


# New mIoU and Acc. formula: accumulate every pixel and average across all pixels in all images

def depth_error(x_pred, x_output):
    """Return ``(mean absolute error, mean relative error)`` over valid pixels.

    A pixel is valid when the sum of ``x_output`` over dimension 1 is non-zero.
    Both values are returned as Python floats.
    """
    valid = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(x_pred.device)
    pred_valid = x_pred.masked_select(valid)
    gt_valid = x_output.masked_select(valid)

    abs_err = torch.abs(pred_valid - gt_valid)
    rel_err = abs_err / torch.abs(gt_valid)

    num_valid = torch.nonzero(valid, as_tuple=False).size(0)
    mean_abs = (torch.sum(abs_err) / num_valid).item()
    mean_rel = (torch.sum(rel_err) / num_valid).item()
    return mean_abs, mean_rel

def normal_error(x_pred, x_output):
    """Return angular-error statistics (in degrees) over valid pixels.

    A pixel is valid when the sum of ``x_output`` over dimension 1 is non-zero.
    The angle is ``acos`` of the dot product over dimension 1, clamped to
    ``[-1, 1]``.

    Returns:
        ``(mean, median, fraction < 11.25, fraction < 22.5, fraction < 30)``.
    """
    valid = torch.sum(x_output, dim=1) != 0
    cosine = torch.sum(x_pred * x_output, 1).masked_select(valid)
    angles = torch.acos(torch.clamp(cosine, -1, 1)).detach().cpu().numpy()
    angles = np.degrees(angles)

    within = [np.mean(angles < limit) for limit in (11.25, 22.5, 30)]
    return np.mean(angles), np.median(angles), within[0], within[1], within[2]


"""
=========== Universal Multi-task Trainer ===========
"""

# def multi_task_trainer(train_loader, test_loader, multi_task_model, device, optimizer, scheduler, opt, total_epoch=200):
#     start_time = time.time()
#     train_batch = len(train_loader)
#     test_batch = len(test_loader)
#     T = opt.temp
#     avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)
#     lambda_weight = np.ones([3, total_epoch])
#     for index in range(total_epoch):
#         epoch_start_time = time.time()
#         cost = np.zeros(24, dtype=np.float32)
#
#         # apply Dynamic Weight Average
#         if opt.weight == 'dwa':
#             if index == 0 or index == 1:
#                 lambda_weight[:, index] = 1.0
#             else:
#                 w_1 = avg_cost[index - 1, 0] / avg_cost[index - 2, 0]
#                 w_2 = avg_cost[index - 1, 3] / avg_cost[index - 2, 3]
#                 w_3 = avg_cost[index - 1, 6] / avg_cost[index - 2, 6]
#                 lambda_weight[0, index] = 3 * np.exp(w_1 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#                 lambda_weight[1, index] = 3 * np.exp(w_2 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#                 lambda_weight[2, index] = 3 * np.exp(w_3 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
#
#         # iteration for all batches
#         multi_task_model.train()
#         train_dataset = iter(train_loader)
#         conf_mat = ConfMatrix(multi_task_model.class_nb)
#         for k in range(train_batch):
#             train_data, train_label, train_depth, train_normal = train_dataset.next()
#             train_data, train_label = train_data.to(device), train_label.long().to(device)
#             train_depth, train_normal = train_depth.to(device), train_normal.to(device)
#
#             train_pred, logsigma = multi_task_model(train_data)
#
#             optimizer.zero_grad()
#             train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
#                           model_fit(train_pred[1], train_depth, 'depth'),
#                           model_fit(train_pred[2], train_normal, 'normal')]
#
#             if opt.weight == 'equal' or opt.weight == 'dwa':
#                 loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
#                 #loss = sum([w[i] * train_loss[i] for i in range(3)])
#             else:
#                 loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))
#
#             loss.backward()
#             optimizer.step()
#
#             # accumulate label prediction for every pixel in training images
#             conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
#
#             cost[0] = train_loss[0].item()
#             cost[3] = train_loss[1].item()
#             cost[4], cost[5] = depth_error(train_pred[1], train_depth)
#             cost[6] = train_loss[2].item()
#             cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
#             avg_cost[index, :12] += cost[:12] / train_batch
#
#         # compute mIoU and acc
#         avg_cost[index, 1:3] = conf_mat.get_metrics()
#
#         # evaluating test data
#         multi_task_model.eval()
#         conf_mat = ConfMatrix(multi_task_model.class_nb)
#         with torch.no_grad():  # operations inside don't track history
#             test_dataset = iter(test_loader)
#             for k in range(test_batch):
#                 test_data, test_label, test_depth, test_normal = test_dataset.next()
#                 test_data, test_label = test_data.to(device), test_label.long().to(device)
#                 test_depth, test_normal = test_depth.to(device), test_normal.to(device)
#
#                 test_pred, _ = multi_task_model(test_data)
#                 test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
#                              model_fit(test_pred[1], test_depth, 'depth'),
#                              model_fit(test_pred[2], test_normal, 'normal')]
#
#                 conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
#
#                 cost[12] = test_loss[0].item()
#                 cost[15] = test_loss[1].item()
#                 cost[16], cost[17] = depth_error(test_pred[1], test_depth)
#                 cost[18] = test_loss[2].item()
#                 cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)
#                 avg_cost[index, 12:] += cost[12:] / test_batch
#
#             # compute mIoU and acc
#             avg_cost[index, 13:15] = conf_mat.get_metrics()
#
#         scheduler.step()
#         epoch_end_time = time.time()
#         print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} ||'
#             'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} | {:.4f}'
#             .format(index, avg_cost[index, 0], avg_cost[index, 1], avg_cost[index, 2], avg_cost[index, 3],
#                     avg_cost[index, 4], avg_cost[index, 5], avg_cost[index, 6], avg_cost[index, 7], avg_cost[index, 8],
#                     avg_cost[index, 9], avg_cost[index, 10], avg_cost[index, 11], avg_cost[index, 12], avg_cost[index, 13],
#                     avg_cost[index, 14], avg_cost[index, 15], avg_cost[index, 16], avg_cost[index, 17], avg_cost[index, 18],
#                     avg_cost[index, 19], avg_cost[index, 20], avg_cost[index, 21], avg_cost[index, 22], avg_cost[index, 23], epoch_end_time-epoch_start_time))
#     end_time = time.time()
#     print("Training time: ", end_time - start_time)

"""
=========== Universal Single-task Trainer ===========
"""

def single_task_trainer(train_loader, test_loader, single_task_model, optimizer, scheduler, opt, total_epoch=200):
    """Train a model on one task and evaluate it every ``opt.eval_freq`` epochs.

    The task is selected by ``opt.task_id`` (0: 'semantic', 1: 'depth',
    2: 'normal').  Per-epoch statistics are accumulated in a local
    ``[total_epoch, 24]`` array with the following column layout; only the
    columns of the selected task are populated:

        0      train semantic loss       12     test semantic loss
        1-2    train ConfMatrix metrics  13-14  test ConfMatrix metrics
        3      train depth loss          15     test depth loss
        4-5    train depth_error         16-17  test depth_error
        6      train normal loss         18     test normal loss
        7-11   train normal_error        19-23  test normal_error

    Batches are moved to the GPU with ``.cuda()``.  Nothing is returned;
    results are reported through ``print_details_single_tasks``.

    Raises:
        ValueError: if ``opt.task_id`` is not 0, 1 or 2.
    """
    task_id = opt.task_id
    if task_id not in (0, 1, 2):
        # Fail fast: an unknown id would otherwise skip every branch below and
        # only crash after a full pass over the data at `del train_loss`.
        raise ValueError(f'opt.task_id must be 0, 1 or 2, got {task_id!r}')

    train_batch = len(train_loader)
    test_batch = len(test_loader)
    avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)

    define_wandb_metric_single_task(task_id)

    for index in range(total_epoch):
        # `cost` holds the statistics of the most recent batch; it is created
        # once per epoch and its task-specific entries are overwritten per batch.
        cost = np.zeros(24, dtype=np.float32)

        # iteration for all batches
        single_task_model.train()
        conf_mat = ConfMatrix(single_task_model.class_nb)

        for train_data, train_label, train_depth, train_normal, _ in train_loader:
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth, train_normal = train_depth.cuda(), train_normal.cuda()

            # the model returns predictions for all three heads plus logsigma
            train_pred, logsigma = single_task_model(train_data)

            optimizer.zero_grad()
            if task_id == 0:
                train_loss = model_fit(train_pred[0], train_label, 'semantic')
                train_loss.backward()
                optimizer.step()

                conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())
                cost[0] = train_loss.item()
            elif task_id == 1:
                train_loss = model_fit(train_pred[1], train_depth, 'depth')
                train_loss.backward()
                optimizer.step()

                cost[3] = train_loss.item()
                cost[4], cost[5] = depth_error(train_pred[1], train_depth)
            else:
                train_loss = model_fit(train_pred[2], train_normal, 'normal')
                train_loss.backward()
                optimizer.step()

                cost[6] = train_loss.item()
                cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)

            # running mean over the epoch
            avg_cost[index, :12] += cost[:12] / train_batch

        if task_id == 0:
            # confusion-matrix metrics are computed once over the whole epoch
            avg_cost[index, 1:3] = conf_mat.get_metrics()

        # drop references to the last batch's graph before evaluation
        del train_loss
        del train_pred
        del logsigma

        # evaluating test data
        single_task_model.eval()
        conf_mat = ConfMatrix(single_task_model.class_nb)
        if (index + 1) % opt.eval_freq == 0:
            with torch.no_grad():  # operations inside don't track history
                for test_data, test_label, test_depth, test_normal, _ in test_loader:
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth, test_normal = test_depth.cuda(), test_normal.cuda()

                    test_pred, logsigma = single_task_model(test_data)

                    if task_id == 0:
                        test_loss = model_fit(test_pred[0], test_label, 'semantic')

                        conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())
                        cost[12] = test_loss.item()
                    elif task_id == 1:
                        test_loss = model_fit(test_pred[1], test_depth, 'depth')

                        cost[15] = test_loss.item()
                        cost[16], cost[17] = depth_error(test_pred[1], test_depth)
                    else:
                        test_loss = model_fit(test_pred[2], test_normal, 'normal')

                        cost[18] = test_loss.item()
                        cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)

                    avg_cost[index, 12:] += cost[12:] / test_batch

                if task_id == 0:
                    avg_cost[index, 13:15] = conf_mat.get_metrics()

            print_details_single_tasks(avg_cost=avg_cost, index=index, task_id=task_id, mode='val')
        else:
            print_details_single_tasks(avg_cost=avg_cost, index=index, task_id=task_id, mode='train')

        scheduler.step()


''' ===== multi task MGD trainer ==== '''
def multi_task_mgd_trainer(train_loader, test_loader, multi_task_model,
                           optimizer, scheduler, opt,
                           epoch=200, method='sumloss', alpha=0.5, seed=0, degree=0.5, eval_freq=4):
    start_time = time.time()

    times_conflict = 0

    def get_lr(optimizer):
        """Return the 'lr' of the optimizer's first parameter group (None if there is none)."""
        first_group = next(iter(optimizer.param_groups), None)
        if first_group is None:
            return None
        return first_group['lr']

    def graddrop(grads):
        """Combine per-task gradients by random sign-based masking.

        ``grads`` is a 2-D tensor with one row per coordinate and one column
        per task.  For every row a probability of keeping the positive entries
        is derived from the ratio of the signed sum to the absolute sum; a
        uniform sample then decides whether the positive or the negative
        entries of that row survive.  Returns the row-wise mean of the masked
        gradients.
        """
        signed_total = grads.sum(1)
        abs_total = grads.abs().sum(1)
        keep_positive_prob = 0.5 * (1. + signed_total / (abs_total + 1e-8))

        # one uniform draw per row
        draw = torch.rand_like(grads[:, 0])
        choose_positive = keep_positive_prob.gt(draw).view(-1, 1)
        choose_negative = keep_positive_prob.lt(draw).view(-1, 1)

        mask = choose_positive * grads.gt(0) + choose_negative * grads.lt(0)
        return (grads * mask.float()).mean(1)

    def mgd(grads):
        """Combine the task gradients with the weights found by MinNormSolver.

        ``grads`` has one column per task.  The columns are handed to
        ``MinNormSolver.find_min_norm_element`` on the CPU (presumably the
        min-norm point of their convex hull -- confirm against the solver),
        and the returned coefficients are used to form a weighted sum of the
        columns on the original device.
        """
        per_task = grads.t().cpu()
        task_vectors = [per_task[t] for t in range(grads.shape[-1])]
        sol, _ = MinNormSolver.find_min_norm_element(task_vectors)
        coeffs = torch.FloatTensor(sol).to(grads.device).view(-1, 1)
        return grads.mm(coeffs).view(-1)

    def pcgrad(grads, rng):
        """Project conflicting task gradients onto each other and average them.

        Args:
            grads: 2-D tensor with one column per task (dim x num_tasks).
            rng: object with an in-place ``shuffle`` method (e.g. a
                ``numpy.random.Generator``) used to randomise, per task, the
                order in which the other tasks are projected against.

        Returns:
            1-D tensor of length dim: the mean over tasks of the projected
            gradients.

        The number of tasks is taken from ``grads`` rather than fixed to 3.
        """
        grad_vec = grads.t()  # num_tasks x dim
        num_tasks = grad_vec.shape[0]

        # Row i lists the indices of every task except i, in random order.
        shuffled_task_indices = np.zeros((num_tasks, num_tasks - 1), dtype=int)
        for i in range(num_tasks):
            task_indices = np.arange(num_tasks)
            task_indices[i] = task_indices[-1]  # overwrite i with the last index, then drop the last slot
            shuffled_task_indices[i] = task_indices[:-1]
            rng.shuffle(shuffled_task_indices[i])
        # After the transpose each row holds, for every task, one "other" task.
        shuffled_task_indices = shuffled_task_indices.T

        normalized_grad_vec = grad_vec / (
                grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )  # num_tasks x dim
        modified_grad_vec = grad_vec.clone()
        for task_indices in shuffled_task_indices:
            normalized_shuffled_grad = normalized_grad_vec[
                task_indices
            ]  # num_tasks x dim
            dot = (modified_grad_vec * normalized_shuffled_grad).sum(
                dim=1, keepdim=True
            )  # num_tasks x 1
            # only negative dot products (conflicts) are projected away
            modified_grad_vec -= torch.clamp_max(dot, 0) * normalized_shuffled_grad
        return modified_grad_vec.mean(dim=0)

    def cagrad(grads, alpha=0.5, rescale=1):
        """Compute the CAGrad update direction from per-task gradients.

        With G = ``grads`` (dim x num_tasks), g0 the mean of the columns and
        c = alpha * ||g0||, the task weights w are found by minimising
        ``<G w, g0> + c * ||G w||`` over the probability simplex (solved with
        ``scipy.optimize.minimize`` on the num_tasks x num_tasks Gram matrix).
        The result is ``g0 + c / ||G w|| * G w``.

        Args:
            grads: 2-D tensor with one column per task.
            alpha: scales c relative to the norm of the mean gradient.
            rescale: 0 returns the raw direction, 1 divides it by
                ``1 + alpha ** 2``, any other value divides it by ``1 + alpha``.

        Returns:
            1-D tensor of length dim on the device of ``grads``.

        The number of tasks is taken from ``grads`` rather than fixed to 3.
        """
        num_tasks = grads.shape[1]
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(num_tasks) / num_tasks
        bnds = ((0, 1),) * num_tasks
        cons = {'type': 'eq', 'fun': lambda x: 1 - sum(x)}
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            # <G x, g0> + c * ||G x||, both expressed through the Gram matrix
            return x.dot(A).dot(b) + c * np.sqrt(x.dot(A).dot(x) + 1e-8)

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        ww = torch.Tensor(res.x).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        lmbda = c / (gw.norm() + 1e-8)
        g = grads.mean(1) + lmbda * gw
        if rescale == 0:
            return g
        if rescale == 1:
            return g / (1 + alpha ** 2)
        return g / (1 + alpha)

    def cagrad_weight(grads, alpha=0.5, rescale=1):
        """Return the per-task weights of the CAGrad update instead of the update itself.

        Solves the same simplex problem as the CAGrad direction (minimise
        ``<G w, g0> + c * ||G w||`` with c = alpha * ||g0||) and returns
        ``1 / num_tasks + c / ||G w|| * w``, i.e. the coefficients that would
        multiply each task gradient.

        Args:
            grads: 2-D tensor with one column per task.
            alpha: scales c relative to the norm of the mean gradient.
            rescale: 0 returns the raw weights, 1 divides them by
                ``1 + alpha ** 2``, any other value divides them by ``1 + alpha``.

        Returns:
            1-D tensor of length num_tasks on the device of ``grads``.

        The uniform base weights are created on ``grads.device`` (not
        unconditionally on CUDA) and the number of tasks is taken from
        ``grads`` rather than fixed to 3.
        """
        num_tasks = grads.shape[1]
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(num_tasks) / num_tasks
        bnds = ((0, 1),) * num_tasks
        cons = {'type': 'eq', 'fun': lambda x: 1 - sum(x)}
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            # <G x, g0> + c * ||G x||, both expressed through the Gram matrix
            return x.dot(A).dot(b) + c * np.sqrt(x.dot(A).dot(x) + 1e-8)

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        ww = torch.Tensor(res.x).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        lmbda = c / (gw.norm() + 1e-8)

        # uniform weights of the mean gradient plus the scaled solver weights
        weight = torch.full((num_tasks,), 1.0 / num_tasks, device=grads.device)
        weight = weight + lmbda * ww

        if rescale == 0:
            return weight
        if rescale == 1:
            return weight / (1 + alpha ** 2)
        return weight / (1 + alpha)

    def cagrad_check(grads, alpha=0.5, rescale=1, degree=0.1):
        """Use CAGrad only when some task gradient disagrees with the summed gradient.

        The cosine similarity between every task gradient (a column of
        ``grads``) and the sum of all task gradients is computed.  If any
        similarity is below ``degree`` the result of ``cagrad(grads, alpha,
        rescale)`` is returned, otherwise the plain sum.

        NOTE(review): the two branches are on different scales -- the sum of
        the task gradients versus cagrad's mean-based direction; confirm the
        caller accounts for this.

        The summed gradient is broadcast against the task gradients instead
        of being repeated three times, so any number of tasks is supported.
        """
        grad_vec = grads.t()  # num_tasks x dim
        grad_direct = grad_vec.sum(dim=0)

        # 1 x dim view; broadcasts against every task row below
        grad_actual = grad_direct.unsqueeze(0)

        normalized_grad_vec = grad_vec / (
                grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )

        normalized_grad_actual = grad_actual / (
                grad_actual.norm(dim=1, keepdim=True) + 1e-8
        )

        dot = (normalized_grad_vec * normalized_grad_actual).sum(dim=1)

        ans = (dot < degree).sum()

        if ans > 0:
            print('Conflict!!Using CAGrad!')
            return cagrad(grads, alpha, rescale)
        return grad_direct

    def checkgrad(grads, degree=0.5):
        """Report whether any task gradient conflicts with the summed gradient.

        A task is flagged when the cosine similarity between its gradient (a
        column of ``grads``) and the sum of all task gradients is below
        ``degree``.  The per-task flags are printed.

        Returns:
            0-dim boolean tensor, true if at least one task is flagged.

        The summed gradient is broadcast against the task gradients instead
        of being repeated three times, so any number of tasks is supported.
        """
        grad_vec = grads.t()  # num_tasks x dim

        # 1 x dim; broadcasts against every task row below
        grad_actual = grad_vec.sum(dim=0, keepdim=True)

        normalized_grad_vec = grad_vec / (
                grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )

        normalized_grad_actual = grad_actual / (
            grad_actual.norm(dim=1, keepdim=True) + 1e-8
        )

        dot = (normalized_grad_vec * normalized_grad_actual).sum(dim=1)

        ans = (dot < degree)
        total = ans.sum()

        print(f'conflict: {ans}, {total > 0}')

        return total > 0

    def grad2vec(m, grads, grad_dims, task):
        """Flatten the current gradients of the shared parameters into one column of ``grads``.

        Args:
            m: model exposing ``shared_parameters()`` (a name -> parameter mapping).
            grads: 2-D tensor, modified in place; column ``task`` is zeroed
                and then filled.
            grad_dims: number of elements of each shared parameter, in the
                iteration order of ``shared_parameters()``.
            task: column index to write.

        Parameters whose ``.grad`` is None leave zeros in their slice.
        """
        # store the gradients
        grads[:, task].fill_(0.0)
        beg = 0  # running offset instead of re-summing grad_dims for every parameter
        for cnt, p in enumerate(m.shared_parameters().values()):
            en = beg + grad_dims[cnt]
            if p.grad is not None:
                # copy_ already copies; reshape also handles non-contiguous grads
                grads[beg:en, task].copy_(p.grad.detach().reshape(-1))
            beg = en

    def freeze_network(m, flags):
        """Set ``requires_grad`` of the i-th shared parameter of ``m`` to ``flags[i]``."""
        for position, shared_param in enumerate(m.shared_parameters().values()):
            shared_param.requires_grad = flags[position]

    def unfreeze_network(m):
        """Re-enable gradient tracking on every shared parameter of ``m``."""
        for shared_param in m.shared_parameters().values():
            shared_param.requires_grad = True

    def overwrite_grad(m, newgrad, grad_dims, num_tasks=3):
        """Write a flat gradient vector back into ``.grad`` of the shared parameters.

        Args:
            m: model exposing ``shared_parameters()`` (a name -> parameter mapping).
            newgrad: 1-D tensor holding the gradients of all shared
                parameters back to back.
            grad_dims: number of elements of each shared parameter, in the
                iteration order of ``shared_parameters()``.
            num_tasks: factor applied to ``newgrad`` so that an averaged
                gradient matches the scale of the summed loss (default 3).
        """
        newgrad = newgrad * num_tasks  # to match the sum loss
        beg = 0  # running offset instead of re-summing grad_dims for every parameter
        for cnt, param in enumerate(m.shared_parameters().values()):
            en = beg + grad_dims[cnt]
            # clone so that the parameter owns its gradient storage
            param.grad = newgrad[beg:en].detach().reshape(param.data.size()).clone()
            beg = en

    def overwrite_grad_with_freeze_level(m, newgrad, grad_dims):
        """Write a flat gradient back into the shared parameters that are not frozen.

        ``newgrad`` holds, back to back, only the gradients of the parameters
        whose ``requires_grad`` is True; ``grad_dims`` is indexed by the
        position of the parameter among all shared parameters.  The gradient
        is multiplied by 3 before being written.
        """
        scaled = newgrad * 3  # to match the sum loss
        offset = 0
        for position, param in enumerate(m.shared_parameters().values()):
            if param.requires_grad is not True:
                # frozen parameters occupy no space in the flat vector
                continue
            stop = offset + grad_dims[position]
            chunk = scaled[offset: stop].contiguous().view(param.data.size())
            param.grad = chunk.data.clone()
            offset = stop

    def grad2vec_list(m):
        """Return flattened copies of the gradients of the shared parameters.

        Parameters whose ``.grad`` is None are skipped, so the list can be
        shorter than the number of shared parameters.
        """
        return [
            param.grad.data.detach().clone().view(-1)
            for param in m.shared_parameters().values()
            if param.grad is not None
        ]

    def get_layer_dict(m):
        """Group the positions of the shared parameters by layer name.

        The layer name is the parameter name with every '.weight' removed,
        or, when the name contains no '.weight', with every '.bias' removed.
        Returns a dict mapping each layer name to the list of positions (in
        ``shared_parameters()`` order) of its parameters.
        """
        layer_dict = {}
        for position, param_name in enumerate(m.shared_parameters().keys()):
            if '.weight' in param_name:
                layer = param_name.replace('.weight', '')
            elif '.bias' in param_name:
                layer = param_name.replace('.bias', '')
            else:
                layer = param_name
            layer_dict.setdefault(layer, []).append(position)

        return layer_dict

    def split_layer(grad_list, name_dict):
        """Concatenate per-parameter gradients into one tensor per layer.

        ``name_dict`` maps a layer name to the positions of its entries in
        ``grad_list``; the result follows the iteration order of ``name_dict``.
        """
        return [
            torch.cat([grad_list[position] for position in positions])
            for positions in name_dict.values()
        ]

    def savegrad(module_list):
        """Return copies of the gradients of all parameters of the given modules.

        Parameters whose ``.grad`` is None are skipped.  The copies keep the
        shape of the parameter and are independent of the live gradients.
        """
        snapshots = []
        for module in module_list:
            snapshots.extend(
                p.grad.data.detach().clone()
                for p in module.parameters()
                if p.grad is not None
            )
        return snapshots
    #

    def reshape_grad(g, grad_dims):
        """Split a flat vector into consecutive, independent 1-D chunks.

        Args:
            g: 1-D tensor.
            grad_dims: length of each chunk, in order.

        Returns:
            List of detached copies ``g[0:d0]``, ``g[d0:d0+d1]``, ...
        """
        grad_all = []
        beg = 0  # running offset instead of re-summing grad_dims for every chunk
        for dim in grad_dims:
            en = beg + dim
            grad_all.append(g[beg:en].detach().clone())
            beg = en
        return grad_all

    def overwrite_grad_with_weight(module_list, task_grad, weight, lr):
        """Attach a look-ahead copy of every parameter as ``p.fast``.

        For the k-th parameter of ``module_list`` the three task gradients
        ``task_grad[t][k]`` are combined with ``weight[t]`` and the value
        ``p - lr * combined`` is stored in ``p.fast``; the parameter itself is
        not modified.

        Returns:
            A tuple ``(grad_vec, variables)``: the combined gradients
            flattened and concatenated, and the list of combined gradients in
            parameter shape.
        """
        params = []
        for module in module_list:
            for _, p in module.named_parameters():
                p.fast = None  # clear any previous look-ahead value
                params.append(p)

        flat_parts = []
        variables = []
        for k, p in enumerate(params):
            combined = task_grad[0][k] * weight[0] + task_grad[1][k] * weight[1] + task_grad[2][k] * weight[2]
            variables.append(combined)
            p.fast = p.data.detach().clone().add(combined, alpha=-lr)
            flat_parts.append(combined.view(-1))
        return torch.cat(flat_parts), variables

    def overwrite_grad_with_weight_final(module_list, task_grad, weight):
        """Set ``.grad`` of every parameter to the weighted sum of three task gradients.

        For the k-th parameter of ``module_list`` the gradient becomes
        ``sum_t task_grad[t][k] * weight[t]`` for t in 0..2, and any
        ``p.fast`` look-ahead value is cleared.
        """
        params = []
        for module in module_list:
            for _, p in module.named_parameters():
                p.fast = None
                params.append(p)

        for k, p in enumerate(params):
            p.fast = None
            combined = task_grad[0][k] * weight[0] + task_grad[1][k] * weight[1] + task_grad[2][k] * weight[2]
            p.grad = combined.data.clone()

    def overwrite_task_dependent_module(module_list, grad_task_dependent, lr):
        """Apply ``p <- p - lr * grad`` in place to every parameter and freeze it.

        ``grad_task_dependent[k]`` is the gradient of the k-th parameter when
        iterating the modules of ``module_list`` in order.  Each updated
        parameter gets ``requires_grad = False``.
        """
        params = []
        for module in module_list:
            params.extend(module.parameters())

        for p, step in zip(params, grad_task_dependent[:len(params)] if isinstance(grad_task_dependent, list) else [grad_task_dependent[k] for k in range(len(params))]):
            p.data.add_(step, alpha=-lr)
            p.requires_grad = False

    def overwrite_task_dependent_module_final(module_list, grad_task_dependent, data_task_dependent):
        """Restore saved values and gradients of every parameter and unfreeze it.

        For the k-th parameter of ``module_list`` the data is replaced by a
        copy of ``data_task_dependent[k]``, ``requires_grad`` is set to True
        and ``.grad`` is set to a copy of ``grad_task_dependent[k]``.
        """
        params = []
        for module in module_list:
            params.extend(module.parameters())

        for k, p in enumerate(params):
            saved_value = data_task_dependent[k].data.detach().clone()
            p.data = saved_value
            p.requires_grad = True
            saved_grad = grad_task_dependent[k].data.detach().clone()
            p.grad = saved_grad

    def copy_task_dependent_module(model):
        """Return detached copies of all parameter values of ``model.task_dependent_modules()``.

        The copies are in module order, then parameter order, and do not
        share storage with the live parameters.
        """
        snapshots = []
        for module in model.task_dependent_modules():
            snapshots.extend(p.data.detach().clone() for p in module.parameters())
        return snapshots

    define_wandb_metric()
    # rng = np.random.default_rng()
    grad_dims = []
    for key, param in multi_task_model.shared_parameters().items():
        grad_dims.append(param.data.numel())

    # grads = torch.Tensor(sum(grad_dims), 3).cuda()
    rng = np.random.default_rng()

    layer_dict = get_layer_dict(multi_task_model)
    layer_name = list(layer_dict.keys())

    layer_wise_angle = OrderedDict()
    layer_wise_dot_value = OrderedDict()
    layer_wise_task_angle = OrderedDict()
    layer_wise_task_dot_value = OrderedDict()
    task_diff = OrderedDict()
    diff = OrderedDict()

    for name in layer_name:
        layer_wise_angle[name] = []
        layer_wise_dot_value[name] = []
        layer_wise_task_angle[name] = []
        layer_wise_task_dot_value[name] = []
        task_diff[name] = []
        diff[name] = []

    train_batch = len(train_loader)
    test_batch = len(test_loader)
    T = opt.temp
    avg_cost = np.zeros([epoch, 24], dtype=np.float32)
    lambda_weight = np.ones([3, epoch])

    timer = Timer()

    running_avg_1 = Running_Avg(100)
    running_avg_2 = Running_Avg(100)
    running_avg_3 = Running_Avg(100)
    inner_index = 0
    total_cnt = 0
    cnt = 0

    turning_cnt = 0

    angle_all = []
    for index in range(epoch):

        avg_t1_angle = Averager()
        avg_t2_angle = Averager()
        avg_t3_angle = Averager()
        avg_angle = Averager()

        cost = np.zeros(24, dtype=np.float32)

        # apply Dynamic Weight Average
        if opt.weight == 'dwa':
            if index == 0 or index == 1:
                lambda_weight[:, index] = 1.0
            else:
                w_1 = avg_cost[index - 1, 0] / avg_cost[index - 2, 0]
                w_2 = avg_cost[index - 1, 3] / avg_cost[index - 2, 3]
                w_3 = avg_cost[index - 1, 6] / avg_cost[index - 2, 6]
                lambda_weight[0, index] = 3 * np.exp(w_1 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
                lambda_weight[1, index] = 3 * np.exp(w_2 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
                lambda_weight[2, index] = 3 * np.exp(w_3 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))

        # iteration for all batches
        multi_task_model.train()
        conf_mat = ConfMatrix(multi_task_model.class_nb)

        for k, data in enumerate(train_loader):
            grads = torch.Tensor(sum(grad_dims), 3).cuda()
            train_data, train_label, train_depth, train_normal, indices = data
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth, train_normal = train_depth.cuda(), train_normal.cuda()

            train_pred, logsigma = multi_task_model(train_data)

            # if opt.top_q <= 0:
            train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
                          model_fit(train_pred[1], train_depth, 'depth'),
                          model_fit(train_pred[2], train_normal, 'normal')]
            # else:
            # train_loss = [model_fit_top_q(train_pred[0], train_label, 'semantic', top_q=opt.top_q, std_b=opt.std),
            #               model_fit_top_q(train_pred[1], train_depth, 'depth', top_q=opt.top_q, std_b=opt.std),
            #               model_fit_top_q(train_pred[2], train_normal, 'normal', top_q=opt.top_q, std_b=opt.std)]

            train_loss_tmp = [0, 0, 0]

            if opt.weight == 'equal' or opt.weight == 'dwa':
                for i in range(3):
                    # train_loss_tmp[i] = train_loss[i] * lambda_weight[i, index]
                    train_loss_tmp[i] = train_loss[i]
            else:
                for i in range(3):
                    train_loss_tmp[i] = 1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2

            optimizer.zero_grad()
            if method == "graddrop":
                for i in range(3):
                    if i < 3:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                angle_all.append(torch.stack([cos(g1, g2), cos(g1, g3), cos(g2, g3)]).view(-1))

                g = graddrop(grads)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            elif method == "mgd":
                for i in range(3):
                    if i < 3:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                angle_all.append(torch.stack([cos(g1, g2), cos(g1, g3), cos(g2, g3)]).view(-1))

                g = mgd(grads)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            elif method == "pcgrad":
                for i in range(3):
                    if i < 3:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                angle_all.append(torch.stack([cos(g1, g2), cos(g1, g3), cos(g2, g3)]).view(-1))

                g = pcgrad(grads, rng)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            # elif method == 'regularizer':
            #     grad_all = []
            #
            #     shared_grad = []
            #     for mm in multi_task_model.shared_modules():
            #         shared_grad.extend(list(mm.parameters()))
            #
            #     for i in range(3):
            #         g = []
            #         grad = torch.autograd.grad(train_loss_tmp[i], shared_grad, create_graph=True)
            #         for p in grad:
            #             g.append(p.view(-1))
            #
            #         g = torch.cat(g)
            #         grad_all.append(g)
            #
            #     grad_all = torch.stack(grad_all)
            #
            #     grad_all = F.normalize(grad_all, dim=1)
            #
            #     m = torch.mm(grad_all, grad_all.T)
            #
            #     regularizer = 1 - m.mean()
            #
            #     wandb.log({'regularizer': regularizer})
            #     if opt.weight == 'equal' or opt.weight == 'dwa':
            #         loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
            #         # loss = sum([w[i] * train_loss[i] for i in range(3)])
            #     else:
            #         loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))
            #
            #     # loss = loss + opt.beta * regularizer
            #     loss.backward()
            #     optimizer.step()
            #     del regularizer
            #     del m
            #     del grad_all
            # # elif method == 'maml':
            # #
            # #     lr = torch.tensor(scheduler.get_lr()).cuda()
            # #     omega = torch.tensor(opt.omega).cuda()
            # #
            # #     weight = torch.nn.parameter.Parameter(torch.tensor([1.0, 1.0, 1.0]), requires_grad=True).cuda()
            # #
            # #     loss_baseline = [l.detach().clone() for l in train_loss]
            # #     loss_baseline = torch.stack(loss_baseline).view(-1)
            # #
            # #     grad_all = []
            # #     for i in range(3):
            # #         if i < 2:
            # #             train_loss_tmp[i].backward(retain_graph=True)
            # #         else:
            # #             train_loss_tmp[i].backward()
            # #         grad = savegrad(multi_task_model.shared_modules())
            # #         grad_all.append(grad)
            # #         multi_task_model.zero_grad_shared_modules()
            # #
            # #     del train_loss
            # #     del train_loss_tmp
            # #
            # #     # for k, m in enumerate(multi_task_model.shared_modules()):
            # #     #     cnt = 0
            # #     #     for p in m.parameters():
            # #     #         if cnt == 3 and k == 1:
            # #     #             print(f"[Former] Test, shared_modules data {k} is {p.data}")
            # #     #         cnt+=1
            # #
            # #     grad_task_dependent = savegrad(multi_task_model.task_dependent_modules())
            # #     data_task_dependent = copy_task_dependent_module(multi_task_model)
            # #
            # #     overwrite_task_dependent_module(multi_task_model.task_dependent_modules(), grad_task_dependent, lr)
            # #
            # #     best_loss = 100000000
            # #     best_weight = weight.clone()
            # #     print(
            # #         f'------------------------------------------------------Batch: {k}----------------------------------------------------')
            # #     for j in range(opt.n_steps):
            # #         overwrite_grad_with_weight(multi_task_model.shared_modules(), task_grad=grad_all, weight=weight, lr=lr)
            # #
            # #         # shared_grad = []
            # #         # for mm in multi_task_model.shared_modules():
            # #         #     shared_grad.extend(list(mm.parameters()))
            # #
            # #         train_pred, logsigma = multi_task_model(train_data)
            # #
            # #         train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
            # #                       model_fit(train_pred[1], train_depth, 'depth'),
            # #                       model_fit(train_pred[2], train_normal, 'normal')]
            # #
            # #         train_loss = torch.stack(train_loss).view(-1)
            # #
            # #         gap = (train_loss - loss_baseline)
            # #         gap = gap / loss_baseline
            # #         loss = gap.mean() * 100
            # #
            # #         if loss < best_loss:
            # #             best_loss = loss
            # #             best_weight = weight.clone()
            # #
            # #         if j < opt.n_steps - 1:
            # #             g = torch.autograd.grad(loss, weight, create_graph=False)
            # #             g = omega * g[0]
            # #             weight = weight - g
            # #
            # #         print(f'gap: {gap.data} | loss: {loss.data} | weight: {weight.data}')
            # #         # for i in range(3):
            # #         #     g = []
            # #         #     grad = torch.autograd.grad(train_loss[i], shared_grad, create_graph=True)
            # #         #     for p in grad:
            # #         #         g.append(p.view(-1))
            # #         #
            # #         #     g = torch.cat(g)
            # #         #     grad_next_all.append(g)
            # #         #
            # #         # grad_next_all = torch.stack(grad_next_all)
            # #         #
            # #         # grad_next_all = F.normalize(grad_next_all, dim=1)
            # #         #
            # #         # m = torch.mm(grad_next_all, grad_next_all.T)
            # #         #
            # #         # regularizer = 1 - m.mean()
            # #         #
            # #         # grad = torch.autograd.grad(regularizer, weight, create_graph=False)
            # #         #
            # #         # weight = weight - opt.omega * grad
            # #         #
            # #         # print(f'regularizer: {regularizer}')
            # #
            # #     overwrite_grad_with_weight_final(multi_task_model.shared_modules(), task_grad=grad_all, weight=best_weight)
            # #     overwrite_task_dependent_module_final(multi_task_model.task_dependent_modules(), grad_task_dependent, data_task_dependent)
            # #
            # #     # for k, m in enumerate(multi_task_model.shared_modules()):
            # #     #     cnt = 0
            # #     #     for p in m.parameters():
            # #     #         if cnt == 3 and k == 1:
            # #     #             print(f"[After] Test, shared_modules data {k} is {p.data}")
            # #     #         cnt += 1
            # #
            # #     optimizer.step()
            # #
            # # elif method == 'maml2':
            # #
            # #     lr = torch.tensor(scheduler.get_lr()).cuda()
            # #     omega = torch.tensor(opt.omega).cuda()
            # #
            # #     weight = torch.nn.parameter.Parameter(torch.tensor([1.0, 1.0, 1.0]), requires_grad=True).cuda()
            # #
            # #     loss_baseline = [l.detach().clone() for l in train_loss]
            # #     loss_baseline = torch.stack(loss_baseline).view(-1)
            # #
            # #     grad_all = []
            # #     for i in range(3):
            # #         if i < 2:
            # #             train_loss_tmp[i].backward(retain_graph=True)
            # #         else:
            # #             train_loss_tmp[i].backward()
            # #         grad = savegrad(multi_task_model.shared_modules())
            # #         grad_all.append(grad)
            # #         multi_task_model.zero_grad_shared_modules()
            # #
            # #     del train_loss
            # #     del train_loss_tmp
            # #
            # #     # for k, m in enumerate(multi_task_model.shared_modules()):
            # #     #     cnt = 0
            # #     #     for p in m.parameters():
            # #     #         if cnt == 3 and k == 1:
            # #     #             print(f"[Former] Test, shared_modules data {k} is {p.data}")
            # #     #         cnt+=1
            # #
            # #     grad_task_dependent = savegrad(multi_task_model.task_dependent_modules())
            # #     data_task_dependent = copy_task_dependent_module(multi_task_model)
            # #
            # #     overwrite_task_dependent_module(multi_task_model.task_dependent_modules(), grad_task_dependent, lr)
            # #
            # #     best_loss = 100000000
            # #     best_weight = weight.clone()
            # #     print(
            # #         f'------------------------------------------------------Batch: {k}----------------------------------------------------')
            # #     for j in range(opt.n_steps):
            # #         variable = overwrite_grad_with_weight(multi_task_model.shared_modules(), task_grad=grad_all, weight=weight, lr=lr)
            # #
            # #         train_pred, logsigma = multi_task_model(train_data)
            # #
            # #         train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
            # #                       model_fit(train_pred[1], train_depth, 'depth'),
            # #                       model_fit(train_pred[2], train_normal, 'normal')]
            # #
            # #         grad_next_all = []
            # #         for i in range(3):
            # #             g = []
            # #             grad = torch.autograd.grad(train_loss[i], variable, create_graph=True)
            # #             for p in grad:
            # #                 g.append(p.view(-1))
            # #
            # #             g = torch.cat(g)
            # #             grad_next_all.append(g)
            # #
            # #         grad_next_all = torch.stack(grad_next_all)
            # #
            # #         grad_next_all = F.normalize(grad_next_all, dim=1)
            # #
            # #         m = torch.mm(grad_next_all, grad_next_all.T)
            # #
            # #         regularizer = 1 - m.mean()
            # #
            # #         grad = torch.autograd.grad(regularizer, weight, create_graph=False)
            # #
            # #         weight = weight - omega * grad[0]
            # #
            # #         print(f'regularizer: {regularizer}')
            # #
            # #     overwrite_grad_with_weight_final(multi_task_model.shared_modules(), task_grad=grad_all, weight=best_weight)
            # #     overwrite_task_dependent_module_final(multi_task_model.task_dependent_modules(), grad_task_dependent, data_task_dependent)
            # #
            # #     # for k, m in enumerate(multi_task_model.shared_modules()):
            # #     #     cnt = 0
            # #     #     for p in m.parameters():
            # #     #         if cnt == 3 and k == 1:
            # #     #             print(f"[After] Test, shared_modules data {k} is {p.data}")
            # #     #         cnt += 1
            # #
            # #     optimizer.step()
            # elif method == 'maml_cagrad':
            #     lr = get_lr(optimizer)
            #
            #     std = torch.tensor(opt.std).cuda()
            #     mean = torch.tensor(opt.mean).cuda()
            #     omega = torch.tensor(opt.omega).cuda()
            #
            #     loss_baseline = [l.detach().clone() for l in train_loss]
            #     loss_baseline = torch.stack(loss_baseline).view(-1)
            #
            #     grad_all = []
            #     for i in range(3):
            #         if i < 2:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = savegrad(multi_task_model.shared_modules())
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     weight = cagrad_weight(grads, alpha, rescale=1)
            #     weight = torch.nn.parameter.Parameter(weight, requires_grad=True).cuda()
            #
            #     grad_task_dependent = savegrad(multi_task_model.task_dependent_modules())
            #     data_task_dependent = copy_task_dependent_module(multi_task_model)
            #
            #     overwrite_task_dependent_module(multi_task_model.task_dependent_modules(), grad_task_dependent, lr)
            #
            #     best_loss = 100000000
            #     best_weight = weight.clone()
            #
            #     print(
            #         f'------------------------------------------------------Batch: {k}----------------------------------------------------')
            #     for j in range(opt.n_steps):
            #         grad_vec, variables = overwrite_grad_with_weight(multi_task_model.shared_modules(), task_grad=grad_all, weight=weight, lr=lr)
            #         # freeze the running mean and running variance of BN layers
            #         multi_task_model.eval()
            #
            #         inner_train_pred, inner_logsigma = multi_task_model(train_data)
            #
            #         inner_train_loss = [model_fit(inner_train_pred[0], train_label, 'semantic'),
            #                             model_fit(inner_train_pred[1], train_depth, 'depth'),
            #                             model_fit(inner_train_pred[2], train_normal, 'normal')]
            #
            #         inner_train_loss = torch.stack(inner_train_loss).view(-1)
            #
            #         gap = (inner_train_loss - loss_baseline)
            #         gap = gap / loss_baseline
            #
            #         # calculate the regularizer
            #         temp = (loss_baseline - inner_train_loss).view(1, -1)
            #         temp = torch.mm(temp.T, temp)
            #         alpha = torch.tensor(lr).cuda()
            #         regularizer = - temp.mean() / ((grad_vec**2).sum() * alpha * alpha)
            #         # regularizer = torch.exp(regularizer)
            #         loss = mean * gap.mean() + std * gap.std() + omega * regularizer
            #         # loss = omega * regularizer
            #
            #         if loss < best_loss:
            #             best_loss = loss.detach().clone()
            #             best_weight = weight.detach().clone()
            #             # best_inner_loss = [inner_train_loss[i].detach().clone() for i in range(3)]
            #
            #         # next_grad = []
            #         # for i in range(3):
            #         #     g = []
            #         #     grad = torch.autograd.grad(inner_train_loss[i], variables, create_graph=True)
            #         #     for p in grad:
            #         #         g.append(p.view(-1))
            #         #
            #         #     g = torch.cat(g)
            #         #     next_grad.append(g)
            #         #     del g
            #         #     del grad
            #         # with torch.no_grad():
            #         #     m = torch.mm(next_grad, next_grad.T)
            #         #     test = 1 - m.mean()
            #
            #         print(f'gap: {gap.data} | loss: {loss.data} | weight: {weight.data} | regularizer: {regularizer.item()}')
            #
            #         # ignore the last iteration
            #         if j < opt.n_steps - 1:
            #             g = torch.autograd.grad(loss, weight, create_graph=False)
            #             weight = weight - g[0]
            #             del g
            #
            #         del inner_train_loss
            #         del inner_train_pred
            #         del inner_logsigma
            #         del loss
            #         del gap
            #         del regularizer
            #         del grad_vec
            #         del temp
            #         # del test
            #         # del grad
            #         # del g
            #         # del next_grad
            #         del variables
            #
            #     overwrite_grad_with_weight_final(multi_task_model.shared_modules(), task_grad=grad_all, weight=best_weight)
            #     overwrite_task_dependent_module_final(multi_task_model.task_dependent_modules(), grad_task_dependent, data_task_dependent)
            #
            #
            #     # overwrite_grad_with_weight_test(multi_task_model.shared_modules(), task_grad=grad_all, weight=best_weight, lr=lr)
            #     multi_task_model.train()
            #     optimizer.step()
            #
            #     # for group in optimizer.param_groups:
            #     #     weight_decay = group['weight_decay']
            #     #     momentum = group['momentum']
            #     #     dampening = group['dampening']
            #     #     nesterov = group['nesterov']
            #     #
            #     #     for p in group['params']:
            #     #         if p.grad is None:
            #     #             continue
            #     #         d_p = p.grad.data
            #     #         if weight_decay != 0:
            #     #             d_p.add_(weight_decay, p.data)
            #     #         # Apply learning rate
            #     #
            #     #         if momentum != 0:
            #     #             param_state = optimizer.state[p]
            #     #             if 'momentum_buffer' not in param_state:
            #     #                 buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            #     #                 buf.mul_(momentum).add_(d_p)
            #     #             else:
            #     #                 buf = param_state['momentum_buffer']
            #     #                 buf.mul_(momentum).add_(1 - dampening, d_p)
            #     #             if nesterov:
            #     #                 d_p = d_p.add(momentum, buf)
            #     #             else:
            #     #                 d_p = buf
            #     #         # d_p.mul_(group['lr'])
            #     #         p.data.add_(d_p, alpha=-0.1)
            #     #
            #     # # for group in optimizer.param_groups:
            #     # #     params_with_grad = []
            #     # #     d_p_list = []
            #     # #     momentum_buffer_list = []
            #     # #     has_sparse_grad = False
            #     # #
            #     # #     for p in group['params']:
            #     # #         if p.grad is not None:
            #     # #             params_with_grad.append(p)
            #     # #             d_p_list.append(p.grad.data)
            #     # #             if p.grad.is_sparse:
            #     # #                 has_sparse_grad = True
            #     # #
            #     # #             state = optimizer.state[p]
            #     # #             if 'momentum_buffer' not in state:
            #     # #                 momentum_buffer_list.append(None)
            #     # #             else:
            #     # #                 momentum_buffer_list.append(state['momentum_buffer'])
            #     # #
            #     # #     for i, param in enumerate(params_with_grad):
            #     # #
            #     # #         d_p = d_p_list[i]
            #     # #         alpha = group['lr'] if group['maximize'] else -group['lr']
            #     # #         param.data.add_(d_p, alpha=-0.1)
            #     #
            #     #
            #     # # optimizer.step()
            #
            #     multi_task_model.eval()
            #     del grad
            #     del grad_task_dependent
            #     del data_task_dependent
            #     del grad_all
            #     del train_loss_tmp
            #     del weight
            #     del loss_baseline
            #
            #     # with torch.no_grad():
            #     #     inner_train_pred, inner_logsigma = multi_task_model(train_data)
            #     #
            #     #     inner_train_loss = [model_fit(inner_train_pred[0], train_label, 'semantic'),
            #     #                         model_fit(inner_train_pred[1], train_depth, 'depth'),
            #     #                         model_fit(inner_train_pred[2], train_normal, 'normal')]
            #     #
            #     #     diff = [best_inner_loss[i] - inner_train_loss[i] for i in range(3)]
            #     #     print(f'Difference: {diff}!')
            #     #     del inner_train_pred
            #     #     del inner_logsigma
            #     #     del inner_train_loss
            #     #     del diff
            # elif method == 'cagrad_check':
            #     for i in range(3):
            #         if i < 3:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #     g = cagrad_check(grads, alpha, rescale=1, degree=opt.degree)
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     del g
            elif method == "cagrad":
                # CAGrad branch: collect one gradient per task, combine them with
                # `cagrad`, log how well the combined direction agrees with each
                # task gradient, then apply the combined gradient.
                for i in range(3):
                    # retain_graph=True is used for every task, including the last
                    # one, so the autograd graph is still alive after this loop.
                    # NOTE(review): the commented-out variants above drop
                    # retain_graph on the final backward -- confirm keeping the
                    # graph here is intentional.
                    train_loss_tmp[i].backward(retain_graph=True)
                    # presumably stores the flattened shared-parameter gradient of
                    # task i into column i of `grads` (helper defined elsewhere)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    # clear shared-module grads so the next task's backward does
                    # not accumulate on top of this one
                    multi_task_model.zero_grad_shared_modules()

                # Detached copies of the per-task gradient columns, taken before
                # `grads` is handed to `cagrad`.
                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                # Pairwise task-vs-task values from `cos` (presumably cosine
                # similarity -- helper defined elsewhere), kept for later analysis.
                angle_all.append(torch.stack([cos(g1, g2), cos(g1, g3), cos(g2, g3)]).view(-1))

                # Combined update direction produced by CAGrad.
                g = cagrad(grads, alpha, rescale=1)

                # Agreement between the combined direction and each task gradient.
                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)
                t3_angle = cos(g, g3)
                # angle = cos(g1, g2)

                log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
                           't3_angle': t3_angle.item()}
                wandb.log(log_dict)

                # Accumulate the per-task agreement values (accumulators are
                # created outside this branch).
                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())
                avg_t3_angle.add(t3_angle.item())

                # presumably writes `g` back into the shared parameters' .grad
                # using the sizes in `grad_dims`; the optimizer then applies it
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                # Drop the references to the flat gradient tensors.
                del g
                del g1
                del g2
                del g3
            # elif method == "cagrad_drop":
            #     for i in range(3):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     g3 = grads[:, 2].data.clone()
            #
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     t3_angle = cos(g, g3)
            #     # angle = cos(g1, g2)
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_t3_angle.add(t3_angle.item())
            #
            #     # avg_angle.add(angle.item())
            #
            #     running_avg_1.add(t1_angle.item())
            #     running_avg_2.add(t2_angle.item())
            #     running_avg_3.add(t3_angle.item())
            #
            #
            #     log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
            #                't3_angle': t3_angle.item()}
            #     wandb.log(log_dict)
            #
            #     if index > 50:
            #         std_1 = running_avg_1.std()
            #         std_2 = running_avg_2.std()
            #         std_3 = running_avg_3.std()
            #
            #         avg_1 = running_avg_1.item()
            #         avg_2 = running_avg_2.item()
            #         avg_3 = running_avg_3.item()
            #
            #         omega1 = 1 - std_1 / avg_1 * opt.omega
            #         omega2 = 1 - std_2 / avg_2 * opt.omega
            #         omega3 = 1 - std_3 / avg_3 * opt.omega
            #
            #         avg_1_bound = avg_1 * omega1
            #         avg_2_bound = avg_2 * omega2
            #         avg_3_bound = avg_3 * omega3
            #
            #         changed = False
            #         if t1_angle.item() < avg_1_bound:
            #             train_loss_tmp[0] = -train_loss_tmp[0]
            #             changed = True
            #             log_dict = {'inner_index': inner_index, 't1_opposite': 1}
            #             wandb.log(log_dict)
            #         else:
            #             log_dict = {'inner_index': inner_index, 't1_opposite': 0}
            #             wandb.log(log_dict)
            #
            #         if t2_angle.item() < avg_2_bound:
            #             train_loss_tmp[1] = -train_loss_tmp[1]
            #             changed = True
            #             log_dict = {'inner_index': inner_index, 't2_opposite': 1}
            #             wandb.log(log_dict)
            #         else:
            #             log_dict = {'inner_index': inner_index, 't2_opposite': 0}
            #             wandb.log(log_dict)
            #
            #         if t3_angle.item() < avg_3_bound:
            #             train_loss_tmp[2] = -train_loss_tmp[2]
            #             changed = True
            #             log_dict = {'inner_index': inner_index, 't3_opposite': 1}
            #             wandb.log(log_dict)
            #         else:
            #             log_dict = {'inner_index': inner_index, 't3_opposite': 0}
            #             wandb.log(log_dict)
            #
            #         if changed:
            #             optimizer.zero_grad()
            #             del g
            #             for i in range(3):
            #                 if i < 2:
            #                     train_loss_tmp[i].backward(retain_graph=True)
            #                 else:
            #                     train_loss_tmp[i].backward()
            #                 grad2vec(multi_task_model, grads, grad_dims, i)
            #                 multi_task_model.zero_grad_shared_modules()
            #             g = cagrad(grads, alpha, rescale=1)
            #
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     del g
            #     del g1
            #     del g2
            #     del g3
            # elif method == 'cagrad_drop2':
            #     for i in range(3):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     g3 = grads[:, 2].data.clone()
            #
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     t3_angle = cos(g, g3)
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_t3_angle.add(t3_angle.item())
            #
            #     # avg_angle.add(angle.item())
            #
            #     running_avg_1.add(t1_angle.item())
            #     running_avg_2.add(t2_angle.item())
            #     running_avg_3.add(t3_angle.item())
            #
            #
            #     if index > 1:
            #         std_1 = running_avg_1.std()
            #         std_2 = running_avg_2.std()
            #         std_3 = running_avg_3.std()
            #
            #         avg_1 = running_avg_1.item()
            #         avg_2 = running_avg_2.item()
            #         avg_3 = running_avg_3.item()
            #
            #         omega1 = 1 + std_1 / avg_1 * opt.omega
            #         omega2 = 1 + std_2 / avg_2 * opt.omega
            #         omega3 = 1 + std_3 / avg_3 * opt.omega
            #
            #         avg_1_bound = avg_1 * omega1
            #         avg_2_bound = avg_2 * omega2
            #         avg_3_bound = avg_3 * omega3
            #
            #         total_cnt += 1
            #         if t1_angle.item() > avg_1_bound and t2_angle.item() > avg_2_bound and t3_angle.item() > avg_3_bound:
            #             cnt += 1
            #             overwrite_grad(multi_task_model, g, grad_dims)
            #             optimizer.step()
            #             del g
            #             del g1
            #             del g2
            #         else:
            #             optimizer.zero_grad()
            #             del g
            #             del g1
            #             del g2
            #     else:
            #         overwrite_grad(multi_task_model, g, grad_dims)
            #         optimizer.step()
            # elif method == 'cagrad_fix':
            #     grad_all = []
            #     n_tasks = 3
            #     for i in range(n_tasks):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     g3 = grads[:, 2].data.clone()
            #
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     t3_angle = cos(g, g3)
            #     # angle = cos(g1, g2)
            #
            #     log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
            #                 't3_angle': t3_angle.item()}
            #     wandb.log(log_dict)
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_t3_angle.add(t3_angle.item())
            #
            #
            #     if index > opt.start_epoch:
            #         target_g_list = reshape_grad(g, grad_dims)
            #
            #         length = len(grad_all[0])
            #         flag2 = [True] * length
            #         freeze_pos = []
            #         cnt2 = 0
            #
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_tasks):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         # cos_saved = []
            #         for i, pair in enumerate(pair_grad):
            #             cos_all = pair_cos_with_d(pair, target_g_list[i])
            #             ans2 = ((cos_all > opt.freeze_level).sum() == cos_all.size(0))
            #             if not ans2:
            #                 flag2[i] = False
            #                 freeze_pos.append(i)
            #                 cnt2 += 1
            #
            #             # cos_saved.append(cos_all)
            #
            #         # all_cos_value.append(cos_saved)
            #
            #         rate = 1 - cnt2 / length
            #         wandb.log({'rate2': rate})
            #
            #         if rate < 1:
            #             if rate > 0.00001:
            #                 freeze_network(multi_task_model, flag2)
            #                 g1 = remove_layer_gradient(grad_all[0], freeze_pos)
            #                 g2 = remove_layer_gradient(grad_all[1], freeze_pos)
            #                 g3 = remove_layer_gradient(grad_all[2], freeze_pos)
            #                 grads = [g1, g2, g3]
            #                 grads = torch.stack(grads).t()
            #                 g = cagrad(grads, alpha, rescale=1)
            #
            #                 t1_angle = cos(g, g1)
            #                 t2_angle = cos(g, g2)
            #                 t3_angle = cos(g, g3)
            #
            #                 wandb.log({'t1_angle_after': t1_angle.item()})
            #                 wandb.log({'t2_angle_after': t2_angle.item()})
            #                 wandb.log({'t3_angle_after': t3_angle.item()})
            #             else:
            #                 g = torch.zeros_like(g)
            #
            #     # grads_ta_all.append(g.data.clone())
            #     overwrite_grad_with_freeze_level(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            #
            #     del g
            #     del g1
            #     del g2
            # elif method == 'cagrad_fix_layer':
            #     grad_all = []
            #     n_tasks = 3
            #     for i in range(n_tasks):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad = split_layer(grad_list=grad, name_dict=layer_dict)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     g3 = grads[:, 2].data.clone()
            #
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     t3_angle = cos(g, g3)
            #
            #     log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
            #                 't3_angle': t3_angle.item()}
            #
            #     wandb.log(log_dict)
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_t3_angle.add(t3_angle.item())
            #
            #     if index > opt.start_epoch:
            #         target_g_list = reshape_grad(g, grad_dims)
            #         target_g_list = split_layer(target_g_list, name_dict=layer_dict)
            #
            #         length = len(grad_all[0])
            #         flag2 = [True] * len(grad_dims)
            #         freeze_pos = []
            #         cnt2 = 0
            #
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_tasks):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         # cos_saved = []
            #         for i, pair in enumerate(pair_grad):
            #             diff_list = LFT_task_cos_diff(grad_all=grad_all, grad_target=target_g_list, n_tasks=n_tasks,
            #                                           pos=i)
            #             task_diff[layer_name[i]].append(diff_list)
            #
            #             layer_wise_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
            #             # layer_wise_dot = pair_dot_with_d(pair, target_g_list[i]).cpu()
            #             layer_wise_task_cos = pair_cos(pair).cpu()
            #             # layer_wise_task_dot = pair_dot(pair).cpu()
            #
            #             layer_wise_angle[layer_name[i]].append(layer_wise_cos)
            #             # layer_wise_dot_value[layer_name[i]].append(layer_wise_dot)
            #             layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)
            #             # layer_wise_task_dot_value[layer_name[i]].append(layer_wise_task_dot)
            #
            #
            #             # cos_all = pair_cos_with_d(pair, target_g_list[i])
            #             # ans2 = ((cos_all > opt.freeze_level).sum() == cos_all.size(0))
            #             # if not ans2:
            #             #     pos = layer_dict[layer_name[i]]
            #             #     conflict_dict[layer_name[i]] += 1
            #             #     for j in pos:
            #             #         flag2[j] = False
            #             #     freeze_pos.append(i)
            #             #     cnt2 += 1
            #             #     wandb.run.summary[layer_name[i]] = conflict_dict[layer_name[i]]
            #
            #             # cos_saved.append(cos_all)
            #
            #         # all_cos_value.append(cos_saved)
            #
            #         # rate = 1 - cnt2 / length
            #         # wandb.log({'rate2': rate})
            #
            #         # if opt.ignore and rate < 1:
            #         #     if rate > 0.00001:
            #         #         freeze_network(multi_task_model, flag2)
            #         #         g1 = remove_layer_gradient(grad_all[0], freeze_pos)
            #         #         g2 = remove_layer_gradient(grad_all[1], freeze_pos)
            #         #         g3 = remove_layer_gradient(grad_all[2], freeze_pos)
            #         #         grads = [g1, g2, g3]
            #         #         grads = torch.stack(grads).t()
            #         #         g = cagrad(grads, alpha, rescale=1)
            #         #
            #         #         t1_angle = cos(g, g1)
            #         #         t2_angle = cos(g, g2)
            #         #         t3_angle = cos(g, g3)
            #         #
            #         #         wandb.log({'t1_angle_after': t1_angle.item()})
            #         #         wandb.log({'t2_angle_after': t2_angle.item()})
            #         #         wandb.log({'t3_angle_after': t3_angle.item()})
            #         #     else:
            #         #         g = torch.zeros_like(g)
            #
            #     # grads_ta_all.append(g.data.clone())
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #
            #     del g
            #     del g1
            #     del g2
            elif method == 'branch_layer':
                # Branch-layer branch: train with an aggregated gradient while
                # recording layer-wise gradient statistics, then, once `index`
                # reaches opt.start_epoch, call `turn` on the model with the layers
                # selected by `topK_prob`.
                grad_all = []
                n_tasks = 3
                for i in range(n_tasks):
                    # retain_graph=True on every task, so the graph stays alive
                    # after the loop.
                    train_loss_tmp[i].backward(retain_graph=True)
                    # presumably stores task i's flattened shared gradient in
                    # column i of `grads` (helper defined elsewhere)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    # Same gradient again, as a list regrouped by `layer_dict`
                    # (exact grouping is defined by split_layer, not visible here).
                    grad = grad2vec_list(multi_task_model)
                    grad = split_layer(grad_list=grad, name_dict=layer_dict)
                    grad_all.append(grad)
                    multi_task_model.zero_grad_shared_modules()

                # Detached copies of the per-task gradient columns.
                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                # Choose how the task gradients are aggregated:
                #   'nothing' -> plain average of the three task gradients
                #   'cagrad'  -> direction returned by `cagrad`
                # Any other value is rejected.
                if opt.method_sub == 'nothing':
                    g = (g1 + g2 + g3) / 3.0
                elif opt.method_sub == 'cagrad':
                    g = cagrad(grads, alpha, rescale=1)
                else:
                    raise ValueError(f'Error: {opt.method_sub}')

                # Agreement between the aggregated direction and each task
                # gradient (`cos` is presumably cosine similarity).
                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)
                t3_angle = cos(g, g3)

                log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
                            't3_angle': t3_angle.item()}

                wandb.log(log_dict)
                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())
                avg_t3_angle.add(t3_angle.item())

                # Regroup the aggregated gradient the same way as the per-task
                # gradients so both can be compared group by group.
                target_g_list = reshape_grad(g, grad_dims)
                target_g_list = split_layer(target_g_list, name_dict=layer_dict)

                length = len(grad_all[0])

                # pair_grad[i] stacks the n_tasks gradients of group i along a new
                # leading dimension.
                pair_grad = []
                for i in range(length):
                    temp = []
                    for j in range(n_tasks):
                        temp.append(grad_all[j][i])
                    temp = torch.stack(temp)
                    pair_grad.append(temp)

                # cos_saved = []
                # Statistics are only collected while index < opt.start_epoch.
                if index < opt.start_epoch:
                    # for i in range(length):
                    #     task_diff_list = LFT_task_cos_diff(grad_all=grad_all, grad_target=target_g_list, n_tasks=n_tasks,
                    #                                    pos=i)
                    #     diff_list = LFT_cos_diff(grad_all=grad_all, n_tasks=n_tasks, pos=i)
                    #     task_diff[layer_name[i]].append(task_diff_list)
                    #     diff[layer_name[i]].append(diff_list)

                    for i, pair in enumerate(pair_grad):
                        # layer_wise_task_cos: task gradients vs. the aggregated
                        # direction for this group -> stored in layer_wise_task_angle.
                        # layer_wise_cos: result of pair_cos on the task gradients
                        # alone -> stored in layer_wise_angle.
                        layer_wise_task_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
                        # layer_wise_dot = pair_dot_with_d(pair, target_g_list[i]).cpu()
                        layer_wise_cos = pair_cos(pair).cpu()
                        # layer_wise_task_dot = pair_dot(pair).cpu()

                        layer_wise_angle[layer_name[i]].append(layer_wise_cos)
                        # layer_wise_dot_value[layer_name[i]].append(layer_wise_dot)
                        layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)
                        # layer_wise_task_dot_value[layer_name[i]].append(layer_wise_task_dot)

                # Apply the aggregated gradient.
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()

                # One-time switch: turning_cnt is incremented so this block runs
                # only once even if `index` equals opt.start_epoch for many steps.
                if index == opt.start_epoch and turning_cnt <= 0:
                    turning_cnt += 1
                    # NOTE(review): `task_diff` is consumed here, but the only code
                    # in this branch that appends to it is commented out above.
                    # Verify it is populated elsewhere, otherwise topK_prob works
                    # on whatever task_diff was initialised with.
                    branch_layers = topK_prob(task_diff, topK=opt.topK)
                    # branch_layers_2 = topK_value(task_diff, topK=opt.topK)
                    multi_task_model.turn(task_branches=branch_layers)

                    # The set of shared parameters may have changed after `turn`,
                    # so rebuild the per-parameter sizes and the layer grouping.
                    # NOTE(review): `grads` is not reallocated in this branch;
                    # confirm its size still matches the new grad_dims.
                    grad_dims = []
                    for key, param in multi_task_model.shared_parameters().items():
                        grad_dims.append(param.data.numel())

                    layer_dict = get_layer_dict(multi_task_model)
                    layer_name = list(layer_dict.keys())

                    # turning_cnt += 1
                    # if opt.branch_mode == 'angle':
                    #     branch_layers = quantile_topK(layer_wise_angle, q=opt.q, topK=opt.topK)
                    #     multi_task_model.turn(task_branches=branch_layers)
                    # elif opt.branch_mode == 'task_angle':
                    #     branch_layers = quantile_topK(layer_wise_task_angle, q=opt.q, topK=opt.topK)
                    #     multi_task_model.turn(task_branches=branch_layers)
                    #
                    # grad_dims = []
                    # for key, param in multi_task_model.shared_parameters().items():
                    #     grad_dims.append(param.data.numel())
                    #
                    # layer_dict = get_layer_dict(multi_task_model)
                    # layer_name = list(layer_dict.keys())

                # Drop the references to the flat gradient tensors.
                del g
                del g1
                del g2
                del g3
            # elif method == 'cagrad_fix3':
            #     grad_all = []
            #     n_tasks = 3
            #     for i in range(n_tasks):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     g3 = grads[:, 2].data.clone()
            #
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     t3_angle = cos(g, g3)
            #     # angle = cos(g1, g2)
            #
            #     log_dict = {'inner_index': inner_index, 't1_angle': t1_angle.item(), 't2_angle': t2_angle.item(),
            #                 't3_angle': t3_angle.item()}
            #     wandb.log(log_dict)
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_t3_angle.add(t3_angle.item())
            #
            #
            #     if index > opt.start_epoch:
            #         target_g_list = reshape_grad(g, grad_dims)
            #
            #         length = len(grad_all[0])
            #         flag2 = [True] * length
            #         freeze_pos = []
            #         cnt2 = 0
            #         n_freeze_para = int(opt.freeze_level * length)
            #
            #         test = LFT_min_cos(grad_all, target_g_list, n_tasks, pos=[])
            #
            #         score = []
            #         for i in range(length):
            #             value = LFT_min_cos(grad_all, target_g_list, n_tasks, pos=i)
            #             score.append(value)
            #
            #         score = torch.stack(score).view(-1)
            #         max_value, max_pos = torch.topk(score, n_freeze_para, largest=True)
            #
            #         test2 = LFT_min_cos(grad_all, target_g_list, n_tasks, pos=max_pos.tolist())
            #
            #         for pos in max_pos:
            #             flag2[pos.item()] = False
            #
            #         wandb.log({'former': test.item()})
            #         wandb.log({'After': test2.item()})
            #
            #         freeze_network(multi_task_model, flag2)
            #
            #         g1 = remove_layer_gradient(grad_all[0], max_pos.tolist())
            #         g2 = remove_layer_gradient(grad_all[1], max_pos.tolist())
            #         g3 = remove_layer_gradient(grad_all[2], max_pos.tolist())
            #         grads = [g1, g2, g3]
            #         grads = torch.stack(grads).t()
            #         g = cagrad(grads, alpha, rescale=1)
            #
            #     # grads_ta_all.append(g.data.clone())
            #     overwrite_grad_with_freeze_level(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            #
            #     del g
            #     del g1
            #     del g2
            # elif method == "cagrad_weight":
            #     for i in range(3):
            #         if i < 3:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #     g = cagrad_weight(grads, alpha, rescale=1)
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     del g
            # elif method == 'cagrad_SVD':
            #     for i in range(3):
            #         if i < 3:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #     weight = cagrad_weight(grads, alpha, rescale=1)
            #     grads = grads.t() * weight.view(3, -1)
            #     g = grads.mean(dim=0)
            #     U, S, Vh = torch.linalg.svd(grads, full_matrices=False)
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     print('---------------------------------------------------------------------------')
            #     print(U)
            #     print(S)
            #     print(Vh)
            #     del g
            #     del weight
            #     del U
            #     del S
            #     del Vh
            elif method == 'nothing':
                optimizer.zero_grad()
                loss = sum(train_loss_tmp)
                loss.backward()
                optimizer.step()

            elif method == 'nothing_v2':
                optimizer.zero_grad()
                # loss = sum(train_loss_tmp)

                for i in range(3):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                g3 = grads[:, 2].data.clone()

                angle_all.append(torch.stack([cos(g1, g2), cos(g1, g3), cos(g2, g3)]).view(-1))

                g = (g1 + g2 + g3) / 3.0
                # if opt.weight == 'equal' or opt.weight == 'dwa':
                #     loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
                #     # loss = sum([w[i] * train_loss[i] for i in range(3)])
                # else:
                #     loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
            # elif method == "check":
            #     for i in range(3):
            #         if i < 3:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     ans = checkgrad(grads, degree)
            #     times_conflict += ans
            #
            #     if not ans:
            #         if opt.weight == 'dwa':
            #             loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
            #             # loss = sum([w[i] * train_loss[i] for i in range(3)])
            #         elif opt.weight == 'equal':
            #             loss = sum([train_loss[i] / 3.0 for i in range(3)])
            #         else:
            #             loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))
            #
            #         loss.backward()
            #         optimizer.step()

            # accumulate label prediction for every pixel in training images
            conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())

            cost[0] = train_loss[0].item()
            cost[3] = train_loss[1].item()
            cost[4], cost[5] = depth_error(train_pred[1], train_depth)
            cost[6] = train_loss[2].item()
            cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
            avg_cost[index, :12] += cost[:12] / train_batch

            # print(f'Epoch: {index}, Batch: {k}, TR0: {train_loss[0].item():.4f}, TR1: {train_loss[1].item():.4f},  TR2: {train_loss[2].item():.4f}')

            del data
            del logsigma
            del train_data
            del train_depth
            del train_label
            del train_loss
            del train_normal
            del train_loss_tmp
            del train_pred
            del grads
            inner_index += 1
            # torch.use_deterministic_algorithms(True)

        # compute mIoU and acc
        avg_cost[index, 1:3] = conf_mat.get_metrics()

        if total_cnt > 0:
            wandb.run.summary['update_rate'] = 1.0 * cnt / total_cnt

        if (index + 1) % eval_freq == 0:
            # evaluating test data
            multi_task_model.eval()
            conf_mat = ConfMatrix(multi_task_model.class_nb)
            with torch.no_grad():  # operations inside don't track history
                for k, data in enumerate(test_loader):
                    test_data, test_label, test_depth, test_normal, indices = data
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth, test_normal = test_depth.cuda(), test_normal.cuda()
                    test_pred, _ = multi_task_model(test_data)
                    test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
                                 model_fit(test_pred[1], test_depth, 'depth'),
                                 model_fit(test_pred[2], test_normal, 'normal')]

                    conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())

                    cost[12] = test_loss[0].item()
                    cost[15] = test_loss[1].item()
                    cost[16], cost[17] = depth_error(test_pred[1], test_depth)
                    cost[18] = test_loss[2].item()
                    cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)
                    avg_cost[index, 12:] += cost[12:] / test_batch

                # compute mIoU and acc
                avg_cost[index, 13:15] = conf_mat.get_metrics()

            # if method == "mean":
            #     torch.save(torch.Tensor(neg_trace), "trace.pt")
            #
            # if "debug" in method:
            #     torch.save(torch.Tensor(obj_trace), f"{method}_obj.pt")

            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle, avg_t3_angle=avg_t3_angle, mode='val')
        else:
            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle, avg_t3_angle=avg_t3_angle, mode='train')


        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)} | lr: {scheduler.get_lr()}')

        scheduler.step()

        # saved_dict = {'cos': layer_wise_angle, 'dot': layer_wise_dot_value,
        #               'task_cos': layer_wise_task_angle, 'task_dot': layer_wise_task_dot_value}



        if (index + 1) == epoch and method == 'branch_layer':
            # saved_dict = {'task_diff': task_diff}
            # torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}_{epoch}_task_diff_data_nyuv.pt')
            # saved_dict = {'diff': diff}
            # torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}_{epoch}_diff_data_nyuv.pt')
            saved_dict = {'cos': layer_wise_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}_{epoch}_lw_cos_nyuv.pt')
            saved_dict = {'task_cos': layer_wise_task_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}_{epoch}_lw_task_cos_nyuv.pt')

        torch.save(angle_all,
                   f'./saved/{opt.seed}{opt.base_model}{opt.topK}{opt.method}{opt.optimizer}_{epoch}_angle.pt')

        # if "cagrad" in method:
        #     torch.save(multi_task_model.state_dict(), f"models/{method}-{opt.weight}-{alpha}-{seed}.pt")
        # else:
        #     torch.save(multi_task_model.state_dict(), f"models/{method}-{opt.weight}-{seed}.pt")


def train_rotograd(train_loader, test_loader, model,
                   optimizer_main, scheduler_main, optimizer_R, scheduler_R, opt):
    """Train a three-task model (semantic / depth / normal) with two optimizers.

    For every epoch the function runs one pass over ``train_loader``, lets the
    model compute its own gradients through ``model.backward(losses)``, steps
    both optimizers, and accumulates per-batch metrics. Every
    ``opt.eval_freq`` epochs it additionally evaluates on ``test_loader``.
    Results are printed via ``print_details``; nothing is returned.

    Args:
        train_loader: iterable with ``len()`` yielding 5-tuples
            ``(image, semantic_label, depth, normal, indices)``.
        test_loader: same structure as ``train_loader``, used for evaluation.
        model: module called as ``model(images)`` whose result is indexed
            ``[0]``, ``[1]``, ``[2]`` for the three task predictions, and
            which exposes ``backward(list_of_losses)``.
            NOTE(review): presumably a RotoGrad wrapper -- confirm with caller.
        optimizer_main: optimizer stepped after every batch.
        scheduler_main: LR scheduler for ``optimizer_main``, stepped per epoch.
        optimizer_R: second optimizer stepped after every batch
            (presumably for the rotation parameters -- confirm).
        scheduler_R: LR scheduler for ``optimizer_R``, stepped per epoch.
        opt: options object; ``opt.n_epoch`` and ``opt.eval_freq`` are read.

    Notes:
        * All tensors are moved with ``.cuda()``, so a CUDA device is required.
        * ``avg_cost`` has 24 columns: 0-11 hold training metrics, 12-23 the
          matching test metrics (loss / confusion-matrix metrics / depth
          errors / normal errors, in that order).
        * The confusion matrix is hard-coded to 13 classes.
    """
    epoch = opt.n_epoch
    eval_freq = opt.eval_freq

    train_batch = len(train_loader)
    test_batch = len(test_loader)
    avg_cost = np.zeros([epoch, 24], dtype=np.float32)

    timer = Timer()

    for index in range(epoch):
        model.train()

        # These averagers are never fed in this routine; they are only created
        # because print_details expects them as arguments.
        avg_t1_angle = Averager()
        avg_t2_angle = Averager()
        avg_t3_angle = Averager()

        conf_mat = ConfMatrix(13)
        cost = np.zeros(24, dtype=np.float32)

        for data in train_loader:
            train_data, train_label, train_depth, train_normal, _ = data
            train_data, train_label = train_data.cuda(), train_label.long().cuda()
            train_depth, train_normal = train_depth.cuda(), train_normal.cuda()

            train_pred = model(train_data)

            train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
                          model_fit(train_pred[1], train_depth, 'depth'),
                          model_fit(train_pred[2], train_normal, 'normal')]

            optimizer_main.zero_grad()
            optimizer_R.zero_grad()

            # The model performs the multi-loss backward pass itself.
            model.backward(train_loss)

            optimizer_main.step()
            optimizer_R.step()

            # Accumulate metrics for EVERY batch. Previously this section sat
            # after the loop, so only the last batch of each epoch was
            # recorded (and then divided by the number of batches).
            conf_mat.update(train_pred[0].argmax(1).flatten().cpu(), train_label.flatten().cpu())

            cost[0] = train_loss[0].item()
            cost[3] = train_loss[1].item()
            cost[4], cost[5] = depth_error(train_pred[1], train_depth)
            cost[6] = train_loss[2].item()
            cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
            avg_cost[index, :12] += cost[:12] / train_batch

        # Confusion-matrix metrics for the training pass (columns 1-2),
        # mirroring what the evaluation branch stores in columns 13-14.
        avg_cost[index, 1:3] = conf_mat.get_metrics()

        if (index + 1) % eval_freq == 0:
            # evaluating test data
            model.eval()
            conf_mat = ConfMatrix(13)
            with torch.no_grad():  # operations inside don't track history
                for data in test_loader:
                    test_data, test_label, test_depth, test_normal, _ = data
                    test_data, test_label = test_data.cuda(), test_label.long().cuda()
                    test_depth, test_normal = test_depth.cuda(), test_normal.cuda()
                    test_pred = model(test_data)
                    test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
                                 model_fit(test_pred[1], test_depth, 'depth'),
                                 model_fit(test_pred[2], test_normal, 'normal')]

                    conf_mat.update(test_pred[0].argmax(1).flatten().cpu(), test_label.flatten().cpu())

                    cost[12] = test_loss[0].item()
                    cost[15] = test_loss[1].item()
                    cost[16], cost[17] = depth_error(test_pred[1], test_depth)
                    cost[18] = test_loss[2].item()
                    cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)
                    avg_cost[index, 12:] += cost[12:] / test_batch

                # confusion-matrix metrics for the test pass
                avg_cost[index, 13:15] = conf_mat.get_metrics()

            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle,
                          avg_t3_angle=avg_t3_angle, mode='val')
        else:
            print_details(avg_cost=avg_cost, index=index, avg_t1_angle=avg_t1_angle, avg_t2_angle=avg_t2_angle,
                          avg_t3_angle=avg_t3_angle, mode='train')

        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)}')

        scheduler_R.step()
        scheduler_main.step()