# %reload_ext autoreload
# %autoreload 2
# %matplotlib inline


import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.datasets.vision import VisionDataset
from torch.optim.optimizer import Optimizer, required
import pickle
from PIL import Image
import math
from pprint import pprint as prt
import functools,itertools
import time
from time import time as timer
import os
import cv2
# from meta_module import *
import sys




# Module-wide compute device: the current CUDA device when CUDA is available, otherwise the CPU.
# Used e.g. by the infinite iterators' sample() to move batches and by wIni().
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def argsPrinter(args,n_epochs,epoch_len,unroll,Target_DataGen,OPT,preproc,hidden_layers_zer,pixels,batch_size,cifarAB,num_workers,n_tests,eva_epoch_len,pin_memory,Target_Optimizee,**w):
    """Derive ``args['task_desc']`` from the run mode and print a configuration summary.

    ``task_desc`` is used for the wIns dir name, the trained .pth file name and the
    SR database name.

    Args:
        args: full config dict; must contain ``'mode'``. It is mutated: ``'task_desc'``
            is written into it.
        OPT: L2O network class; instantiated with ``**args`` only to read its ``name``.
        Target_Optimizee: optimizee class; instantiated (no arguments) to read its
            ``name`` for every mode except ``'nort'``.
        n_epochs ... pin_memory: values that are only printed.
        **w: remaining config keys (ignored).

    Raises:
        ValueError: if ``args['mode']`` is not one of the known modes.
    """
    _l2o_net = OPT(**args)
    mode = args['mode']
    eval_modes = ('pure', 'srck', 'srgen', 'rec', 'tradi')
    if mode == 'tras':
        task_desc = f'train_scratch {_l2o_net.name} @ {Target_Optimizee().name}'
    elif mode == 'tune':
        task_desc = f'tune {_l2o_net.name} @ {Target_Optimizee().name}'
    elif mode == 'nort':
        task_desc = 'normal_train resnet18'
    elif mode in eval_modes:
        task_desc = f"eva({mode}), {_l2o_net.name} @ {Target_Optimizee().name}"
    else:
        known = ', '.join(repr(m) for m in ('tras', 'tune', 'nort') + eval_modes)
        raise ValueError(f"argsPrinter: unknown mode {mode!r}; expected one of {known}")

    args['task_desc'] = task_desc

    # NOTE(review): the 'SR Data Gen' section prints the *_tuneSR values, duplicating the
    # 'Tune SR' section below -- confirm which keys were intended there.
    s=f'''

        # ============ GENERAL ============
        'pid':                  {os.getpid()}
        'task_desc':            {task_desc}




        # ============ Training from scratch ============
        'n_epochs':             {n_epochs}
        'epoch_len':            {epoch_len}
        'unroll':               {unroll}

        'batch_size':           {batch_size}
        'num_workers':          {num_workers}
        'pin_memory':           {pin_memory}


        'l2o_net.name':         {_l2o_net.name}
        'Ngf'(num_grad_feat):   {args.get('Ngf')}
        'preproc':              {preproc}
        'hidden_layers_zer':    {hidden_layers_zer}




        # ============ Evaluation ============
        'n_tests':              {n_tests}
        'eva_epoch_len':        {eva_epoch_len}

        'l2o_net.name':         {_l2o_net.name}




        # ============ Which Problem ============

        'Target_DataGen':       {Target_DataGen}
        'Target_Optimizee':     {Target_Optimizee}
        'pixels'(784/3072):     {pixels}




        # ============ SR Data Gen ============
        'want_SR':              {args.get('want_SR')}
        'n_epochs':             {args.get('n_epochs_tuneSR')}
        'epoch_len':            {args.get('epoch_len_tuneSR')}
        'num_Xyitems_SR':       {args.get('num_Xyitems_SR')}
        'SR_memlen':            {args.get('SR_memlen')}



        # ============ Tune SR ============
        'n_epochs':             {args.get('n_epochs_tuneSR')}
        'epoch_len':            {args.get('epoch_len_tuneSR')}
        'unroll':               {args.get('unroll_tuneSR')}

        'lr_meta_tuneSR':       {args.get('lr_meta_tuneSR')}
        'OPT_META_SR':          {args.get('OPT_META_SR')}


    '''
    print(s)


    return
























# ============================================
# ================ Problems ==================
# ============================================


class Cifar_f:
    """Endless CIFAR-10 batch source built on ``get_infinite_iter``.

    ``training`` selects the CIFAR-10 train or test split for the sampled batches.
    A separate, non-augmented ``testloader`` over the test split is also provided.
    """
    def __init__(self, training=True, shuffle_cifar=True,num_workers=2,batch_size=None,**kwargs):
        """
        Args:
            training: truthy -> CIFAR-10 train split, falsy -> test split.
            shuffle_cifar: accepted for interface parity; batches are always shuffled.
            num_workers: DataLoader worker processes.
            batch_size: batch size of the sampled batches.
            **kwargs: forwarded to ``get_infinite_iter`` (e.g. ``pin_memory``).
        """
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # NOTE(review): augmentation (transform_train) is applied even when training=False,
        # i.e. to the test split -- confirm this is intended for the sampled batches.
        dataset = torchvision.datasets.CIFAR10(root='./datasets', train=bool(training), download=False, transform=transform_train)
        self.infIter = get_infinite_iter(dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True, **kwargs)

        testset = torchvision.datasets.CIFAR10(root='./datasets', train=False,
                                               download=False, transform=transform_test)
        self.testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                                      shuffle=True, num_workers=num_workers)

    def sample(self):
        """Return the next (images, labels) batch from the endless iterator."""
        return self.infIter.sample()



# class SMALL_CNN(MetaModule):
#     name = 'SMALL_CNN'
#     # def __init__(self, pixels = [28*28, 3*32*32][1], hidden_layers_zee=[('conv', 3,6,3), ('conv', 6, 16, 3), ('fc', 16*5*5, 120), ('fc', 120, 84), ('fc', 84, 5),], activation=[nn.Sigmoid, nn.ReLU][1], **kwargs):
#     def __init__(self, pixels = [28*28, 3*32*32][1], hidden_layers_zee=[('conv', 3,6,3), ('conv', 6, 12, 3), ('fc', 12*6*6, 10),], activation=[nn.Sigmoid, nn.ReLU][1], **kwargs):
#         # used for cifar-AB, where output has 5 labels

#         super().__init__()
#         self.hidden_layers_zee = hidden_layers_zee
#         self.N_layers = len(hidden_layers_zee)
#         self.activation = activation()

#         self.layers = {}
#         for il in range(len(hidden_layers_zee)):
#             if hidden_layers_zee[il][0]=='conv':
#                 self.layers[f'conv_{il}'] = MetaConv2d(*hidden_layers_zee[il][1:]).float()
#             elif hidden_layers_zee[il][0]=='fc':
#                 self.layers[f'fc_{il}'] = MetaLinear(*hidden_layers_zee[il][1:]).float()
#             else:
#                 raise NotImplementedError
#         self.layers = nn.ModuleDict(self.layers)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.loss = nn.CrossEntropyLoss()



#     def fwdPass(self, x):
#         for il in range(self.N_layers):
#             if self.hidden_layers_zee[il][0]=='conv':
#                 x = self.pool(self.activation(self.layers[f'conv_{il}'](x)))
#             elif self.hidden_layers_zee[il][0]=='fc':
#                 if self.hidden_layers_zee[il-1][0]=='conv':
#                     # print(x.shape)
#                     x = x.view(-1, self.hidden_layers_zee[il][1])
#                 x = self.layers[f'fc_{il}'](x)
#                 if il<self.N_layers-1:
#                     x = self.activation(x)
#         return x


#     def forward(self, loss):

#         x, label = loss.sample()

#         pred = self.fwdPass(x)

#         l = self.loss(pred, label)

#         return l




#     def cal_test_acc(self, testloader):

#         correct = 0
#         total = 0
#         with torch.no_grad():
#             for i,data in enumerate(testloader):
#                 # print('i+',i)
#                 images, labels = data
#                 images, labels = images.to(DEVICE), labels.to(DEVICE)

#                 outputs = self.fwdPass(images)

#                 # outputs = net(images)
#                 _, predicted = torch.max(outputs.data, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()

        
#         acc = 100 * correct / total
    
#         print(f'\n\n\n  Test acc is: {acc}\n\n\n')
#         return acc




























def get_infinite_iter(dataset, batch_size=1,num_workers=2,shuffle=True,sampler=None, pin_memory=None, **args):
    """Wrap ``dataset`` in a DataLoader whose iterator restarts instead of stopping.

    Args:
        dataset: any map-style dataset (needs ``__len__`` when ``sampler`` is None).
        batch_size: DataLoader batch size.
        num_workers: worker processes; when > 0 the workers are persistent, so each
            restart reuses the same pool.
        shuffle: used only when ``sampler`` is None; picks RandomSampler vs SequentialSampler.
        sampler: explicit sampler; overrides ``shuffle``.
        pin_memory: forwarded to the DataLoader (None is treated as False).
        **args: ignored (lets callers pass a whole config dict).

    Returns:
        An iterator object: ``next(it)`` yields batches forever, ``it.sample()``
        returns ``(data, labels)`` moved to ``DEVICE``, and ``it.loader`` is the DataLoader.
    """
    if sampler is None:
        # Samplers only need len(data_source); a range avoids materialising an index list.
        sampler_cls = (torch.utils.data.sampler.RandomSampler if shuffle
                       else torch.utils.data.sampler.SequentialSampler)
        sampler = sampler_cls(range(len(dataset)))

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        pin_memory=bool(pin_memory),
    )

    class InfIter:
        """Endless iterator over ``loader``: re-iterates it whenever an epoch ends."""

        def __init__(self, loader):
            self.loader = loader
            # Only one DataLoader iterator is ever created up front; the previous code built
            # a throwaway one (spawning an extra, never-released worker pool) just to get its class.
            self._it = iter(loader)

        def __iter__(self):
            return self

        def __len__(self):
            return len(self.loader)

        def __next__(self):
            try:
                return next(self._it)
            except StopIteration:
                # With persistent workers, iter(loader) resets and reuses the existing pool.
                self._it = iter(self.loader)
                return next(self._it)

        def sample(self):
            data_batch, label_batch = next(self)
            return data_batch.to(DEVICE), label_batch.to(DEVICE)

    return InfIter(dataloader)



class MNISTLoss:
    """MNIST batch source that caches one epoch of batches (batch size 128) at a time.

    Only the MNIST *train* split is read; a fixed permutation (seed 10) divides it in
    half, the first half for ``training=True`` and the second half otherwise.
    """
    def __init__(self, training=True, **kwargs):
        to_tensor_norm = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (1.0,)),
        ])
        mnist = datasets.MNIST('./datasets', train=True, download=False, transform=to_tensor_norm)

        order = list(range(len(mnist)))
        np.random.RandomState(10).shuffle(order)
        half = len(order) // 2
        subset = order[:half] if training else order[half:]

        self.loader = torch.utils.data.DataLoader(
            mnist, batch_size=128,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(subset))

        self.batches = []
        self.cur_batch = 0

    def sample(self):
        """Return the next cached batch, reloading a full epoch once the cache is used up."""
        exhausted = self.cur_batch >= len(self.batches)
        if exhausted:
            print('\nnew sample!!!\n')
            self.batches = []
            self.cur_batch = 0
            self.batches.extend(self.loader)
            print('\nnew sample ,  DONE. !!!\n')
        current = self.batches[self.cur_batch]
        self.cur_batch += 1
        return current






class Cifar_half:
    """CIFAR-half (split 'A' or 'B') batch source that caches one epoch of batches at a time."""
    def __init__(self, cifarAB='A', training=True, shuffle_cifar=True,num_workers=2,batch_size=128,**kwargs):
        root = {'A': 'datasets/cifar10-A', 'B': 'datasets/cifar10-B'}[cifarAB]
        self.cifarAB = cifarAB

        half_label_set = CIFAR10_halfLabel(
            root=root,
            train=bool(training),
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        self.loader = torch.utils.data.DataLoader(
            half_label_set,
            batch_size=batch_size,
            shuffle=bool(shuffle_cifar),
            num_workers=num_workers,
        )

        self.batches = []
        self.cur_batch = 0

    def sample(self):
        """Return the next cached batch, reloading a full epoch once the cache is used up."""
        exhausted = self.cur_batch >= len(self.batches)
        if exhausted:
            print('\nnew sample!!!\n')
            self.batches = []
            self.cur_batch = 0
            self.batches.extend(self.loader)
            print('\nnew sample ,  DONE. !!!\n')
        current = self.batches[self.cur_batch]
        self.cur_batch += 1
        return current
















class MNISTLoss_f:
    """Endless MNIST batch source over a fixed half of the MNIST train split.

    A seed-10 permutation of the train split is cut in half: the first half is used
    when ``training`` is truthy, the second half otherwise.
    """
    def __init__(self, training=True, num_workers=2, batch_size=None, **kwargs):
        normalize_mnist = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (1.0,)),
        ])
        mnist_train = datasets.MNIST('./datasets', train=True, download=False, transform=normalize_mnist)

        perm = list(range(len(mnist_train)))
        np.random.RandomState(10).shuffle(perm)
        cut = len(perm) // 2
        chosen = perm[:cut] if training else perm[cut:]

        subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen)
        self.infIter = get_infinite_iter(
            mnist_train,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=subset_sampler,
            **kwargs,
        )

    def sample(self):
        """Return the next (images, labels) batch from the endless iterator."""
        return self.infIter.sample()


class Cifar_half_f:
    """Endless CIFAR-half (split 'A' or 'B') batch source built on ``get_infinite_iter``."""
    def __init__(self, cifarAB='A', training=True, shuffle_cifar=True,num_workers=2,batch_size=None,**kwargs):
        root = {'A': 'datasets/cifar10-A', 'B': 'datasets/cifar10-B'}[cifarAB]
        self.cifarAB = cifarAB

        normalize_11 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        half_label_set = CIFAR10_halfLabel(root=root, train=bool(training), transform=normalize_11)

        # shuffle_cifar is accepted for interface parity but not used: batches are always shuffled.
        self.infIter = get_infinite_iter(
            half_label_set,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
            **kwargs,
        )

    def sample(self):
        """Return the next (images, labels) batch from the endless iterator."""
        return self.infIter.sample()





# class MLP_MNIST(MetaModule):
#     name = 'MLP_MNIST'
#     def __init__(self, pixels = [28*28, 3*32*32][0], hidden_layers_zee=[[30,20,], [50,20]][0], activation=[nn.Sigmoid, nn.ReLU][1], **kwargs):
#         super().__init__()
#         self.hidden_layers_zee = hidden_layers_zee
#         self.pixels = pixels
        
#         self.layers = {}
#         for i in range(len(hidden_layers_zee)):
#             self.layers[f'mat_{i}'] = MetaLinear(pixels, hidden_layers_zee[i]).float()
#             pixels = hidden_layers_zee[i]

#         self.layers['final_mat'] = MetaLinear(pixels, 10).float()
#         self.layers = nn.ModuleDict(self.layers)

#         # print(f'optimizee 111 .device={self.layers.device}')


#         self.activation = activation()
#         self.loss = nn.NLLLoss()

#     # def all_named_parameters(self):
#     #     return [(k, v) for k, v in self.named_parameters()]


#     def forward(self, loss):
#         inp, out = loss.sample()
#         # print(inp.shape)
#         inp = Variable(inp.view(inp.size()[0], self.pixels)).to(DEVICE)
#         out = Variable(out).to(DEVICE)

#         cur_layer = 0
#         while f'mat_{cur_layer}' in self.layers:
#             inp = self.activation(self.layers[f'mat_{cur_layer}'](inp))
#             cur_layer += 1

#         inp = F.log_softmax(self.layers['final_mat'](inp), dim=1)
#         l = self.loss(inp, out)
#         return l


# class MLP_MNIST2(MLP_MNIST):
#     name = 'MLP_MNIST2'
#     def __init__(self, *args, **kwargs):
#         super().__init__(hidden_layers_zee=[50,20,20,12,], *args, **kwargs)
























































# =========================================
# ================ Utils ==================
# =========================================



def proj_batch(img):
    """Split every image channel into its mean-free part and its constant-image coefficient.

    With N = H*W and the unit constant vector v = 1/sqrt(N) per channel:
        ph  = <img, v>      = sum(img over H, W) / sqrt(N)
        pim = img - ph * v  = img - channel mean

    Args:
        img: batch in (B, C, H, W) layout, either a torch tensor or a numpy array.

    Returns:
        (pim, ph): ``pim`` has the shape of ``img`` and ``ph`` has shape (B, C), both float64.
        Torch input yields CPU tensors (via ``.type(torch.DoubleTensor)``); numpy input
        yields numpy arrays. (Previously numpy input crashed on ``.type``.)

    Note: removing the mean does not keep values inside [-1, 1]. For example, a pixel at 1
    in a channel whose mean is -0.5 becomes 1.5. Out-of-range outputs are expected.
    """
    if isinstance(img, np.ndarray):
        img = img.astype(np.float64)
        B, C, H, W = img.shape
        channel_sum = img.sum(axis=(2, 3))
    else:
        img = img.type(torch.DoubleTensor)
        B, C, H, W = img.shape
        channel_sum = torch.sum(img, (2, 3))

    sqrt_n = np.sqrt(H * W)
    ph = channel_sum / sqrt_n
    pim = img - ph.reshape(B, C, 1, 1) / sqrt_n
    return pim, ph


def viz_batch_img(images_BCHW, lbstr = ''):
    """Show a batch of images as one grid, then print the given label string.

    Args:
        images_BCHW: image tensor in (B, C, H, W) layout. It is mapped back with
            ``x / 2 + 0.5``, i.e. it is assumed to be normalised to [-1, 1]
            (mean = std = 0.5).
        lbstr: text printed after the figure, typically the batch's class names.
    """
    # detach()/cpu() so CUDA tensors or tensors that require grad can be turned into numpy.
    grid = torchvision.utils.make_grid(images_BCHW.detach().cpu())
    grid = grid / 2 + 0.5     # unnormalize
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()
    print(f'labels are: < {lbstr} >')
    return








def playground_projection():
    """Interactive demo: view CIFAR-half test images before and after ``proj_batch``.

    Shows one random batch of 4 images from split A (``AB = 0``), its mean-free
    projection, and the same pair for two gamma-adjusted versions (gamma 3 and 0.3).
    """
    AB = 0
    batch_size = 4
    root = ['datasets/cifar10-A', 'datasets/cifar10-B'][AB]
    classesA = ('plane', 'bird', 'deer', 'frog', 'ship', )
    classesB = ('car', 'cat', 'dog', 'horse', 'truck')

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = CIFAR10_halfLabel(root=root, train=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
    dataiter = iter(trainloader)
    # DataLoader iterators no longer expose ``.next()`` in recent PyTorch; use the builtin.
    images, labels = next(dataiter)  # BCHW, values in [-1, 1]

    lbstr = f'{[[classesA,classesB][AB][labels[j]] for j in range(batch_size)]}'

    viz_batch_img(images, lbstr=lbstr)
    proj_ims, ph = proj_batch(images)
    viz_batch_img(proj_ims, lbstr=lbstr)

    # NOTE(review): gamma_lightness_on_11 is not defined in this file; presumably provided elsewhere.
    for gamma in (3, 0.3):
        ims_gamma = gamma_lightness_on_11(images, gamma)
        proj_ims_gamma, _ = proj_batch(ims_gamma)
        viz_batch_img(ims_gamma, lbstr=lbstr)
        viz_batch_img(proj_ims_gamma, lbstr=lbstr)















# Timing state shared by progress_bar(); updated on every call.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    """Draw a one-line text progress bar for step ``current`` (0-based) of ``total`` on stdout.

    The line ends with '\\r' so the next call overwrites it, except on the last step
    (``current == total - 1``), which ends with a newline.

    Args:
        current: 0-based index of the finished step.
        total: total number of steps.
        msg: optional text appended after the bar as `` | msg``.
    """
    global last_time, begin_time
    import shutil

    # shutil falls back to 80 columns when stdout is not a terminal; the previous
    # ``os.popen('stty size')`` crashed there and never closed its pipe.
    term_width = shutil.get_terminal_size().columns
    TOTAL_BAR_LENGTH = 65.
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    sys.stdout.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    # Keep the timing state current (step/total time display is currently disabled).
    last_time = time.time()

    msg = ' | ' + msg if msg else ''
    sys.stdout.write(msg)
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()

def format_time(seconds):
    """Render a duration in seconds using at most its two leading non-zero units.

    Units, largest first: days 'D', hours 'h', minutes 'm', seconds 's', milliseconds 'ms'.
    Zero units are skipped; if nothing is positive the result is '0ms'.
    e.g. 90061.5 -> '1D1h', 65.25 -> '1m5s', 2.25 -> '2s250ms'.
    """
    days = int(seconds / 3600/24)
    remaining = seconds - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_secs = int(remaining)
    millis = int((remaining - whole_secs)*1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'), (whole_secs, 's'), (millis, 'ms'))
    shown = [f'{value}{suffix}' for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'










def eva(args, l2o_net=None, eva_epoch_len=None, n_tests=None, want_save_eva_loss=None, **kwargs):
    """Run ``n_tests`` evaluation epochs (``should_train=False``), plot and optionally save the losses.

    Args:
        args: full config dict; passed to ``argsPrinter`` (which sets ``args['task_desc']``)
            and expanded into every ``run_epoch`` call.
        l2o_net: unused.
        eva_epoch_len: length of each evaluation run.
        n_tests: number of independent runs.
        want_save_eva_loss: when truthy, the raw loss array is also saved via ``wzRec``.

    Returns:
        (figDir, avg_eva_sum_loss, sr_r_t_n_f): the saved figure path, the per-run summed
        loss averaged over runs, and the stacked SR records of every run.
    """
    argsPrinter(args, **args)

    runs = []
    for _ in tqdm(range(n_tests)):
        runs.append(run_epoch(args, eva_epoch_len, should_train=False, **args))
    losses_per_run, sr_records = zip(*runs)
    # per-run SR record, presumably [run_epoch_len, N_optimizee_param_groups, 1+l2o_net.Ngf+1]
    sr_r_t_n_f = np.asarray(sr_records)  # stacked over runs: [n_tests, ...]

    all_losses = np.asarray(losses_per_run)
    avg_eva_sum_loss = np.mean(np.sum(all_losses, 1), 0)
    avg_eva_last_loss = np.mean(all_losses[:, -1], 0)  # computed but not returned

    figDir, recDir = wzRec(all_losses, 'Eva-loss', args['task_desc'], want_save=want_save_eva_loss)
    print(f'\n=======================\nEva result figure saved to:\n< {figDir} >\n=======================\n')
    if want_save_eva_loss:
        print(f'All loss/acc records saved to:\n< {recDir} >\n=======================\n')

    return figDir, avg_eva_sum_loss, sr_r_t_n_f







def viz(net, ttl='', print_device=False):
    """Print the number of parameter tensors of ``net`` and each one's name and shape.

    Args:
        net: module whose ``named_parameters()`` are listed.
        ttl: title printed in the header.
        print_device: when true, each entry is (name, device, shape) instead of (name, shape).
    """
    entries = [
        (name, p.device, list(p.size())) if print_device else (name, list(p.size()))
        for name, p in net.named_parameters()
    ]
    print(f'\nparams of: {ttl}\nN_groups = {len(entries)}')
    prt(entries)







def viz_dataset(AB):
    """Print and display one random batch of 4 test images from CIFAR-half split A (AB=0) or B (AB=1)."""
    batch_size = 4
    root = ['datasets/cifar10-A', 'datasets/cifar10-B'][AB]
    classesA = ('plane', 'bird', 'deer', 'frog', 'ship', )
    classesB = ('car', 'cat', 'dog', 'horse', 'truck')

    trainset = CIFAR10_halfLabel(root=root, train=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
    dataiter = iter(trainloader)
    # DataLoader iterators no longer expose ``.next()`` in recent PyTorch; use the builtin.
    images, labels = next(dataiter)

    print(images.shape)
    print(labels)

    lbstr = ' '.join('%5s' % [classesA,classesB][AB][labels[j]] for j in range(batch_size))

    # NOTE(review): vizImgs is not defined in this file; presumably provided elsewhere.
    vizImgs(images,labels=lbstr,vrange='11')

    # imshowcifar(torchvision.utils.make_grid(images), lbstr)





# def set_seed(seed):
#     import random
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.cuda.manual_seed_all(seed)
#         # os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_num)
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)





class CIFAR10_halfLabel(VisionDataset):
    """CIFAR-10 in the python-pickle batch format, with every label mapped to ``label // 2``.

    Reads ``data_batch_1..5`` (train=True) or ``test_batch`` (train=False) from ``root``.
    The integer division folds the 10 class ids into 5 (0,1 -> 0; 2,3 -> 1; ...).
    Presumably each ``root`` (datasets/cifar10-A / -B) holds only one class of each pair
    (TODO confirm). The md5 values are listed for reference but are not verified.
    """
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    @staticmethod
    def _entry_field(entry, name):
        # Accept both bytes keys (b'data') and str keys ('data', as the original CIFAR
        # python pickles give under encoding='latin1').
        key = name.encode()
        return entry[key] if key in entry else entry[name]

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform = None,
            target_transform = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set
        self.data = []
        self.targets = []
        downloaded_list = self.train_list if self.train else self.test_list

        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, file_name)
            # NOTE: pickle.load can execute code from the file; only read trusted dataset files.
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(self._entry_field(entry, 'data'))
                self.targets.extend(self._entry_field(entry, 'labels'))
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index: int):
        """Return ``(image, label // 2)``, with the optional transforms applied."""
        img, target = self.data[index], self.targets[index]//2
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        # target_transform was accepted before but silently ignored.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.data)







def wzRec(datas, ttl='', task_desc='newTask', want_save=False):
    """Plot ``datas`` to ``wIns/<ttl>_<task_desc>_<timestamp>.jpg`` and optionally save it as .npy.

    Args:
        datas: 1-D series (plotted directly, title shows its min) or 2-D array of runs
            (rows = runs; plotted via ``plot_ci``; title shows mean/std of per-run minima).
        ttl: title prefix, also used in the file names.
        task_desc: task description, used in the title and file names.
        want_save: when truthy, also save the raw data to ``wIns/Recs/<ttl>-<task_desc>.npy``.

    Returns:
        (figDir, recDir): figure path, and the .npy path (None when not saved).

    Raises:
        ValueError: if ``datas`` is neither 1-D nor 2-D.
    """
    if want_save:
        os.makedirs('wIns/Recs', exist_ok=True)
        recDir = f'wIns/Recs/{ttl}-{task_desc}.npy'
        np.save(recDir, datas)
    else:
        recDir = None

    datas = np.asarray(datas)
    plt.close('all')
    plt.figure()
    if datas.ndim == 1:
        min_v = np.min(datas)
        plt.plot(datas)
        plt.title(ttl+f', min = {min_v:5.4f}\n'+task_desc)
        plt.xlabel('step')
    elif datas.ndim == 2:
        min_s = np.min(datas, axis=1)
        mean_min = np.mean(min_s)
        std_min = np.std(min_s)
        min_str = f'(avg={mean_min:5.4f},std={std_min:5.4f})'
        plot_ci(datas, ttl=ttl+f', min = {min_str}\n'+task_desc, xlb='step')
    else:
        raise ValueError('dim should be 1D or 2D')

    lt2 = time.strftime("%Y-%m-%d--%H_%M_%S", time.localtime())
    # The figure always goes to wIns/, which previously existed only if want_save had created it.
    os.makedirs('wIns', exist_ok=True)
    figDir = f'wIns/{ttl}_{task_desc}_{lt2}.jpg'
    plt.savefig(figDir)

    return figDir, recDir




def wIni(net):
    """Overwrite every ``net.state_dict()`` entry with N(0, 0.01^2) noise, in place.

    Buffers are replaced too, e.g. BatchNorm running statistics. ``load_state_dict``
    casts the noise to each entry's dtype, so integer buffers such as
    ``num_batches_tracked`` are truncated to (practically always) 0.
    """
    std = 0.01
    noisy = {
        # Pass the shape as one argument: ``torch.randn(*())`` for a 0-dim buffer has no
        # size and raises. Noise is drawn on each entry's own device.
        k: torch.randn(v.shape, device=v.device) * std
        for k, v in net.state_dict().items()
    }
    net.load_state_dict(noisy)



def updateDict_skip_running(net,dic):
    """Load ``dic`` into ``net`` while keeping ``net``'s own entries whose key contains 'running'.

    ``dic`` is not modified. The previous version wrote ``net``'s live running-stat tensors
    into the caller's dict, which aliased them into any later use of that dict.
    """
    kept = {k: v for k, v in net.state_dict().items() if 'running' in k}
    net.load_state_dict({**dic, **kept})
    return


# def MLP(layers, activation=nn.ReLU):
#     # layers: all n+1 ints, from the input through the hidden layers to the output
#     nn_list = []
#     for i in range(len(layers)-1):
#         nn_list.extend([nn.Linear(layers[i], layers[i+1]), activation(), ])
#     return nn.Sequential(*nn_list)


def getMLP(neurons, activation=nn.ReLU, bias=True):
    """Build a plain MLP as an ``nn.Sequential``.

    Args:
        neurons: all n+1 layer widths from input to output (len(neurons) = n+1 >= 2).
            This gives n Linear layers and n-1 activations (none after the last Linear).
        activation: activation class, instantiated with no arguments between layers.
        bias: whether every Linear layer has a bias.

    Raises:
        ValueError: if fewer than two widths are given (the old code silently built a
            meaningless layer or raised IndexError).
    """
    widths = list(neurons)
    if len(widths) < 2:
        raise ValueError(f'getMLP needs at least 2 layer widths (input and output), got {len(widths)}')

    pairs = list(zip(widths[:-1], widths[1:]))
    nn_list = []
    for d_in, d_out in pairs[:-1]:
        nn_list.extend([nn.Linear(d_in, d_out, bias=bias), activation()])
    last_in, last_out = pairs[-1]
    nn_list.append(nn.Linear(last_in, last_out, bias=bias))
    return nn.Sequential(*nn_list)




def load_model(net, cwd, verbose=True, strict=True):
    """Load the state_dict stored at path ``cwd`` into ``net`` if the file exists.

    The tensors are mapped to CPU storage on load. A missing file is reported (when
    ``verbose``) rather than raised.

    Args:
        net: module to load into.
        cwd: path of a file written with ``torch.save(state_dict, ...)``.
        verbose: print a success / not-found message.
        strict: forwarded to ``load_state_dict``.
    """
    if not os.path.exists(cwd):
        if verbose: print("\n\n\n  !!! FileNotFound when load_model: {}".format(cwd))
        return
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(cwd, map_location=lambda storage, loc: storage)
    net.load_state_dict(state, strict=strict)
    if verbose: print("\nLOAD success! from :", cwd)


def save_model(net, cwd):  # 2020-05-20
    """Write ``net.state_dict()`` to the path ``cwd`` with ``torch.save`` and print the path."""
    state = net.state_dict()
    torch.save(state, cwd)
    print("\nSaved @ :", cwd)




def plot_ci(arr, vx=None, is_std=True, ttl='', xlb='',ylb='',semilogy=False, viz_un_log=False):
    """Plot the mean over rows of ``arr`` with a shaded band, on the current figure.

    Args:
        arr: 2-D array of runs (rows) x steps (columns); a 1-D array is treated as one run.
        vx: optional x-coordinates (also used as tick positions). Default: 0..len-1.
            Previously any non-empty ``vx`` crashed with UnboundLocalError.
        is_std: truthy -> band is mean +/- std * is_std (a number scales the band);
            falsy -> band spans the per-step min..max.
        ttl, xlb, ylb: title and axis labels.
        semilogy: use a logarithmic y axis.
        viz_un_log: exponentiate mean and band before plotting (data given in log space).
    """
    arr = np.asarray(arr)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    rdcolor = plt.get_cmap('viridis')(np.random.rand())  # random colour

    mean = np.mean(arr, axis=0)
    if is_std:
        ci = np.std(arr, axis=0)
        lowci = mean-ci*is_std
        hici = mean+ci*is_std
    else:
        lowci = np.min(arr, axis=0)
        hici = np.max(arr, axis=0)
    if viz_un_log:
        mean = np.exp(mean)
        lowci = np.exp(lowci)
        hici = np.exp(hici)

    has_vx = vx is not None and len(vx) > 0
    vx_ = np.asarray(vx) if has_vx else np.arange(len(mean))
    plot_fn = plt.semilogy if semilogy else plt.plot
    plot_fn(vx_, mean, color=rdcolor)
    plt.fill_between(vx_, lowci, hici, color=rdcolor, alpha=0.4)
    plt.xlabel(xlb)
    plt.ylabel(ylb)
    if has_vx: plt.xticks(vx)
    plt.title(ttl)
    return
       

def bestGPU(gpu_verbose=True, **w):
    """Return the index of the GPU (as listed by GPUtil) with the lowest memory utilisation.

    Args:
        gpu_verbose: print each GPU's memory/load percentages and the chosen index.
        **w: ignored (lets callers pass a whole config dict).

    Raises:
        ValueError: if GPUtil reports no GPUs (previously an opaque ``argmin`` error).
    """
    import GPUtil
    gpus = GPUtil.getGPUs()
    if not gpus:
        raise ValueError('bestGPU: GPUtil found no GPUs')

    mems = []
    for ig, gpu in enumerate(gpus):
        memUtil = gpu.memoryUtil*100
        load = gpu.load*100
        mems.append(memUtil)
        if gpu_verbose: print(f'gpu-{ig}:   Memory: {memUtil:.2f}%   |   load: {load:.2f}% ')

    best = int(np.argmin(mems))
    if gpu_verbose: print(f'//////   Will Use GPU - {best}  //////')
    return best
























