import argparse
import torch as ch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import os
import collections
from torch.utils.data import TensorDataset, DataLoader, Dataset
from pathlib import Path

# helpers for class names, filepaths, and loading datasets
from helpers import get_target, get_test, get_classes, load_config, load_files, shuffle_and_subset, mix_datasets, adjust_if_stl
# helpers for training
from helpers import TRAIN_TRANSFORM, TRAIN_TRANSFORM_SIMPLE, TEST_TRANSFORM

# imports for train/eval
from cox import store
from cox.utils import Parameters
from robustness import defaults
from robustness.train import train_model, eval_model
from robustness.datasets import CIFAR
from robustness.model_utils import make_and_restore_model

# language model
from transformers import BertForSequenceClassification


# For applying transforms to numpy-array-based datasets in PyTorch
# From: https://stackoverflow.com/questions/55588201/pytorch-transforms-on-tensordataset
class CustomTensorDataset(Dataset):
    """A ``TensorDataset`` look-alike that can transform the input tensor.

    Args:
        tensors: sequence of tensors that all share the same length along
            dim 0. Item ``i`` is ``(tensors[0][i], tensors[1][i])``; any
            additional tensors are length-checked but never returned.
        transform: optional callable applied to ``tensors[0][i]`` whenever
            it is truthy.
    """

    def __init__(self, tensors, transform=None):
        lengths = [t.size(0) for t in tensors]
        # Every tensor must index the same number of samples.
        assert all(n == lengths[0] for n in lengths)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        sample = self.tensors[0][index]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.tensors[1][index]

    def __len__(self):
        return self.tensors[0].size(0)


def data2loader(data, is_train=True, use_strong_da=True, nlp=False, img_size=32, **kwargs):
    """Turn a list of per-class tensors into a labelled ``DataLoader``.

    Args:
        data: sequence of tensors, one per class. Every row of the tensor at
            position ``i`` gets label ``i``. The tensors are concatenated along
            dim 0, so they must agree on all other dimensions.
        is_train: choose a training transform; otherwise ``TEST_TRANSFORM``.
        use_strong_da: with ``is_train``, use ``TRAIN_TRANSFORM`` rather than
            ``TRAIN_TRANSFORM_SIMPLE``.
        nlp: if True, apply no transform (an empty ``Compose``).
        img_size: size passed to ``TRAIN_TRANSFORM_SIMPLE`` / ``TEST_TRANSFORM``;
            defaults to 32, the value previously hard-coded.
        **kwargs: forwarded unchanged to ``torch.utils.data.DataLoader``.

    Returns:
        A ``DataLoader`` over a ``CustomTensorDataset`` yielding ``(x, label)``.

    Raises:
        ValueError: if ``data`` contains no tensors.
    """
    data = list(data)
    if not data:
        raise ValueError("data2loader needs at least one per-class tensor")

    if nlp:
        transform = transforms.Compose([])
    elif is_train and use_strong_da:
        transform = TRAIN_TRANSFORM
    elif is_train:
        transform = TRAIN_TRANSFORM_SIMPLE(img_size)
    else:
        transform = TEST_TRANSFORM(img_size)

    # Class i's rows are all labelled with the int64 value i.
    labels = [ch.full((d.size(0),), i, dtype=ch.long) for i, d in enumerate(data)]
    X = ch.cat(data)
    y = ch.cat(labels)
    ds = CustomTensorDataset(tensors=(X, y), transform=transform)
    return DataLoader(ds, **kwargs)

# class LinearWrapper(nn.Module): 
#     def __init__(self, inf, outf): 
#         super(LinearWrapper, self).__init__()
#         self.net = nn.Sequential(
#             nn.Linear(inf, 256), 
#             nn.ReLU(), 
#             nn.Linear(256,outf)
#         )
#     def forward(self, x, **kwargs): 
#         return self.net(x),None
class Wrapper(nn.Module):
    """Adapt a HuggingFace-style sequence classifier to a ``(output, None)`` interface.

    The wrapped ``net`` is called with ``input_ids`` and ``attention_mask``
    keyword arguments and must return an object exposing ``.logits``
    (as ``BertForSequenceClassification`` does).
    """

    def __init__(self, net):
        super(Wrapper, self).__init__()
        self.net = net

    def forward(self, x, **kwargs):
        """Split ``x`` into token ids and attention mask and classify it.

        ``x`` is expected to be 2-D, with the token ids in the first half of
        dim 1 and the attention mask in the second half. This used to be
        hard-coded as two chunks of 512; any even width now works. Extra
        keyword arguments (presumably passed by the robustness training
        loop) are ignored.

        Returns:
            ``(logits, None)``. The tuple shape is kept for callers that
            unpack two values.

        Raises:
            ValueError: if ``x.size(1)`` is odd and cannot be split in half.
        """
        x = x.long()
        width = x.size(1)
        if width % 2:
            raise ValueError(f"expected an even width (input_ids | attention_mask), got {width}")
        input_ids, attention_mask = x.split(width // 2, dim=1)
        output = self.net(input_ids=input_ids, attention_mask=attention_mask).logits
        return output, None

def train(run_num, train_loader_name, train_loader, val_loader, arch, dataset, epochs=100, nlp=False, out_dir=None):
    """Train one model with ``robustness.train.train_model`` and log it to a cox store.

    Args:
        run_num: repetition index; the store is named ``f"{train_loader_name}_{run_num}"``.
        train_loader_name: name of the training loader, used in the store name.
        train_loader: training ``DataLoader``.
        val_loader: validation ``DataLoader``.
        arch: ``'mlp'`` selects a pretrained ``bert-base-cased`` sequence
            classifier wrapped in ``Wrapper``. Any other value is passed to
            ``make_and_restore_model`` (trained from scratch).
        dataset: robustness dataset object; ``dataset.num_classes`` sizes the BERT head.
        epochs: number of training epochs.
        nlp: if True, use learning rate 1e-5 instead of the robustness default.
        out_dir: directory for the cox store. Defaults to the module-level
            ``OUT_DIR`` that the script defines before calling ``train``.

    Returns:
        Whatever ``train_model`` returns (the trained model).
    """
    if out_dir is None:
        # Backward-compatible default: the __main__ block sets the global OUT_DIR.
        out_dir = OUT_DIR
    loaders = (train_loader, val_loader)

    if arch == 'mlp':
        net = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=dataset.num_classes)
        model = Wrapper(net)
    else:
        model, _ = make_and_restore_model(arch=arch, dataset=dataset, pytorch_pretrained=False)

    train_args = {
        'out_dir': out_dir,
        'adv_train': 0,
        'adv_eval': 0,
        'epochs': epochs,
        'config_path': None,
        'dataset': dataset,
        'arch': arch,
    }
    if nlp:
        train_args['lr'] = 1e-5

    args = Parameters(train_args)
    # Remaining training hyperparameters are filled with robustness's CIFAR defaults.
    args = defaults.check_and_fill_args(args, defaults.TRAINING_ARGS, CIFAR)

    my_store = store.Store(out_dir, f'{train_loader_name}_{run_num}')
    try:
        args_dict = args.as_dict() if isinstance(args, Parameters) else vars(args)
        schema = store.schema_from_dict(args_dict)
        my_store.add_table('metadata', schema)
        my_store['metadata'].append_row(args_dict)
        model = train_model(args, model, loaders, checkpoint=None, store=my_store)
    finally:
        # Close the store even if training fails, so its files are not left open.
        my_store.close()
    return model
    
def get_projected_away(z):
    """Return the "projected away" mixture weights complementary to ``z``.

    ``z`` holds per-subpopulation sampling weights. The result puts equal
    weight on every entry where ``z`` is exactly zero (the subpopulations
    that were projected away) and zero weight everywhere else. If no entry
    of ``z`` is zero, nothing was projected away. In that case it falls back
    to the uniform ``rand_fixed`` baseline, ``1 / z.size(0)`` on every entry.

    Args:
        z: tensor of subpopulation weights (not modified).

    Returns:
        A new tensor shaped like ``z``. For 1-D ``z`` its entries sum to 1.
    """
    is_zero = z == 0
    # Number of ZERO entries of z, i.e. subpopulations that were projected away.
    num_zero = int(is_zero.sum())
    x = ch.ones_like(z)
    if num_zero > 0:
        x[~is_zero] = 0
        x = x / num_zero
    else:
        x = x / z.size(0)
    return x


if __name__ == "__main__": 
    # Script entry point for one source->target experiment:
    #   1. build a class-balanced test loader from the target's test files,
    #   2. build training sets per class from saved mixture weights (mmd.pth),
    #   3. train 5 repetitions per selected training loader via train().
    parser = argparse.ArgumentParser(description='Process cluster at a threshold.')
    parser.add_argument('source')
    parser.add_argument('target')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps', type=float, default=1)
    parser.add_argument('--out-dir', type=str, default='./trained_models/')
    parser.add_argument('--res-dir', type=str, default='./results/')
    # parser.add_argument('--trainconfig', type=int)
    parser.add_argument('--pgd', action='store_true')
    parser.add_argument('--nlp', action='store_true') # experimental setup for nlp, which sets target classes and model to fit
    parser.set_defaults(pgd=False)

    args = parser.parse_args()

    ch.manual_seed(args.seed)

    eps = args.eps
    # trainconfig = args.trainconfig
    # eps appears in config and output directory names. 1.0 becomes the int 1,
    # so names read "eps_1" rather than "eps_1.0".
    # NOTE(review): other whole-number eps values (e.g. 2.0) still format as
    # "2.0"; presumably that matches existing result directories. Confirm
    # before generalizing.
    if eps == 1.0:
        eps = int(1)
    # print(f"Source: {args.source} | Target: {args.target} | Eps: {eps} | Trainconfig: {trainconfig}")
    print(f"Source: {args.source} | Target: {args.target} | Eps: {eps}")

    print("Creating real data loaders")
    tgt_classes = get_classes(args.target)
    src_classes = get_classes(args.source)
    # NLP: use every target class and mixtures of pop_size=100.
    # Images: use only classes present in both source and target, pop_size=1000.
    if args.nlp: 
        classes = tgt_classes
        pop_size = 100
        # NOTE(review): the test file is looked up with "encoding" replaced by
        # "input" in the class name. Presumably class names and test-file
        # names differ in this way; confirm against get_test.
        classes_npy = [np.load(get_test(args.target, c.replace("encoding","input"))[0]) for c in classes]
    else: 
        classes = [c for c in tgt_classes if c in src_classes]
        pop_size = 1000
        classes_npy = [np.load(get_test(args.target, c)[0]) for c in classes]
    # Every class's test data is cut to the smallest class size, presumably
    # (via shuffle_and_subset) a random subset, so the test set is class-balanced.
    min_class_sz = min(c_npy.shape[0] for c_npy in classes_npy)
    print(f"class size: {min_class_sz}, num_classes: {len(classes)}")
    # Make the test loader (no shuffling; test transform unless nlp).
    kwargs = {"batch_size": 8 if args.nlp else 32, "shuffle": False, "nlp": args.nlp}
    test_ld = data2loader([shuffle_and_subset(ch.from_numpy(c_npy), min_class_sz) for c_npy in classes_npy], is_train=False, **kwargs)
    # assert False

    # Name of the "projected" training set: PGD results or active-set results.
    if args.pgd:
        projected_dataset = 'pgd'
    else:
        projected_dataset = 'active_set'


    # Load the 4 training datasets under comparison; each list gets one tensor per class:
    #   projected_dataset            - mixture weighted by the saved z
    #   '<projected_dataset>_projaway' - uniform over the subpopulations z zeroed out
    #   'rand_fixed'                 - uniform mixture over all source subpopulations
    #   'real'                       - the target distribution (config.target_dist)
    data = {
        projected_dataset: [], 
        f'{projected_dataset}_projaway': [], 
        'rand_fixed': [], 
        'real': []
    }

    for class_name in classes: 
        print(f'loading projected datasets for {class_name}')
        config_name = f'{args.source}_{args.target}_class_{class_name}_eps_{eps}'
        config = load_config(config_name, args.nlp)
        # Saved optimization result for this class; only its 'z' entry is used here.
        result = os.path.join(args.res_dir,config_name,"mmd.pth")
        
        d = ch.load(result)

        # Below, z corresponds to the dataset proportions; that is, how much of each subpopulation should we sample?
        # get active set result
        z = d['z']
        ds, _ = load_files(config.base_dir, config.sources, z)
        mixed_ds = adjust_if_stl(mix_datasets(ds, z, pop_size=pop_size), args.source)
        # Make sure we aren't accidentally loading data like STL in the range 0->255 anymore, make sure everything is in the range 0->1
        # Don't want some weird sampling to result in super small dataset; actually, I checked this, and it's mostly okay. 
        # (The range check is skipped for NLP data.)
        assert(mixed_ds.max() <= 1.01) or args.nlp
        data[projected_dataset].append(mixed_ds)

        # get projected away result: uniform weight on the subpopulations z left out
        # print(projected_dataset)
        # print(z)
        z_projaway = get_projected_away(z)
        # print('projected away portion')
        # print(z_projaway)
        ds, _ = load_files(config.base_dir, config.sources, z_projaway)
        mixed_ds = adjust_if_stl(mix_datasets(ds, z_projaway, pop_size=pop_size), args.source)
        assert(mixed_ds.max() <= 1.01) or args.nlp
        data[f'{projected_dataset}_projaway'].append(mixed_ds)

        # get random result: uniform weights over all source subpopulations
        # (z is rebound here; the projected z is no longer needed).
        z = ch.ones_like(z)
        z = z/(z.size(0))
        ds, _ = load_files(config.base_dir, config.sources, z)
        mixed_ds = adjust_if_stl(mix_datasets(ds, z, pop_size=pop_size), args.source)
        # Make sure we aren't accidentally loading data like STL in the range 0->255 anymore, make sure everything is in the range 0->1
        assert(mixed_ds.max() <= 1.01) or args.nlp
        data['rand_fixed'].append(mixed_ds)

        # get real result: sample from the target files with the config's target distribution
        z = ch.Tensor(config.target_dist)
        ds, _ = load_files(config.base_dir, config.targets, z)
        mixed_ds = adjust_if_stl(mix_datasets(ds, z, pop_size=pop_size), args.target)
        # Make sure we aren't accidentally loading data like STL in the range 0->255 anymore, make sure everything is in the range 0->1
        assert(mixed_ds.max() <= 1.01) or args.nlp
        data['real'].append(mixed_ds)

    print("Creating remaining training data loaders")
    kwargs = {"batch_size": 8 if args.nlp else 32, "shuffle": True, "nlp": args.nlp}
    loaders = {}
    # Each dataset gets two loaders: simple augmentation (key k) and strong
    # augmentation (key f'{k}_autoaugment_da').
    for k in data:
        loaders[k] = data2loader(data[k], is_train=True, use_strong_da=False, **kwargs)
        # loaders[f'{k}_da'] = data2loader(data[k], is_train=True, use_strong_da=True, **kwargs)
        loaders[f'{k}_autoaugment_da'] = data2loader(data[k], is_train=True, use_strong_da=True, **kwargs)
    
    # add hybrid loaders - that is, loaders that use both the real data and the Active Set Projected/Random/Projected-Away datasets
    # (per class, the real tensor and the k tensor are concatenated along dim 0).
    for k in [projected_dataset, 'rand_fixed', f'{projected_dataset}_projaway']: 
        hybrid_loader = data2loader([ch.cat([d1,d2],dim=0) for d1,d2 in zip(data['real'], data[k])], is_train=True, use_strong_da=False, **kwargs)
        loaders[f'hybrid_{k}'] = hybrid_loader
        hybrid_loader = data2loader([ch.cat([d1,d2],dim=0) for d1,d2 in zip(data['real'], data[k])], is_train=True, use_strong_da=True, **kwargs)
        # loaders[f'hybrid_{k}_da'] = hybrid_loader
        loaders[f'hybrid_{k}_autoaugment_da'] = hybrid_loader

    # Training order: plain/augmented loaders first, then the hybrids (insertion order).
    all_loaders = [ (k,loaders[k]) for k in loaders ]
    all_loaders = collections.OrderedDict(all_loaders)


    # dummy dataset placeholder for using the robustness library
    # The images we use are 32x32, so we apply normalization that corresponds to CIFAR images (even if the dataset is not CIFAR)
    # NOTE(review): for NLP a plain ToTensor transform is passed, presumably just
    # to satisfy the CIFAR constructor; the actual data comes from our own loaders.
    if args.nlp: 
        arch = 'mlp'
        t = transforms.Compose([transforms.ToTensor()])
        MY_DS = CIFAR(num_classes = len(tgt_classes), transform_train=t, transform_test=t)
        epochs = 100

    else: 
        arch = 'resnet18'
        MY_DS = CIFAR(num_classes = len(tgt_classes))
        epochs = 200
    # OUT_DIR is a module-level global: train() reads it for the cox store location.
    OUT_DIR = os.path.join(f'{args.out_dir}',f'{args.source}_{args.target}_{eps}')
    print(f"Saving trained models in: {OUT_DIR}")
    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

    # Repeat 5 times to get error bars
    for i in range(0, 5):
        for loader_name, loader in all_loaders.items():
            # Real-data baselines do not depend on eps, so they are trained only for eps == 1.
            if loader_name in ['real', 'real_da', 'real_autoaugment_da'] and eps != 1:
                print(f"Skipping {loader_name}, we already did this once for eps=1, dont need to redo for different epsilons")
                continue
            if args.pgd and loader_name in ['rand_fixed', 'hybrid_rand_fixed', 'rand_fixed_autoaugment_da', 'hybrid_rand_fixed_autoaugment_da', 'real', 'real_autoaugment_da']:
                print(f"Skipping {loader_name}, we are just doing PGD, so we already did the random baselines")
                continue
            # THIS IS HARDCODED
            # if loader_name in ['real', 'real_da'] and eps != 1:
            #     print(f"Skipping {loader_name}, we already did this once for eps=1, dont need to redo for different epsilons")
            #     continue
            # if args.pgd and loader_name in ['rand_fixed', 'hybrid_rand_fixed']:
            #     print(f"Skipping {loader_name}, we are just doing PGD, so we already did the random baselines")
            #     continue
            # Resume support: skip runs whose final-epoch checkpoint file already exists.
            if os.path.exists(os.path.join(OUT_DIR, f"{loader_name}_{i}",f"{epochs}_checkpoint.pt")): 
                print(f"Skipping {loader_name}_{i}, already done")
                continue
            # NOTE(review): this skips EVERY augmentation loader ("_da"), whether or
            # not args.nlp is set, although the message says "language task".
            # It looks like a temporary restriction; confirm intent.
            if "_da" in loader_name: 
                print(f"skipping {loader_name}_{i}, language task")
                continue
            # NOTE(review): this skips every NON-hybrid loader, so only hybrid_*
            # loaders are trained; the printed message reads the other way round.
            # It looks like leftover debugging; confirm intent.
            if "hybrid" not in loader_name: 
                print("skipping on hybrid just to check")
                continue
            print(f'Now training on {loader_name}')
            # NOTE(review): fetches and discards one test batch; looks like leftover
            # debugging. Creating a DataLoader iterator may draw from the global
            # torch RNG, so removing it could change later random draws. Verify first.
            for item in test_ld:
                break
            train(i, loader_name, loader, test_ld, arch, MY_DS, nlp=args.nlp, epochs=epochs)