from __future__ import print_function
import os
import torch
from torch.utils.data import Dataset
import pdb


class MetaD2A(Dataset):
  """Meta-training dataset pairing ImageNet-32 class subsets with architectures.

  Every item is built from a task id looked up in ``self.task_lst``, which
  yields a list of 1-based class labels. For each class, up to ``max_img``
  images are drawn without replacement from ``self.x[cls - 1][0]`` and all
  classes are concatenated along dim 0.

  Modes:
    * ``'tr'`` / ``'va'``: items are ``(x, graph, acc, 'MetaD2A')``. A random
      permutation of all predictor records (cached on disk as
      ``metaD2A_idx.pt``) is split: the first 400 indices form the validation
      set and the rest form the training set.
    * ``'te'`` / ``'vate'``: items are only the image tensor ``x`` for a fixed
      list of task ids.

  Args:
    args: namespace providing ``acc_norm`` and ``max_img``.
    mode: one of ``'tr'``, ``'va'``, ``'te'``, ``'vate'``.
    use_447: in ``'tr'``/``'va'`` modes, also append the records stored in
      ``task447_info.pt``.

  Raises:
    ValueError: for an unknown ``mode``, or when the cached split permutation
      does not cover exactly the loaded records.

  NOTE(review): recent PyTorch releases default ``torch.load`` to
  ``weights_only=True``, which may reject non-tensor payloads (e.g. graph
  objects); confirm against the installed version.
  """

  _ROOT = '/w14/dataset/MetaD2A/'
  _PREDICTOR_DIR = _ROOT + 'predictor/'
  # Number of leading permutation entries reserved for validation.
  _NUM_VAL = 400
  # Fixed task ids (indices into task_lst_vate.pth) used in 'te'/'vate' modes.
  _TEST_TASKS = [1100, 4662, 516, 2089, 965, 4058, 3682, 3868, 3109, 1719,
                 768, 3996, 232, 3193, 3545, 4976, 17, 3648, 2181, 1874]

  def __init__(self, args, mode, use_447=True):
    # Validate first: an unknown mode used to leave the instance half-built.
    if mode not in ('tr', 'va', 'te', 'vate'):
      raise ValueError(
          f"unknown mode {mode!r}; expected 'tr', 'va', 'te' or 'vate'")
    self.mode = mode
    self.acc_norm = args.acc_norm
    self.max_img = {'MetaD2A': args.max_img}
    self.x = torch.load(self._ROOT + 'imgnet32bylabel')
    if mode in ('te', 'vate'):
      # Copy so per-instance mutation cannot leak into the class constant.
      self.task = list(self._TEST_TASKS)
      self.task_lst = torch.load(self._ROOT + 'task_lst_vate.pth')
    else:
      self._load_predictor_records(use_447)
      self.task_lst = torch.load(self._ROOT + 'task_lst_tr.pth')

  def _load_predictor_records(self, use_447):
    """Load (acc, task, graph) records, the train/val split and acc stats."""
    data = torch.load('/w14/dataset/D2A/predictor/447_1232/rand_top_tr.pth')
    self.acc = data['vacc1']
    self.task = data['task']
    self.graph = data['arch_igraph']
    if use_447:
      data2 = torch.load(self._PREDICTOR_DIR + 'task447_info.pt')
      # NOTE(review): ``+=`` concatenates only if these are Python lists
      # (tensors would be added element-wise); confirm the stored types.
      self.acc += data2['acc']
      self.task += data2['task']
      self.graph += data2['graph']

    idx_path = self._PREDICTOR_DIR + 'metaD2A_idx.pt'
    if os.path.exists(idx_path):
      ridx = torch.load(idx_path)
    else:
      ridx = torch.randperm(len(self.graph))
      torch.save(ridx, idx_path)
    # The cache is keyed only by path. A permutation written for a different
    # record count (e.g. another use_447 setting) would silently drop records
    # or index past the end, so reject it explicitly.
    if len(ridx) != len(self.graph):
      raise ValueError(
          f'{idx_path} holds a permutation of {len(ridx)} records but '
          f'{len(self.graph)} were loaded; delete it to regenerate the split')

    self.vaidx = ridx[:self._NUM_VAL]
    self.tridx = ridx[self._NUM_VAL:]
    # Normalisation statistics come from training records only.
    tr_acc = torch.tensor(self.acc)[self.tridx]
    self.mean = torch.mean(tr_acc).item()
    self.std = torch.std(tr_acc).item()

  def __len__(self):
    if self.mode == 'tr':
      return len(self.tridx)
    if self.mode == 'va':
      return len(self.vaidx)
    # 'te' / 'vate' (the mode is validated in __init__).
    return len(self.task)

  def _sample_images(self, classes):
    """Stack up to ``max_img`` randomly chosen images per 1-based class label."""
    k = self.max_img['MetaD2A']
    per_class = []
    for cls in classes:
      imgs = self.x[cls - 1][0]
      picks = torch.randperm(len(imgs))[:k]
      per_class.append(torch.stack([imgs[i] for i in picks]))
    return torch.cat(per_class)

  def __getitem__(self, index):
    if self.mode in ('te', 'vate'):
      return self._sample_images(self.task_lst[self.task[index]])

    # Map the dataset index to a record index through the split permutation.
    ridx = (self.tridx if self.mode == 'tr' else self.vaidx)[index]
    x = self._sample_images(self.task_lst[self.task[ridx]])
    graph = self.graph[ridx]
    acc = self.acc[ridx]
    if self.acc_norm:
      # NOTE(review): z-scoring and then also dividing by 100 is unusual; kept
      # unchanged because downstream code may invert exactly this transform.
      acc = ((acc - self.mean) / self.std) / 100.0
    else:
      acc = acc / 100.0
    return x, graph, acc, 'MetaD2A'


class TestDataset(Dataset):
  """Single-item dataset yielding one image set that covers every class.

  Images are read from ``/w14/dataset/MetaD2A/<dname>bylabel``. The only item
  draws, for every class index ``c`` in ``range(num_cls)``, up to ``max_img``
  images without replacement from ``self.x[c][0]`` (order chosen by
  ``torch.randperm``) and concatenates all classes along dim 0.

  Args:
    max_img: maximum number of images sampled per class.
    dname: dataset name; it selects the class count ('cifar100' -> 100,
      'aircraft' -> 30, 'pets' -> 37, anything else -> 10) and the file path.
  """

  def __init__(self, max_img=20, dname='cifar10'):
    class_counts = {'cifar100': 100, 'aircraft': 30, 'pets': 37}
    self.dname = dname
    self.max_img = {dname: max_img}
    self.num_cls = class_counts.get(dname, 10)
    self.x = torch.load('/w14/dataset/MetaD2A/{}bylabel'.format(dname))

  def __len__(self):
    return 1

  def __getitem__(self, index):
    limit = self.max_img[self.dname]
    stacks = []
    for c in range(self.num_cls):
      pool = self.x[c][0]
      chosen = torch.randperm(len(pool))[:limit]
      stacks.append(torch.stack([pool[j] for j in chosen]))
    return torch.cat(stacks)

  
if __name__ == '__main__':
  # Smoke run: build the single-item downstream dataset and report its size
  # and the shape of the sampled image tensor. Only classes defined in this
  # module are used, so the script can run on its own (given the data files).
  dataset = TestDataset()
  print(len(dataset))
  x = dataset[0]
  print(tuple(x.shape))
