import json
import glob
import os
import torch as ch
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from collections import namedtuple
import re

# Directory holding the JSON experiment configs; load_config opens
# f'{CONFIG_PATH}/{name}.json'.
CONFIG_PATH = '/PATH/TO/CONFIGS/'
# CONFIG_PATH = './configs'

# Root directory of the .npy data files; get_test and get_target build their
# paths as f"{DS_PATH}/test/..." and f"{DS_PATH}/targets/...".
DS_PATH = '/PATH/TO/DATA/'
# DS_PATH = './data'

def dict2namedtuple(d):
    """Convert a dict into a namedtuple called ``Tuple``.

    The keys of ``d`` become the field names (in the dict's iteration order)
    and the values become the field values, so ``d['x']`` is reachable as
    ``result.x``. Keys must be valid namedtuple field names.
    """
    field_names = list(d)
    record_type = namedtuple('Tuple', field_names)
    return record_type(**d)

def get_test(ds, class_name):
    """Return a one-element list with the test-set ``.npy`` path for a class.

    The path is built under ``DS_PATH`` as ``test/<ds>/<class_name>.npy``.
    """
    test_file = f"{DS_PATH}/test/{ds}/{class_name}.npy"
    return [test_file]

def get_target(ds, class_name):
    """Return a one-element list with the target training ``.npy`` path.

    The path is built under ``DS_PATH`` as
    ``targets/<ds>/<class_name>/train.npy``.
    """
    target_file = f"{DS_PATH}/targets/{ds}/{class_name}/train.npy"
    return [target_file]

def get_classes(ds):
    """Return the list of class names for the dataset identifier ``ds``.

    Recognised identifiers:
      * ``'pets'``                   -> cat / dog
      * ``'cifar10'``, ``'3db'``, ``'in'`` -> the ten CIFAR-10 style names
      * ``'stl'``                    -> the same ten names without "frog"
      * ``'sst'``                    -> ``encoding_0`` .. ``encoding_1``
      * ``'emoji'``                  -> ``encoding_0`` .. ``encoding_19``

    Any other identifier yields ``None``. A new list is built on every call.
    """
    if ds in ('pets',):
        return ["cat", "dog"]

    ten_way = [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ]
    if ds in ('cifar10', '3db', 'in'):
        return ten_way
    if ds in ('stl',):
        # Same names as above minus "frog"; "monkey" is left out of the
        # STL list as well.
        return [label for label in ten_way if label != "frog"]

    if ds in ('sst',):
        n_encodings = 2
    elif ds in ('emoji',):
        n_encodings = 20
    else:
        return None
    return [f'encoding_{i}' for i in range(n_encodings)]

# small hack to ensure that mixed populations have at least
# some data by over-sampling datasets and subsetting them
def mix_datasets(datasets, alphas, pop_size=100):
    """Build a mixed population of at most ``pop_size`` rows.

    The mixture is first drawn with ``mix_datasets_`` at a size of
    ``max(pop_size, alphas.size(0))`` and then randomly cut back to
    ``pop_size`` rows with ``shuffle_and_subset``.

    Args:
        datasets: indexable collection of tensors, one per mixture component.
        alphas: 1-D tensor of mixture proportions, aligned with ``datasets``.
        pop_size: number of rows requested for the final population.
    """
    oversampled_size = max(pop_size, alphas.size(0))
    oversampled = mix_datasets_(datasets, alphas, oversampled_size)
    return shuffle_and_subset(oversampled, pop_size)

def mix_datasets_(datasets, alphas, pop_size=100):
    """Concatenate per-dataset samples sized by ``alphas * pop_size``.

    Component ``i`` contributes ``round(alphas[i] * pop_size)`` rows drawn
    with ``subsample``; components whose rounded count is zero are skipped
    (their ``datasets[i]`` entry is never touched). Because each count is
    rounded independently, the total number of rows may differ from
    ``pop_size``.

    Returns:
        The selected samples concatenated along dim 0.
    """
    per_source = (alphas * pop_size).round().long()
    picked = [
        subsample(datasets[idx], how_many)
        for idx, how_many in enumerate(per_source)
        if how_many > 0
    ]
    return ch.cat(picked, dim=0)

def subsample(x,n):
    """Draw ``n`` rows from ``x`` along dim 0, repeating ``x`` if needed.

    When ``n`` does not exceed ``x.size(0)`` the result is ``n`` distinct rows
    in random order. Otherwise the result consists of as many whole copies of
    ``x`` (in their original order) as fit, followed by a random selection of
    distinct rows that brings the total to ``n``.

    Args:
        x: tensor to sample from; rows are indexed along dim 0.
        n: number of rows to draw, as an int or a single-element integer
            tensor. The argument itself is never modified.

    Returns:
        A tensor with ``n`` rows (for non-negative ``n``).

    Raises:
        ValueError: if ``x`` has no rows but ``n`` is positive.
    """
    # Work on a plain int: callers may pass a 0-dim tensor that is a view into
    # a larger counts tensor, and in-place arithmetic on it would silently
    # write back into the caller's data.
    n = int(n)
    m = x.size(0)
    if n <= m:
        return x[ch.randperm(m)[:n]]
    if m == 0:
        # Repeating an empty tensor can never reach n rows.
        raise ValueError(f"cannot draw {n} samples from an empty tensor")

    # Whole copies first; the final chunk always holds between 1 and m rows.
    full_copies = (n - 1) // m
    remainder = n - full_copies * m
    pieces = [x] * full_copies
    pieces.append(x[ch.randperm(m)[:remainder]])
    return ch.cat(pieces,dim=0)

def shuffle_and_subset(x, n):
    """Return up to ``n`` rows of ``x`` (along dim 0) in random order.

    A random permutation of the row indices is drawn and its first ``n``
    entries are used to index ``x``; if ``x`` has fewer than ``n`` rows, all
    of them are returned, shuffled.
    """
    order = ch.randperm(x.size(0))
    keep = order[:n]
    return x[keep]

# faster version of load_files_all; no need to load
# all files since the dataset proportions z will not change anymore.
# do not use this if z is going to change
def load_files(prefix, files, z=None):
    """Load only the ``.npy`` files whose position is non-zero in ``z``.

    Every pattern in ``files`` is globbed under ``prefix``; the matches of
    each pattern are ordered by the integer obtained from concatenating all
    digits of the full path, and the per-pattern lists are chained in order.
    File ``i`` of that sequence is read only if ``i`` appears among the
    non-zero indices of ``z``.

    Args:
        prefix: directory joined in front of every pattern.
        files: iterable of glob patterns.
        z: tensor of dataset proportions. When ``None`` every file is loaded
            by delegating to ``load_files_all``.

    Returns:
        ``(datasets, idx2fname)``: two lists of equal length. ``datasets[i]``
        is the tensor loaded from ``idx2fname[i]``, or ``None`` for files that
        were skipped.

    Raises:
        ValueError: if a selected file cannot be loaded (the underlying error
            is chained as ``__cause__``), or if a matched path contains no
            digits to sort by.
    """
    if z is None:
        return load_files_all(prefix,files)

    # Collect the selected positions once as plain ints so the per-file
    # membership test is a set lookup rather than a tensor comparison.
    selected = set(z.nonzero().flatten().tolist())

    fnames = []
    for fpattern in files:
        sublist = glob.glob(os.path.join(prefix,fpattern))
        sublist.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        fnames.extend(sublist)

    datasets = []
    idx2fname = []
    for i,fname in enumerate(tqdm(fnames)):
        if i in selected:
            try:
                # NOTE(security): allow_pickle=True lets np.load unpickle
                # object arrays, which can execute arbitrary code. Only use
                # this on trusted files.
                ds_np = np.load(fname,allow_pickle=True)
            except Exception as err:
                print(fname)
                raise ValueError(f"exception when loading file {fname}") from err
            datasets.append(ch.from_numpy(ds_np))
        else:
            # Placeholder keeps positions aligned with idx2fname.
            datasets.append(None)
        idx2fname.append(fname)
    return datasets,idx2fname

def load_files_all(prefix,files):
    """Load every ``.npy`` file matched by the glob patterns in ``files``.

    Every pattern is globbed under ``prefix``; the matches of each pattern
    are ordered by the integer obtained from concatenating all digits of the
    full path, and the per-pattern lists are chained in order.

    Args:
        prefix: directory joined in front of every pattern.
        files: iterable of glob patterns.

    Returns:
        ``(datasets, idx2fname)``: two lists of equal length, where
        ``datasets[i]`` is the tensor loaded from the path ``idx2fname[i]``.

    Raises:
        ValueError: if a file cannot be loaded (the underlying error is
            chained as ``__cause__``), or if a matched path contains no
            digits to sort by.
    """
    fnames = []
    for fpattern in files:
        sublist = glob.glob(os.path.join(prefix,fpattern))
        sublist.sort(key=lambda f: int(re.sub(r'\D', '', f)))
        fnames.extend(sublist)

    datasets = []
    idx2fname = []
    for fname in tqdm(fnames):
        try:
            # NOTE(security): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code. Only use this on
            # trusted files.
            ds_np = np.load(fname,allow_pickle=True)
        except Exception as err:
            print(fname)
            raise ValueError(f"exception when loading file {fname}") from err
        datasets.append(ch.from_numpy(ds_np))
        idx2fname.append(fname)
    return datasets,idx2fname

# Load the config file
def load_config(name, nlp=False):
    """Read ``<CONFIG_PATH>/<name>.json`` and return it as a namedtuple.

    The parsed config is printed to stdout. When ``nlp`` is true, every
    occurrence of "encoding" in the entries of the ``targets`` and ``sources``
    lists is replaced by "input" before conversion.

    Args:
        name: config file name without the ``.json`` extension.
        nlp: whether to apply the "encoding" -> "input" renaming.

    Returns:
        The config converted with ``dict2namedtuple``.
    """
    config_file = f'{CONFIG_PATH}/{name}.json'
    with open(config_file) as handle:
        config = json.load(handle)
    print(config)

    if nlp:
        for key in ("targets", "sources"):
            config[key] = [
                entry.replace("encoding", "input") for entry in config[key]
            ]
    return dict2namedtuple(config)

# Scale dataset to [0, 1] if the dataset is STL
def adjust_if_stl(dataset, dataset_name):
    """Divide ``dataset`` by 255 when ``dataset_name`` is ``'stl'``.

    For any other name the input object is returned unchanged.
    """
    return dataset/255. if dataset_name == 'stl' else dataset

# Simple Data augmentation for training
def TRAIN_TRANSFORM_SIMPLE(size):
    """Build the simple training augmentation pipeline for a given crop size.

    The returned ``transforms.Compose`` converts the input to a PIL image,
    applies ``RandomCrop(size, padding=4)`` and a random horizontal flip, and
    converts the result back to a tensor.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

# Strong data augmentation for training: PIL conversion, TrivialAugmentWide,
# then back to a tensor.
# NOTE: unlike TRAIN_TRANSFORM_SIMPLE and TEST_TRANSFORM, which are called
# with a size to build a pipeline, this is an already-built Compose instance.
TRAIN_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])
# Alternative size-parameterised variant (disabled):
# TRAIN_TRANSFORM = lambda size: transforms.Compose([
#             transforms.ToPILImage(),
#             transforms.RandomResizedCrop(size),
#             transforms.RandomHorizontalFlip(),
#             transforms.ColorJitter(.25,.25,.25),
#             transforms.RandomRotation(2),
#             transforms.ToTensor(),
#         ])

# Standard data augmentation for testing
def TEST_TRANSFORM(size):
    """Build the deterministic test-time preprocessing pipeline.

    The returned ``transforms.Compose`` converts the input to a PIL image,
    applies ``Resize(size)`` followed by ``CenterCrop(size)``, and converts
    the result back to a tensor.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
