import torch
from .datasets import KneeOA, WaterBirds, FoodReview
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import time
import os
import argparse
import itertools
from .utils import train_mediator_model, evaluate_mediator_model
from sklearn.model_selection import KFold
from . import const
from transformers import AutoTokenizer, BertTokenizer
import numpy as np


# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("Using the following device: " + str(device))

# Turn off Weights & Biases logging; presumably honoured by the transformers/wandb
# integration used inside the training utilities -- TODO confirm.
os.environ.update({"WANDB_DISABLED": "true"})

# Make autograd report where an invalid gradient originated during backward.
# PyTorch documents this mode as debug-only because it slows execution down.
torch.autograd.set_detect_anomaly(True)

# Locations of the generated training data for each supported dataset.
WB_DATASET_DIR, KOA_DATASET_DIR, FR_DATASET_DIR = (
    const.WB_DATASET_DIR,
    const.KOA_DATASET_DIR,
    const.FR_DATASET_DIR,
)

# Output directory for the result CSVs written at the end of the script.
CROSS_VAL_RESULTS_DIRECTORY = const.CROSS_VAL_RESULTS_DIRECTORY

# Number of cross-validation folds, and number of DataLoader worker processes.
NUM_FOLDS, NUM_WORKERS = const.NUM_FOLDS, const.NUM_WORKERS


# Command-line interface. Each --mm_* flag accepts one or more values; the script
# evaluates every combination of the supplied values on every training set.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
                    type=str,
                    required=True,
                    # Reject unknown names up front; otherwise they fall through every
                    # dataset branch and crash much later with an unrelated NameError.
                    choices=['koa', 'waterbirds', 'food_review'],
                    help='Dataset to use -- can be \"koa\", \"waterbirds\", or \"food_review\"')
parser.add_argument('--batch_size',
                    type=int,
                    help='Batch size for all models')
parser.add_argument('--num_epochs',
                    type=int,
                    help='Number of epochs for all models')
parser.add_argument('--mm_lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate(s) for the mediator model')
parser.add_argument('--mm_l2_weight',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost(s) for the mediator model')
parser.add_argument('--num_samples',
                    type=int,
                    nargs='?',
                    help='If given, only the first N rows of each training set are used')
parser.add_argument('--datasets_seeds',
                    type=int,
                    nargs='*',
                    default=[],
                    help='Random seeds identifying the training sets to evaluate on')
parser.add_argument('--mm_dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout value(s) for the mediator model')

# Parse command line arguments into module-level settings used by the rest of the script.
args = parser.parse_args()
dataset = args.dataset
batch_size = args.batch_size
num_epochs = args.num_epochs
mm_lr = list(args.mm_lr)
mm_l2_weight = list(args.mm_l2_weight)
mm_dropout = list(args.mm_dropout)
num_samples = args.num_samples
datasets_seeds = args.datasets_seeds


# Hyperparameter grid: the Cartesian product of these value lists is evaluated.
hp_grid = {'mm_lr': mm_lr, 'mm_l2_weight': mm_l2_weight, 'mm_dropout': mm_dropout}

# itertools.product over an empty list yields no combinations at all, so omitting
# any --mm_* flag would silently skip every run and write empty result files.
_missing_hps = [name for name, hp_values in hp_grid.items() if not hp_values]
if _missing_hps:
    parser.error('no values supplied for: ' + ', '.join('--' + name for name in _missing_hps))

# Results table: the hyperparameters, then the training set (seed) and its score.
columns = list(hp_grid) + ['training_set', 'score']
results_df = pd.DataFrame(columns=columns)

# Expand the grid into a list of {hyperparameter name: value} dicts, one per combination.
keys, values = zip(*hp_grid.items())
hp_iterator = [dict(zip(keys, combo)) for combo in itertools.product(*values)]


def _read_training_frame(dataset_name, seed):
    """Read the training CSV that was generated with ``seed`` for ``dataset_name``.

    Raises:
        ValueError: if ``dataset_name`` is not "koa", "waterbirds" or "food_review".
    """
    if dataset_name == 'koa':
        csv_path = os.path.join(KOA_DATASET_DIR, f'training_{seed}', 'training.csv')
    elif dataset_name == 'waterbirds':
        csv_path = os.path.join(WB_DATASET_DIR, f'training_{seed}', 'training.csv')
    elif dataset_name == 'food_review':
        csv_path = os.path.join(FR_DATASET_DIR, f'training_{seed}.csv')
    else:
        # Previously an unknown name fell through and failed later with a NameError.
        raise ValueError(f'Unknown dataset {dataset_name!r}; expected "koa", "waterbirds" or "food_review"')
    return pd.read_csv(csv_path)


def _make_loaders(dataset_name, train_df, val_df):
    """Build the mediator-model training and validation loaders.

    Returns:
        ``(train_loader, val_loader, tokenizer)``; ``tokenizer`` is None except for food_review.
    """
    tokenizer = None
    if dataset_name == 'koa':
        train_dataset = KneeOA(train_df)
        val_dataset = KneeOA(val_df)
    elif dataset_name == 'waterbirds':
        train_dataset = WaterBirds(train_df, include_med=True)
        val_dataset = WaterBirds(val_df, include_med=True)
    elif dataset_name == 'food_review':
        tokenizer = AutoTokenizer.from_pretrained('t5-small', model_max_length=512)
        train_dataset = FoodReview(train_df, tokenizer, include_med=True)
        # Bug fix: this used to be built from train_df, so the reported
        # "validation" score was measured on the training data.
        val_dataset = FoodReview(val_df, tokenizer, include_med=True)
    else:
        raise ValueError(f'Unknown dataset {dataset_name!r}')

    if dataset_name == 'food_review':
        # FoodReview supplies its own training loader; it is given indices covering
        # every row of train_df.
        train_loader = train_dataset.get_med_dataset(np.arange(len(train_df)))
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    return train_loader, val_loader, tokenizer


# Evaluate every hyperparameter combination on every seeded training set.
for random_seed in tqdm(datasets_seeds, position=0, desc='Random Seeds'):

    # The data depends only on the seed, so read and split it once per seed
    # instead of once per hyperparameter combination.
    full_df = _read_training_frame(dataset, random_seed)
    if num_samples:
        print("Experiment is only using a subset of the data...")
        full_df = full_df.iloc[0:num_samples]

    # Leading const.TRAINING_SIZE fraction of rows trains, the remainder validates.
    split_at = int(len(full_df) * const.TRAINING_SIZE)
    train_df = full_df.iloc[:split_at]
    val_df = full_df.iloc[split_at:]

    # Inner bar sits below the outer one instead of overwriting it.
    for hp_args in tqdm(hp_iterator, position=1, leave=False, desc='Analysis'):

        # Datasets and loaders are rebuilt for every combination (as before), so any
        # state held by the dataset objects is not shared between runs.
        train_loader, val_loader, tokenizer = _make_loaders(dataset, train_df, val_df)

        # Train the mediator model
        mediator_model = train_mediator_model(dataset,
                                              train_loader,
                                              mm_lr=hp_args.get('mm_lr'),
                                              mm_l2_weight=hp_args.get('mm_l2_weight'),
                                              mm_dropout=hp_args.get('mm_dropout'),
                                              num_epochs=num_epochs,
                                              tokenizer=tokenizer,
                                              batch_size=batch_size
                                              )

        # Evaluate the model on the held-out split
        score = evaluate_mediator_model(mediator_model, val_loader, dataset, mm_tokenizer=tokenizer)

        # Append one row per (training set, hyperparameter combination)
        new_row = {}
        new_row['training_set'] = random_seed
        new_row['score'] = score
        new_row['mm_lr'] = hp_args.get('mm_lr')
        new_row['mm_l2_weight'] = hp_args.get('mm_l2_weight')
        new_row['mm_dropout'] = hp_args.get('mm_dropout')
        print(new_row)
        results_df.loc[len(results_df)] = new_row
        print(results_df)


# Write three CSVs: every run ('raw_'), and per training set the run with the
# highest ('max_') and lowest ('min_') score.
file_name = 'mediator_model' + '_' + dataset + '_' + str(num_samples) + '_'  + str(time.time())

# The experiment may have run for hours; don't lose its results because the
# output directory has not been created yet.
os.makedirs(CROSS_VAL_RESULTS_DIRECTORY, exist_ok=True)

file_name_raw = 'raw_' + file_name
results_df.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name_raw), index=False)

# A stable sort makes the choice among tied scores deterministic
# (earliest run for 'min_', latest run for 'max_').
scores_per_seed = results_df.sort_values('score', kind='stable').groupby(['training_set'])

file_name_max = 'max_' + file_name
results_df_max = scores_per_seed.tail(1).sort_values('training_set')
results_df_max.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name_max), index=False)

file_name_min = 'min_' + file_name
results_df_min = scores_per_seed.head(1).sort_values('training_set')
results_df_min.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name_min), index=False)