import argparse
import torch as ch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import os
import collections
from torch.utils.data import TensorDataset, DataLoader, Dataset
from pathlib import Path

# helpers for class names, filepaths, and loading datasets
from helpers import get_target, get_test, get_classes, load_config, load_files, shuffle_and_subset, mix_datasets, adjust_if_stl
# helpers for training
from helpers import TRAIN_TRANSFORM, TRAIN_TRANSFORM_SIMPLE, TEST_TRANSFORM

# imports for train/eval
from cox import store
from cox.utils import Parameters
from robustness import defaults
from robustness.train import train_model, eval_model
from robustness.datasets import CIFAR
from robustness.model_utils import make_and_restore_model

# language model
from transformers import TrainingArguments, Trainer
from transformers import BertForSequenceClassification
from datasets import load_metric, concatenate_datasets, Value
import nlp_datasets as nd

def finetune(run_num, train_loader_name, tr_ds, te_ds, out_dir, nclasses):
    """Fine-tune a pretrained ``bert-base-cased`` classifier on one dataset.

    Args:
        run_num: Index of the repetition; also passed as the training seed.
        train_loader_name: Name of the training dataset, used to name the
            output directory.
        tr_ds: Dataset handed to the ``Trainer`` as ``train_dataset``.
        te_ds: Dataset handed to the ``Trainer`` as ``eval_dataset``;
            evaluated once per epoch.
        out_dir: Parent directory; the run writes to
            ``<out_dir>/<train_loader_name>_<run_num>``.
        nclasses: Number of output labels of the classification head.

    Returns:
        None. The trainer state is saved via ``trainer.save_state()``.
    """
    classifier = BertForSequenceClassification.from_pretrained(
        "bert-base-cased",
        num_labels=nclasses,
    )
    classifier.config.problem_type = "single_label_classification"

    accuracy = load_metric("accuracy")

    def compute_metrics(eval_pred):
        # eval_pred unpacks into (logits, labels); the predicted class is
        # the argmax over the last axis of the logits.
        logits, references = eval_pred
        return accuracy.compute(
            predictions=np.argmax(logits, axis=-1),
            references=references,
        )

    run_dir = os.path.join(out_dir, f'{train_loader_name}_{run_num}')
    trainer = Trainer(
        model=classifier,
        args=TrainingArguments(
            output_dir=run_dir,
            evaluation_strategy="epoch",
            seed=run_num,
        ),
        train_dataset=tr_ds,
        eval_dataset=te_ds,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_state()
    
def get_projected_away(z):
    """Return uniform weights over the entries of ``z`` that are zero.

    Given the projected weights ``z``, the "projected away" portion puts
    equal mass on every entry whose weight is exactly zero and no mass on
    the others. If nothing is projected away (no entry of ``z`` is zero),
    fall back to the ``rand_fixed`` baseline: ``1 / z.size(0)`` everywhere.

    Args:
        z: Tensor of weights.

    Returns:
        A tensor shaped like ``z`` holding the projected-away weights.
    """
    projected_away = z == 0
    num_zero = int(projected_away.sum())

    if num_zero == 0:
        # Every entry has nonzero weight: uniform baseline.
        return ch.ones_like(z) / z.size(0)

    # Indicator of the zero-weight entries, normalised to sum to one.
    return projected_away.to(z.dtype) / num_zero

def balance_datasets(datasets, pop_size, seed):
    """Cap every target/test dataset at ``pop_size`` rows, in place.

    Args:
        datasets: Mapping from split name to a list of per-class datasets.
            Must contain the keys ``'targets'`` and ``'test'``.
        pop_size: Maximum number of rows kept per dataset.
        seed: Seed passed to each dataset's ``shuffle``.

    Returns:
        A tuple ``(datasets, classes_to_keep)``. ``datasets`` is the same
        mapping object, with its ``'targets'`` and ``'test'`` lists replaced
        by shuffled, truncated datasets. ``classes_to_keep`` lists the class
        indices whose dataset had at least ``0.75 * pop_size`` rows in every
        split of the mapping, measured before truncation.
    """
    min_rows = 0.75 * pop_size

    # Sizes are checked across all splits present, before any truncation.
    classes_to_keep = []
    for class_idx in range(len(datasets['targets'])):
        if all(len(datasets[split][class_idx]) >= min_rows for split in datasets):
            classes_to_keep.append(class_idx)

    for split in ('targets', 'test'):
        capped = []
        for ds in datasets[split]:
            shuffled = ds.shuffle(seed=seed)
            n_keep = min(pop_size, len(ds))
            capped.append(shuffled.select(range(n_keep)))
        datasets[split] = capped

    return datasets, classes_to_keep

def process_ds(ds, pop_size, seed, translate=False):
    """Subsample, optionally augment, and tokenize a dataset.

    Args:
        ds: Dataset to process.
        pop_size: Maximum number of rows to keep after shuffling.
        seed: Seed passed to ``ds.shuffle``.
        translate: If true, run ``nd.nlp_augment`` on the subsample before
            tokenizing.

    Returns:
        The processed dataset, restricted to the columns ``input_ids``,
        ``token_type_ids``, ``attention_mask``, ``label`` and ``text``.
    """
    # Cap at the dataset size, as balance_datasets does: selecting
    # range(pop_size) from a shorter dataset would index past its end.
    n_keep = min(pop_size, len(ds))
    ds = ds.shuffle(seed=seed).select(range(n_keep))
    if translate:
        ds = nd.nlp_augment(ds)
    ds = ds.map(nd.generic_tokenizer)

    cols_to_keep = ['input_ids', 'token_type_ids', 'attention_mask', 'label', 'text']
    cols_to_remove = [col for col in ds.column_names if col not in cols_to_keep]
    # Drop the extra columns through map() rather than ds.remove_columns():
    # the original author found that remove_columns did not update the
    # format fingerprint / format columns.
    ds = ds.map(
        lambda example: {k: example[k] for k in cols_to_keep},
        remove_columns=cols_to_remove,
    )
    return ds

if __name__ == "__main__":
    # Script entry point: build the real, augmented and source-mixture
    # training sets for a target task, then fine-tune BERT on each of them.
    parser = argparse.ArgumentParser(description='Process cluster at a threshold.')
    parser.add_argument('source')
    parser.add_argument('target')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps', type=float, default=1)
    parser.add_argument('--out-dir', type=str, default='./trained_models/')
    parser.add_argument('--res-dir', type=str, default='./results/')
    parser.add_argument('--pgd', action='store_true')
    # experimental setup for nlp, which sets target classes and model to fit
    parser.add_argument('--nlp', action='store_true')
    parser.set_defaults(pgd=False)

    args = parser.parse_args()

    ch.manual_seed(args.seed)

    # Epsilons 1 and 10 are turned into ints so they format as "1" / "10"
    # (not "1.0" / "10.0") in the dataset and result-directory names below.
    eps = args.eps
    if eps == 1.0 or eps == 10.0:
        eps = int(eps)
    print(f"Source: {args.source} | Target: {args.target} | Eps: {eps}")

    tgt_ds = nd.create_datasets[args.target]()
    src_ds = nd.create_datasets[args.source]()
    # Second, independent copy of the target data, used for augmentation.
    aug_ds = nd.create_datasets[args.target]()

    pop_size = 100
    # Counted before the 'dydae' restriction below, so the classifier head
    # keeps one output per original target class.
    nclasses = len(tgt_ds["targets"])

    if args.target == 'dydae':
        tgt_ds['targets'] = [tgt_ds['targets'][i] for i in [3,4,6]]
        tgt_ds['test'] = [tgt_ds['test'][i] for i in [3,4,6]]
        aug_ds['targets'] = [aug_ds['targets'][i] for i in [3,4,6]]

    tgt_ds, c2k = balance_datasets(tgt_ds, pop_size, args.seed)
    # Balance the augmentation copy itself. Passing tgt_ds here (as the code
    # previously did) discarded the copy loaded above and shuffled tgt_ds in
    # place a second time, because balance_datasets mutates its argument.
    aug_ds, _ = balance_datasets(aug_ds, pop_size, args.seed)
    src_ds, _ = balance_datasets(src_ds, pop_size, args.seed)

    tr_ds = process_ds(concatenate_datasets(tgt_ds['targets']), pop_size=pop_size, seed=args.seed)
    te_ds = process_ds(concatenate_datasets(tgt_ds['test']), pop_size=pop_size, seed=args.seed)
    aug_ds = process_ds(concatenate_datasets(aug_ds['targets']), pop_size=pop_size, seed=args.seed, translate=True)

    if args.pgd:
        projected_dataset = 'pgd'
    else:
        projected_dataset = f'active_set_eps_{eps}'

    # Load the 3 different datasets under question
    data = {
        projected_dataset: [],
        f'{projected_dataset}_projaway': [],
        'rand_fixed': [],
        'real': tr_ds
    }

    def build_mixture(weights, label):
        # Sample a mixture of the source subpopulations with the given
        # proportions and relabel every example with the target class.
        def set_label(example):
            example['label'] = label
            return example

        mixed = nd.mix_datasets(src_ds['sources'], weights, pop_size=pop_size)
        mixed = mixed.map(set_label)
        return mixed.cast_column("label", Value(dtype='int32', id=None))

    # NOTE(review): c2k holds positions in tgt_ds['targets'] as it was when
    # balanced; for 'dydae' that list was already restricted to the original
    # classes [3, 4, 6], so the positions are 0..2. Confirm that these are
    # the class ids expected in the result paths and labels below.
    for class_name in c2k:
        print(f'loading projected datasets for class {class_name}')
        config_name = f'{args.source}_{args.target}_class_encoding_{class_name}_eps_{eps}'
        result = os.path.join(args.res_dir, config_name, "mmd.pth")

        d = ch.load(result)

        # Below, z corresponds to the dataset proportions; that is, how much
        # of each subpopulation should we sample?
        # get active set result
        z = d['z']
        data[projected_dataset].append(build_mixture(z, class_name))

        # get projected away result
        print(projected_dataset)
        print(z)
        z_projaway = get_projected_away(z)
        print('projected away portion')
        print(z_projaway)
        data[f'{projected_dataset}_projaway'].append(build_mixture(z_projaway, class_name))

        # get random result: uniform proportions over all subpopulations
        z_uniform = ch.ones_like(z) / z.size(0)
        data['rand_fixed'].append(build_mixture(z_uniform, class_name))

    for s in [projected_dataset, f'{projected_dataset}_projaway', 'rand_fixed']:
        data[s] = process_ds(concatenate_datasets(data[s]), pop_size=pop_size, seed=args.seed)

    # add hybrid loaders - that is, loaders that use both the real data and
    # the Active Set Projected/Random/Projected-Away datasets
    for k in [projected_dataset, 'rand_fixed', f'{projected_dataset}_projaway']:
        real_cast = data['real'].cast(data[k].features)
        data[f'hybrid_{k}'] = concatenate_datasets([data[k], real_cast])
        data[f'hybrid_aug_{k}'] = concatenate_datasets([aug_ds, data[k], real_cast])
    data['hybrid_aug'] = concatenate_datasets([aug_ds, data['real'].cast(aug_ds.features)])

    # Repeat 5 times to get error bars
    for i in range(5):
        # dicts preserve insertion order, so runs follow the order in which
        # the datasets were added to `data`.
        for dataset_name, dataset in data.items():
            if args.pgd and 'pgd' not in dataset_name:
                print(f"Skipping {dataset_name}, we are just doing PGD")
                continue
            if eps != 1 and 'active_set' not in dataset_name:
                print(f"Skipping {dataset_name}, we are just doing active set epsilons")
                continue
            # A run directory that already holds trainer_state.json is
            # treated as finished, so the script can be resumed.
            if os.path.exists(os.path.join(args.out_dir, f"{dataset_name}_{i}", "trainer_state.json")):
                print(f"Skipping {dataset_name}_{i}, already done")
                continue
            print(f'Now training on {dataset_name}')
            finetune(i, dataset_name, dataset, te_ds, args.out_dir, nclasses)
