import math

import numpy as np
import time
import torch

import torch.nn.functional as F
import gc
import random
import wandb

# from sklearn import random_projection
from torch.nn.modules.loss import CrossEntropyLoss
from copy import deepcopy
from min_norm_solvers import MinNormSolver
from scipy.optimize import minimize, Bounds, minimize_scalar
from collections import OrderedDict
# from random_projection import random_projection_matrix
# Handle to the module namespace: at module level ``locals()`` returns the same
# dict as ``globals()``.
# NOTE(review): no use of ``loc`` is visible in this part of the file — confirm
# whether anything still reads it before removing.
loc = locals()

class Timer():
    """Wall-clock stopwatch that reports elapsed time as a short label.

    The clock starts when the instance is created; ``o`` stores the start
    timestamp obtained from ``time.time()``.
    """

    def __init__(self):
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a compact string.

        The quotient is truncated to whole seconds and rendered as ``'<n>s'``
        below one minute, ``'<n>m'`` (minutes, rounded) below one hour and
        ``'<x.y>h'`` from one hour upwards.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return f'{seconds}s'
        if seconds < 3600:
            return f'{round(seconds / 60)}m'
        return f'{seconds / 3600:.1f}h'

class Averager():
    """Weighted running mean that also keeps every raw value it was given.

    Attributes are plain and public: ``val`` (last value added), ``sum``
    (weighted sum), ``count`` (total weight), ``avg`` (``sum / count``) and
    ``data`` (list of all raw values, in insertion order).
    """

    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0
        self.data = []

    def reset(self):
        """Zero the running statistics.

        Only ``val``, ``avg``, ``sum`` and ``count`` are cleared; the ``data``
        history (and therefore ``len(self)``) is left untouched.
        """
        self.val = self.avg = self.sum = self.count = 0

    def add(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the weighted mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.data.append(val)

    def item(self):
        """Return the current weighted mean."""
        return self.avg

    def obtain_data(self):
        """Return the list of raw values passed to ``add`` (not a copy)."""
        return self.data

    def __len__(self):
        """Number of raw values recorded so far."""
        return len(self.data)

class Running_Avg():
    """Mean / standard deviation over the most recent ``max_cnt`` values.

    Every value passed to ``add`` is kept in ``data``; only the statistics are
    restricted to the trailing window.
    """

    def __init__(self, max_cnt):
        self.max_cnt = max_cnt
        self.data = []
        self.cnt = 0

    def _window(self):
        """Return the values the statistics are computed over."""
        if self.cnt >= self.max_cnt:
            return self.data[-self.max_cnt:]
        # Fewer than ``max_cnt`` samples so far: use everything.
        return self.data

    def add(self, val):
        """Append ``val`` to the history."""
        self.data.append(val)
        self.cnt += 1

    def item(self):
        """Arithmetic mean of the current window."""
        recent = self._window()
        return sum(recent) / len(recent)

    def std(self):
        """Standard deviation (``np.std``) of the current window."""
        return np.std(self._window())

def calculate_score(score):
    """Sum the strict upper triangle of ``score`` and divide by n(n+1)/2.

    ``n`` is ``score.size(0)``. Entries on and below the main diagonal are
    zeroed by ``torch.triu(..., diagonal=1)`` before summing.

    NOTE(review): the strict upper triangle of an n x n matrix holds
    n(n-1)/2 entries, whereas the divisor used here is n(n+1)/2 — confirm the
    normalisation is intended.
    """
    n_epoch = score.size(0)
    upper = torch.triu(score, diagonal=1)
    normaliser = n_epoch * (n_epoch + 1) / 2
    return upper.sum() / normaliser

def topK_prob(task_diff, topK):
    """Return the ``topK`` keys whose records are most often all-positive.

    For each key the list of tensors is stacked; a record (row of the stacked
    tensor) counts as a hit when every entry along dimension 1 is ``> 0``.
    Keys are ranked by their hit fraction, highest first (ties keep the
    insertion order of ``task_diff``), and the first ``topK`` are returned.
    """
    hit_rate = OrderedDict()
    for layer, records in task_diff.items():
        stacked = torch.stack(records)
        all_positive = (stacked > 0).all(dim=1)
        hit_rate[layer] = all_positive.sum() / all_positive.size(0)

    ranked = sorted(hit_rate, key=hit_rate.get, reverse=True)
    return ranked[:topK]

def topK_value(task_diff, topK):
    """Return the ``topK`` keys with the smallest mean recorded value.

    Each key's list of tensors is stacked and averaged over all elements; keys
    are then ranked in ascending order of that mean (ties keep the insertion
    order of ``task_diff``).
    """
    mean_by_layer = OrderedDict(
        (layer, torch.stack(records).mean())
        for layer, records in task_diff.items()
    )
    ranked = sorted(mean_by_layer, key=mean_by_layer.get)
    return ranked[:topK]

def cos(t1, t2):
    """Cosine similarity of ``t1`` and ``t2`` along dimension 0.

    Both inputs are L2-normalised along dimension 0 with ``F.normalize`` and
    their element-wise product is summed over that dimension.
    """
    return (F.normalize(t1, dim=0) * F.normalize(t2, dim=0)).sum(dim=0)

def dot(t1, t2):
    """Sum of the element-wise product of ``t1`` and ``t2`` over dimension 0."""
    product = t1 * t2
    return product.sum(dim=0)

def pair_cos_with_d(pair, d):
    """Cosine similarity between ``d`` and every slice ``pair[i]``.

    Returns the similarities stacked and flattened with ``view(-1)``, in the
    order of ``pair``'s first dimension.
    """
    similarities = [cos(d, row) for row in pair]
    return torch.stack(similarities).view(-1)

def pair_dot_with_d(pair, d):
    """Dot product between ``d`` and every slice ``pair[i]``.

    Returns the products stacked and flattened with ``view(-1)``, in the order
    of ``pair``'s first dimension.
    """
    products = [dot(d, row) for row in pair]
    return torch.stack(products).view(-1)

def pair_cos(pair):
    """Cosine similarity of every unordered pair of slices of ``pair``.

    Pairs ``(i, j)`` with ``i < j`` are visited in lexicographic order and the
    results are returned stacked and flattened with ``view(-1)``.
    """
    n_rows = pair.size(0)
    similarities = [
        cos(pair[a], pair[b])
        for a in range(n_rows - 1)
        for b in range(a + 1, n_rows)
    ]
    return torch.stack(similarities).view(-1)

def pair_dot(pair):
    """Dot product of every unordered pair of slices of ``pair``.

    Pairs ``(i, j)`` with ``i < j`` are visited in lexicographic order and the
    results are returned stacked and flattened with ``view(-1)``.
    """
    n_rows = pair.size(0)
    products = [
        dot(pair[a], pair[b])
        for a in range(n_rows - 1)
        for b in range(a + 1, n_rows)
    ]
    return torch.stack(products).view(-1)

# def pari_cos(pair):
#     normalize_pair = F.normalize(pair, dim=1)
#     dot_value = torch.mm(normalize_pair, normalize_pair.T)
#     dot_value = torch.triu(dot_value, diagonal=1).view(-1)
#     dot_value = dot_value[dot_value.nonzero().detach()]
#     return dot_value
#
# def pair_cos2(pair):
#     mean = torch.mean(pair, dim=0)
#     dot_value = [cos(mean, pair[0]), cos(mean, pair[1])]
#     dot_value = torch.stack(dot_value).view(-1)
#     return dot_value
#
# def pair_cos_with_d(pair, d):
#     dot_value = [cos(d, pair[0]), cos(d, pair[1])]
#     dot_value = torch.stack(dot_value).view(-1)
#     return dot_value

def remove_layer_gradient(grad_list, pos):
    """Concatenate ``grad_list`` while leaving out the entries at ``pos``.

    Args:
        grad_list: sequence of tensors accepted by ``torch.cat``.
        pos: a single index, or a list / tuple / set / frozenset / range of
            indices into ``grad_list`` to drop.

    Returns:
        ``torch.cat`` of the remaining tensors, in their original order.

    Previously only ``list`` was recognised as a collection: a tuple or set of
    indices was wrapped into a one-element list and silently matched nothing,
    so no layer was removed.
    """
    if not isinstance(pos, (list, tuple, set, frozenset, range)):
        pos = [pos]
    kept = [g for i, g in enumerate(grad_list) if i not in pos]
    return torch.cat(kept)

def LFT_min_cos(grad_all, n_tasks, pos):
    """Smallest pairwise cosine between task gradients with ``pos`` removed.

    Args:
        grad_all: indexable by task; ``grad_all[i]`` is the list of gradient
            tensors handed to ``remove_layer_gradient``.
        n_tasks: number of tasks to compare (indices ``0 .. n_tasks - 1``).
        pos: index or list of indices dropped from every task's gradient list.

    Returns:
        The minimum of ``cos(g_i, g_j)`` over all ``i < j``. When there is no
        pair to compare (``n_tasks < 2``) the sentinel ``1000`` is returned.
    """
    min_cos = 1000
    if n_tasks < 2:
        return min_cos

    # Build each task's trimmed gradient once instead of once per (i, j) pair.
    trimmed = [remove_layer_gradient(grad_all[i], pos) for i in range(n_tasks)]
    for i in range(n_tasks - 1):
        for j in range(i + 1, n_tasks):
            cos_v = cos(trimmed[i], trimmed[j])
            if cos_v < min_cos:
                min_cos = cos_v
    return min_cos

def LFT_task_cos_diff(grad_all, grad_target, n_tasks, pos):
    """Per-task change in cosine to the target gradient when ``pos`` is removed.

    For every task ``i`` this computes
    ``cos(trimmed_task_i, trimmed_target) - cos(full_task_i, full_target)``,
    where "full" is the concatenation of all gradient tensors and "trimmed"
    is the concatenation with the entries at ``pos`` left out.

    Args:
        grad_all: indexable by task; ``grad_all[i]`` is a list of gradient
            tensors.
        grad_target: list of gradient tensors for the reference direction.
        n_tasks: number of tasks (indices ``0 .. n_tasks - 1``).
        pos: index or list of indices to drop.

    Returns:
        Flattened tensor holding one difference per task.
    """
    # The target gradient does not depend on the task: build both variants once.
    g_previous = torch.cat(grad_target)
    g = remove_layer_gradient(grad_target, pos)

    diff_list = []
    for i in range(n_tasks):
        cos_v_previous = cos(torch.cat(grad_all[i]), g_previous)
        cos_v = cos(remove_layer_gradient(grad_all[i], pos), g)
        diff_list.append(cos_v - cos_v_previous)

    return torch.stack(diff_list).view(-1)

def LFT_cos_diff(grad_all, n_tasks, pos):
    """Change in pairwise task-gradient cosine when ``pos`` is removed.

    For every pair ``i < j`` (lexicographic order) this computes
    ``cos(trimmed_i, trimmed_j) - cos(full_i, full_j)``, where "full" is the
    concatenation of all of a task's gradient tensors and "trimmed" is the
    concatenation with the entries at ``pos`` left out.

    Args:
        grad_all: indexable by task; ``grad_all[i]`` is a list of gradient
            tensors.
        n_tasks: number of tasks (indices ``0 .. n_tasks - 1``).
        pos: index or list of indices to drop.

    Returns:
        Flattened tensor holding one difference per pair.
    """
    # Concatenate each task's gradients once rather than once per pair.
    full = [torch.cat(grad_all[i]) for i in range(n_tasks)]
    trimmed = [remove_layer_gradient(grad_all[i], pos) for i in range(n_tasks)]

    diff_list = []
    for i in range(n_tasks - 1):
        for j in range(i + 1, n_tasks):
            cos_v_previous = cos(full[i], full[j])
            cos_v = cos(trimmed[i], trimmed[j])
            diff_list.append(cos_v - cos_v_previous)

    return torch.stack(diff_list).view(-1)

def cagrad(grads, alpha=0.5, rescale=1, n_tasks=2):
    """Combine per-task gradients into a single update direction.

    A weight vector ``w`` on the probability simplex is found with
    ``scipy.optimize.minimize`` by minimising
    ``w^T GG b + c * sqrt(w^T GG w + 1e-8)``, where ``GG = grads^T grads``,
    ``b`` is the uniform weight vector and ``c = alpha * ||g0||`` with
    ``||g0|| = sqrt(mean(GG))``. The result is
    ``mean(grads, dim=1) + (c / ||grads @ w||) * (grads @ w)``, optionally
    rescaled.

    Args:
        grads: 2-D tensor with one column per task.
        alpha: scale applied to the average-gradient norm to obtain ``c``.
        rescale: ``0`` returns the raw combination, ``1`` divides it by
            ``1 + alpha ** 2``, any other value divides it by ``1 + alpha``.
        n_tasks: kept for backward compatibility. The number of tasks is taken
            from ``grads.shape[1]``; a value that disagreed with the matrix
            used to fail with a shape error inside the objective.

    Returns:
        1-D tensor on ``grads.device`` with ``grads.shape[0]`` entries.
    """
    n_tasks = grads.shape[1]

    GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
    g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

    x_start = np.ones(n_tasks) / n_tasks
    bnds = tuple((0, 1) for _ in range(n_tasks))
    cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
    A = GG.numpy()
    b = x_start.copy()
    c = (alpha * g0_norm + 1e-8).item()

    def objfn(x):
        xA = x.dot(A)
        return xA.dot(b) + c * np.sqrt(xA.dot(x) + 1e-8)

    res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
    ww = torch.Tensor(res.x).to(grads.device)
    gw = (grads * ww.view(1, -1)).sum(1)
    lmbda = c / (gw.norm() + 1e-8)
    g = grads.mean(1) + lmbda * gw

    if rescale == 0:
        return g
    if rescale == 1:
        return g / (1 + alpha ** 2)
    return g / (1 + alpha)

def LFT_task_cos_diff_cagrad(grad_all, grad_target, n_tasks, pos, alpha):
    """Per-task cosine change after removing ``pos`` and re-running ``cagrad``.

    The "before" value for task ``i`` is the cosine between its full gradient
    and the full ``grad_target``. The "after" value is the cosine between the
    task's trimmed gradient (entries at ``pos`` left out) and the direction
    ``cagrad`` returns for the trimmed gradients of all tasks.

    Args:
        grad_all: indexable by task; ``grad_all[i]`` is a list of gradient
            tensors.
        grad_target: list of gradient tensors for the reference direction.
        n_tasks: number of tasks (indices ``0 .. n_tasks - 1``).
        pos: index or list of indices to drop.
        alpha: forwarded to ``cagrad``.

    Returns:
        Flattened tensor with one ``after - before`` difference per task.
        Earlier the result was hard-wired to the first two tasks regardless of
        ``n_tasks``; for ``n_tasks == 2`` the output is unchanged.
    """
    # The target gradient is the same for every task: concatenate it once.
    g_previous = torch.cat(grad_target)
    cos_v_previous_list = [
        cos(torch.cat(grad_all[i]), g_previous) for i in range(n_tasks)
    ]

    # One column per task, as expected by cagrad.
    g_all = torch.stack(
        [remove_layer_gradient(grad_all[i], pos) for i in range(n_tasks)]
    ).T
    g = cagrad(g_all, alpha, rescale=1, n_tasks=n_tasks)

    diff_list = [
        cos(g, g_all[:, i]) - cos_v_previous_list[i] for i in range(n_tasks)
    ]
    return torch.stack(diff_list).view(-1)

def LFT_min(grad_all, n_tasks, pos):
    """Smallest pairwise dot product between task gradients with ``pos`` removed.

    Args:
        grad_all: indexable by task; ``grad_all[i]`` is the list of gradient
            tensors handed to ``remove_layer_gradient``.
        n_tasks: number of tasks to compare (indices ``0 .. n_tasks - 1``).
        pos: index or list of indices dropped from every task's gradient list.

    Returns:
        The minimum of ``dot(g_i, g_j)`` over all ``i < j``. When there is no
        pair to compare (``n_tasks < 2``) the legacy sentinel ``1000`` is
        returned.

    Dot products are not bounded, so the running minimum is seeded from the
    first pair rather than from ``1000``; with the old seed any true minimum
    of 1000 or more was reported as ``1000``.
    """
    if n_tasks < 2:
        return 1000

    # Build each task's trimmed gradient once instead of once per (i, j) pair.
    trimmed = [remove_layer_gradient(grad_all[i], pos) for i in range(n_tasks)]

    min_dot = None
    for i in range(n_tasks - 1):
        for j in range(i + 1, n_tasks):
            dot_v = dot(trimmed[i], trimmed[j])
            if min_dot is None or dot_v < min_dot:
                min_dot = dot_v

    return min_dot

''' ===== multi task MGD trainer ==== '''
def multi_task_mgd_trainer(train_loader, test_loader, multi_task_model,
                           optimizer, scheduler, opt):
    epoch = opt.n_epoch
    alpha = opt.alpha
    method = opt.method
    eval_freq = opt.eval_freq

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def graddrop(grads):
        P = 0.5 * (1. + grads.sum(1) / (grads.abs().sum(1) + 1e-8))
        U = torch.rand_like(grads[:, 0])
        M = P.gt(U).view(-1, 1) * grads.gt(0) + P.lt(U).view(-1, 1) * grads.lt(0)
        g = (grads * M.float()).mean(1)
        return g

    def mgd(grads):
        grads_cpu = grads.t().cpu()
        sol, min_norm = MinNormSolver.find_min_norm_element([
            grads_cpu[t] for t in range(grads.shape[-1])])
        w = torch.FloatTensor(sol).to(grads.device)
        g = grads.mm(w.view(-1, 1)).view(-1)
        return g

    def pcgrad(grads, rng):
        grad_vec = grads.t()
        num_tasks = 2

        shuffled_task_indices = np.zeros((num_tasks, num_tasks - 1), dtype=int)
        for i in range(num_tasks):
            task_indices = np.arange(num_tasks)
            task_indices[i] = task_indices[-1]
            shuffled_task_indices[i] = task_indices[:-1]
            rng.shuffle(shuffled_task_indices[i])
        shuffled_task_indices = shuffled_task_indices.T

        normalized_grad_vec = grad_vec / (
                grad_vec.norm(dim=1, keepdim=True) + 1e-8
        )  # num_tasks x dim
        modified_grad_vec = deepcopy(grad_vec)
        for task_indices in shuffled_task_indices:
            normalized_shuffled_grad = normalized_grad_vec[
                task_indices
            ]  # num_tasks x dim
            dot = (modified_grad_vec * normalized_shuffled_grad).sum(
                dim=1, keepdim=True
            )  # num_tasks x dim
            modified_grad_vec -= torch.clamp_max(dot, 0) * normalized_shuffled_grad
        g = modified_grad_vec.mean(dim=0)
        return g

    def cagrad(grads, alpha=0.5, rescale=1, n_tasks=2):
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(n_tasks) / n_tasks
        bnds = tuple((0, 1) for x in x_start)
        cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (x.reshape(1, n_tasks).dot(A).dot(b.reshape(n_tasks, 1)) + c * np.sqrt(
                x.reshape(1, n_tasks).dot(A).dot(x.reshape(n_tasks, 1)) + 1e-8)).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x
        ww = torch.Tensor(w_cpu).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm + 1e-8)
        g = grads.mean(1) + lmbda * gw
        if rescale == 0:
            return g
        elif rescale == 1:
            return g / (1 + alpha ** 2)
        else:
            return g / (1 + alpha)

    def cagrad_weight(grads, alpha=0.5, rescale=1):
        GG = grads.t().mm(grads).cpu()  # [num_tasks, num_tasks]
        g0_norm = (GG.mean() + 1e-8).sqrt()  # norm of the average gradient

        x_start = np.ones(3) / 3
        bnds = tuple((0, 1) for x in x_start)
        cons = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        A = GG.numpy()
        b = x_start.copy()
        c = (alpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (x.reshape(1, 3).dot(A).dot(b.reshape(3, 1)) + c * np.sqrt(
                x.reshape(1, 3).dot(A).dot(x.reshape(3, 1)) + 1e-8)).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x
        ww = torch.Tensor(w_cpu).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm + 1e-8)

        weight = torch.Tensor([1 / 3, 1 / 3, 1 / 3]).cuda()
        weight = weight + lmbda * ww

        if rescale == 0:
            return weight
        elif rescale == 1:
            weight = weight / (1 + alpha ** 2)
            return weight
        else:
            return weight / (1 + alpha)

    def grad2vec(m, grads, grad_dims, task):
        # store the gradients
        grads[:, task].fill_(0.0)
        cnt = 0
        for name, p in m.shared_parameters().items():
            grad = p.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone()
                beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                en = sum(grad_dims[:cnt + 1])
                grads[beg:en, task].copy_(grad_cur.data.view(-1))
            cnt += 1

    def grad2vec_g(m, grads, grad_dims):
        grads[:].fill_(0.0)
        cnt = 0
        for name, p in m.shared_parameters().items():
                grad = p.grad
                if grad is not None:
                    grad_cur = grad.data.detach().clone()
                    beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
                    en = sum(grad_dims[:cnt + 1])
                    grads[beg:en].copy_(grad_cur.data.view(-1))
                cnt += 1

    def freeze_network(m, flags):
        cur = 0
        for name, param in m.shared_parameters().items():
            param.requires_grad = flags[cur]
            cur += 1

    def unfreeze_network(m):
        for name, param in m.shared_parameters().items():
            param.requires_grad = True

    def overwrite_grad(m, newgrad, grad_dims):
        newgrad = newgrad * 2  # to match the sum loss
        cnt = 0
        for name, param in m.shared_parameters().items():
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            this_grad = newgrad[beg: en].contiguous().view(param.data.size())
            param.grad = this_grad.data.clone()
            cnt += 1

    def overwrite_grad_with_freeze_level(m, newgrad, grad_dims):
        newgrad = newgrad * 2  # to match the sum loss
        beg = 0
        cnt = 0
        for name, param in m.shared_parameters().items():
            if param.requires_grad is True:
                end = beg + grad_dims[cnt]
                this_grad = newgrad[beg: end].contiguous().view(param.data.size())
                param.grad = this_grad.data.clone()
                beg = end
            cnt += 1

    def grad2vec_list(m):
        grad_list = []
        for name, param in m.shared_parameters().items():
            grad = param.grad
            if grad is not None:
                grad_cur = grad.data.detach().clone().view(-1)
                grad_list.append(grad_cur)
        return grad_list

    # def grad2vec_dict(m):
    #     grad_list = []
    #     name_list = []
    #     for mm in m.shared_modules():
    #         for name, p in mm.named_parameters():
    #             grad = p.grad
    #             if grad is not None:
    #                 grad_cur = grad.data.detach().clone().view(-1)
    #                 grad_list.append(grad_cur)
    #                 name_list.append(name)
    #     return grad_list, name_list

    def get_layer_dict(m):
        shared_parameters = m.shared_parameters()
        name_list = list(shared_parameters.keys())

        layer_dict = {}
        for i, name in enumerate(name_list):
            if '.weight' in name:
                name = name.replace('.weight', '')
            elif '.bias' in name:
                name = name.replace('.bias', '')

            if name not in layer_dict:
                layer_dict[name] = [i]
            else:
                layer_dict[name].append(i)

        return layer_dict

    def split_layer(grad_list, name_dict):
        grad_new = []
        for key, value in name_dict.items():
            grad = [grad_list[i] for i in value]
            grad = torch.cat(grad)
            grad_new.append(grad)

        return grad_new

    def reshape_grad(g, grad_dims):
        grad_all = []
        for i in range(len(grad_dims)):
            beg = 0 if i == 0 else sum(grad_dims[:i])
            en = sum(grad_dims[:i+1])
            this_grad = g[beg:en].data.detach().clone()
            grad_all.append(this_grad)
        return grad_all

    rng = np.random.default_rng()
    grad_dims = []

    for key, param in multi_task_model.shared_parameters().items():
        grad_dims.append(param.data.numel())

    layer_dict = get_layer_dict(multi_task_model)
    layer_name = list(layer_dict.keys())

    timer = Timer()
    ce_loss = CrossEntropyLoss()

    n_classes = 10

    running_avg_1 = Running_Avg(100)
    running_avg_2 = Running_Avg(100)
    running_avg = Running_Avg(10)

    total_cnt = 0
    t1_cnt = 0
    t2_cnt = 0
    t_cnt = 0

    layer_wise_angle = OrderedDict()
    layer_wise_dot_value = OrderedDict()
    layer_wise_task_angle = OrderedDict()
    layer_wise_task_dot_value = OrderedDict()
    task_diff = OrderedDict()
    diff = OrderedDict()

    for name in layer_name:
        layer_wise_angle[name] = []
        layer_wise_dot_value[name] = []
        layer_wise_task_angle[name] = []
        layer_wise_task_dot_value[name] = []
        task_diff[name] = []
        diff[name] = []

    angle_all = []

    for index in range(epoch):
        multi_task_model.train()

        avg_tr_loss_1 = Averager()
        avg_tr_acc_1 = Averager()
        avg_te_loss_1 = Averager()
        avg_te_acc_1 = Averager()

        avg_tr_loss_2 = Averager()
        avg_tr_acc_2 = Averager()
        avg_te_loss_2 = Averager()
        avg_te_acc_2 = Averager()

        avg_t1_angle = Averager()
        avg_t2_angle = Averager()
        avg_angle = Averager()

        wandb.define_metric("epoch")
        wandb.define_metric("Train_loss_t1_avg", step_metric="epoch")
        wandb.define_metric("Train_loss_t2_avg", step_metric="epoch")
        wandb.define_metric("Train_acc_t1", step_metric="epoch")
        wandb.define_metric("Train_acc_t2", step_metric="epoch")
        wandb.define_metric("avg_t1_angle", step_metric="epoch")
        wandb.define_metric("avg_t2_angle", step_metric="epoch")
        wandb.define_metric("avg_angle", step_metric="epoch")
        wandb.define_metric("Test_loss_t1", step_metric="epoch")
        wandb.define_metric("Test_loss_t2", step_metric="epoch")
        wandb.define_metric("Test_acc_t1", step_metric="epoch")
        wandb.define_metric("Test_acc_t2", step_metric="epoch")

        loss_list_1 = []
        loss_list_2 = []
        if method == 'test':
            grads_t1_all = [[] for i in range(n_classes)]
            grads_t2_all = [[] for i in range(n_classes)]
        else:
            grads_t1_all = []
            grads_t2_all = []
        grads_ta_all = []

        t1_loss = None
        t2_loss = None
        t_g = None

        stage1 = opt.stage1
        stage2 = opt.stage2
        stage = stage1

        for k, data in enumerate(train_loader):
            grads = torch.Tensor(sum(grad_dims), 2).cuda()
            x = data[0].cuda()
            ts = data[1].cuda()

            t1_pred, t2_pred = multi_task_model(x)
            n_samples = t1_pred.size(0)

            # if method == 'test':
            #     train_loss = []
            #     t1_s = ts[:, 0]
            #     t2_s = ts[:, 1]
            #     for i in range(n_classes):
            #         pos_t1 = (t1_s == i).nonzero().view(-1)
            #         t1_s_sub = t1_s.index_select(dim=0, index=pos_t1)
            #         t1_pred_sub = t1_pred.index_select(dim=0, index=pos_t1)
            #         #-------------------------------------
            #         pos_t2 = (t2_s == i).nonzero().view(-1)
            #         t2_s_sub = t2_s.index_select(dim=0, index=pos_t2)
            #         t2_pred_sub = t2_pred.index_select(dim=0, index=pos_t2)
            #         L = [ce_loss(t1_pred_sub, t1_s_sub), ce_loss(t2_pred_sub, t2_s_sub)]
            #         train_loss.append(L)

            train_loss_tmp = [ce_loss(t1_pred, ts[:,0]), ce_loss(t2_pred, ts[:,1])]

            wandb.log({'Train_loss_t1': train_loss_tmp[0].item()})
            wandb.log({'Train_loss_t2': train_loss_tmp[1].item()})

            avg_tr_loss_1.add(train_loss_tmp[0].item(), n_samples)
            avg_tr_loss_2.add(train_loss_tmp[1].item(), n_samples)

            # if index > 55:
            #     train_loss_tmp[0] = torch.abs(train_loss_tmp[0] - 0.11) + 0.11
            # if index > 45:

            # if opt.flood:
            #     # train_loss_tmp[1] = torch.abs(train_loss_tmp[1] - 0.26) + 0.26
            #     # train_loss_tmp[0] = torch.abs(train_loss_tmp[0] - 0.09) + 0.09
            #     if index == 50:
            #         loss_list_1.append(train_loss_tmp[0].item())
            #         loss_list_2.append(train_loss_tmp[1].item())
            #     if index > 50:
            #         train_loss_tmp[0] = torch.abs(train_loss_tmp[0] - L1_alpha) + L1_alpha
            #         train_loss_tmp[1] = torch.abs(train_loss_tmp[1] - L2_alpha) + L2_alpha


            # calculate the accuracy of prediction
            t1_estimate_labels = torch.argmax(t1_pred, dim=1)
            t2_estimate_labels = torch.argmax(t2_pred, dim=1)

            acc_t1 = (t1_estimate_labels == ts[:,0]).sum() / n_samples
            acc_t2 = (t2_estimate_labels == ts[:,1]).sum() / n_samples
            avg_tr_acc_1.add(acc_t1.item(), n_samples)
            avg_tr_acc_2.add(acc_t2.item(), n_samples)

            optimizer.zero_grad()

            if method == "graddrop":
                for i in range(2):
                    if i < 2:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                angle = cos(g1, g2)
                angle_all.append(angle)

                g = graddrop(grads)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            elif method == "mgd":
                for i in range(2):
                    if i < 2:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                angle = cos(g1, g2)
                angle_all.append(angle)

                g = mgd(grads)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            elif method == "pcgrad":
                for i in range(2):
                    if i < 2:
                        train_loss_tmp[i].backward(retain_graph=True)
                    else:
                        train_loss_tmp[i].backward()
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                angle = cos(g1, g2)
                angle_all.append(angle)

                g = pcgrad(grads, rng)
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
            elif method == "cagrad":
                for i in range(2):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                # grads_t1_all.append(g1)
                # grads_t2_all.append(g2)
                g = cagrad(grads, alpha, rescale=1)

                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)
                angle = cos(g1, g2)

                angle_all.append(angle)

                wandb.log({'t1_angle': t1_angle.item()})
                wandb.log({'t2_angle': t2_angle.item()})
                wandb.log({'angle': angle.item()})

                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())
                avg_angle.add(angle.item())

                # grads_ta_all.append(g.data.clone())
                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
                del g1
                del g2
            # elif method == "cagrad_drop":
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #     running_avg_1.add(t1_angle.item())
            #     running_avg_2.add(t2_angle.item())
            #
            #     if index > 1:
            #         std_1 = running_avg_1.std()
            #         std_2 = running_avg_2.std()
            #         avg_1 = running_avg_1.item()
            #         avg_2 = running_avg_2.item()
            #
            #         omega1 = 1 - std_1 / avg_1 * opt.omega
            #         omega2 = 1 - std_2 / avg_2 * opt.omega
            #         avg_1_bound = avg_1 * omega1
            #         avg_2_bound = avg_2 * omega2
            #
            #         total_cnt += 1
            #
            #         changed = False
            #         if t1_angle.item() < avg_1_bound:
            #             train_loss_tmp[0] = -train_loss_tmp[0]
            #             changed = True
            #             t1_cnt += 1
            #             wandb.log({'t1_opposite': 1})
            #         else:
            #             wandb.log({'t1_opposite': 0})
            #
            #         if t2_angle.item() < avg_2_bound:
            #             train_loss_tmp[1] = -train_loss_tmp[1]
            #             changed = True
            #             t2_cnt += 1
            #             wandb.log({'t2_opposite': 1})
            #         else:
            #             wandb.log({'t2_opposite': 0})
            #
            #         if changed:
            #             optimizer.zero_grad()
            #             for i in range(2):
            #                 train_loss_tmp[i].backward(retain_graph=True)
            #                 grad2vec(multi_task_model, grads, grad_dims, i)
            #                 multi_task_model.zero_grad_shared_modules()
            #             g = cagrad(grads, alpha, rescale=1)
            #
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #
            #     del g
            #     del g1
            #     del g2
            # elif method == "cagrad_random_drop":
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #     running_avg_1.add(t1_angle.item())
            #     running_avg_2.add(t2_angle.item())
            #
            #     if index > 1:
            #         random_rate = opt.random_rate
            #
            #         rs = torch.rand(2)
            #
            #         changed = False
            #         total_cnt += 1
            #
            #         if rs[0] < random_rate:
            #             train_loss_tmp[0] = -train_loss_tmp[0]
            #             changed = True
            #             t1_cnt += 1
            #             wandb.log({'t1_opposite': 1})
            #         else:
            #             wandb.log({'t1_opposite': 0})
            #
            #         if rs[1] < random_rate:
            #             train_loss_tmp[1] = -train_loss_tmp[1]
            #             changed = True
            #             t2_cnt += 1
            #             wandb.log({'t2_opposite': 1})
            #         else:
            #             wandb.log({'t2_opposite': 0})
            #
            #         if changed:
            #             optimizer.zero_grad()
            #             for i in range(2):
            #                 train_loss_tmp[i].backward(retain_graph=True)
            #                 grad2vec(multi_task_model, grads, grad_dims, i)
            #                 multi_task_model.zero_grad_shared_modules()
            #             g = cagrad(grads, alpha, rescale=1)
            #
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #
            #     del g
            #     del g1
            #     del g2
            # elif method == 'cagrad_flood':
            #     train_loss_tmp[0] = torch.abs(train_loss_tmp[0] - opt.t1_flood) + opt.t1_flood
            #     train_loss_tmp[1] = torch.abs(train_loss_tmp[1] - opt.t2_flood) + opt.t2_flood
            #
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     grads_t1_all.append(g1)
            #     grads_t2_all.append(g2)
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     del g
            #     del g1
            #     del g2
            # elif method == 'cagrad_fix':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #     target_g_list = reshape_grad(g, grad_dims)
            #
            #     length = len(grad_all[0])
            #     flag2 = [True] * length
            #     freeze_pos = []
            #     cnt2 = 0
            #     n_task = 2
            #
            #     if index > opt.start_epoch:
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_task):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         cos_saved = []
            #         for i, pair in enumerate(pair_grad):
            #             cos_all = pair_cos_with_d(pair, target_g_list[i])
            #             ans2 = ((cos_all > opt.freeze_level).sum() == cos_all.size(0))
            #             if not ans2:
            #                 flag2[i] = False
            #                 freeze_pos.append(i)
            #                 cnt2 += 1
            #
            #             # cos_saved.append(cos_all)
            #
            #         # all_cos_value.append(cos_saved)
            #
            #         rate = 1 - cnt2 / length
            #         wandb.log({'rate2': rate})
            #
            #         if rate > 0.0001:
            #             freeze_network(multi_task_model, flag2)
            #             g1 = remove_layer_gradient(grad_all[0], freeze_pos)
            #             g2 = remove_layer_gradient(grad_all[1], freeze_pos)
            #             grads = [g1, g2]
            #             grads = torch.stack(grads).t()
            #             g = cagrad(grads, alpha, rescale=1)
            #
            #             t1_angle = cos(g, g1)
            #             t2_angle = cos(g, g2)
            #             angle = cos(g1, g2)
            #
            #             wandb.log({'t1_angle_after': t1_angle.item()})
            #             wandb.log({'t2_angle_after': t2_angle.item()})
            #             wandb.log({'angle_after': angle.item()})
            #         else:
            #             g = torch.zeros_like(g)
            #
            #     # grads_ta_all.append(g.data.clone())
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #
            #     del g
            #     del g1
            #     del g2
            # elif method == 'cagrad_fix_layer':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad = split_layer(grad_list=grad, name_dict=layer_dict)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     g = cagrad(grads, alpha, rescale=1)
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     total_cnt += 1
            #     if t1_angle < 0:
            #         t1_cnt += 1
            #         wandb.run.summary['t1_freq'] = t1_cnt / total_cnt
            #     if t2_angle < 0:
            #         t2_cnt += 1
            #         wandb.run.summary['t2_freq'] = t2_cnt / total_cnt
            #     if angle < 0:
            #         t_cnt += 1
            #         wandb.run.summary['t_freq'] = t_cnt / total_cnt
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #     target_g_list = reshape_grad(g, grad_dims)
            #     target_g_list = split_layer(target_g_list, name_dict=layer_dict)
            #
            #
            #     flag2 = [True] * len(grad_dims)
            #     freeze_pos = []
            #     cnt2 = 0
            #     n_task = 2
            #
            #     length = len(grad_all[0])
            #     if index > opt.start_epoch:
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_task):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         for i, pair in enumerate(pair_grad):
            #
            #
            #             layer_wise_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
            #             # layer_wise_dot = pair_dot_with_d(pair, target_g_list[i]).cpu()
            #             layer_wise_task_cos = pair_cos(pair).cpu()
            #             # layer_wise_task_dot = pair_dot(pair).cpu()
            #
            #             layer_wise_angle[layer_name[i]].append(layer_wise_cos)
            #             # layer_wise_dot_value[layer_name[i]].append(layer_wise_dot)
            #             layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)
            #             # layer_wise_task_dot_value[layer_name[i]].append(layer_wise_task_dot)
            #
            #             # cos_all = pair_cos_with_d(pair, target_g_list[i])
            #             # ans2 = ((cos_all > opt.freeze_level).sum() == cos_all.size(0))
            #             # if not ans2:
            #             #     pos = layer_dict[layer_name[i]]
            #             #     conflict_dict[layer_name[i]] += 1
            #             #     for j in pos:
            #             #         flag2[j] = False
            #             #     freeze_pos.append(i)
            #             #     cnt2 += 1
            #             #     wandb.run.summary[layer_name[i]] = conflict_dict[layer_name[i]]
            #
            #         rate = 1 - cnt2 / length
            #         wandb.log({'rate2': rate})
            #
            #         if opt.ignore and rate < 1.0:
            #             if rate > 0.0001:
            #                 freeze_network(multi_task_model, flag2)
            #                 g1 = remove_layer_gradient(grad_all[0], freeze_pos)
            #                 g2 = remove_layer_gradient(grad_all[1], freeze_pos)
            #                 grads = [g1, g2]
            #                 grads = torch.stack(grads).t()
            #                 g = cagrad(grads, alpha, rescale=1)
            #
            #                 t1_angle = cos(g, g1)
            #                 t2_angle = cos(g, g2)
            #                 angle = cos(g1, g2)
            #
            #                 wandb.log({'t1_angle_after': t1_angle.item()})
            #                 wandb.log({'t2_angle_after': t2_angle.item()})
            #                 wandb.log({'angle_after': angle.item()})
            #             else:
            #                 g = torch.zeros_like(g)
            #
            #     # grads_ta_all.append(g.data.clone())
            #     overwrite_grad_with_freeze_level(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            #
            #     del g
            #     del g1
            #     del g2
            elif method == 'conflict_prob':
                for i in range(2):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()

                if opt.method_sub == 'nothing':
                    g = 0.5 * g1 + 0.5 * g2
                elif opt.method_sub == 'cagrad':
                    g = cagrad(grads, alpha, rescale=1)

                angle = cos(g1, g2)
                angle_all.append(angle)

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
                del g1
                del g2

            elif method == 'branch_layer':
                grad_all = []
                for i in range(2):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    grad = grad2vec_list(multi_task_model)
                    grad = split_layer(grad_list=grad, name_dict=layer_dict)
                    grad_all.append(grad)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()
                # grads_t1_all.append(g1)
                # grads_t2_all.append(g2)
                if opt.method_sub == 'nothing':
                    g = 0.5 * g1 + 0.5 * g2
                elif opt.method_sub == 'cagrad':
                    g = cagrad(grads, alpha, rescale=1)
                elif opt.method_sub == 'graddrop':
                    g = graddrop(grads)
                elif opt.method_sub == 'pcgrad':
                    g = pcgrad(grads, rng)
                elif opt.method_sub == 'mgd':
                    g = mgd(grads)
                # g = 0.5 * g1 + 0.5 * g2

                t1_angle = cos(g, g1)
                t2_angle = cos(g, g2)
                angle = cos(g1, g2)

                # total_cnt += 1
                # if t1_angle < 0:
                #     t1_cnt += 1
                #     wandb.run.summary['t1_freq'] = t1_cnt / total_cnt
                # if t2_angle < 0:
                #     t2_cnt += 1
                #     wandb.run.summary['t2_freq'] = t2_cnt / total_cnt
                # if angle < 0:
                #     t_cnt += 1
                #     wandb.run.summary['t_freq'] = t_cnt / total_cnt

                wandb.log({'t1_angle': t1_angle.item()})
                wandb.log({'t2_angle': t2_angle.item()})
                wandb.log({'angle': angle.item()})

                avg_t1_angle.add(t1_angle.item())
                avg_t2_angle.add(t2_angle.item())
                avg_angle.add(angle.item())

                target_g_list = reshape_grad(g, grad_dims)
                target_g_list = split_layer(target_g_list, name_dict=layer_dict)


                flag2 = [True] * len(grad_dims)
                freeze_pos = []
                cnt2 = 0
                n_tasks = 2

                length = len(grad_all[0])

                pair_grad = []
                for i in range(length):
                    temp = []
                    for j in range(n_tasks):
                        temp.append(grad_all[j][i])
                    temp = torch.stack(temp)
                    pair_grad.append(temp)

                for i in range(length):
                    task_diff_list = LFT_task_cos_diff(grad_all=grad_all, grad_target=target_g_list,
                                                       n_tasks=n_tasks,
                                                       pos=i)
                    diff_list = LFT_cos_diff(grad_all=grad_all, n_tasks=n_tasks, pos=i)
                    task_diff[layer_name[i]].append(task_diff_list)
                    diff[layer_name[i]].append(diff_list)

                for i, pair in enumerate(pair_grad):
                    layer_wise_task_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
                    # layer_wise_dot = pair_dot_with_d(pair, target_g_list[i]).cpu()
                    layer_wise_cos = pair_cos(pair).cpu()
                    # layer_wise_task_dot = pair_dot(pair).cpu()

                    layer_wise_angle[layer_name[i]].append(layer_wise_cos)
                    # layer_wise_dot_value[layer_name[i]].append(layer_wise_dot)
                    layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)

                    # layer_wise_cos = pair_cos_with_d(pair, target_g_list[i]).cpu()
                    # # layer_wise_dot = pair_dot_with_d(pair, target_g_list[i]).cpu()
                    # layer_wise_task_cos = pair_cos(pair).cpu()
                    # # layer_wise_task_dot = pair_dot(pair).cpu()
                    #
                    # layer_wise_angle[layer_name[i]].append(layer_wise_cos)
                    # # layer_wise_dot_value[layer_name[i]].append(layer_wise_dot)
                    # layer_wise_task_angle[layer_name[i]].append(layer_wise_task_cos)

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()

                # if index == opt.start_epoch and turning_cnt <= 0:
                #     turning_cnt += 1
                #     branch_layers = topK_prob(task_diff, topK=opt.topK)
                #     # branch_layers_2 = topK_value(task_diff, topK=opt.topK)
                #     multi_task_model.turn(task_branches=branch_layers)
                #
                #     grad_dims = []
                #     for key, param in multi_task_model.shared_parameters().items():
                #         grad_dims.append(param.data.numel())
                #
                #     layer_dict = get_layer_dict(multi_task_model)
                #     layer_name = list(layer_dict.keys())

                del g
                del g1
                del g2
            # elif method == "cagrad_weight":
            #     for i in range(2):
            #         if i < 2:
            #             train_loss_tmp[i].backward(retain_graph=True)
            #         else:
            #             train_loss_tmp[i].backward()
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #     g = cagrad_weight(grads, alpha, rescale=1)
            #     overwrite_grad(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #     del g
            elif method == 'nothing':
                optimizer.zero_grad()
                loss = sum([train_loss_tmp[i] for i in range(2)])
                loss.backward()
                optimizer.step()

            elif method == 'nothing_v2':
                for i in range(2):
                    train_loss_tmp[i].backward(retain_graph=True)
                    grad2vec(multi_task_model, grads, grad_dims, i)
                    multi_task_model.zero_grad_shared_modules()

                g1 = grads[:, 0].data.clone()
                g2 = grads[:, 1].data.clone()

                g = 0.5 * g1 + 0.5 * g2

                angle = cos(g1, g2)
                angle_all.append(angle)

                overwrite_grad(multi_task_model, g, grad_dims)
                optimizer.step()
                del g
                del g1
                del g2

            # elif method == 'nothing_fix':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     # grads_ta_all.append(grads.mean(dim=1).data.clone())
            #     g = grads.mean(dim=1).data.clone()
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #     del g1
            #     del g2
            #     del g
            #
            #     length = len(grad_all[0])
            #     flag2 = [False] * length
            #     cnt2 = 0
            #     n_task = 2
            #
            #     if index > opt.start_epoch:
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_task):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         cos_saved = []
            #         for i, pair in enumerate(pair_grad):
            #             cos_all = pair_cos2(pair)
            #             ans2 = ((cos_all > opt.freeze_level).sum() == cos_all.size(0))
            #             if ans2:
            #                 flag2[i] = True
            #                 cnt2 += 1
            #
            #             cos_saved.append(cos_all)
            #
            #         all_cos_value.append(cos_saved)
            #
            #         wandb.log({'rate2': cnt2 / length})
            #
            #         freeze_network(multi_task_model, flag2)
            #
            #     optimizer.zero_grad()
            #     loss = sum([train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            #
            # elif method == 'nothing_fix2':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     # grads_ta_all.append(grads.mean(dim=1).data.clone())
            #     g = grads.mean(dim=1).data.clone()
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #     del g1
            #     del g2
            #     del g
            #
            #     length = len(grad_all[0])
            #     flag2 = [True] * length
            #     score = []
            #     n_freeze_para = int(opt.freeze_level * length)
            #
            #     cnt2 = 0
            #     n_task = 2
            #
            #     if index > opt.start_epoch:
            #         pair_grad = []
            #         for i in range(length):
            #             temp = []
            #             for j in range(n_task):
            #                 temp.append(grad_all[j][i])
            #             temp = torch.stack(temp)
            #             pair_grad.append(temp)
            #
            #         cos_saved = []
            #         for i, pair in enumerate(pair_grad):
            #             cos_all = pair_cos(pair)
            #             ans = torch.min(cos_all)
            #             score.append(ans)
            #             cos_saved.append(cos_all)
            #
            #         all_cos_value.append(cos_saved)
            #         score = torch.stack(score).view(-1)
            #         min_value, min_pos = torch.topk(score, n_freeze_para, largest=False)
            #         for pos in min_pos:
            #             flag2[pos.item()] = False
            #
            #         wandb.log({'max_freeze_value': min_value[-1].item()})
            #
            #         freeze_network(multi_task_model, flag2)
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #
            #     # #debug
            #     # cur=0
            #     # for m in multi_task_model.shared_modules():
            #     #     for p in m.parameters():
            #     #         if flag2[cur] is False:
            #     #             print(p.grad)
            #     #         cur += 1
            #
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            # elif method == 'nothing_fix3':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     # grads_ta_all.append(grads.mean(dim=1).data.clone())
            #     g = grads.mean(dim=1).data.clone()
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #     del g1
            #     del g2
            #     del g
            #
            #     length = len(grad_all[0])
            #     flag2 = [True] * length
            #     n_freeze_para = int(opt.freeze_level * length)
            #
            #     n_tasks = 2
            #
            #     if index > opt.start_epoch:
            #         test = LFT_min_cos(grad_all, n_tasks, pos=[])
            #
            #         score = []
            #         for i in range(length):
            #             value = LFT_min_cos(grad_all, n_tasks, pos=i)
            #             score.append(value)
            #         score = torch.stack(score).view(-1)
            #         max_value, max_pos = torch.topk(score, n_freeze_para, largest=True)
            #         test2 = LFT_min_cos(grad_all, n_tasks, pos=max_pos.tolist())
            #         for pos in max_pos:
            #             flag2[pos.item()] = False
            #
            #         wandb.log({'former': test.item()})
            #         wandb.log({'After': test2.item()})
            #
            #         freeze_network(multi_task_model, flag2)
            #
            #     optimizer.zero_grad()
            #     loss = sum([train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            # elif method == 'nothing_fix4':
            #     grad_all = []
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         grad = grad2vec_list(multi_task_model)
            #         grad_all.append(grad)
            #         multi_task_model.zero_grad_shared_modules()
            #
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #     # grads_t1_all.append(g1)
            #     # grads_t2_all.append(g2)
            #     # grads_ta_all.append(grads.mean(dim=1).data.clone())
            #     g = grads.mean(dim=1).data.clone()
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #     del g1
            #     del g2
            #     del g
            #
            #     length = len(grad_all[0])
            #     flag2 = [True] * length
            #     n_freeze_para = int(opt.freeze_level * length)
            #
            #     n_tasks = 2
            #
            #     if index > opt.start_epoch:
            #         test = LFT_min(grad_all, n_tasks, pos=[])
            #         score = []
            #         for i in range(length):
            #             value = LFT_min(grad_all, n_tasks, pos=i)
            #             score.append(value)
            #         score = torch.stack(score).view(-1)
            #         max_value, max_pos = torch.topk(score, n_freeze_para, largest=True)
            #         test2 = LFT_min(grad_all, n_tasks, pos=max_pos.tolist())
            #         for pos in max_pos:
            #             flag2[pos.item()] = False
            #
            #         wandb.log({'former': test.item()})
            #         wandb.log({'After': test2.item()})
            #
            #         freeze_network(multi_task_model, flag2)
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #
            #     optimizer.step()
            #     unfreeze_network(multi_task_model)
            #
            # elif method == 'nothing_regularizer':
            #     grad_all = []
            #     shared_grad = []
            #     for mm in multi_task_model.shared_modules():
            #         shared_grad.extend(list(mm.parameters()))
            #
            #     for i in range(2):
            #         g = []
            #         grad = torch.autograd.grad(train_loss_tmp[i], shared_grad, create_graph=True)
            #         for p in grad:
            #             g.append(p.view(-1))
            #
            #         g = torch.cat(g)
            #         grad_all.append(g)
            #
            #     g1 = grad_all[0]
            #     g2 = grad_all[1]
            #
            #     g = 0.5 * g1 + 0.5 * g2
            #
            #     t1_angle = cos(g, g1).item()
            #     t2_angle = cos(g, g2).item()
            #
            #     angle = cos(g1, g2)
            #
            #     regularizer = 1 - angle
            #
            #     wandb.log({'t1_angle': t1_angle})
            #     wandb.log({'t2_angle': t2_angle})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle)
            #     avg_t2_angle.add(t2_angle)
            #     avg_angle.add(angle.item())
            #
            #     running_avg_1.add(t1_angle)
            #     running_avg_2.add(t2_angle)
            #     running_avg.add(angle.item())
            #
            #     del g1
            #     del g2
            #
            #     # print(scheduler.get_last_lr())
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #
            #     avg = running_avg.item()
            #     if index > opt.start_epoch:
            #         if avg < stage:
            #             loss += regularizer * opt.omega
            #             if avg < stage1:
            #                 stage = stage2
            #         else:
            #             if avg > stage2:
            #                 stage = stage1
            #
            #     loss.backward()
            #     optimizer.step()
            #     del g
            # elif method == 'nothing_regularizer2':
            #     for i in range(2):
            #         train_loss_tmp[i].backward(retain_graph=True)
            #         grad2vec(multi_task_model, grads, grad_dims, i)
            #         multi_task_model.zero_grad_shared_modules()
            #     g1 = grads[:, 0].data.clone()
            #     g2 = grads[:, 1].data.clone()
            #
            #     g = grads.mean(dim=1).data.clone()
            #
            #     t1_angle = cos(g, g1)
            #     t2_angle = cos(g, g2)
            #     angle = cos(g1, g2)
            #
            #     wandb.log({'t1_angle': t1_angle.item()})
            #     wandb.log({'t2_angle': t2_angle.item()})
            #     wandb.log({'angle': angle.item()})
            #
            #     avg_t1_angle.add(t1_angle.item())
            #     avg_t2_angle.add(t2_angle.item())
            #     avg_angle.add(angle.item())
            #
            #
            #     lr = scheduler.get_last_lr()[0]
            #     if t_g is not None and t1_loss is not None and t2_loss is not None:
            #         cos_t1 = (t1_loss - train_loss_tmp[0]) / (t_g * lr * torch.norm(g1))
            #         cos_t2 = (t2_loss - train_loss_tmp[1]) / (t_g * lr * torch.norm(g2))
            #         print(f'c1:{cos_t1}, c2:{cos_t2}')
            #         regularizer = 1 - cos_t1 * cos_t2
            #         print(f'regularizer:{regularizer}')
            #     else:
            #         regularizer = 0
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     if index > opt.start_epoch:
            #         loss += regularizer * opt.omega
            #
            #     loss.backward()
            #
            #     del g
            #     del g1
            #     del g2
            #
            #     g = torch.Tensor(sum(grad_dims)).cuda()
            #     grad2vec_g(multi_task_model, g, grad_dims)
            #     optimizer.step()
            #
            #     t_g = torch.norm(g)
            #     t1_loss = train_loss_tmp[0].data.clone()
            #     t2_loss = train_loss_tmp[1].data.clone()
            #     del g
            #
            # elif method == 'nothing_drop':
            #     if opt.flood:
            #         for i in range(2):
            #             train_loss_tmp[i].backward(retain_graph=True)
            #             grad2vec(multi_task_model, grads, grad_dims, i)
            #             multi_task_model.zero_grad_shared_modules()
            #         g1 = grads[:, 0].data.clone()
            #         g2 = grads[:, 1].data.clone()
            #         grads_t1_all.append(g1)
            #         grads_t2_all.append(g2)
            #         # grads_ta_all.append(grads.mean(dim=1).data.clone())
            #         g = grads.mean(dim=1).data.clone()
            #
            #         t1_angle = cos(g, g1)
            #         t2_angle = cos(g, g2)
            #         angle = cos(g1, g2)
            #
            #         wandb.log({'t1_angle': t1_angle.item()})
            #         wandb.log({'t2_angle': t2_angle.item()})
            #         wandb.log({'angle': angle.item()})
            #
            #         avg_t1_angle.add(t1_angle.item())
            #         avg_t2_angle.add(t2_angle.item())
            #         avg_angle.add(angle)
            #
            #         running_avg_1.add(t1_angle.item())
            #         running_avg_2.add(t2_angle.item())
            #
            #
            #         if index > 50:
            #             std_1 = running_avg_1.std()
            #             std_2 = running_avg_2.std()
            #             avg_1 = running_avg_1.item()
            #             avg_2 = running_avg_2.item()
            #
            #             omega1 = 1 - std_1 / avg_1 * opt.omega
            #             omega2 = 1 - std_2 / avg_2 * opt.omega
            #             avg_1_bound = avg_1 * omega1
            #             avg_2_bound = avg_2 * omega2
            #
            #             if t1_angle.item() < avg_1_bound:
            #                 train_loss_tmp[0] = -train_loss_tmp[0]
            #                 wandb.log({'t1_opposite': 1})
            #             else:
            #                 wandb.log({'t1_opposite': 0})
            #
            #             if t2_angle.item() < avg_2_bound:
            #                 train_loss_tmp[1] = -train_loss_tmp[1]
            #                 wandb.log({'t2_opposite': 1})
            #             else:
            #                 wandb.log({'t2_opposite': 0})
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #     optimizer.step()
            #
            #     del g1
            #     del g2
            #     del g
            #
            # elif method == 'test':
            #     for i in range(n_classes):
            #         for j in range(2):
            #             train_loss[i][j].backward(retain_graph=True)
            #             grad2vec(multi_task_model, grads, grad_dims, j)
            #             multi_task_model.zero_grad_shared_modules()
            #         grads_t1_all[i].append(grads[:, 0].data.clone())
            #         grads_t2_all[i].append(grads[:, 1].data.clone())
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #     optimizer.step()
            # elif method == 'random':
            #     prob_t1 = np.exp(-train_loss_tmp[0].item() / 1.1 * opt.sigma)
            #     cur = np.random.uniform(high=1.0)
            #     if prob_t1 > cur:
            #         train_loss_tmp[0] = -train_loss_tmp[0]
            #
            #     optimizer.zero_grad()
            #     loss = sum([0.5 * train_loss_tmp[i] for i in range(2)])
            #     loss.backward()
            #     optimizer.step()

        # if len(loss_list_1) > 0:
        #     loss_list_1 = np.array(loss_list_1)
        #     wandb.run.summary["loss_1_min"] = loss_list_1.min()
        #     wandb.run.summary["loss_1_max"] = loss_list_1.max()
        #     wandb.run.summary["loss_1_median"] = np.median(loss_list_1)
        #     L1_alpha = np.median(loss_list_1) * 0.9
        # if len(loss_list_2) > 0:
        #     loss_list_2 = np.array(loss_list_2)
        #     wandb.run.summary["loss_2_min"] = loss_list_2.min()
        #     wandb.run.summary["loss_2_max"] = loss_list_2.max()
        #     wandb.run.summary["loss_2_median"] = np.median(loss_list_2)
        #     L2_alpha = np.median(loss_list_2) * 0.9
        #
        # if method == 'test':
        #     for i in range(n_classes):
        #         grads_t1_all[i] = torch.stack(grads_t1_all[i])
        #         # grads_t1_all[i] = F.normalize(grads_t1_all[i], dim=1)
        #         grads_t2_all[i] = torch.stack(grads_t2_all[i])
        #         # grads_t2_all[i] = F.normalize(grads_t2_all[i], dim=1)
        #
        #     # grads_t1_all = torch.stack(grads_t1_all)
        #     # grads_t2_all = torch.stack(grads_t2_all)
        #     #
        #     # grads_t1_all_T = grads_t1_all.permute((0, 2, 1))
        #     # grads_t2_all_T = grads_t2_all.permute((0, 2, 1))
        #     #
        #     # grads_t1_all = grads_t1_all.unsqueeze(dim=1)
        #     # grads_t2_all = grads_t2_all.unsqueeze(dim=1)
        #     #
        #     # score_t1 = torch.matmul(grads_t1_all, grads_t1_all_T)
        #     # score_t2 = torch.matmul(grads_t2_all, grads_t2_all_T)
        #     #
        #     # # score_t1_sum = 0
        #     # # score_t2_sum = 0
        #     # # cnt = 0
        #     # # for i in range(n_classes):
        #     # #     for j in range(n_classes):
        #     # #         if i != j:
        #     # #             score_t1_sum += score_t1[i][j].mean()
        #     # #             score_t2_sum += score_t2[i][j].mean()
        #     # #             cnt += 1.0
        #     # #
        #     # # score_t1 = score_t1_sum / cnt
        #     # # score_t2 = score_t2_sum / cnt
        #     #
        #     # score_t1 = score_t1.mean()
        #     # score_t2 = score_t2.mean()


        # t1_score = 0.0
        # with torch.no_grad():
        #     for i in range(len(grads_t1_all) - 1):
        #         for j in range(i+1, len(grads_t2_all)):
        #             score = cos(grads_t1_all[i], grads_t2_all[j])
        #             t1_score += score.item()
        #
        # t2_score = 0.0
        # with torch.no_grad():
        #     for i in range(len(grads_t2_all) - 1):
        #         for j in range(i+1, len(grads_t1_all)):
        #             score = cos(grads_t2_all[i], grads_t1_all[j])
        #             t2_score += score.item()

        # grads_t1_all = torch.stack(grads_t1_all)
        # grads_t1_all = F.normalize(grads_t1_all, dim=1)
        #
        # grads_t2_all = torch.stack(grads_t2_all)
        # grads_t2_all = F.normalize(grads_t2_all, dim=1)

        # grads_ta_all = torch.stack(grads_ta_all)
        # grads_ta_all = F.normalize(grads_ta_all, dim=1)


        # score = torch.mm(grads_t1_all, grads_t2_all.T)
        # score = score.mean()
        #
        # score_t1 = torch.mm(grads_t1_all, grads_t1_all.T)
        # score_t1 = score_t1.mean()
        # score_t1 = score + score_t1

        # score_t2 = torch.mm(grads_t2_all, grads_t2_all.T)
        # score_t2 = score_t2.mean()
        # score_t2 = score + score_t2

        # n_epoch = len(grads_t1_all)
        # score = torch.triu(score, diagonal=1)
        # score = score.sum() / (n_epoch * (n_epoch + 1) / 2)


        # score = torch.median(score)

        # temp = torch.mm(grads_t1_all, grads_ta_all.T)
        # n_epoch = len(grads_ta_all)
        # score_t1 = (temp.sum() - torch.diag(temp).sum()) / (n_epoch * (n_epoch - 1))
        #
        # temp = torch.mm(grads_t2_all, grads_ta_all.T)
        # n_epoch = len(grads_ta_all)
        # score_t2 = (temp.sum() - torch.diag(temp).sum()) / (n_epoch * (n_epoch - 1))

        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)} | lr: {get_lr(optimizer=optimizer)}')
        print(f'----Train----| Train_loss_t1: {avg_tr_loss_1.item()}, Train_loss_t2: {avg_tr_loss_2.item()} | Train_acc_t1: {avg_tr_acc_1.item()}, Train_acc_t2: {avg_tr_acc_2.item()}')
        # print(f'----Train----| Score1: {score_t1} | Score2: {score_t2}')

        log_dict = {
            'epoch': index,
            'Train_loss_t1_avg': avg_tr_loss_1.item(),
            'Train_loss_t2_avg': avg_tr_loss_2.item(),
            'Train_acc_t1': avg_tr_acc_1.item(),
            'Train_acc_t2': avg_tr_acc_2.item(),
            'avg_t1_angle': avg_t1_angle.item(),
            'avg_t2_angle': avg_t2_angle.item(),
            'avg_angle': avg_angle.item()
        }

        wandb.log(log_dict)
        # wandb.log({'score_t1': score_t1})
        # wandb.log({'score_t2': score_t2})
        # wandb.log({'score': score})
        if total_cnt > 0:
            wandb.run.summary["t1_opposite_rate"] = 1.0 * t1_cnt / total_cnt
            wandb.run.summary["t2_opposite_rate"] = 1.0 * t2_cnt / total_cnt

        if (index + 1) % eval_freq == 0:
            # evaluating test data
            multi_task_model.eval()
            with torch.no_grad():  # operations inside don't track history
                for k, data in enumerate(test_loader):
                    x = data[0].cuda()
                    ts = data[1].cuda()

                    t1_pred, t2_pred = multi_task_model(x)
                    test_loss = [ce_loss(t1_pred, ts[:, 0]), ce_loss(t2_pred, ts[:, 1])]
                    n_samples = t1_pred.size(0)

                    # calculate the accuracy of prediction
                    t1_estimate_labels = torch.argmax(t1_pred, dim=1)
                    t2_estimate_labels = torch.argmax(t2_pred, dim=1)

                    acc_t1 = (t1_estimate_labels == ts[:, 0]).sum() / n_samples
                    acc_t2 = (t2_estimate_labels == ts[:, 1]).sum() / n_samples

                    avg_te_loss_1.add(test_loss[0].item(), n_samples)
                    avg_te_loss_2.add(test_loss[1].item(), n_samples)
                    avg_te_acc_1.add(acc_t1.item(), n_samples)
                    avg_te_acc_2.add(acc_t2.item(), n_samples)

            print(
                f'----Test----| Test_loss_t1: {avg_te_loss_1.item()}, Test_loss_t2: {avg_te_loss_2.item()} | Test_acc_t1: {avg_te_acc_1.item()}, Test_acc_t2: {avg_te_acc_2.item()}')

            log_dict = {
                'epoch': index,
                'Test_loss_t1': avg_te_loss_1.item(),
                'Test_loss_t2': avg_te_loss_2.item(),
                'Test_acc_t1': avg_te_acc_1.item(),
                'Test_acc_t2': avg_te_acc_2.item(),
            }

            wandb.log(log_dict)


        del grads
        del x
        del ts
        del t1_pred
        del t2_pred
        del train_loss_tmp
        del t1_estimate_labels
        del t2_estimate_labels
        del acc_t1
        del acc_t2
        del n_samples

        scheduler.step()

        # saved_dict = {'cos': layer_wise_angle,
        #               'task_cos': layer_wise_task_angle}

        if (index + 1) == epoch and method == 'branch_layer':
            saved_dict = {'task_diff': task_diff}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{opt.optimizer}_{epoch}_task_diff_data_mnist.pt')
            saved_dict = {'diff': diff}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{opt.optimizer}_{epoch}_diff_data_mnist.pt')
            saved_dict = {'cos': layer_wise_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{opt.optimizer}_{epoch}_lw_cos_mnist.pt')
            saved_dict = {'task_cos': layer_wise_task_angle}
            torch.save(saved_dict, f'./saved/{opt.seed}{opt.method_sub}{opt.optimizer}_{epoch}_lw_task_cos_mnist.pt')

        if (index + 1) == epoch:
            if opt.base_model == 'fw_ablation':
                torch.save(angle_all,
                           f'./saved/{opt.seed}{opt.base_model}{opt.flag}_{opt.topK}{opt.method}{opt.optimizer}_{epoch}_angle.pt')
            else:
                torch.save(angle_all, f'./saved/{opt.seed}{opt.base_model}{opt.topK}{opt.method}{opt.optimizer}_{epoch}_angle.pt')

def _single_task_forward(model, criterion, data, task_id):
    """Run one batch through ``model`` and score the head picked by ``task_id``.

    Args:
        model: Callable returning exactly two outputs for an input batch.
        criterion: Loss callable applied as ``criterion(pred, target)``.
        data: Batch where ``data[0]`` is the input and ``data[1]`` the label
            tensor, indexed as ``labels[:, task_id]``.
        task_id: Index (0 or 1) of the output/label column to evaluate.

    Returns:
        Tuple ``(loss, acc, n_samples)``: the loss tensor, the batch accuracy
        as a Python number, and the first dimension of the selected output.
    """
    x = data[0].cuda()
    ts = data[1].cuda()

    t1_pred, t2_pred = model(x)
    pred = (t1_pred, t2_pred)[task_id]
    target = ts[:, task_id]

    n_samples = pred.size(0)
    loss = criterion(pred, target)

    estimate_labels = torch.argmax(pred, dim=1)
    acc = (estimate_labels == target).sum() / n_samples

    # Only the loss tensor leaves this function; the batch tensors and both
    # head outputs go out of scope on return, so no manual `del` is needed.
    return loss, acc.item(), n_samples


def single_task_mgd_trainer(train_loader, test_loader, single_task_model,
                           optimizer, scheduler, opt):
    """Train a two-output model on one of its tasks and log progress to wandb.

    For each of ``opt.n_epoch`` epochs the model is trained over
    ``train_loader`` with cross-entropy on the output selected by
    ``opt.task_id``; every ``opt.eval_freq`` epochs it is also evaluated over
    ``test_loader``. ``scheduler.step()`` is called once per epoch.

    Args:
        train_loader: Iterable of training batches (``data[0]`` inputs,
            ``data[1]`` labels with one column per task).
        test_loader: Iterable of test batches with the same layout.
        single_task_model: Model returning two outputs per input batch.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        opt: Options object; ``n_epoch``, ``task_id`` and ``eval_freq`` are read.

    Raises:
        ValueError: If ``opt.task_id`` is neither 0 nor 1.
    """
    epoch = opt.n_epoch
    task_id = opt.task_id
    eval_freq = opt.eval_freq

    # Fail before training starts: any other value previously crashed later
    # (IndexError on the outputs, or NameError on the unbound log dict).
    if task_id not in (0, 1):
        raise ValueError(f'task_id must be 0 or 1, got {task_id!r}')
    tag = f't{task_id + 1}'

    def get_lr(optimizer):
        # Learning rate of the first parameter group (None if there is none).
        for param_group in optimizer.param_groups:
            return param_group['lr']

    timer = Timer()
    ce_loss = CrossEntropyLoss()

    # Register the epoch axis once, and for the keys that are actually logged
    # below, so the per-epoch series are plotted against "epoch".
    wandb.define_metric("epoch")
    for metric_name in (f'Train_loss_{tag}_avg', f'Train_acc_{tag}_avg',
                        f'Test_loss_{tag}', f'Test_acc_{tag}'):
        wandb.define_metric(metric_name, step_metric="epoch")

    for index in range(epoch):
        avg_tr_loss = Averager()
        avg_tr_acc = Averager()

        single_task_model.train()

        for data in train_loader:
            train_loss, acc, n_samples = _single_task_forward(
                single_task_model, ce_loss, data, task_id)

            loss_value = train_loss.item()
            wandb.log({'Train_loss': loss_value})

            avg_tr_loss.add(loss_value, n_samples)
            avg_tr_acc.add(acc, n_samples)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        print(f'Epoch:{index} | ETA:{timer.measure()}/{timer.measure((index + 1) / epoch)} | lr: {get_lr(optimizer=optimizer)}')
        # NOTE: the "Train_acc_t1" label is kept verbatim for every task so
        # existing console-log consumers see unchanged output.
        print(f'----Train----| Train_loss: {avg_tr_loss.item()}| Train_acc_t1: {avg_tr_acc.item()}')

        wandb.log({
            'epoch': index,
            f'Train_loss_{tag}_avg': avg_tr_loss.item(),
            f'Train_acc_{tag}_avg': avg_tr_acc.item(),
        })

        if (index + 1) % eval_freq == 0:
            # evaluating test data
            avg_te_loss = Averager()
            avg_te_acc = Averager()

            single_task_model.eval()
            with torch.no_grad():  # operations inside don't track history
                for data in test_loader:
                    test_loss, acc, n_samples = _single_task_forward(
                        single_task_model, ce_loss, data, task_id)

                    # Each batch is recorded exactly once per averager.
                    avg_te_loss.add(test_loss.item(), n_samples)
                    avg_te_acc.add(acc, n_samples)

            print(
                f'----Test----| Test_loss: {avg_te_loss.item()} | Test_acc: {avg_te_acc.item()}')

            wandb.log({
                'epoch': index,
                f'Test_loss_{tag}': avg_te_loss.item(),
                f'Test_acc_{tag}': avg_te_acc.item(),
            })

        scheduler.step()

def train_rotograd(train_loader, test_loader, model,
                           optimizer_main, scheduler_main, optimizer_R, scheduler_R, opt):
    """Train a two-task model whose backward pass is driven by ``model.backward``.

    Each training batch computes one cross-entropy loss per task, logs both to
    wandb, and hands the list of losses to ``model.backward`` before stepping
    ``optimizer_main`` and ``optimizer_R``. Every ``opt.eval_freq`` epochs the
    model is evaluated over ``test_loader`` and sample-weighted test loss and
    accuracy per task are printed and logged. Both schedulers are stepped once
    per epoch.

    Args:
        train_loader: Iterable of training batches (``data[0]`` inputs,
            ``data[1]`` labels with columns 0 and 1 for the two tasks).
        test_loader: Iterable of test batches with the same layout.
        model: Model returning two outputs and exposing ``backward(losses)``.
        optimizer_main: First optimizer, zeroed and stepped on every batch.
        scheduler_main: Scheduler stepped once per epoch.
        optimizer_R: Second optimizer, zeroed and stepped on every batch.
        scheduler_R: Scheduler stepped once per epoch.
        opt: Options object; ``n_epoch`` and ``eval_freq`` are read.
    """
    n_epochs = opt.n_epoch
    eval_every = opt.eval_freq
    timer = Timer()  # constructed as before; not read anywhere in this function
    criterion = CrossEntropyLoss()
    train_keys = ('Train_loss_t1', 'Train_loss_t2')

    for index in range(n_epochs):
        model.train()

        for batch in train_loader:
            inputs = batch[0].cuda()
            labels = batch[1].cuda()

            out_t1, out_t2 = model(inputs)
            task_losses = [
                criterion(out, labels[:, task])
                for task, out in enumerate((out_t1, out_t2))
            ]

            # One wandb.log call per task, in task order.
            for key, task_loss in zip(train_keys, task_losses):
                wandb.log({key: task_loss.item()})

            optimizer_main.zero_grad()
            optimizer_R.zero_grad()

            model.backward(task_losses)

            optimizer_main.step()
            optimizer_R.step()

        run_eval = (index + 1) % eval_every == 0
        if run_eval:
            # Fresh per-task averagers for this evaluation pass.
            loss_meters = (Averager(), Averager())
            acc_meters = (Averager(), Averager())

            model.eval()
            with torch.no_grad():  # no autograd history during evaluation
                for batch in test_loader:
                    inputs = batch[0].cuda()
                    labels = batch[1].cuda()

                    out_t1, out_t2 = model(inputs)
                    batch_size = out_t1.size(0)

                    for task, out in enumerate((out_t1, out_t2)):
                        target = labels[:, task]
                        batch_loss = criterion(out, target)

                        # Fraction of correct argmax predictions in the batch.
                        predicted = torch.argmax(out, dim=1)
                        batch_acc = (predicted == target).sum() / batch_size

                        loss_meters[task].add(batch_loss.item(), batch_size)
                        acc_meters[task].add(batch_acc.item(), batch_size)

            te_loss_1, te_loss_2 = (meter.item() for meter in loss_meters)
            te_acc_1, te_acc_2 = (meter.item() for meter in acc_meters)

            print(
                f'----Test----| Test_loss_t1: {te_loss_1}, Test_loss_t2: {te_loss_2} | Test_acc_t1: {te_acc_1}, Test_acc_t2: {te_acc_2}')

            wandb.log({
                'epoch': index,
                'Test_loss_t1': te_loss_1,
                'Test_loss_t2': te_loss_2,
                'Test_acc_t1': te_acc_1,
                'Test_acc_t2': te_acc_2,
            })

        scheduler_main.step()
        scheduler_R.step()
