""" Evaluate various optimizers augmented with teleportation. """

import numpy as np
import time
from matplotlib import pyplot as plt
import pickle
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
import torch.utils.data as data

import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
import os
from gradient_descent_utils import init_param_MLP, init_param_transformer, train_step, valid_MLP, valid_transformer, teleport, teleport_transformer
from models import MLP, Transformer, Transformer_LM, CNN, Resnet, Transformer_Classification_LM
from plot import plot_optimization
from tqdm import tqdm
from data.Adding_Problem.utils import AddingProblemDataset
from data.PennTree.utils import get_PennTree_dataset
from data.imdb.utils import get_imdb_dataset
from data.electricity.utils import get_electricity_dataset
from data.traffic.utils import get_traffic_dataset

from Custom_Optimizer import CustomAdam

# ---------------------------------------------------------------------------
# Experiment switches
# ---------------------------------------------------------------------------
device = 'cuda:0'            # device for models and batches (alternative: 'cuda')
run_new = True               # set to False to skip training and use cached results
model_type = 'transformer'   # one of 'MLP', 'CNN', 'transformer'
dataset = 'traffic'          # valid names are listed in the configuration branches below
opt_method_list = ['Adam']   # full sweep: ['Adam', 'Adagrad', 'Adamax', 'SGD', 'AdamW', 'momentum', 'RMSprop']

# Shared training settings.
sigma = nn.ReLU()                   # activation passed to every model and to teleport()
batch_size = tele_batch_size = 32   # mini-batch size for optimizer steps and for teleportation
valid_size = 0.2                    # not referenced in the visible part of this script

# model_type = 'MLP'
# dataset and hyper-parameters
# MLP configurations: number of epochs, loss, learning rate, layer widths
# (`dim`) and the train/test datasets for each supported dataset.
if model_type == 'MLP':
    if dataset == 'MNIST':
        epoch_num = 100
        criterion = nn.CrossEntropyLoss()
        lr = 2e-4
        dim = [28*28, 1024, 1024, 10]
        train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'FashionMNIST':
        epoch_num = 100
        criterion = nn.CrossEntropyLoss()
        lr = 2e-4
        dim = [28*28, 1024, 1024, 10]
        train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'CIFAR10':
        epoch_num = 10
        criterion = nn.CrossEntropyLoss()
        lr = 2e-4
        # NOTE(review): unlike the MNIST/FashionMNIST entries, these width lists
        # are prefixed with the batch size, while the training loop flattens each
        # image to 32*32*3 features. Confirm MLP(dim) expects this layout.
        dim = [batch_size, 32*32*3, 128, 32, 10]
        teledim = [tele_batch_size, 32*32*3, 128, 32, 10]
        train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    else:
        raise ValueError('dataset should be one of MNIST, FashionMNIST, and CIFAR10 for the MLP model')

# CNN / transformer configurations. Each branch sets the training schedule
# (epoch_num, lr), the loss (criterion), the architecture hyper-parameters read
# by the model constructors further down, and the train/test datasets.
if model_type == 'CNN':
    if dataset == 'MNIST':
        epoch_num = 20
        criterion = nn.CrossEntropyLoss()
        lr = 2e-3
        in_channels = 1
        num_classes = 10
        train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'FashionMNIST':
        epoch_num = 20
        criterion = nn.CrossEntropyLoss()
        lr = 2e-3
        in_channels = 1
        num_classes = 10
        train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'CIFAR10':
        epoch_num = 100
        criterion = nn.CrossEntropyLoss()
        lr = 2e-4
        in_channels = 3
        num_classes = 10
        train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'CIFAR100':
        epoch_num = 100
        criterion = nn.CrossEntropyLoss()
        lr = 2e-4
        in_channels = 3
        num_classes = 100
        train_data = datasets.CIFAR100(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.CIFAR100(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'Imagenet':
        # Tiny-ImageNet (200 classes) read from ImageFolder directories.
        epoch_num = 60
        criterion = nn.CrossEntropyLoss()
        lr = 5e-5
        in_channels = 3
        num_classes = 200
        data_dir = "./data/tiny_imagenet/tiny-224/"
        num_workers = {"train": 16, "val": 0, "test": 0}
        data_transforms = {
            "train": transforms.Compose(
                [
                    transforms.RandomRotation(20),
                    transforms.RandomHorizontalFlip(0.5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
                ]
            ),
            "val": transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
                ]
            ),
            "test": transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
                ]
            ),
        }
        image_datasets = {
            x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ["train", "val", "test"]
        }
        train_data = image_datasets["train"]
        # The "val" split serves as the evaluation set.
        test_data = image_datasets["val"]
    else:
        raise ValueError('dataset should be one of MNIST, FashionMNIST, CIFAR10, CIFAR100, and Imagenet for the CNN model')

elif model_type == 'transformer':
    if dataset == 'MNIST':
        epoch_num = 400
        criterion = nn.CrossEntropyLoss()
        lr = 1e-4
        input_dim = 1
        d_model = 128
        num_heads = 2
        num_layers = 2
        dim_feedforward = 1024
        num_classes = 10
        seq_len = 28*28
        train_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'FashionMNIST':
        # NOTE(review): this branch defines model_dim / out_dim / layer_num, but
        # the Transformer(...) call later in this script reads d_model,
        # num_heads, num_layers, dim_feedforward and num_classes, none of which
        # are set here, so this combination currently fails with a NameError.
        # Confirm the intended hyper-parameters and rename accordingly.
        epoch_num = 10
        criterion = nn.CrossEntropyLoss()
        lr = 1e-3
        input_dim = 1
        seq_len = 28*28
        model_dim = 64
        out_dim = 10
        layer_num = 1
        train_data = datasets.FashionMNIST(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.FashionMNIST(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'CIFAR10':
        epoch_num = 10
        criterion = nn.CrossEntropyLoss()
        lr = 2e-3
        input_dim = 3
        d_model = 128
        num_heads = 1
        num_layers = 4
        dim_feedforward = 128
        num_classes = 10
        seq_len = 32*32
        train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'CIFAR100':
        epoch_num = 10
        criterion = nn.CrossEntropyLoss()
        lr = 2e-3
        input_dim = 3
        d_model = 512
        num_heads = 8
        num_layers = 4
        dim_feedforward = 512
        num_classes = 100
        seq_len = 32*32
        # Must be CIFAR-100 to match num_classes = 100 (CIFAR-10 has only 10 labels).
        train_data = datasets.CIFAR100(root = 'data', train = True, download = True, transform = transforms.ToTensor())
        test_data = datasets.CIFAR100(root = 'data', train = False, download = True, transform = transforms.ToTensor())
    elif dataset == 'adding':
        epoch_num = 20
        criterion = nn.MSELoss()
        lr = 5e-3
        input_dim = 2
        d_model = 64
        num_heads = 1
        num_layers = 2
        dim_feedforward = 64
        num_classes = 1
        seq_len = 64
        train_data = AddingProblemDataset(40000, seq_len)
        test_data = AddingProblemDataset(4000, seq_len)
    elif dataset == 'PennTree':
        epoch_num = 10
        batch_size = 32  # explicit override of the global batch size for language modelling
        criterion = nn.CrossEntropyLoss()
        lr = 5e-5
        d_model = 256
        num_heads = 4
        num_layers = 4
        dim_feedforward = 1024
        seq_len = 256
        train_data, test_data, vocab_size = get_PennTree_dataset(seq_len)
    elif dataset == 'imdb':
        epoch_num = 100
        criterion = nn.CrossEntropyLoss()
        lr = 2e-3
        d_model = 128
        num_heads = 2
        num_layers = 2
        dim_feedforward = 1024
        seq_len = 256
        num_classes = 2
        train_data, test_data, vocab_size = get_imdb_dataset(seq_len)
    elif dataset == 'electricity':
        epoch_num = 50
        criterion = nn.MSELoss()
        lr = 1e-4
        input_dim = 321
        d_model = 256
        num_heads = 4
        num_layers = 4
        dim_feedforward = 1024
        seq_len = 168
        num_classes = 321
        train_data, test_data = get_electricity_dataset()
    elif dataset == 'traffic':
        epoch_num = 50
        criterion = nn.MSELoss()
        lr = 1e-4
        input_dim = 862
        d_model = 256
        num_heads = 4
        num_layers = 4
        dim_feedforward = 1024
        seq_len = 168
        num_classes = 862
        train_data, test_data = get_traffic_dataset()
    else:
        raise ValueError('dataset should be one of MNIST, FashionMNIST, CIFAR10, CIFAR100, adding, PennTree, imdb, electricity, and traffic for the transformer model')
        

# Build the data loaders shared by every run:
#   train_loader    - shuffled mini-batches for the optimizer steps
#   test_loader     - un-shuffled mini-batches for evaluation
#   teleport_loader - independently shuffled mini-batches consumed by the
#                     teleportation steps, drawn from the same training data
if dataset in ['MNIST', 'FashionMNIST']:
    # Split the 60k training images 50k/10k; only the 50k subset is used below.
    train_subset, val_subset = torch.utils.data.random_split(
            train_data, [50000, 10000], generator=torch.Generator().manual_seed(1))
    train_source = train_subset
    train_num_workers = 0
elif dataset in ['CIFAR100', 'CIFAR10', 'adding', 'PennTree', 'Imagenet', 'imdb', 'electricity', 'traffic']:
    train_source = train_data
    train_num_workers = 16
else:
    raise ValueError(f'no data-loader configuration for dataset {dataset!r}')

train_loader = torch.utils.data.DataLoader(train_source, batch_size = batch_size,
                                           shuffle = True, num_workers = train_num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size,
                                          num_workers = 0)
teleport_loader = torch.utils.data.DataLoader(train_source, batch_size = tele_batch_size,
                                              shuffle = True, num_workers = 0)
# Persistent iterator; the teleport loop re-creates it when it is exhausted.
teleport_loader_iterator = iter(teleport_loader)
print(f'number of training samples: {len(train_loader.sampler)}')

def get_optimizer(model, opt_method, lr, dataset):
    """Build a fresh optimizer named ``opt_method`` over ``model.parameters()``.

    Args:
        model: module whose parameters are optimized.
        opt_method: one of 'SGD', 'Adagrad', 'momentum', 'RMSprop', 'Adam',
            'Adamax', 'CustomAdam', 'AdamW'.
        lr: base learning rate from the dataset configuration. 'SGD' and
            'Adagrad' use it unchanged, 'momentum' uses ``lr / 10`` and every
            other method uses ``lr / 100``.
        dataset: unused; kept so existing call sites keep working.

    Returns:
        A newly constructed optimizer with empty state.

    Raises:
        ValueError: if ``opt_method`` is not one of the names above.
    """
    # Each factory takes the parameter iterable so that only the requested
    # optimizer is ever constructed.
    factories = {
        'SGD': lambda params: torch.optim.SGD(params, lr=lr),
        'Adagrad': lambda params: torch.optim.Adagrad(params, lr=lr),
        'momentum': lambda params: torch.optim.SGD(params, lr=lr / 1e1, momentum=0.9),
        'RMSprop': lambda params: torch.optim.RMSprop(params, lr=lr / 1e2),
        'Adam': lambda params: torch.optim.Adam(params, lr=lr / 1e2),
        # Adamax (infinity-norm Adam variant), matching the method name.
        'Adamax': lambda params: torch.optim.Adamax(params, lr=lr / 1e2),
        'CustomAdam': lambda params: CustomAdam(params, lr=lr / 1e2),
        'AdamW': lambda params: torch.optim.AdamW(params, lr=lr / 1e2),
    }
    try:
        factory = factories[opt_method]
    except KeyError:
        raise ValueError(
            f"opt_method should be one of {', '.join(factories)}; got {opt_method!r}"
        ) from None
    return factory(model.parameters())


# Epoch window bounds. Neither name is referenced in the visible part of this
# script; presumably consumed further down — confirm before changing.
start_epoch, end_epoch = 15, 40

if run_new == True:
    for opt_method in opt_method_list:   
        if model_type == 'MLP':
            if opt_method == 'SGD':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'Adagrad':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'momentum':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'RMSprop':
                tele_epochs = [1]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'Adam':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'Adamax':
                tele_epochs = [1]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'CustomAdam':
                tele_epochs = [1]
                tele_lr = 2e-1
                tele_step = 8
            elif opt_method == 'AdamW':
                tele_epochs = [1]
                tele_lr = 2e-1
                tele_step = 8
            else:
                raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam')
        elif model_type == 'transformer' and dataset in ['MNIST','CIFAR100', 'CIFAR10', 'electricity','traffic']:
            if opt_method == 'SGD':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'Adagrad':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'momentum':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'RMSprop':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            elif opt_method == 'Adam':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'Adamax':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            elif opt_method == 'CustomAdam':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            elif opt_method == 'AdamW':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            else:
                raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam')
        elif model_type == 'transformer' and dataset == 'adding':
            if opt_method == 'SGD':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            elif opt_method == 'Adagrad':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            elif opt_method == 'momentum':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            elif opt_method == 'RMSprop':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            elif opt_method == 'Adam':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            elif opt_method == 'AdamW':
                tele_epochs = [1]
                tele_lr = 1e-5
                tele_step = 8
            else:
                raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam')
        elif model_type == 'transformer' and dataset in ['PennTree', "imdb"]:
            if opt_method == 'SGD':
                tele_epochs = [1,2,3]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'Adagrad':
                tele_epochs = [1,2,3]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'momentum':
                tele_epochs = [1,2,3]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'RMSprop':
                tele_epochs = [1]
                tele_lr = 5e-2
                tele_step = 8
            elif opt_method == 'Adam':
                tele_epochs = [1,2,3]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'Adamax':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            elif opt_method == 'CustomAdam':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            elif opt_method == 'AdamW':
                tele_epochs = [1]
                tele_lr = 5e-3
                tele_step = 8
            else:
                raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam')

        elif model_type == 'CNN':
            if opt_method == 'SGD':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'Adagrad':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'momentum':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'RMSprop':
                tele_epochs = [0,1,2]
                tele_lr = 5e-5
                tele_step = 8
            elif opt_method == 'Adam':
                tele_epochs = [0,1,2,3,4]
                tele_lr = 3e-3
                tele_step = 8
            elif opt_method == 'Adamax':
                tele_epochs = [0,1,2]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'CustomAdam':
                tele_epochs = [0]
                tele_lr = 5e-4
                tele_step = 8
            elif opt_method == 'AdamW':
                tele_epochs = [0,1,2]
                tele_lr = 5e-4
                tele_step = 8
            else:
                raise ValueError('opt_method should be one of SGD, AdaGrad, momentum, RMSProp, and Adam')

        for run_num in range(3):
            print(opt_method, 'run', run_num)
            
            ##############################################################
            # training with opt_method without teleportation (e.g. AdaGrad)
            loss_arr_SGD = []
            dL_dt_arr_SGD = []
            valid_loss_SGD = []
            valid_correct_SGD = []
            time_SGD = []

            if model_type == 'MLP':
                model = MLP(dim, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'CNN' and dataset != 'Imagenet':
                model = CNN(in_channels, num_classes, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'CNN' and dataset == 'Imagenet':
                model = Resnet(in_channels, num_classes, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset in ['electricity','traffic']:
                model = Transformer(input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset == 'imdb':
                model = Transformer_Classification_LM(vocab_size, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset != 'PennTree':
                model = Transformer(input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            else:
                model = Transformer_LM(vocab_size, d_model, num_heads, num_layers, dim_feedforward, vocab_size, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            model.to(device)
            optimizer = get_optimizer(model, opt_method, lr, dataset)

            
            t0 = time.time()
            for epoch in tqdm(range(epoch_num)):
                epoch_loss = 0.0
                model.train()
                for idx, data in tqdm(enumerate(train_loader)):
                    if len(data) == 2:
                        data, label = data
                    else:
                        data, attention_mask, label = data['input_ids'], data['attention_mask'], data['targets']
                    batch_size = data.shape[0]
                    if model_type == 'MLP':
                        data = data.view(batch_size, -1).to(device) # [20, 1, 28, 28] -> [784, 20]
                    elif model_type == 'CNN':
                        data = data.to(device) # [20, 1, 28, 28] -> [20,28,28,1]
                    elif model_type == 'transformer' and dataset in ['electricity','traffic']:
                        data = data.to(device)
                    elif model_type == 'transformer' and dataset in ['imdb']:
                        data = data.to(device)
                        attention_mask = attention_mask.to(device)
                    elif model_type == 'transformer' and dataset != 'PennTree':
                        data = data.view(batch_size, input_dim, -1).transpose(1,2).to(device) # [20, 1, 28, 28] -> [20, 784, 1]
                    elif model_type == 'transformer' and dataset == 'PennTree':
                        data = data.to(device) # [batch_size*seq_len]
                    label = label.to(device)
                    if model_type == 'transformer' and dataset == 'PennTree':
                        label = label.reshape(-1)
                    if dataset == 'imdb':
                        outs, _ = model(data, attention_mask)
                    elif not model_type == 'CNN':
                        outs, _ = model(data)
                    else:
                        outs, _,_,_ = model(data)
                    if dataset == 'PennTree':
                        loss = criterion(outs.reshape(-1,outs.shape[-1]), label)
                    else:
                        loss = criterion(outs, label)
                    epoch_loss += loss.item() * batch_size
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    
                loss_arr_SGD.append(epoch_loss / len(train_loader.sampler))
                print(f'train loss {epoch_loss / len(train_loader.sampler)}')
                

                model.eval()
                valid_correct = 0
                valid_loss = 0
                with torch.no_grad():
                    for data in tqdm(test_loader):
                        if len(data) == 2:
                            data, label = data
                        else:
                            data, attention_mask, label = data['input_ids'], data['attention_mask'], data['targets']
                        batch_size = data.shape[0]
                        if model_type == 'MLP':
                            data = data.view(batch_size, -1).to(device) # [20, 1, 28, 28] -> [784, 20]
                        elif model_type == 'CNN':
                            data = data.to(device) # [20, 1, 28, 28] -> [20,28,28,1]
                        elif model_type == 'transformer' and dataset in ['electricity','traffic']:
                            data = data.to(device)
                        elif model_type == 'transformer' and dataset in ['imdb']:
                            data = data.to(device)
                            attention_mask = attention_mask.to(device)
                        elif model_type == 'transformer' and dataset != 'PennTree':
                            data = data.view(batch_size, input_dim, -1).transpose(1,2).to(device) # [20, 1, 28, 28] -> [20, 784, 1]
                        elif model_type == 'transformer' and dataset == 'PennTree':
                            data = data.to(device) # [batch_size*seq_len]
                        label = label.to(device)
                        if model_type == 'transformer' and dataset == 'PennTree':
                            label = label.reshape(-1)
                        if dataset == 'imdb':
                            outs, _ = model(data, attention_mask)
                        elif not model_type == 'CNN':
                            outs, _ = model(data)
                        else:
                            outs, _,_,_ = model(data)
                        if dataset == 'PennTree':
                            loss = criterion(outs.reshape(-1,outs.shape[-1]), label)
                        else:
                            loss = criterion(outs, label)
                        valid_loss += loss.item() * batch_size
                        
                #         _, pred = torch.max(outs, 1)
                #         valid_correct += pred.eq(label.data.view_as(pred)).sum().item()
                # print(100.0*valid_correct/len(test_loader.sampler))
                
                valid_loss_SGD.append(valid_loss/ len(test_loader.sampler))
                print(f'test loss {valid_loss / len(test_loader.sampler)}')
                # valid_correct_SGD.append(100.0*valid_correct/len(test_loader.sampler))

                # print(epoch, loss_arr_SGD[-1], valid_loss_SGD[-1], valid_correct_SGD[-1])

                t1 = time.time()
                time_SGD.append(t1 - t0)

            results = (loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, time_SGD, 0)
            # results = (loss_arr_SGD, valid_loss_SGD, dL_dt_arr_SGD, valid_correct_SGD, time_SGD, 0)
            with open('logs/optimization_final/{}/{}_{}_lr_{}_{}_{}.pkl'.format(dataset, dataset, opt_method, lr, model_type, run_num), 'wb') as f:
                pickle.dump(results, f)
            

            ##############################################################
            # training with opt_method + teleport
            loss_arr_teleport = []
            dL_dt_arr_teleport = []
            valid_loss_teleport = []
            valid_correct_teleport = []
            time_teleport = []

            if model_type == 'MLP':
                model = MLP(dim, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'CNN' and dataset != 'Imagenet':
                model = CNN(in_channels, num_classes, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'CNN' and dataset == 'Imagenet':
                model = Resnet(in_channels, num_classes, activation=sigma, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset in ['electricity','traffic']:
                model = Transformer(input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset == 'imdb':
                model = Transformer_Classification_LM(vocab_size, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            elif model_type == 'transformer' and dataset != 'PennTree':
                model = Transformer(input_dim, d_model, num_heads, num_layers, dim_feedforward, num_classes, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            else:
                model = Transformer_LM(vocab_size, d_model, num_heads, num_layers, dim_feedforward, vocab_size, activation=sigma, max_len=seq_len, seed = run_num**2 + run_num)
            model.to(device)
            optimizer = get_optimizer(model, opt_method, lr, dataset)

            
            t0 = time.time()

            for epoch in tqdm(range(epoch_num)):
                teleport_count = 0
                epoch_loss = 0.0

                while (epoch in tele_epochs and teleport_count < 32): 
                    teleport_count += 1

                    # load data batch
                    try:
                        if dataset != 'imdb':
                            tele_data, tele_target = next(teleport_loader_iterator)
                            tele_data, tele_target = tele_data.to(device), tele_target.to(device)
                        else:
                            data = next(teleport_loader_iterator)
                            tele_data, tele_attention_mask, tele_target = data['input_ids'], data['attention_mask'], data['targets']
                            tele_data, tele_attention_mask, tele_target = tele_data.to(device), tele_attention_mask.to(device),tele_target.to(device)


                    
                    except StopIteration:
                        teleport_loader_iterator = iter(teleport_loader)
                        if dataset != 'imdb':
                            tele_data, tele_target = next(teleport_loader_iterator)
                            tele_data, tele_target = tele_data.to(device), tele_target.to(device)
                        else:
                            data = next(teleport_loader_iterator)
                            tele_data, tele_attention_mask, tele_target = data['input_ids'], data['attention_mask'], data['targets']
                            tele_data, tele_attention_mask, tele_target = tele_data.to(device), tele_attention_mask.to(device),tele_target.to(device)
                        
                    # Teleport step: prepare the teleport batch (tele_data / tele_target /
                    # tele_attention_mask, obtained before this point) the same way the
                    # training loop prepares its batches, then call `teleport` with
                    # per-model settings and rebuild the optimizer for the updated model.
                    batch_size_tele = tele_data.shape[0]
                    if model_type == 'MLP':
                        # flatten every sample into a single vector
                        tele_data = tele_data.view(batch_size_tele, -1).to(device)
                    elif model_type == 'CNN':
                        tele_data = tele_data.to(device)
                    elif model_type == 'transformer':
                        if dataset in ['electricity', 'traffic', 'imdb', 'PennTree']:
                            tele_data = tele_data.to(device)
                        else:
                            # (n, ...) -> (n, input_dim, L) -> (n, L, input_dim)
                            tele_data = tele_data.view(batch_size_tele, input_dim, -1).transpose(1, 2).to(device)
                        if dataset == 'imdb':
                            tele_attention_mask = tele_attention_mask.to(device)
                    # The training loop moves its labels to `device`; do the same for the
                    # teleport targets so the loss inside `teleport` sees one device.
                    # `.to` is a no-op when the tensor is already there.
                    if torch.is_tensor(tele_target):
                        tele_target = tele_target.to(device)

                    # dL_dt_cap and the model-specific flags differ per model/dataset.
                    if model_type == 'MLP':
                        model = teleport(model, tele_data, tele_target, criterion, tele_lr, sigma, dL_dt_cap=5, telestep=tele_step, random_teleport=False, reverse=False)
                    elif model_type == 'CNN':
                        model = teleport(model, tele_data, tele_target, criterion, tele_lr, sigma, dL_dt_cap=40, telestep=tele_step, random_teleport=False, reverse=False, CNN = True)
                    elif model_type == 'transformer':
                        if dataset == 'imdb':
                            model = teleport(model, tele_data, tele_target, criterion, tele_lr, sigma, dL_dt_cap=20, telestep=tele_step, random_teleport=False, reverse=False, attention_mask = tele_attention_mask)
                        elif dataset == 'PennTree':
                            model = teleport(model, tele_data, tele_target, criterion, tele_lr, sigma, dL_dt_cap=5, telestep=tele_step, random_teleport=False, reverse=False, PennTree = True)
                        else:
                            model = teleport(model, tele_data, tele_target, criterion, tele_lr, sigma, dL_dt_cap=10, telestep=tele_step, random_teleport=False, reverse=False)

                    # Rebuild the optimizer around the teleported model. The previous
                    # optimizer state is deliberately NOT restored (restoring it was
                    # disabled in this experiment), so optimizer statistics restart here.
                    optimizer = get_optimizer(model, opt_method, lr, dataset)
                    
                def _batch_loss(batch):
                    """Prepare one loader batch, run ``model`` on it, and return ``(loss, n)``.

                    ``batch`` is either an ``(inputs, targets)`` pair or a dict with
                    'input_ids', 'attention_mask' and 'targets' keys (presumably the
                    imdb loader's format -- confirm). Inputs and targets are moved to
                    ``device`` and reshaped identically for training and evaluation.
                    ``n`` is ``inputs.shape[0]`` and is used to weight the batch loss.
                    Reads the script-level names model, model_type, dataset, device,
                    input_dim and criterion at call time.
                    """
                    if len(batch) == 2:
                        inputs, targets = batch
                        mask = None
                    else:
                        inputs = batch['input_ids']
                        mask = batch['attention_mask']
                        targets = batch['targets']
                    n = inputs.shape[0]

                    if model_type == 'MLP':
                        inputs = inputs.view(n, -1).to(device)  # flatten each sample
                    elif model_type == 'CNN':
                        inputs = inputs.to(device)
                    elif model_type == 'transformer':
                        if dataset in ['electricity', 'traffic', 'imdb', 'PennTree']:
                            inputs = inputs.to(device)
                        else:
                            # (n, ...) -> (n, input_dim, L) -> (n, L, input_dim)
                            inputs = inputs.view(n, input_dim, -1).transpose(1, 2).to(device)
                        if dataset == 'imdb':
                            mask = mask.to(device)

                    targets = targets.to(device)
                    if model_type == 'transformer' and dataset == 'PennTree':
                        # flat targets to match the flattened outputs below
                        targets = targets.reshape(-1)

                    if dataset == 'imdb':
                        outs, _ = model(inputs, mask)
                    elif model_type != 'CNN':
                        outs, _ = model(inputs)
                    else:
                        # the CNN model returns four values; only the first is used
                        outs, _, _, _ = model(inputs)

                    if dataset == 'PennTree':
                        # collapse all leading dims so criterion sees 2-D outputs
                        outs = outs.reshape(-1, outs.shape[-1])
                    return criterion(outs, targets), n

                # Ordinary optimizer steps over the whole training set. Loop variables
                # use fresh names so they do not clobber the `data` module alias or
                # the global `batch_size` setting.
                model.train()
                for train_batch in tqdm(train_loader):
                    batch_loss, batch_len = _batch_loss(train_batch)
                    # Weight by batch length so dividing by the sampler length gives a
                    # per-sample average (assumes criterion averages over the batch).
                    epoch_loss += batch_loss.item() * batch_len
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()

                loss_arr_teleport.append(epoch_loss / len(train_loader.sampler))
                print(epoch_loss / len(train_loader.sampler))

                # Evaluation on test_loader (recorded as the validation loss). Only the
                # loss is tracked; no accuracy metric is computed here.
                model.eval()
                valid_loss = 0
                with torch.no_grad():
                    for valid_batch in tqdm(test_loader):
                        batch_loss, batch_len = _batch_loss(valid_batch)
                        valid_loss += batch_loss.item() * batch_len
                valid_loss_teleport.append(valid_loss / len(test_loader.sampler))

                # Elapsed time since t0 (set outside this chunk), logged once per epoch.
                t1 = time.time()
                time_teleport.append(t1 - t0)

            # Persist this run's histories. Tuple layout: (train-loss history,
            # validation-loss history, dL_dt_arr_teleport, elapsed times t1 - t0, 0).
            # The trailing 0 is a placeholder whose meaning is not visible here --
            # TODO confirm before changing the on-disk format.
            results = (loss_arr_teleport, valid_loss_teleport, dL_dt_arr_teleport, time_teleport, 0)
            results_path = 'logs/optimization_final/{}/{}_{}_lr_{}_{}_teleport_{}.pkl'.format(
                dataset, dataset, opt_method, lr, model_type, run_num)
            # open(..., 'wb') fails if the per-dataset log directory does not exist yet.
            os.makedirs(os.path.dirname(results_path), exist_ok=True)
            with open(results_path, 'wb') as f:
                pickle.dump(results, f)

# After all optimizers and runs finish, pass the experiment configuration to
# `plot_optimization` (imported from plot.py) to draw the results.
plot_optimization(
    opt_method_list,
    dataset,
    lr,
    model_type,
    epoch_num,
)
