import torch
from .datasets import KneeOA, WaterBirds, FoodReview
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import time
import os
import argparse
import itertools
from .utils import evaluate_teacher, train_teacher_model
from sklearn.model_selection import KFold
from . import const
from transformers import BertTokenizer


# Select the compute device: first CUDA GPU when one is visible, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using the following device: {device!s}")

# Turn on autograd anomaly detection to surface problems in backpropagation
torch.autograd.set_detect_anomaly(True)

# Input locations of the three supported datasets (taken from const)
WB_DATASET_DIR, KOA_DATASET_DIR, FR_DATASET_DIR = (
    const.WB_DATASET_DIR,
    const.KOA_DATASET_DIR,
    const.FR_DATASET_DIR,
)

# Output location for the cross validation result files
CROSS_VAL_RESULTS_DIRECTORY = const.CROSS_VAL_RESULTS_DIRECTORY

# Fold count and DataLoader worker count (taken from const)
NUM_FOLDS, NUM_WORKERS = const.NUM_FOLDS, const.NUM_WORKERS


# Generate command line arguments
#
# The three teacher hyperparameters each accept a list of values; the script
# evaluates every combination of them. Their default is an empty list, which
# yields no combinations at all, so all three must be supplied for anything
# to be trained.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
                    type=str,
                    required=True,
                    # Reject unsupported names up front instead of failing
                    # later, once data loading is attempted.
                    choices=['koa', 'waterbirds', 'food_review'],
                    help='Dataset to use -- can be \"koa\", \"waterbirds\", or \"food_review\"')
parser.add_argument('--datasets_seeds',
                    type=int,
                    nargs='*',
                    default=[],
                    help='Seeds of the training sets to evaluate (one run of the '
                         'hyperparameter search per seed).')
parser.add_argument('--batch_size',
                    type=int,
                    help='Batch size for all models')
parser.add_argument('--num_epochs',
                    type=int,
                    help='Number of epochs for all models')
parser.add_argument('--teacher_lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate(s) for the teacher')
parser.add_argument('--teacher_l2_cost',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost(s) for the teacher')
parser.add_argument('--teacher_dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout value(s) for the teacher')
parser.add_argument('--num_samples',
                    type=int,
                    nargs='?',
                    help='Number of leading rows of the data to use (all rows when omitted)')

# Read the command line and unpack the parsed values into module-level names
args = parser.parse_args()

# Scalar settings
dataset, batch_size, num_epochs = args.dataset, args.batch_size, args.num_epochs
num_samples = args.num_samples

# Seeds of the training sets to iterate over
datasets_seeds = args.datasets_seeds

# Hyperparameter candidates, copied into fresh lists
teacher_lr = [*args.teacher_lr]
teacher_l2_cost = [*args.teacher_l2_cost]
teacher_dropout = [*args.teacher_dropout]


# Candidate values for every hyperparameter, keyed by hyperparameter name
hp_grid = {
    'teacher_lr': teacher_lr,
    'teacher_l2_cost': teacher_l2_cost,
    'teacher_dropout': teacher_dropout,
}

# Result table: one column per hyperparameter, then the training set
# identifier and the score obtained
columns = [*hp_grid, 'training_set', 'score']
results_df = pd.DataFrame(columns=columns)

# Expand the grid into a list of dicts, one per hyperparameter combination
# (cartesian product of the candidate lists)
keys = tuple(hp_grid)
values = tuple(hp_grid.values())
hp_iterator = []
for combination in itertools.product(*values):
    hp_iterator.append(dict(zip(keys, combination)))

def _training_csv_path(dataset_name, seed):
    """Return the path of the training CSV for ``dataset_name`` and ``seed``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == 'koa':
        return os.path.join(KOA_DATASET_DIR, f'training_{seed}', 'training.csv')
    if dataset_name == 'waterbirds':
        return os.path.join(WB_DATASET_DIR, f'training_{seed}', 'training.csv')
    if dataset_name == 'food_review':
        return os.path.join(FR_DATASET_DIR, f'training_{seed}.csv')
    # Fail with a clear message rather than a NameError on an unset frame.
    raise ValueError(f'Unsupported dataset: {dataset_name!r}')


# Iterate over the training sets, one per seed
for random_seed in tqdm(datasets_seeds, position=0, desc='Random Seeds'):

    # The data depends only on the seed, so it is loaded and split once here
    # instead of once per hyperparameter combination.
    full_df = pd.read_csv(_training_csv_path(dataset, random_seed))

    if num_samples:
        print("Experiment is only using a subset of the data...")
        full_df = full_df.iloc[0:num_samples]

    # First TRAINING_SIZE fraction of the rows is used for training, the
    # remainder for validation.
    split_index = int(len(full_df) * const.TRAINING_SIZE)
    train_df = full_df[:split_index]
    val_df = full_df[split_index:]

    # Wrap the frames in the dataset class matching the chosen dataset.
    # NOTE(review): the wrappers are now shared by all hyperparameter runs of
    # a seed; this assumes they keep no per-run state -- confirm.
    tokenizer = None
    if dataset == 'koa':
        train_dataset = KneeOA(train_df)
        val_dataset = KneeOA(val_df)
    elif dataset == 'waterbirds':
        train_dataset = WaterBirds(train_df)
        val_dataset = WaterBirds(val_df)
    else:
        # Only 'food_review' can reach this branch: every other name was
        # rejected by _training_csv_path above.
        tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny', model_max_length=512)
        train_dataset = FoodReview(train_df, tokenizer)
        val_dataset = FoodReview(val_df, tokenizer)

    # Iterate over all hyperparameter combinations. The inner bar is drawn on
    # its own line (position=1) so it does not overwrite the outer one, and
    # is cleared when finished (leave=False).
    for hp_args in tqdm(hp_iterator, position=1, leave=False, desc='Analysis'):

        # Fresh data loaders for this run
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

        # Train the teacher model
        teacher_model = train_teacher_model(dataset,
                                            train_loader,
                                            hp_args.get('teacher_lr'),
                                            hp_args.get('teacher_l2_cost'),
                                            hp_args.get('teacher_dropout'),
                                            num_epochs)

        # Evaluate the model on the held-out validation split
        score = evaluate_teacher(teacher_model, val_loader, dataset)

        # Save results to dataframe
        new_row = {
            'score': score,
            'training_set': random_seed,
            'teacher_lr': hp_args.get('teacher_lr'),
            'teacher_l2_cost': hp_args.get('teacher_l2_cost'),
            'teacher_dropout': hp_args.get('teacher_dropout'),
        }

        results_df.loc[len(results_df)] = new_row
        print(results_df)


# Process results
# Create the output directory if needed, so that the results of a finished
# run are not lost to a missing-directory error at write time.
os.makedirs(CROSS_VAL_RESULTS_DIRECTORY, exist_ok=True)

# Both output files share a base name made of the dataset, the sample limit
# and the current timestamp.
file_name = 'teacher' + '_' + dataset + '_' + str(num_samples) + '_' + str(time.time())

# Raw results: one row per (training set, hyperparameter combination)
file_name_raw = 'raw_' + file_name
results_df.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name_raw), index=False)

# Best results: the highest-scoring row of each training set. NaN scores are
# sorted first so that they are never picked as the best row of a group
# (unless every score in that group is NaN).
results_df_max = (
    results_df
    .sort_values('score', na_position='first')
    .groupby(['training_set'])
    .tail(1)
    .sort_values('training_set')
)
results_df_max.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name), index=False)