from argparse import ArgumentError
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Callable, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import matplotlib.pyplot as plt

from omegaconf import DictConfig
from scipy.spatial.distance import cdist
from torch.utils.data import DataLoader, Dataset

import timer_logger
import utils
from logging import getLogger
import Hyperparameters_Tunnel
import revtorch as rv
import NEWTWDE_Model as Model
from torch.distributions.categorical import Categorical
from datasets import Lower_TSPDataset,Lower_TSPTunnel
from torch.optim.lr_scheduler import MultiStepLR as Scheduler
import torch.optim as optim


class TunnelState:
    """Visit bookkeeping for a batch of grouped rollouts on a tunnel graph.

    For every (batch, group) trajectory the state keeps the ordered list of
    visited node indices and an additive mask that holds ``-inf`` at visited
    nodes.  When a selected node is an endpoint of a tunnel, the opposite
    endpoint is recorded immediately after it.
    """

    def __init__(self, group_size, nodes, tunnels):
        """Create an empty state.

        Args:
            group_size: number of parallel trajectories (G) per instance.
            nodes: tensor whose first two dimensions are read as
                (batch B, graph size N); its device is used for all state
                tensors.
            tunnels: tensor indexed as ``[:, :, 0]`` (sources) and
                ``[:, :, 1]`` (targets); dimension 1 is the tunnel count.
        """
        batch, num_nodes = nodes.size(0), nodes.size(1)
        device = nodes.device

        self.batch_size = batch
        self.graph_size = num_nodes
        self.group_size = group_size
        self.device = device

        self.num_tunnels = tunnels.size(1)
        self.tunnel_source = tunnels[:, :, 0]
        self.tunnel_target = tunnels[:, :, 1]

        self.selected_count = 0
        self.current_node = None
        # selected_node_list.shape = [B, G, number of recorded nodes]
        self.selected_node_list = torch.zeros(
            (batch, group_size, 0), dtype=torch.long, device=device
        )
        # ninf_mask.shape = [B, G, N]
        self.ninf_mask = torch.zeros((batch, group_size, num_nodes), device=device)
        # tunnel_ninf_mask.shape = [B, G, num_tunnels]
        self.tunnel_ninf_mask = torch.zeros(
            (batch, group_size, self.num_tunnels), device=device
        )

    def move_to(self, selected_idx_mat):
        """Record one decision step.

        Args:
            selected_idx_mat: node indices chosen this step, shape [B, G].

        The chosen nodes are recorded first.  If at least one of them has a
        distinct tunnel partner, the whole partner matrix is recorded as a
        second move (entries without a partner simply repeat their node).
        """
        self.selected_count += 1
        self.__move_to(selected_idx_mat)

        partner_idx_mat = self.__connect_source_target_city(selected_idx_mat)
        nothing_moved = not (selected_idx_mat != partner_idx_mat).any()
        if nothing_moved:
            return
        self.__move_to(partner_idx_mat)

    def __move_to(self, selected_idx_mat):
        """Append ``selected_idx_mat`` to the visit list and mask those nodes."""
        self.current_node = selected_idx_mat
        column = selected_idx_mat[:, :, None]  # [B, G, 1]
        self.selected_node_list = torch.cat([self.selected_node_list, column], dim=2)
        self.ninf_mask.scatter_(dim=-1, index=column, value=float("-inf"))

    def __connect_source_target_city(self, selected_idx_mat):
        """Map every selected node to the other endpoint of its tunnel.

        Args:
            selected_idx_mat: node indices, shape [B, G].

        Returns:
            ``selected_idx_mat`` plus, per entry, the summed index offsets to
            the opposite endpoint over all tunnels the node matches.  Entries
            that match no tunnel, or whose tunnel has identical endpoints,
            are returned unchanged.
        """
        batch, n_tunnels = self.tunnel_source.shape
        _, groups = selected_idx_mat.shape
        shape = (batch, groups, n_tunnels)

        sources = self.tunnel_source.unsqueeze(1).expand(*shape).long().to(self.device)
        targets = self.tunnel_target.unsqueeze(1).expand(*shape).long().to(self.device)
        chosen = selected_idx_mat.unsqueeze(2).expand(*shape).to(self.device)

        # Per tunnel: jump to the target when standing on the source, and to
        # the source when standing on the target; otherwise stay in place.
        jump_to_target = torch.where(sources == chosen, targets, chosen)
        jump_to_source = torch.where(targets == chosen, sources, chosen)

        offset = (jump_to_source + jump_to_target - 2 * chosen).sum(dim=2)
        return selected_idx_mat + offset
    

class TunnelTSPEnv:
    """Rollout environment for TSP instances with mandatory tunnel edges.

    The reward returned on the final step is the negated tour length with the
    summed length of all tunnel edges subtracted.
    """

    def __init__(self, nodes, tunnels, device):
        """Store the problem instance.

        Args:
            nodes: node coordinates, read as [B, N, C].
            tunnels: tunnel endpoint indices, read as [B, T, 2] with the
                source in ``[..., 0]`` and the target in ``[..., 1]``.
            device: device on which the fixed tunnel lengths are computed.

        Raises:
            ValueError: if ``nodes`` and ``tunnels`` disagree on the batch size.
        """
        self.device = device
        self.nodes = nodes
        self.tunnels = tunnels
        self.tunnel_source = tunnels[:, :, 0]
        self.tunnel_target = tunnels[:, :, 1]
        self.batch_size = nodes.size(0)
        self.graph_size = nodes.size(1)
        self.node_dim = nodes.size(2)
        self.tunnel_batch = tunnels.size(0)
        self.num_tunnels = tunnels.size(1)

        # The check compares the batch dimensions of the two inputs.
        if self.batch_size != self.tunnel_batch:
            raise ValueError("Nodes dim inequals Tunnels dim")

    def reset(self, group_size):
        """Start a new rollout with ``group_size`` trajectories per instance.

        Returns:
            ``(state, reward, done)`` where ``reward`` and ``done`` are ``None``.
        """
        self.group_size = group_size
        self.group_state = TunnelState(
            group_size=group_size, nodes=self.nodes, tunnels=self.tunnels
        )

        # Accumulate into a tensor from the start so the attribute is a
        # [B, G] tensor even when there are no tunnels at all.
        fixed = torch.zeros(
            self.batch_size, group_size, dtype=self.nodes.dtype, device=self.device
        )
        endpoints = zip(
            self.tunnel_source.unbind(dim=1), self.tunnel_target.unbind(dim=1)
        )
        for source, target in endpoints:
            fixed = fixed + self._get_fixed_length(
                source=source.unsqueeze(-1), target=target.unsqueeze(-1)
            )
        self.fixed_edge_length = fixed

        reward = None
        done = None
        return self.group_state, reward, done

    def step(self, selected_idx_mat):
        """Apply one action.

        Args:
            selected_idx_mat: chosen node indices, shape [B, G].

        Returns:
            ``(state, reward, done)``; ``reward`` is ``None`` until the number
            of steps taken equals the number of tunnels, at which point it is
            the negated path distance.
        """
        state = self.group_state
        # The state tensors live on the device of ``nodes``; send the action
        # there instead of unconditionally to the CPU.
        state.move_to(selected_idx_mat.to(state.device))

        done = state.selected_count == self.num_tunnels
        reward = -self._get_path_distance() if done else None
        return state, reward, done

    def _get_fixed_length(self, source, target):
        """Euclidean length of one tunnel edge per instance, broadcast over groups.

        Args:
            source: source node index per instance, shape [B, 1].
            target: target node index per instance, shape [B, 1].

        Returns:
            Tensor of shape [B, G] on ``self.device``.
        """
        idx_shp = (self.batch_size, self.group_size, 1, self.node_dim)  # [B,G,1,C]
        coord_shp = (
            self.batch_size, self.group_size, self.graph_size, self.node_dim
        )  # [B,G,N,C]

        # Move first, expand afterwards: expanding is a view, whereas copying
        # an already expanded tensor to another device materialises it.
        coords = self.nodes.to(self.device)[:, None, :, :].expand(*coord_shp)
        source_idx = source[..., None, None].expand(*idx_shp).to(self.device).long()
        target_idx = target[..., None, None].expand(*idx_shp).to(self.device).long()

        source_xy = coords.gather(dim=2, index=source_idx)  # [B,G,1,C]
        target_xy = coords.gather(dim=2, index=target_idx)  # [B,G,1,C]
        return ((source_xy - target_xy) ** 2).sum(3).sqrt().sum(2)

    def _get_path_distance(self) -> torch.Tensor:
        """Closed-tour length of every trajectory minus the fixed tunnel length.

        Returns:
            Tensor of shape [B, G] on the device of ``self.nodes``.
        """
        # selected_node_list.shape = [B, G, number of recorded nodes]
        # A -1 sentinel is appended to each trajectory so that collapsing
        # consecutive duplicates on the flattened list never merges entries
        # of two neighbouring trajectories.
        interval = (
            torch.tensor([-1], device=self.nodes.device)
            .long()
            .expand(self.batch_size, self.group_size)
        )
        selected_node_list = torch.cat(
            (self.group_state.selected_node_list, interval[:, :, None]),
            dim=2,
        ).flatten()
        unique_selected_node_list = selected_node_list.unique_consecutive()
        assert unique_selected_node_list.shape[0] == (
            self.batch_size * self.group_size * (self.graph_size + 1)
        ), unique_selected_node_list.shape
        unique_selected_node_list = unique_selected_node_list.view(
            [self.batch_size, self.group_size, -1]
        )[..., :-1]

        shp = (self.batch_size, self.group_size, self.graph_size, self.node_dim)
        gathering_index = unique_selected_node_list.unsqueeze(3).expand(*shp)
        seq_expanded = self.nodes[:, None, :, :].expand(*shp)
        ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        delta = ordered_seq - rolled_seq
        tour_distances = (delta ** 2).sum(3).sqrt().sum(2)

        # Subtract the length of the fixed (tunnel) edges.  ``as_tensor``
        # accepts both a plain number and a tensor and places the operand on
        # the same device as the tour distances.
        fixed = torch.as_tensor(self.fixed_edge_length, device=tour_distances.device)
        return tour_distances - fixed
            

class TunnelTSPrunning(nn.Module):
    """Trainer and evaluator for the tunnel-TSP policy model.

    Calling the module evaluates the current policy on a set of instances and
    returns the average resulting length.  ``start_train_tunnel`` runs the
    policy-gradient training loop and writes checkpoints.
    """

    def __init__(self, env_params, lower_params,
                 running_params, optimizer_params, device):
        """Build model, optimizer, scheduler and logging helpers.

        Args:
            env_params: environment settings (``graph_size``, ``group_size``,
                ``node_dim``, ``tunnel_per_graph``, ``data_distribution``).
            lower_params: model settings (``val_type``, ``data_augment`` and
                whatever ``Model.TSPModel`` consumes).
            running_params: trainer settings (``use_cuda``,
                ``cuda_device_num``, ``train_batch_size``, ``train_size``,
                ``total_epoch``, ``log_per_epoch``, ``saveckpt_per_epoch``,
                ``divide_std``, ``start_from_ckpt``).
            optimizer_params: optimizer settings (``learning_rate``).
            device: device used for rollout tensors and checkpoint loading.
        """
        super().__init__()
        self.device = device
        self.env_params = env_params
        self.model_params = lower_params
        self.trainer_params = running_params
        self.optimizer_params = optimizer_params

        # NOTE: this changes the process-wide default tensor type, so tensors
        # created later without an explicit device follow this choice.
        if self.trainer_params['use_cuda']:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            model_device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            model_device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        self.model = Model.TSPModel(model_params=self.model_params,
                                    device=model_device)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=optimizer_params['learning_rate'])
        # gamma=1.0 keeps the learning rate constant; adjust to enable decay.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=1.0
        )

        self.logger = getLogger(name='trainer')
        self.result_folder = timer_logger.get_result_folder()
        self.result_log = timer_logger.LogData()

        self.start_epoch = 1
        self.start_from_ckpt = self.trainer_params['start_from_ckpt']
        self.timer = timer_logger.TimeEstimator()

    def forward(self, batch_nodes, batch_tunnel, return_pi=False, val_type=None):
        """Evaluate the policy and return the average best length.

        The inputs are processed in chunks of ``train_batch_size`` instances.

        Args:
            batch_nodes: node coordinates, read as [B, N, C].
            batch_tunnel: tunnel description for the same B instances.
            return_pi: forwarded to ``_forward``; the selected tours are
                computed there but only the lengths are aggregated here.
            val_type: evaluation mode, see ``_forward``.

        Returns:
            The running average kept by ``utils.AverageMeter`` over all chunks.

        Raises:
            ValueError: if ``train_batch_size`` is not positive.
        """
        chunk_size = self.trainer_params['train_batch_size']
        if chunk_size <= 0:
            # A non-positive chunk size would never advance through the batch.
            raise ValueError("train_batch_size must be a positive integer")

        score_eval = utils.AverageMeter()
        total = batch_nodes.shape[0]
        for start in range(0, total, chunk_size):
            stop = start + chunk_size
            out = self._forward(batch_nodes[start:stop], batch_tunnel[start:stop],
                                return_pi, val_type)
            # ``_forward`` returns (lengths, tours) only when return_pi is set.
            lengths = out[0] if return_pi else out
            # Average over the entries actually returned so that a short last
            # chunk is not diluted by the configured chunk size.
            count = lengths.shape[0]
            score_eval.update(lengths.sum(dim=0) / count, count)
        return score_eval.avg

    def _prepare_rollout(self, batch_nodes, batch_tunnel, group_size):
        """Create the environment, encode the instances and prime the decoder.

        Shared by evaluation and training so both paths encode identically.

        Returns:
            ``(env, state, done)`` as produced by ``TunnelTSPEnv.reset``.
        """
        N = batch_nodes.shape[1]
        batch_tunnel_env = utils.expand_all_as_tunnels(N, batch_tunnel)
        # Only consumed by the direction-agnostic variant (encoding the raw
        # tunnel list instead of the expanded one); the call is kept so the
        # sequence of helper invocations is unchanged.
        batch_tunnel = utils.expand_every_tunnels(N, batch_tunnel)

        # The environment is built from the tensors as handed in, before they
        # are moved to the compute device.
        env = TunnelTSPEnv(batch_nodes, batch_tunnel_env, device=self.device)
        state, _, done = env.reset(group_size=group_size)

        batch_nodes = batch_nodes.to(self.device)
        batch_tunnel = batch_tunnel.to(self.device)
        batch_tunnel_env = batch_tunnel_env.to(self.device)
        batch_coord_tunnel = utils.generate_coord_from_indexes(batch_nodes, batch_tunnel_env)
        if self.model_params['data_augment']:
            batch_nodes = utils.augment_xy_data_by_8_fold_POMO(batch_nodes, training=True)
            batch_coord_tunnel = utils.augment_tunnel_data_by_8_fold(
                batch_coord_tunnel, training=True
            )

        embeddings_nodes = self.model.encoder_nodes(batch_nodes)
        embeddings_tunnels = self.model.encoder_tunnels(batch_coord_tunnel)

        # Direction-aware variant: project the tunnel embeddings through the
        # table built from the expanded tunnel list.
        _, L, _ = embeddings_nodes.shape
        tunnel_table = utils.create_output_matrix_with_batch(batch_tunnel_env, L)
        tunnel_table = torch.tensor(tunnel_table).to(self.device).double()
        logit_k_tunnels = torch.bmm(
            embeddings_tunnels.transpose(1, 2).double(), tunnel_table
        ).float()
        self.model.pre_forward(embeddings_nodes, logit_k_tunnels.mT,
                               batch_tunnel_env.to(self.device))
        return env, state, done

    def _forward(self, batch_nodes, batch_tunnel, return_pi=False, val_type=None):
        """Roll out the policy on one chunk and pick the best trajectory.

        Args:
            batch_nodes: node coordinates, read as [B, N, C].
            batch_tunnel: tunnel description for the same instances.
            return_pi: also return the selected action sequences.
            val_type: ``"noAug_1Traj"`` (single trajectory),
                ``"noAug_nTraj"`` (best of G trajectories) or anything else
                (best over 8 augmentations and G trajectories); defaults to
                ``model_params['val_type']``.

        Returns:
            The negated best reward per instance, and additionally the
            corresponding action sequences when ``return_pi`` is true.
        """
        B, N, _ = batch_nodes.shape
        val_type = val_type or self.model_params['val_type']
        G = 1 if val_type == "noAug_1Traj" else self.env_params['group_size']
        assert G <= self.env_params['graph_size']

        env, s, d = self._prepare_rollout(batch_nodes, batch_tunnel, G)

        first_step = torch.randperm(N)[None, :G].expand(B, G)
        pi = first_step[..., None]
        batch_idx_range = torch.arange(B)[:, None].expand(B, G).to(device=self.device)
        group_idx_range = torch.arange(G)[None, :].expand(B, G).to(device=self.device)
        while not d:
            action, _ = self.model(s, B, G, batch_idx_range, group_idx_range)
            s, r, d = env.step(action)
            pi = torch.cat([pi, action[..., None]], dim=-1)
        pi = pi.cpu()

        if val_type == "noAug_1Traj":
            max_reward = r
            best_pi = pi
        elif val_type == "noAug_nTraj":
            max_reward, idx_dim_1 = r.max(dim=1)
            idx_dim_1 = idx_dim_1.reshape(B, 1, 1)
            best_pi = pi.gather(
                1, idx_dim_1.repeat(1, 1, N - self.env_params['tunnel_per_graph'])
            )
        else:
            # NOTE(review): this branch assumes the leading dimension holds 8
            # augmented copies of every instance - confirm against the
            # augmentation helpers.
            B = round(B / 8)

            reward = r.reshape(8, B, G)
            max_reward, idx_dim_2 = reward.max(dim=2)
            max_reward, idx_dim_0 = max_reward.max(dim=0)
            pi = pi.reshape(8, B, G, N)
            idx_dim_0 = idx_dim_0.reshape(1, B, 1, 1)
            idx_dim_2 = idx_dim_2.reshape(8, B, 1, 1).gather(0, idx_dim_0)
            best_pi = pi.gather(0, idx_dim_0.repeat(1, 1, G, N))
            best_pi = best_pi.gather(2, idx_dim_2.repeat(1, 1, 1, N))

        if return_pi:
            return -max_reward, best_pi.squeeze()
        return -max_reward

    def train_dataloader(self):
        """Create the node and tunnel training loaders.

        Also stores ``group_size`` on the instance; ``_train_one_epoch`` reads
        it, so this method has to run before training.

        Returns:
            ``(dataloader_nodes, dataloader_tunnels)``.
        """
        self.group_size = self.env_params['group_size']
        batch_size = self.trainer_params['train_batch_size']
        num_samples = self.trainer_params['train_size']

        dataset_nodes = Lower_TSPDataset(
            size=self.env_params['graph_size'],
            node_dim=self.env_params['node_dim'],
            num_samples=num_samples,
            data_distribution=self.env_params['data_distribution'],
        )
        self.logger.debug("node dataset device: %s", dataset_nodes.data.device)
        dataloader_nodes = DataLoader(
            dataset_nodes,
            num_workers=0,
            batch_size=batch_size,
        )

        dataset_tunnels = Lower_TSPTunnel(
            size=self.env_params['graph_size'],
            tunnels=self.env_params['tunnel_per_graph'],
            node_dim=self.env_params['node_dim'],
            num_samples=num_samples,
        )
        dataloader_tunnels = DataLoader(
            dataset_tunnels,
            num_workers=0,
            batch_size=batch_size,
        )

        return dataloader_nodes, dataloader_tunnels

    # Debugging helper; not used by the training or evaluation paths.
    def _getdevicelogger(self, params, str_params):
        """Log the device (and CUDA index, if any) of ``params`` under the label ``str_params``."""
        device = params.device
        device_info = f"{str_params}'s Device: {device.type}"
        if device.type == 'cuda' and device.index is not None:
            device_info += f", Device index: {device.index}"
        self.logger.info(device_info)

    def read_ckpt(self, path='result/20240930_183423_train__tsp_n20',
                  epoch=500, enable=True):
        """Restore model, optimizer and log data from a saved checkpoint.

        Args:
            path: directory that holds the checkpoint files.
            epoch: epoch number of the checkpoint to load; training resumes
                at ``epoch + 1``.
            enable: when false the call is a no-op.
        """
        if not enable:
            return

        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(path=path, epoch=epoch)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.start_epoch = 1 + epoch
        self.result_log.set_raw_data(checkpoint['result_log'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.last_epoch = epoch - 1
        self.logger.info('Saved Model Loaded !!')

    def start_train_tunnel(self):
        """Run the training loop from ``start_epoch`` to ``total_epoch``."""
        # Restore first: the checkpoint moves ``start_epoch``, and the time
        # estimator has to be reset with the epoch training really starts at.
        if self.start_from_ckpt:
            self.read_ckpt()
        self.timer.reset(self.start_epoch)

        train_dataset_nodes_loader, train_dataset_tunnel_loader = self.train_dataloader()

        # Per-epoch histories, filled by ``_train_one_epoch`` on logging epochs.
        self.train_lossmean = []
        self.train_scoremean = []
        self.train_rmaxloss = []
        self.train_advmean = []

        total_epoch = self.trainer_params['total_epoch']
        for epoch in range(self.start_epoch, total_epoch + 1):
            is_log_epoch = epoch % self.trainer_params['log_per_epoch'] == 0
            if is_log_epoch:
                self.logger.info('=================================================================')
            self._train_one_epoch(epoch=epoch,
                                  train_dataset_nodes_loader=train_dataset_nodes_loader,
                                  train_dataset_tunnel_loader=train_dataset_tunnel_loader)
            if is_log_epoch:
                self.timer.print_est_time(epoch, total_epoch)
        self.logger.info('*****   TRAINING POMO-TUNNEL OVER   *****')

    def _train_one_epoch(self, epoch, train_dataset_nodes_loader, train_dataset_tunnel_loader):
        """Train for one pass over the paired node/tunnel loaders.

        Args:
            epoch: current epoch number (used for logging and checkpointing).
            train_dataset_nodes_loader: loader yielding node batches.
            train_dataset_tunnel_loader: loader yielding tunnel batches; the
                two loaders are iterated in lockstep.

        Raises:
            AssertionError: if the loss becomes NaN.
        """
        loss_onepoch = utils.AverageMeter()
        score_onepoch = utils.AverageMeter()
        rmax_onepoch = utils.AverageMeter()
        adv_onepoch = utils.AverageMeter()
        self.model.train()
        episodes = 0
        for batch_nodes, batch_tunnel in zip(train_dataset_nodes_loader,
                                             train_dataset_tunnel_loader):
            B, N, _ = batch_nodes.shape
            G = self.group_size
            batch_idx_range = torch.arange(B)[:, None].expand(B, G).to(device=self.device)
            group_idx_range = torch.arange(G)[None, :].expand(B, G).to(device=self.device)

            env, s, d = self._prepare_rollout(batch_nodes, batch_tunnel, G)

            log_prob = torch.zeros(B, G, device=self.device)
            while not d:
                action, prob = self.model(s, B, G, batch_idx_range, group_idx_range)
                s, r, d = env.step(action)
                log_prob += prob.log()

            # r.shape = [B, G]; the baseline is the mean over the group.
            r_trans = r.to(self.device)
            divide_std = self.trainer_params['divide_std']
            if G == 1:
                advantage = r_trans
            else:
                advantage = r_trans - r_trans.mean(dim=1, keepdim=True)
                if divide_std:
                    advantage = advantage / (
                        r_trans.std(dim=1, unbiased=False, keepdim=True) + 1e-8
                    )
            loss = (-advantage * log_prob).mean()

            # Check before the update so a NaN loss cannot reach the weights;
            # an explicit raise also survives ``python -O``.
            if torch.isnan(loss).any():
                raise AssertionError("loss is nan!")

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Plain floats only: storing the loss tensor in the meter would
            # keep its autograd graph alive for the whole epoch.
            loss_value = loss.item()
            length_max = -r.max(dim=1)[0].mean().item()
            length_mean = -r.mean(1).mean().item()
            adv = advantage.abs().mean().item()
            episodes += 1
            if episodes % 500 == 0:
                self.logger.info(
                    "loss:{:3.4f},length_max:{:3.4f},length_mean:{:3.4f},adv:{:3.4f}"
                    .format(loss_value, length_max, length_mean, adv)
                )

            loss_onepoch.update(loss_value, B)
            score_onepoch.update(length_mean, B)
            rmax_onepoch.update(length_max, B)
            adv_onepoch.update(adv, B)

        if epoch % self.trainer_params['log_per_epoch'] == 0:
            self.train_lossmean.append(float(loss_onepoch.avg))
            self.train_scoremean.append(score_onepoch.avg)
            self.train_rmaxloss.append(rmax_onepoch.avg)
            self.train_advmean.append(adv_onepoch.avg)
            self.logger.info(
                "In Epoch {:3d}, loss:{:3.4f},length_mean:{:3.4f},length_max:{:3.4f},adv:{:3.4f}"
                .format(epoch, loss_onepoch.avg, score_onepoch.avg,
                        rmax_onepoch.avg, adv_onepoch.avg)
            )

        if epoch % self.trainer_params['saveckpt_per_epoch'] == 0:
            self.save_ckpt(epoch=epoch)

    def save_ckpt(self, epoch):
        """Write model, optimizer, scheduler and log state for ``epoch`` to the result folder."""
        self.logger.info("Saving trained_model")
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'result_log': self.result_log.get_raw_data()
        }
        torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))



# Module-level defaults; the entry script below replaces the CUDA settings
# with the values from Hyperparameters_Tunnel.
DEBUG_MODE = False
USE_CUDA = not DEBUG_MODE
CUDA_DEVICE_NUM = 0

if __name__ == "__main__":
    import logging
    from timer_logger import create_logger

    # Parameter groups.  The names and their order matter: every global whose
    # name ends in "params" is logged below, in definition order.
    env_params = Hyperparameters_Tunnel.env_params
    lower_params = Hyperparameters_Tunnel.lower_params
    optimizer_params = Hyperparameters_Tunnel.optimizer_params
    running_params = Hyperparameters_Tunnel.running_params
    logger_params = Hyperparameters_Tunnel.logger_params

    # Random evaluation set: node coordinates plus tunnel endpoint pairs.
    batch = running_params['eval_size']
    size = env_params['graph_size']
    tunnel = env_params['tunnel_per_graph']
    data_nodes = torch.rand(batch, size, 2)
    tunnel_order = utils.generate_random_order(batch, size, 2 * tunnel)
    data_tunnels = torch.tensor(tunnel_order.reshape(batch, tunnel, 2))

    USE_CUDA = Hyperparameters_Tunnel.USE_CUDA
    CUDA_DEVICE_NUM = Hyperparameters_Tunnel.CUDA_DEVICE_NUM

    create_logger(**logger_params)
    logger = logging.getLogger('root')
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))

    # Snapshot the names first; the loop variable itself becomes a global and
    # must not change the dictionary while it is being scanned.
    param_names = [key for key in globals() if key.endswith('params')]
    for param_name in param_names:
        logger.info(param_name + "{}".format(globals()[param_name]))

    cuda_requested = torch.cuda.is_available() and running_params['use_cuda']
    device = torch.device("cuda" if cuda_requested else "cpu")

    trainer = TunnelTSPrunning(env_params, lower_params,
                               running_params, optimizer_params, device)

    # Score before training, train, then score again on the same instances.
    reward = trainer(data_nodes, data_tunnels, return_pi=True)
    print(reward)
    trainer.start_train_tunnel()
    reward = trainer(data_nodes, data_tunnels, return_pi=True)
    print(reward)

