import random, os, time,sys
from abc import ABC, abstractmethod
from comb_modules.utils import cached_vertex_grid_to_edges, cached_vertex_grid_to_edges_grid_coords
import copy

import torch
import torch.nn as nn
from comb_modules.losses import HammingLoss
from comb_modules.dijkstra import ShortestPath, ShortestPathDPO, dijkstra, get_solver # JK 02/21
# from logger import Logger
from models import get_model
from utils import AverageMeter, optimizer_from_string, customdefaultdict, create_LP_matrix, create_adj, nodewt_to_edgewt, gini_indices, gini_coefficient, gini_indices_square, maybe_parallelize
from decorators import to_tensor, to_numpy
from . import metrics
from .metrics import compute_metrics
import numpy as np
from collections import defaultdict
#sys.path.insert(0,'../..')
#sys.path.append("../../")
sys.path.append("../../fair_ltr/")
from frank_wolfe import compute_Moreau_grad_softsort, compute_owa
def get_trainer(trainer_name):
    """Return the trainer class registered under ``trainer_name``.

    Raises:
        KeyError: if ``trainer_name`` is not one of the registered names.
    """
    registry = {
        "Baseline": BaselineTrainer,
        "DijkstraOnFull": DijkstraOnFull,
        "DijkstraSPO": DijkstraSPO,
        "DijkstraDescent": DijkstraDescent,
        "DijkstraOWADescent": DijkstraOWADescent,
        # use grad of owa methods
        "DijkstraMultiOWADescent": DijkstraMultiOWADescent,
        "DijkstraMultiOWADescent2": DijkstraMultiOWADescent2,
        # sum method
        "DijkstraMultiDescent": DijkstraMultiDescent,
        # grad norm method
        "DijkstraMultiGradNormDescent": DijkstraMultiGradNormDescent,
        "BaselineMulti": BaselineTrainerMulti,
    }
    trainer_cls = registry[trainer_name]
    return trainer_cls
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from .visualization import draw_paths_on_image




#JK
from qpth.qp import QPFunction
from perturbations import perturbations

#import sys
#sys.path.insert(0,'./NeurIPSIntopt/Interior/')
#sys.path.insert(0,'../..')
sys.path.append('./NeurIPSIntopt/Interior/')
#from ip_model_whole import *

class ShortestPathAbstractMulitTrainer(ABC): 
    def __init__(
        self,
        *,
        train_iterator,
        test_iterator,
        metadata,
        use_cuda,
        batch_size,
        optimizer_name,
        optimizer_params,
        model_params,
        fast_mode,
        neighbourhood_fn,
        preload_batch,
        lr_milestone_1,
        lr_milestone_2,
        use_lr_scheduling,
        normalize_path=0,
        owa_weight='gini',
        beta=0.1,
    ):
        """Set up data, model, optimizer, OWA weights and early-stopping state.

        All arguments are keyword-only. ``fast_mode``, ``lr_milestone_1`` and
        ``lr_milestone_2`` are accepted but not used by this constructor.
        ``owa_weight`` selects the OWA weight vector ('one', 'gini' or
        'ginisquare'); for any other value ``self.w_species`` is left unset.
        ``self.build_model`` is not defined in this class and is expected to
        create ``self.model``.
        """
        # Caller-supplied data and configuration (all set before build_model runs).
        self.train_iterator = train_iterator
        self.test_iterator = test_iterator
        self.metadata = metadata
        self.batch_size = batch_size
        self.use_cuda = use_cuda
        self.preload_batch = preload_batch
        self.neighbourhood_fn = neighbourhood_fn
        self.optimizer_params = optimizer_params
        self.n_task = model_params.arch_params.n_task
        # Side length of the grid, derived from the number of output features.
        self.grid_dim = int(np.sqrt(metadata["output_features"]))

        self.build_model(**model_params)
        optimizer_cls = optimizer_from_string(optimizer_name)
        self.optimizer = optimizer_cls(self.model.parameters(), **optimizer_params)

        self.use_lr_scheduling = use_lr_scheduling
        self.normalize_path = normalize_path
        self.beta = beta
        if use_lr_scheduling:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=0.2,
                min_lr=1e-7,
                verbose=True,
                patience=50,
            )

        # Progress counters.
        self.epochs = 0
        self.iter = 0

        # OWA weight vector: (builder, log message) per supported option.
        owa_builders = {
            'one': (lambda n: torch.ones(n).float(), 'Generated owa weight of one.'),
            'gini': (gini_indices, 'Generated owa weight of gini.'),
            'ginisquare': (gini_indices_square, 'Generated owa weight of gini square.'),
        }
        if owa_weight in owa_builders:
            build_weights, message = owa_builders[owa_weight]
            self.w_species = build_weights(self.n_task)
            print(message)

        self.alpha = 0.16  # hyperparam of grad norm
        self.loss_func = nn.MSELoss(reduction='none')
        self.spsolver = get_solver(neighbourhood_fn)
        self.training_time = 0

        # Early-stopping / best-model bookkeeping.
        self.best = 1e+8
        self.patience = 10
        self.counter = 0
        self.early_stopping = False
        self.best_model = None

        # One (initially empty) slot per task for the true path lengths.
        self.true_path_lens_species = {
            'species_{}'.format(task): [] for task in range(self.n_task)
        }


    def train_epoch(self):
        """Train for one pass over ``self.train_iterator`` and return logged statistics.

        For every step one batch is drawn from each iterator in
        ``self.train_iterator`` and stacked along a leading axis (the task /
        species axis, judging by the einsum comments below). The backward pass
        depends on ``self.name``:

        * ``DijkstraMultiGradNormDescent`` - weighted sum of per-task path
          lengths, with the task weights adapted by a GradNorm-style update;
        * ``DijkstraMultiOWADescent`` - back-propagates the Moreau/softsort
          gradient returned by ``compute_Moreau_grad_softsort``;
        * ``DijkstraMultiOWADescent2`` - mean of ``MoreauOWALossLayer``;
        * ``DijkstraMultiDescent`` - batch mean of the summed path lengths;
        * ``BaselineMulti`` - the loss returned by ``self.forward_pass``.

        Every ``write_losses_interval`` iterations the model is evaluated with
        ``self.evaluate()``. The model with the lowest validation OWA path
        length so far is deep-copied into ``self.best_model``;
        ``self.early_stopping`` is set once ``self.patience`` evaluations have
        passed without an improvement.

        Returns:
            dict: per-batch training lists, validation lists collected during
            the epoch, the mean batch time and the mean training loss.
        """
        grad_norm_criterion = nn.L1Loss()

        batch_time = AverageMeter("Batch time")
        data_time = AverageMeter("Data time")
        avg_loss = AverageMeter("Loss")
        avg_accuracy = AverageMeter("Accuracy")
        avg_path_diff = AverageMeter("Path length diff")
        avg_owa_path_lens = AverageMeter("OWA Path length")
        avg_path_lens = AverageMeter("Path length")

        # Never updated in this method; kept so the returned dict keeps its
        # "train_*" expansion.
        avg_metrics = customdefaultdict(lambda k: AverageMeter("train_"+k))

        batch_loss_list = []
        batch_path_lens_diff_list = []
        batch_path_lens_list, batch_owa_path_lens_list, batch_path_lens_species_list = [], [], []
        val_path_lens_list, val_owa_path_lens_list = [], []
        val_path_lens_diff_list, val_path_lens_species_list = [], []
        val_path_lens_gini_list = []
        val_regret = []
        training_interval = []

        def mean_species_path_lens():
            """Mean per-species path length over the batches logged so far (NaN if none)."""
            if not batch_path_lens_species_list:
                return np.full(self.n_task, np.nan)
            return np.stack(batch_path_lens_species_list).mean(0)

        def record_batch_stats(lens, n_samples, track_diff=False):
            """Log path-length statistics of one batch (``lens`` is batch x species)."""
            with torch.no_grad():
                lens = lens.detach()
                owa = -(compute_owa(self.w_species, -lens).mean()).item()
                total = lens.sum(dim=-1).mean().item()
                # .cpu() keeps the numpy conversion valid when training on GPU.
                batch_path_lens_species_list.append(lens.mean(0).cpu().numpy())
                batch_owa_path_lens_list.append(owa)
                batch_path_lens_list.append(lens.mean().item())
                avg_owa_path_lens.update(owa, n_samples)
                avg_path_lens.update(total, n_samples)
                if track_diff:
                    # Mean absolute deviation of each species from the per-sample mean.
                    spread = (lens - lens.mean(1).view(-1, 1)).abs().mean(1).mean().item()
                    batch_path_lens_diff_list.append(spread)
                    avg_path_diff.update(spread, n_samples)

        end = time.time()
        write_losses_interval = 1 if self.batch_size <= 50 else 5

        if self.epochs == 0:
            # Baseline evaluation before any parameter update.
            eval_results = self.evaluate()
            val_owa_path_lens_list.append(eval_results['owa_path_lens'])
            val_path_lens_list.append(eval_results['path_lens'])

            print('='*20)
            print("Evaluating on test set: iteration {} of epoch {}: owa_path_lens {}\t Path_lens {}\t Gini pathlen: {}".
                               format(self.iter, self.epochs, eval_results['owa_path_lens'],eval_results['path_lens'], eval_results['gini_path_lens']))
            val_path_lens_diff_list.append(eval_results['path_lens_diff'])
            val_path_lens_species_list.append([eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task)])
            val_path_lens_gini_list.append(eval_results['gini_path_lens'])
            training_interval = [0]
            print('Test: path len of each species ',[eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task)])

        print('='*100)
        self.model.train()
        for data in zip(*self.train_iterator):
            start_time = time.time()

            xs = torch.stack([d[0] for d in data])  # feat
            cs = torch.stack([d[1] for d in data])  # cost
            ws = torch.stack([d[2] for d in data])  # path
            if self.use_cuda:
                xs, cs, ws = xs.cuda(), cs.cuda(), ws.cuda()
            n_samples = xs.size(0)

            is_baseline = self.name == 'BaselineMulti'
            # The baseline is supervised with the costs, every other trainer with the paths.
            labels = cs if is_baseline else ws
            loss, accuracy, last_suggestion, pred_weights = self.forward_pass(xs, labels)

            avg_loss.update(loss.item(), n_samples)
            avg_accuracy.update(accuracy.item(), n_samples)
            batch_loss_list.append(loss.item())

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            suggested_path = last_suggestion["suggested_path"]

            if not is_baseline:
                path_lens = torch.einsum("mijk, mijk->mi", suggested_path, cs)  # numspecies x batch
                path_lens = torch.einsum("mi->im", path_lens)  # batch x species
                if self.normalize_path:
                    true_path_lens = torch.einsum("mijk, mijk->mi", ws, cs)
                    true_path_lens = torch.einsum("mi->im", true_path_lens)
                    path_lens = path_lens / true_path_lens

            if self.name == 'DijkstraMultiGradNormDescent':
                path_lens_batch_sum = path_lens.sum(dim=0)  # one loss per task
                if self.iter == 0:
                    self.weights = torch.nn.Parameter(torch.ones_like(path_lens_batch_sum))
                    self.layer = self.model.block
                    self.weight_optim = torch.optim.Adam([self.weights], lr=0.005)
                    self.l0 = path_lens_batch_sum.detach()  # initial per-task losses

                weighted_loss = self.weights @ path_lens_batch_sum  # compute the weighted loss
                weighted_loss.backward(retain_graph=True)  # backward pass for weighted task loss

                # L2 norm of the gradient of each weighted task loss w.r.t. the shared layer
                gw = []
                for task_weight, task_loss in zip(self.weights, path_lens_batch_sum):
                    dl = torch.autograd.grad(task_weight * task_loss, self.layer.parameters(),
                                             retain_graph=True, create_graph=True)
                    gw.append(torch.norm(dl[0], 2))
                gw = torch.stack(gw)

                # Loss ratio per task: current task loss over its initial value
                # (previously the scalar weighted loss was used for every task).
                loss_ratio = path_lens_batch_sum.detach() / self.l0
                rt = loss_ratio / loss_ratio.mean()  # relative inverse training rate per task
                gw_avg = gw.mean()  # average gradient norm
                constant = (gw_avg * rt ** self.alpha).detach()
                gradnorm_loss = grad_norm_criterion(gw, constant).sum()  # GradNorm loss
                self.weight_optim.zero_grad()
                gradnorm_loss.backward()  # backward pass for GradNorm
                self.weight_optim.step()
                # renormalize weights so that they sum to n_task
                coef = (self.n_task / self.weights.data.sum())
                self.weights.data = self.weights.data * coef

                record_batch_stats(path_lens, n_samples)

            elif self.name == "DijkstraMultiOWADescent":
                with torch.no_grad():
                    grad = compute_Moreau_grad_softsort(self.w_species, (-path_lens/self.beta))
                record_batch_stats(path_lens, n_samples, track_diff=True)
                path_lens.backward(gradient=grad)

            elif self.name == "DijkstraMultiOWADescent2":
                loss_multi = MoreauOWALossLayer.apply(path_lens, self.w_species, self.beta)
                loss = loss_multi.mean()
                record_batch_stats(path_lens, n_samples, track_diff=True)
                loss.backward()

            elif self.name == "DijkstraMultiDescent":
                path_lens_sum = path_lens.sum(dim=-1).mean()
                path_lens_sum.backward()
                record_batch_stats(path_lens, n_samples)

            elif is_baseline:
                loss.backward()
                with torch.no_grad():
                    path_lens = torch.einsum("mijk, mijk->mi", suggested_path, cs)  # numspecies x batch
                    path_lens = torch.einsum("mi->im", path_lens)  # batch x species
                record_batch_stats(path_lens, n_samples)

            self.optimizer.step()
            batch_time.update(time.time() - start_time, n_samples)

            self.training_time += time.time() - start_time

            if self.iter % write_losses_interval == 0:
                print("Evaluating on train set: iteration {} of epoch {}:".
                       format(self.iter, self.epochs ))
                meters = [batch_time, data_time, avg_loss, avg_accuracy, avg_path_lens, avg_owa_path_lens]
                meter_str = "\t".join([str(meter) for meter in meters])
                print(f"Epoch: {self.epochs}\t{meter_str}")
                print("SGD lr=%.4f" % (self.optimizer.param_groups[0]["lr"]))
                print('Train: path len of each species ', mean_species_path_lens())

                eval_results = self.evaluate()
                val_path_lens_list.append(eval_results['path_lens'])

                val_owa_path_lens_list.append(eval_results['owa_path_lens'])
                print('='*20)
                print("Evaluating on test set: iteration {} of epoch {}: owa_path_lens {} path_lens {}\t Gini pathlen: {}".
                                   format(self.iter, self.epochs, eval_results['owa_path_lens'],eval_results['path_lens'], eval_results['gini_path_lens']))
                val_path_lens_diff_list.append(eval_results['path_lens_diff'])
                predicted_path_len = [eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task)]
                val_path_lens_species_list.append(list(predicted_path_len))
                val_path_lens_gini_list.append(eval_results['gini_path_lens'])
                val_regret.append(eval_results['regret'])
                training_interval.append(time.time() - start_time)

                print('Test: path len of each species ',predicted_path_len, np.sum(predicted_path_len))
                print('Actual path_len prediction:', [predicted_path_len[i] * self.true_path_lens_species['species_{}'.format(i)] if self.normalize_path else predicted_path_len[i] for i in range(self.n_task)  ])

                # Lower is better (self.best starts at 1e8 and the LR scheduler
                # runs in 'min' mode). Compare and store the same metric.
                if eval_results['owa_path_lens'] < (self.best - 1e-4):
                    self.best = eval_results['owa_path_lens']
                    self.counter = 0
                    self.best_model = copy.deepcopy(self.model)
                    print('=========> saving new best results. ')
                else:
                    self.counter += 1

            if self.use_lr_scheduling:
                self.scheduler.step(avg_owa_path_lens.avg)

            self.iter += 1
            print(self.iter)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        self.epochs += 1
        # ">=": the counter can grow by more than one per epoch, so an exact
        # match with the patience could be skipped.
        if self.counter >= self.patience:
            print('Reached early stopping.')
            self.early_stopping = True
        return {
            "batch_time": batch_time.avg,
            "batch_loss_list": batch_loss_list,
            "batch_path_lens_list": batch_path_lens_list,
            "batch_owa_path_lens_list": batch_owa_path_lens_list,
            "batch_path_lens_diff_list": batch_path_lens_diff_list,
            "batch_path_lens_species": mean_species_path_lens(),
            "train_loss": avg_loss.avg,
            "val_path_lens_list": val_path_lens_list,
            "val_owa_path_lens_list": val_owa_path_lens_list,
            "val_path_lens_diff_list": val_path_lens_diff_list,
            "val_path_lens_species_list": val_path_lens_species_list,
            "val_path_lens_gini_list": val_path_lens_gini_list,
            'training_interval': training_interval,
            "val_regret": val_regret,
            **{"train_"+k: avg_metrics[k].avg for k in avg_metrics.keys()}
        }

    def evaluate(self, is_test=False):
        """Evaluate the model on the full test data in a single forward pass.

        The ``tmaps``, ``costs`` and ``paths`` arrays of every loader in
        ``self.test_iterator`` are stacked along a leading axis and fed through
        ``self.forward_pass`` at once.

        Args:
            is_test: when true, ``self.model`` is replaced by
                ``self.best_model`` before evaluating. If no best model has
                been recorded yet, the current model is evaluated instead.

        Returns:
            dict: metric name -> average value, with keys ``loss``,
            ``accuracy``, ``path_lens_diff``, ``path_lens_species_<i>``,
            ``gini_path_lens``, ``owa_path_lens``, ``path_lens`` and ``regret``.
        """
        avg_metrics = defaultdict(AverageMeter)

        # Swap in the best model *before* switching to eval mode: the stored
        # copy is taken while the model is in training mode.
        if is_test:
            if self.best_model is not None:
                self.model = self.best_model
            else:
                print('No best model recorded; evaluating the current model.')
        self.model.eval()

        xs = torch.FloatTensor(np.stack([dat.dataset.tmaps for dat in self.test_iterator]))
        cs = torch.FloatTensor(np.stack([dat.dataset.costs for dat in self.test_iterator]))
        ws = torch.FloatTensor(np.stack([dat.dataset.paths for dat in self.test_iterator]))
        if self.use_cuda:
            xs, cs, ws = xs.cuda(), cs.cuda(), ws.cuda()

        # NOTE(review): train_epoch passes the costs (cs) as labels for
        # 'BaselineMulti', while the paths (ws) are always used here - confirm
        # that this is intended.
        loss, accuracy, last_suggestion, pred_weights = self.forward_pass(xs, ws)
        suggested_path = last_suggestion["suggested_path"]
        n_samples = xs.size(0)

        true_path_lens = torch.einsum("mijk, mijk->mi", ws, cs)
        true_path_lens = torch.einsum("mi->im", true_path_lens)
        # Reference values of the true paths: refreshed during the first epoch,
        # and also computed if they are still missing on a later call.
        if self.epochs == 0 or not hasattr(self, 'owa_true_path_lens'):
            for s in range(self.n_task):
                self.true_path_lens_species['species_{}'.format(s)] = np.mean(
                    true_path_lens[:, s].mean().cpu().numpy())
            self.owa_true_path_lens = -(compute_owa(self.w_species, (-true_path_lens)).mean()).item()

        with torch.no_grad():
            path_lens = torch.einsum("mijk, mijk->mi", suggested_path, cs)  # numspecies x batch
            path_lens = torch.einsum("mi->im", path_lens)
            # The regret is measured on the un-normalized path lengths.
            owa_pred_path_lens = -(compute_owa(self.w_species, (-path_lens)).mean()).item()
            if self.normalize_path:
                path_lens = path_lens/true_path_lens
            path_lens_diff = (path_lens - path_lens.mean(1).view(-1,1)).abs()
            # .cpu() keeps the numpy conversions valid when evaluating on GPU.
            species_path_lens = path_lens.detach().mean(0).cpu().numpy()
            gini_path_lens = gini_coefficient(path_lens.cpu().numpy()).mean()
            owa_path_lens = -(compute_owa(self.w_species, (-path_lens)).mean()).item()

        avg_metrics["loss"].update(loss.mean().item(), n_samples)
        avg_metrics["accuracy"].update(accuracy.item(), n_samples)
        avg_metrics["path_lens_diff"].update(path_lens_diff.mean(1).mean().item(), n_samples)
        for i in range(self.n_task):
            avg_metrics['path_lens_species_{}'.format(i)].update(species_path_lens[i], n_samples)
        avg_metrics['gini_path_lens'].update(gini_path_lens, n_samples)
        avg_metrics["owa_path_lens"].update(owa_path_lens, n_samples)
        avg_metrics["path_lens"].update(path_lens.sum(dim=-1).mean().item(), n_samples)
        avg_metrics["regret"].update(owa_pred_path_lens - self.owa_true_path_lens, 1)
        print('avg_regret', avg_metrics["regret"], owa_pred_path_lens - self.owa_true_path_lens)

        print('true avg path_lengths of each species')
        for s in range(self.n_task):
            print('species: {}'.format(s), self.true_path_lens_species['species_{}'.format(s)])

        avg_metrics_values = {key: avg_metric.avg for key, avg_metric in avg_metrics.items()}
        self.model.train()

        return avg_metrics_values



class ShortestPathAbstractTrainer(ABC):
    def __init__(
        self,
        *,
        train_iterator,
        test_iterator,
        metadata,
        use_cuda,
        batch_size,
        optimizer_name,
        optimizer_params,
        model_params,
        fast_mode,
        neighbourhood_fn,
        preload_batch,
        lr_milestone_1,
        lr_milestone_2,
        use_lr_scheduling,
        normalize_path,
        owa_weight,
        beta=0.1,
    ):
        """Set up data, model, shortest-path solver, optimizer and OWA weights.

        All arguments are keyword-only. ``lr_milestone_1`` and
        ``lr_milestone_2`` are accepted but not used by this constructor.
        ``owa_weight == 'one'`` selects uniform OWA weights; any other value
        selects the weights returned by ``gini_indices``. ``self.build_model``
        is not defined here and is expected to create ``self.model``.
        """
        # Caller-supplied data and configuration (all set before build_model runs).
        self.train_iterator = train_iterator
        self.test_iterator = test_iterator
        self.metadata = metadata
        self.batch_size = batch_size
        self.fast_mode = fast_mode
        self.use_cuda = use_cuda
        self.preload_batch = preload_batch
        self.neighbourhood_fn = neighbourhood_fn
        self.optimizer_params = optimizer_params
        self.normalize_path = normalize_path
        self.beta = beta
        self.n_task = model_params.arch_params.n_task
        # Side length of the grid, derived from the number of output features.
        self.grid_dim = int(np.sqrt(metadata["output_features"]))

        self.model = None
        self.build_model(**model_params)

        self.spsolver = get_solver(neighbourhood_fn)

        if self.use_cuda:
            self.model.to("cuda")
        optimizer_cls = optimizer_from_string(optimizer_name)
        self.optimizer = optimizer_cls(self.model.parameters(), **optimizer_params)

        self.use_lr_scheduling = use_lr_scheduling
        if use_lr_scheduling:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=0.2,
                min_lr=1e-7,
                verbose=True,
                patience=70,
            )

        # Progress counters.
        self.epochs = 0
        self.iter = 0

        # OWA weight vector over the tasks.
        if owa_weight != 'one':
            print(self.n_task)
            self.w_species = gini_indices(self.n_task)
        else:
            self.w_species = torch.ones(self.n_task).float()
        print('w speices', self.w_species)


    def train_epoch(self):
        self.epochs += 1
        batch_time = AverageMeter("Batch time")
        data_time = AverageMeter("Data time")
        cuda_time = AverageMeter("Cuda time")
        avg_loss = AverageMeter("Loss")
        avg_accuracy = AverageMeter("Accuracy")
        avg_perfect_accuracy = AverageMeter("Perfect Accuracy")
        avg_path_diff = AverageMeter("Path length diff")
        avg_owa_path_lens = AverageMeter("OWA Path length")
        avg_path_lens = AverageMeter("Path length")

        avg_metrics = customdefaultdict(lambda k: AverageMeter("train_"+k))

        batch_regret_list = []
        batch_regret_std_list = []
        batch_loss_list   = []
        batch_path_lens_diff_list = []
        batch_path_lens_list, batch_owa_path_lens_list, batch_path_lens_species_list = [],[],[]
        val_path_lens_list ,val_owa_path_lens_list, val_path_lens_diff_list, val_path_lens_species_list =[], [], [],[]
        val_path_lens_gini_list =[]


        if self.epochs == 1: 
            eval_results = self.evaluate()
            val_owa_path_lens_list.append(eval_results['owa_path_lens'])
            val_path_lens_list.append(eval_results['path_lens'])

            print("Evaluating on test set: iteration {} of epoch {}: owa_path_lens {} \t  path_lens {}".
                               format(self.iter, self.epochs, eval_results['owa_path_lens'], eval_results['path_lens']))
            if self.n_task>1:
                val_path_lens_diff_list.append(eval_results['path_lens_diff'])
                val_path_lens_species_list.append([eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task) ])
                val_path_lens_gini_list.append(eval_results['gini_path_lens'])
                print('Gini path len: {}'.format(eval_results['gini_path_lens']))
                print('Test: path len of each species ',[eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task) ])

        end = time.time()
        write_losses_interval = 20
        iterator = self.train_iterator.get_epoch_iterator(batch_size=self.batch_size, number_of_epochs=1, device='cuda' if self.use_cuda else 'cpu', preload=self.preload_batch)
        self.model.train()
        print('='*100)

        for i, data in enumerate(iterator):

            feat, true_path, true_weights = data["images"], data["labels"],  data["true_weights"]

            if i == 0:
                self.log(data, train=True)
            cuda_begin = time.time()
            cuda_time.update(time.time()-cuda_begin)

            # measure data loading time
            data_time.update(time.time() - end)


            if self.name == 'DijkstraSPO':
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path, true_weights, train=True, i=i)
            elif (self.name== 'DijkstraOWADescent'): 
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path[:, 0, :, :])
            else:
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path, train=True, i=i)

            suggested_path = last_suggestion["suggested_path"]
            # print('suggested_path.shape', suggested_path.shape, last_suggestion['suggested_weights'].shape)

            # JK 0630 calculate regret
            if self.name != 'DijkstraOWADescent': 
                with torch.no_grad():
                    true_path_np    = true_path.detach().numpy()
                    true_weights_np = true_weights.detach().numpy()
                    pred_weights_np = pred_weights.detach().numpy()
                    pred_path = np.asarray(  [dijkstra(wt).shortest_path for wt in pred_weights_np]  )

                    regret = (pred_path*true_weights_np).sum(2).sum(1) - (true_path_np*true_weights_np).sum(2).sum(1)
                    batch_regret_list.append(regret.mean())
                    batch_regret_std_list.append(regret.std())
                batch_metrics = metrics.compute_metrics(true_paths=true_path,
                            suggested_paths=suggested_path, true_vertex_costs=true_weights)
                # update batch metrics
                {avg_metrics[k].update(v, feat.size(0)) for k, v in batch_metrics.items()}
                assert len(avg_metrics.keys()) > 0
     
            avg_loss.update(loss.item(), feat.size(0))
            avg_accuracy.update(accuracy.item(), feat.size(0))
            batch_loss_list.append(loss.item())


            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            if self.name == 'DijkstraSPO':
                spo_weights  = (2*pred_weights.detach() - true_weights)
                grad = []
                for j in range(len(true_weights)):
                    spo_sol  = dijkstra(spo_weights[j].detach().cpu()).shortest_path
                    true_sol = dijkstra(true_weights[j].detach().cpu()).shortest_path
                    grad.append( torch.Tensor(spo_sol - true_sol) )
                pred_weights.backward(gradient=-torch.stack(grad).to(pred_weights.device))
            elif self.name == 'DijkstraDescent':
                path_lens = (suggested_path*true_weights).sum(1).sum(1)

                if self.normalize_path: 
                    true_path_lens = torch.einsum("ijk, ijk->i", true_path, true_weights)
                    path_lens = (path_lens/true_path_lens).mean()

                else: 
                    path_lens = path_lens.mean()
                path_lens.backward()
                batch_path_lens_list.append(path_lens)
                batch_owa_path_lens_list.append(path_lens.item())

                avg_path_lens.update(path_lens, feat.size(0))
                avg_owa_path_lens.update(path_lens, feat.size(0))

            #M: 05/23
            elif self.name == 'DijkstraOWADescent':
                # print('check', true_path[0, 0, :, :],true_path[0, 1, :, :],  )
                path_lens = torch.einsum("ijk, imjk->im",suggested_path, true_weights) # suggested path: B x 12 x 12, true weights:  B x n_task x 12 x2 -> B x n_task 
                if self.normalize_path: 
                    true_path_lens = torch.einsum("imjk, imjk->im", true_path, true_weights)
                    path_lens = path_lens/true_path_lens
                    # print(path_lens2)
                with torch.no_grad():
                    grad = compute_Moreau_grad_softsort(self.w_species,(-path_lens/self.beta))
                    path_lens_diff = (path_lens - path_lens.mean(1).view(-1,1)).abs()
                    batch_path_lens_species_list.append(path_lens.detach().mean(0).numpy())
                    batch_path_lens_list.append(path_lens.mean().item())
                    batch_owa_path_lens_list.append(-(compute_owa(self.w_species, -path_lens).mean()).item())
                    batch_path_lens_diff_list.append(path_lens_diff.mean(1).mean().item())
                    avg_path_diff.update(path_lens_diff.mean(1).mean().item(), feat.size(0))
                    avg_owa_path_lens.update(-(compute_owa(self.w_species, (-path_lens)).mean()).item(), feat.size(0))
                    avg_path_lens.update(path_lens.sum(dim=-1).mean(), feat.size(0))
                path_lens.backward(gradient=grad)
 
            else:
                loss.backward()

            if self.iter % write_losses_interval == 0:
                print('*'*30)
            # training_regrets.append(np.mean(epoch_regrets))
                print("Evaluating on train set: iteration {} of epoch {}:".
                       format(i, self.epochs ))


                meters = [batch_time, data_time, avg_loss, avg_accuracy,avg_path_lens]
                meter_str = "\t".join([str(meter) for meter in meters])
                print(f"Epoch: {self.epochs}\t{meter_str}")
                print("SGD lr=%.4f" % (self.optimizer.param_groups[0]["lr"]))
                if self.n_task>1:
                    print('Train: path len of each species ',np.stack(batch_path_lens_species_list).mean(0))

                eval_results = self.evaluate()
                val_owa_path_lens_list.append(eval_results['owa_path_lens'])
                val_path_lens_list.append(eval_results['path_lens'])

                print("Evaluating on test set: iteration {} of epoch {}: owa_path_lens {} \t path_lens {}".
                                   format(i, self.epochs, eval_results['owa_path_lens'], eval_results['path_lens']))
                if self.n_task>1:
                    val_path_lens_diff_list.append(eval_results['path_lens_diff'])
                    val_path_lens_species_list.append([eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task) ])
                    val_path_lens_gini_list.append(eval_results['gini_path_lens'])
                    print('Test: path len of each species ',[eval_results['path_lens_species_{}'.format(i)] for i in range(self.n_task) ])

                # if self.use_lr_scheduling:
                #     self.scheduler.step(eval_results['path_lens'].avg) #M
            self.optimizer.step()

            #M
            if self.use_lr_scheduling:
                self.scheduler.step(avg_path_lens.avg) #M
                
            self.iter +=1
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()


            if self.fast_mode:
                break

        # if self.use_lr_scheduling:
            # self.scheduler.step()

        # self.train_logger.log(avg_loss.avg, "loss")
        # self.train_logger.log(avg_accuracy.avg, "accuracy")
        # for key, avg_metric in avg_metrics.items():
        #     self.train_logger.log(avg_metric.avg, key=key)

        return {
            "batch_regret_list":batch_regret_list,
            "batch_regret_std_list":batch_regret_std_list,
            "batch_loss_list":batch_loss_list,
            "batch_path_lens_list": batch_path_lens_list, 
            "batch_owa_path_lens_list": batch_owa_path_lens_list, 
            "batch_path_lens_diff_list": batch_path_lens_diff_list,
            "batch_path_lens_species": np.stack(batch_path_lens_species_list).mean(0) if self.n_task >1 else batch_path_lens_species_list,
            "train_loss": avg_loss.avg,
            "train_accuracy": avg_accuracy.avg,
            "val_path_lens_list": val_path_lens_list, 
            "val_owa_path_lens_list": val_owa_path_lens_list, 
            "val_path_lens_diff_list": val_path_lens_diff_list, 
            "val_path_lens_species_list": val_path_lens_species_list, 
            "val_path_lens_gini_list": val_path_lens_gini_list, 
            **{"train_"+k: avg_metrics[k].avg for k in avg_metrics.keys()}
        }

    def evaluate(self):
        """Evaluate the current model on the test iterator and return averaged metrics.

        The model is put into eval mode for the pass and switched back to train
        mode before returning.  For every batch the trainer-specific
        ``forward_pass`` is called and the cost of the suggested path under the
        true weights is accumulated.  Trainers other than
        ``'DijkstraOWADescent'`` additionally record regret and the values
        returned by ``metrics.compute_metrics``; when ``self.n_task > 1`` the
        OWA value, the spread around the per-sample mean, per-species path
        lengths and a Gini coefficient are recorded as well.

        Returns:
            dict: metric name -> running average (``AverageMeter.avg``).
        """
        avg_metrics = defaultdict(AverageMeter)

        self.model.eval()

        iterator = self.test_iterator.get_epoch_iterator(batch_size=self.batch_size, number_of_epochs=1, shuffle=False, device='cuda' if self.use_cuda else 'cpu', preload=self.preload_batch)
        # Per-sample, per-task path lengths collected over all batches seen so far.
        batch_path_lens_task = []
        for i, data in enumerate(iterator):
            feat, true_path, true_weights = (
                data["images"].contiguous(),
                data["labels"].contiguous(),
                data["true_weights"].contiguous(),
            )

            if self.use_cuda:
                feat = feat.cuda(non_blocking=True)
                true_path = true_path.cuda(non_blocking=True)

            if self.name == 'DijkstraSPO':
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path, true_weights, train=False, i=i)
            elif self.name == 'DijkstraOWADescent':
                # The OWA trainer is scored against the first task's ground-truth path only.
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path[:, 0, :, :])
            else:
                loss, accuracy, last_suggestion, pred_weights = self.forward_pass(feat, true_path, train=False, i=i)

            suggested_path = last_suggestion["suggested_path"]
            data.update(last_suggestion)

            with torch.no_grad():
                if self.name != 'DijkstraOWADescent':
                    # Tensors may live on the GPU; .numpy() only works on CPU tensors.
                    true_path_np = true_path.detach().cpu().numpy()
                    true_weights_np = true_weights.detach().cpu().numpy()
                    pred_weights_np = pred_weights.detach().cpu().numpy()
                    # Exact shortest path under the predicted weights; only regret needs it,
                    # so it is not computed for the OWA trainer.
                    pred_path = np.asarray([dijkstra(wt).shortest_path for wt in pred_weights_np])

                    regret = (pred_path*true_weights_np).sum(2).sum(1) - (true_path_np*true_weights_np).sum(2).sum(1)
                    evaluated_metrics = metrics.compute_metrics(true_paths=true_path,
                                                                suggested_paths=suggested_path, true_vertex_costs=true_weights)

                    # One path cost per sample (einsum contracts both grid axes).
                    path_lens = torch.einsum("ijk, ijk->i", suggested_path, true_weights)
                    if self.normalize_path:
                        true_path_lens = torch.einsum("ijk, ijk->i", true_path, true_weights)
                        path_lens = (path_lens/true_path_lens).mean()
                    avg_metrics["regret"].update(regret.mean(), feat.size(0))
                    for key, value in evaluated_metrics.items():
                        avg_metrics[key].update(value, feat.size(0))
                    # NOTE(review): this branch never defines path_lens_diff, so combining a
                    # non-OWA trainer with n_task > 1 fails in the multi-task block below.

                else:
                    # One path cost per (sample, task): a single suggested path is priced
                    # under every task's weights.
                    path_lens = torch.einsum("ijk, imjk->im", suggested_path, true_weights)

                    if self.normalize_path:
                        true_path_lens = torch.einsum("imjk, imjk->im", true_path, true_weights)
                        path_lens = path_lens/true_path_lens

                    # Absolute deviation of each task's cost from the per-sample mean.
                    path_lens_diff = (path_lens - path_lens.mean(1).view(-1, 1)).abs()
                    batch_path_lens_task.extend(path_lens.cpu().numpy())

            avg_metrics["accuracy"].update(accuracy.item(), feat.size(0))

            if self.n_task > 1:
                avg_metrics["path_lens"].update(path_lens.sum(dim=-1).mean().item(), feat.size(0))
                avg_metrics["owa_path_lens"].update(-(compute_owa(self.w_species, (-path_lens)).mean()).item(), feat.size(0))
                avg_metrics["path_lens_diff"].update(path_lens_diff.mean(1).mean().item(), feat.size(0))
                # NOTE(review): batch_path_lens_task spans every batch so far, so a cumulative
                # mean is fed into a running average; kept as-is to preserve reported numbers.
                # The loop variable is named `task` so it does not clobber the batch index `i`.
                for task in range(self.n_task):
                    avg_metrics['path_lens_species_{}'.format(task)].update(np.stack(batch_path_lens_task)[:, task].mean(), feat.size(0))
                avg_metrics['gini_path_lens'].update(gini_coefficient(path_lens.cpu().numpy()).mean(), feat.size(0))
            else:
                # path_lens is a per-sample vector unless normalize_path already reduced it,
                # so reduce with mean() before item() (a no-op on a scalar).
                avg_metrics["path_lens"].update(path_lens.mean().item(), feat.size(0))

                avg_metrics["owa_path_lens"].update(path_lens.mean().item(), feat.size(0))

            if self.fast_mode:
                break

        avg_metrics_values = {key: avg_metric.avg for key, avg_metric in avg_metrics.items()}
        self.model.train()

        return avg_metrics_values

    @abstractmethod
    def build_model(self, **kwargs):
        """Construct the network for this trainer; concrete trainers must override.

        The concrete trainers in this file implement it by assigning the
        result of ``get_model(...)`` to ``self.model``.
        """
        return None

    @abstractmethod
    def forward_pass(self, feat, true_shortest_paths, train, i):
        """Run the model on one batch; concrete trainers must override.

        The concrete trainers in this file return the tuple
        ``(loss, accuracy, last_suggestion, weights)``.
        """
        return None

    def log(self, data, train, k=None, num=None):
        """Build visualisations for sample ``k`` of an evaluation batch.

        Nothing happens for training batches.  For evaluation batches the
        denormalised image, the path masks and the overlay image are computed,
        but no logger is attached at the moment, so the results are discarded.
        ``num`` is currently unused.
        """
        if train:
            return

        raw_image = self.metadata['denormalize'](data["images"][k]).squeeze().astype(np.uint8)
        predicted = data["suggested_path"][k].squeeze()
        target = data["labels"][k].squeeze()

        # Three-channel masks with value 255 on path cells.
        predicted_rgb = torch.ones((3, *predicted.shape)) * 255 * predicted.cpu()
        target_rgb = torch.ones((3, *target.shape)) * 255 * target.cpu()
        overlay = draw_paths_on_image(
            image=raw_image,
            true_path=target,
            suggested_path=predicted,
            scaling_factor=10,
        )



class BaselineTrainer(ShortestPathAbstractTrainer):
    """Supervised baseline that predicts path membership per vertex.

    The network output is squashed with a sigmoid and trained with binary
    cross-entropy against the flattened ground-truth path mask; no
    combinatorial solver is involved in training.
    """

    def build_model(self, model_name, arch_params):
        """Create the network and remember the side length of the vertex grid."""
        # Side length of the square grid; forward_pass uses it to un-flatten the output.
        self.grid_dim = int(np.sqrt(self.metadata["output_features"]))
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )
        self.name = 'BaselineTrainer'

    def forward_pass(self, feat, label, train, i):
        """Score one batch.

        Args:
            feat: input batch passed to the model.
            label: ground-truth path mask; flattened per sample for the loss.
            train: unused, kept for interface compatibility.
            i: unused, kept for interface compatibility.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, weights)`` where
            ``weights`` is the sigmoid output reshaped to
            ``(batch, grid_dim, grid_dim)``.
        """
        output = torch.sigmoid(self.model(feat))

        # The output is flat (BCELoss below requires it to match flat_target), so its
        # last dimension is the number of grid cells, not the grid side length.
        # Reshape with the grid side computed in build_model instead.
        weights = output.reshape(-1, self.grid_dim, self.grid_dim)

        flat_target = label.view(label.size()[0], -1)

        criterion = torch.nn.BCELoss()
        loss = criterion(output, flat_target).mean()
        # Fraction of true path cells that the rounded prediction also marks.
        accuracy = (output.round() * flat_target).sum() / flat_target.sum()

        suggested_path = output.view(label.shape).round()
        last_suggestion = {"vertex_costs": None, "suggested_path": suggested_path}

        return loss, accuracy, last_suggestion, weights






class DijkstraOnFull(ShortestPathAbstractTrainer):
    """Trainer that feeds absolute-valued model outputs to ``self.solver`` and
    scores the resulting path with a Hamming loss plus a mean-output penalty."""

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.name = 'DijkstraOnFull'
        self.loss_fn = HammingLoss()
        self.l1_regconst, self.lambda_val = l1_regconst, lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        meta = self.metadata
        self.model = get_model(
            model_name,
            out_features=meta["output_features"],
            in_channels=meta["num_channels"],
            arch_params=arch_params,
        )

    def forward_pass(self, feat, true_shortest_paths, train, i):
        """Return ``(loss, accuracy, last_suggestion, weights)`` for one batch."""
        # Absolute value keeps the grid weights non-negative.
        costs = torch.abs(self.model(feat))
        side = costs.shape[-1]
        weights = costs.reshape(-1, side, side)

        # Debug print of the first sample on the first evaluation batch.
        if not train and i == 0:
            print(costs[0])
        assert weights.dim() == 3, f"{str(weights.shape)}"

        paths = self.solver(weights)
        loss = self.loss_fn(paths, true_shortest_paths)

        last_suggestion = {"suggested_weights": weights, "suggested_path": paths}

        # A cell counts as correct when prediction and target differ by less than 0.5.
        mismatch = torch.abs(paths - true_shortest_paths)
        accuracy = (mismatch < 0.5).to(torch.float32).mean()

        loss += self.l1_regconst * torch.mean(costs)

        return loss, accuracy, last_suggestion, weights




class DijkstraSPO(ShortestPathAbstractTrainer):
    """Trainer whose ``forward_pass`` additionally receives the true weights.

    The true weights are accepted but not used inside ``forward_pass``; the
    returned predicted weights let the caller evaluate decision quality.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.name = 'DijkstraSPO'
        self.loss_fn = HammingLoss()
        self.l1_regconst, self.lambda_val = l1_regconst, lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)

        print("META:", self.metadata)

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        meta = self.metadata
        self.model = get_model(
            model_name,
            out_features=meta["output_features"],
            in_channels=meta["num_channels"],
            arch_params=arch_params,
        )

    def forward_pass(self, feat, true_shortest_paths, true_weights, train, i):
        """Return ``(loss, accuracy, last_suggestion, weights)`` for one batch."""
        # Absolute value keeps the grid weights non-negative.
        costs = torch.abs(self.model(feat))
        side = costs.shape[-1]
        weights = costs.reshape(-1, side, side)

        # Debug print of the first sample on the first evaluation batch.
        if not train and i == 0:
            print(costs[0])
        assert weights.dim() == 3, f"{str(weights.shape)}"

        paths = self.solver(weights)
        loss = self.loss_fn(paths, true_shortest_paths)

        last_suggestion = {"suggested_weights": weights, "suggested_path": paths}

        # A cell counts as correct when prediction and target differ by less than 0.5.
        mismatch = torch.abs(paths - true_shortest_paths)
        accuracy = (mismatch < 0.5).to(torch.float32).mean()

        loss += self.l1_regconst * torch.mean(costs)

        return loss, accuracy, last_suggestion, weights




class DijkstraDescent(ShortestPathAbstractTrainer):
    """Trainer that passes raw (possibly negative) model outputs to
    ``self.solver.apply`` together with ``self.spsolver`` and ``lambda_val``."""

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.name = 'DijkstraDescent'
        self.loss_fn = HammingLoss()
        self.l1_regconst, self.lambda_val = l1_regconst, lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)

        print("META:", self.metadata)

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        meta = self.metadata
        self.model = get_model(
            model_name,
            out_features=meta["output_features"],
            in_channels=meta["num_channels"],
            arch_params=arch_params,
        )

    def forward_pass(self, feat, true_shortest_paths, train, i):
        """Return ``(loss, accuracy, last_suggestion, weights)`` for one batch."""
        # Outputs are used as-is; no positivity constraint is applied here.
        raw = self.model(feat)
        side = raw.shape[-1]
        weights = raw.reshape(-1, side, side)

        assert weights.dim() == 3, f"{str(weights.shape)}"

        paths = self.solver.apply(weights, self.spsolver, self.lambda_val)
        loss = self.loss_fn(paths, true_shortest_paths)

        last_suggestion = {"suggested_weights": weights, "suggested_path": paths}

        # A cell counts as correct when prediction and target differ by less than 0.5.
        mismatch = torch.abs(paths - true_shortest_paths)
        accuracy = (mismatch < 0.5).to(torch.float32).mean()

        loss += self.l1_regconst * torch.mean(raw)

        return loss, accuracy, last_suggestion, weights


#M
class DijkstraOWADescent(ShortestPathAbstractTrainer):
    """Single-model trainer identified by ``self.name == 'DijkstraOWADescent'``.

    ``forward_pass`` predicts one weight grid per sample from raw model
    outputs and solves for a path via ``self.solver.apply``.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)
        self.loss_fn = HammingLoss()
        self.name = 'DijkstraOWADescent'

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )

    def forward_pass(self, feat, true_shortest_paths, train=None, i=None):
        """Score one batch.

        Args:
            feat: input batch passed to the model.
            true_shortest_paths: ground-truth path mask compared with the
                suggested path.
            train: ignored; accepted so the method matches the abstract
                ``forward_pass(feat, true_shortest_paths, train, i)`` signature.
            i: ignored; accepted for the same reason.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, weights)``.
        """
        output = self.model(feat)
        # Outputs are used as-is; no positivity constraint is applied here.
        weights = output.reshape(-1, output.shape[-1], output.shape[-1])

        assert len(weights.shape) == 3, f"{str(weights.shape)}"

        shortest_paths = self.solver.apply(weights, self.spsolver, self.lambda_val)

        loss = self.loss_fn(shortest_paths, true_shortest_paths)

        last_suggestion = {
            "suggested_weights": weights,
            "suggested_path": shortest_paths
        }

        # A cell counts as correct when prediction and target differ by less than 0.5.
        accuracy = (torch.abs(shortest_paths - true_shortest_paths) < 0.5).to(torch.float32).mean()
        extra_loss = self.l1_regconst * torch.mean(output)
        loss += extra_loss

        return loss, accuracy, last_suggestion, weights


class BaselineTrainerMulti(ShortestPathAbstractMulitTrainer): # two stage
    """Two-stage multi-task baseline: regress the weights with MSE, then run
    Dijkstra on the predictions (without gradients) to obtain paths.

    ``l1_regconst`` and ``lambda_val`` are accepted by the constructor but
    not used by this trainer.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.name = 'BaselineMulti'
        self.solver = dijkstra
        self.loss_fn = torch.nn.MSELoss()

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        meta = self.metadata
        self.model = get_model(
            model_name,
            out_features=meta["output_features"],
            in_channels=meta["num_channels"],
            arch_params=arch_params,
        )

    def forward_pass(self, feat, true_weights):
        """Return ``(loss, accuracy, last_suggestion, all_weights)``.

        ``loss`` is the sum of the per-task MSE losses and ``accuracy`` the
        sum (not the average) over tasks of the fraction of predicted weights
        within 0.5 of the true ones.  The grid size is hard-coded to 12 x 12.
        """
        batch = feat[0].shape[0]
        suggested_path = torch.zeros((self.n_task, batch, 12, 12))

        accuracy = 0
        task_losses = []
        task_weights = []
        for task in range(self.n_task):
            # The model returns one output per task; keep the one for this task.
            prediction = self.model(feat[task])[task]
            weights = prediction.reshape(-1, 12, 12)  # B x 12 x 12
            task_losses.append(self.loss_fn(weights, true_weights[task]))
            task_weights.append(weights)

            with torch.no_grad():
                for sample, grid in enumerate(weights):
                    solved = dijkstra(grid.detach().cpu().numpy())
                    suggested_path[task, sample] = torch.tensor(solved.shortest_path)
                close = torch.abs(weights - true_weights[task]) < 0.5
                accuracy += close.to(torch.float32).mean()

        all_weights = torch.stack(task_weights)
        last_suggestion = {
            "suggested_weights": all_weights,
            "suggested_path": suggested_path,
        }
        return torch.stack(task_losses).sum(), accuracy, last_suggestion, all_weights


class DijkstraMultiOWADescent(ShortestPathAbstractMulitTrainer):
    """Multi-task trainer identified by ``self.name == 'DijkstraMultiOWADescent'``.

    For every task the model predicts a weight grid and ``self.solver.apply``
    turns it into a path; per-task Hamming loss and accuracy are averaged.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)
        self.loss_fn = HammingLoss()
        self.name = 'DijkstraMultiOWADescent'

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )

    def forward_pass(self, feat, true_shortest_paths):
        """Score one batch across all tasks.

        Args:
            feat: per-task inputs, indexed as ``feat[i]``.
            true_shortest_paths: per-task ground-truth paths, indexed the same way.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, all_weights)`` where loss
            and accuracy are means over the ``self.n_task`` tasks and the
            stacked tensors have the task axis first.
        """
        lst_shortest_paths, lst_weights = [], []
        accuracy, loss_total = 0, 0
        for i in range(self.n_task):
            # The model returns one output per task; keep the one for task i.
            output = self.model(feat[i])[i]
            weights = output.reshape(-1, 12, 12)  # B x 12 x 12 (grid size is hard-coded)
            shortest_paths = self.solver.apply(weights, self.spsolver, self.lambda_val)

            lst_shortest_paths.append(shortest_paths)
            lst_weights.append(weights)
            loss_total = loss_total + self.loss_fn(shortest_paths, true_shortest_paths[i])
            accuracy += (torch.abs(shortest_paths - true_shortest_paths[i]) < 0.5).to(torch.float32).mean()

        all_weights = torch.stack(lst_weights)
        all_shortest_path = torch.stack(lst_shortest_paths)
        last_suggestion = {
            "suggested_weights": all_weights,
            "suggested_path": all_shortest_path
        }

        # Average the accumulated loss; previously only the last task's loss was
        # returned (divided by n_task), dropping every other task.
        return loss_total/self.n_task, accuracy/self.n_task, last_suggestion, all_weights


class DijkstraMultiOWADescent2(ShortestPathAbstractMulitTrainer):
    """Multi-task trainer identified by ``self.name == 'DijkstraMultiOWADescent2'``.

    For every task the model predicts a weight grid and ``self.solver.apply``
    turns it into a path; per-task Hamming loss and accuracy are averaged.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)
        self.loss_fn = HammingLoss()
        self.name = 'DijkstraMultiOWADescent2'

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )

    def forward_pass(self, feat, true_shortest_paths):
        """Score one batch across all tasks.

        Args:
            feat: per-task inputs, indexed as ``feat[i]``.
            true_shortest_paths: per-task ground-truth paths, indexed the same way.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, all_weights)`` where loss
            and accuracy are means over the ``self.n_task`` tasks and the
            stacked tensors have the task axis first.
        """
        lst_shortest_paths, lst_weights = [], []
        accuracy, loss_total = 0, 0
        for i in range(self.n_task):
            # The model returns one output per task; keep the one for task i.
            output = self.model(feat[i])[i]
            weights = output.reshape(-1, 12, 12)  # B x 12 x 12 (grid size is hard-coded)
            shortest_paths = self.solver.apply(weights, self.spsolver, self.lambda_val)

            lst_shortest_paths.append(shortest_paths)
            lst_weights.append(weights)
            loss_total = loss_total + self.loss_fn(shortest_paths, true_shortest_paths[i])
            accuracy += (torch.abs(shortest_paths - true_shortest_paths[i]) < 0.5).to(torch.float32).mean()

        all_weights = torch.stack(lst_weights)
        all_shortest_path = torch.stack(lst_shortest_paths)
        last_suggestion = {
            "suggested_weights": all_weights,
            "suggested_path": all_shortest_path
        }

        # Average the accumulated loss; previously only the last task's loss was
        # returned (divided by n_task), dropping every other task.
        return loss_total/self.n_task, accuracy/self.n_task, last_suggestion, all_weights


class DijkstraMultiDescent(ShortestPathAbstractMulitTrainer):
    """Multi-task trainer identified by ``self.name == 'DijkstraMultiDescent'``.

    For every task the model predicts a weight grid and ``self.solver.apply``
    turns it into a path; per-task Hamming loss and accuracy are averaged.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)
        self.loss_fn = HammingLoss()
        self.name = 'DijkstraMultiDescent'

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )

    def forward_pass(self, feat, true_shortest_paths):
        """Score one batch across all tasks.

        Args:
            feat: per-task inputs, indexed as ``feat[i]``.
            true_shortest_paths: per-task ground-truth paths, indexed the same way.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, all_weights)`` where loss
            and accuracy are means over the ``self.n_task`` tasks and the
            stacked tensors have the task axis first.
        """
        lst_shortest_paths, lst_weights = [], []
        accuracy, loss_total = 0, 0
        for i in range(self.n_task):
            # The model returns one output per task; keep the one for task i.
            output = self.model(feat[i])[i]
            weights = output.reshape(-1, 12, 12)  # B x 12 x 12 (grid size is hard-coded)
            shortest_paths = self.solver.apply(weights, self.spsolver, self.lambda_val)

            lst_shortest_paths.append(shortest_paths)
            lst_weights.append(weights)
            loss_total = loss_total + self.loss_fn(shortest_paths, true_shortest_paths[i])
            accuracy += (torch.abs(shortest_paths - true_shortest_paths[i]) < 0.5).to(torch.float32).mean()

        all_weights = torch.stack(lst_weights)
        all_shortest_path = torch.stack(lst_shortest_paths)
        last_suggestion = {
            "suggested_weights": all_weights,
            "suggested_path": all_shortest_path
        }

        # Average the accumulated loss; previously only the last task's loss was
        # returned (divided by n_task), dropping every other task.
        return loss_total/self.n_task, accuracy/self.n_task, last_suggestion, all_weights


class DijkstraMultiGradNormDescent(ShortestPathAbstractMulitTrainer):
    """Multi-task trainer identified by ``self.name == 'DijkstraMultiGradNormDescent'``.

    For every task the model predicts a weight grid and ``self.solver.apply``
    turns it into a path; per-task Hamming loss and accuracy are averaged.
    """

    def __init__(self, *, l1_regconst, lambda_val, **kwargs):
        super().__init__(**kwargs)
        self.l1_regconst = l1_regconst
        self.lambda_val = lambda_val
        self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn)
        self.loss_fn = HammingLoss()
        self.name = 'DijkstraMultiGradNormDescent'

    def build_model(self, model_name, arch_params):
        """Instantiate the network described by ``model_name`` / ``arch_params``."""
        self.model = get_model(
            model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params
        )

    def forward_pass(self, feat, true_shortest_paths):
        """Score one batch across all tasks.

        Args:
            feat: per-task inputs, indexed as ``feat[i]``.
            true_shortest_paths: per-task ground-truth paths, indexed the same way.

        Returns:
            tuple ``(loss, accuracy, last_suggestion, all_weights)`` where loss
            and accuracy are means over the ``self.n_task`` tasks and the
            stacked tensors have the task axis first.
        """
        lst_shortest_paths, lst_weights = [], []
        accuracy, loss_total = 0, 0
        for i in range(self.n_task):
            # The model returns one output per task; keep the one for task i.
            output = self.model(feat[i])[i]
            weights = output.reshape(-1, 12, 12)  # B x 12 x 12 (grid size is hard-coded)
            shortest_paths = self.solver.apply(weights, self.spsolver, self.lambda_val)

            lst_shortest_paths.append(shortest_paths)
            lst_weights.append(weights)
            loss_total = loss_total + self.loss_fn(shortest_paths, true_shortest_paths[i])
            accuracy += (torch.abs(shortest_paths - true_shortest_paths[i]) < 0.5).to(torch.float32).mean()

        all_weights = torch.stack(lst_weights)
        all_shortest_path = torch.stack(lst_shortest_paths)
        last_suggestion = {
            "suggested_weights": all_weights,
            "suggested_path": all_shortest_path
        }

        # Average the accumulated loss; previously only the last task's loss was
        # returned (divided by n_task), dropping every other task.
        return loss_total/self.n_task, accuracy/self.n_task, last_suggestion, all_weights



class MoreauOWALossLayer(torch.autograd.Function):
    """OWA aggregation with a surrogate gradient.

    The forward pass returns, per row, the weighted sum of the ascending-sorted
    objectives.  The backward pass does not differentiate the sort; it uses
    ``compute_Moreau_grad_softsort(w_gini, -multi_obj / beta)`` as the
    gradient with respect to ``multi_obj``.
    """

    @staticmethod
    def forward(ctx, multi_obj, w_gini, beta):
        """Return the OWA value of each row of ``multi_obj``.

        Args:
            multi_obj: 2-D tensor, one row of objective values per sample.
            w_gini: 1-D weight vector applied to the ascending-sorted row.
            beta: scale passed on to the surrogate gradient in ``backward``.

        Returns:
            1-D tensor with one OWA value per row.
        """
        ctx.w_gini = w_gini
        ctx.beta = beta
        ctx.save_for_backward(multi_obj)
        ascending = torch.sort(multi_obj, dim=-1).values
        return torch.einsum("m, im-> i", w_gini, ascending)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient for ``multi_obj``; none for the other inputs."""
        multi_obj, = ctx.saved_tensors
        z = (-multi_obj/ctx.beta)
        grad = compute_Moreau_grad_softsort(ctx.w_gini, z).to(grad_output.device)
        # Chain rule: scale each row by the upstream gradient of its OWA value.
        # Without this, any downstream reduction or scaling of the loss is ignored.
        return grad * grad_output.unsqueeze(-1), None, None