import torch
from .datasets import KneeOA, WaterBirds, FoodReview, TeacherDataset, MediatorDataset
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import time
import os
import argparse
import itertools
from .utils import train, evaluate, train_teacher_model, train_mediator_model, get_predicted_mediators, get_teacher_logits, evaluate_worst_group_accuracy
from sklearn.model_selection import KFold
from . import const
from transformers import BertTokenizer, AutoTokenizer, AutoProcessor
import numpy as np


# Prefer the first CUDA GPU and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using the following device: {device}")

# Turn on autograd anomaly detection globally so errors raised during
# backpropagation point back at the forward operation that caused them.
# PyTorch documents this mode as a debugging aid that slows execution.
torch.autograd.set_detect_anomaly(True)

# Root directories of the per-seed training splits for each dataset variant,
# all taken from the project-wide ``const`` module.
KOA_DATASET_DIR, KOA_DOUBLE_DATASET_DIR = const.KOA_DATASET_DIR, const.KOA_DOUBLE_DATASET_DIR
WB_DATASET_DIR, WB_DOUBLE_DATASET_DIR = const.WB_DATASET_DIR, const.WB_DOUBLE_DATASET_DIR
FR_DATASET_DIR, FR_DOUBLE_DATASET_DIR = const.FR_DATASET_DIR, const.FR_DOUBLE_DATASET_DIR

# Directory the raw and best-per-seed result CSVs are written to
CROSS_VAL_RESULTS_DIRECTORY = const.CROSS_VAL_RESULTS_DIRECTORY

# Number of KFold splits used by the cross-fitted models (tipmi-cf, mm-cf)
NUM_FOLDS = const.NUM_FOLDS

# Worker process count passed to every DataLoader in this script
NUM_WORKERS = const.NUM_WORKERS


# Generate command line arguments

# Models this script can train. 'teacher' is not offered: no hyperparameter
# grid is defined for it, so selecting it could never complete a run.
MODEL_NAMES = ('l2', 'tipmi-nc', 'tipmi-cf', 'sd', 'gdro', 'mmd', 'kcit', 'irm', 'mm', 'mm-cf')

# Datasets this script knows how to load
DATASET_NAMES = ('koa', 'koa_double', 'waterbirds', 'waterbirds_double', 'food_review', 'food_review_double')

parser = argparse.ArgumentParser()


def _add_grid_argument(flag, help_text):
    """Register a float hyperparameter flag that accepts zero or more values.

    The parsed value is always a list (empty when the flag is omitted).
    """
    parser.add_argument(flag, type=float, nargs='*', default=[], help=help_text)


# Validate the two selectors up front; an unknown or missing value would
# otherwise only surface later as a NameError deep inside the run.
parser.add_argument('--model_name',
                    type=str,
                    required=True,
                    choices=MODEL_NAMES,
                    help='Type of model to train -- one of: ' + ', '.join(MODEL_NAMES))
parser.add_argument('--dataset',
                    type=str,
                    required=True,
                    choices=DATASET_NAMES,
                    help='Dataset to use -- one of: ' + ', '.join(DATASET_NAMES))
parser.add_argument('--datasets_seeds',
                    type=int,
                    nargs='*',
                    default=[],
                    help='Seeds of the dataset splits (training_<seed>) to perform cross validation over.')
parser.add_argument('--batch_size',
                    type=int,
                    help='Batch size for all models')
parser.add_argument('--num_epochs',
                    type=int,
                    help='Number of epochs for all models')

# Grids searched for the trained model itself
_add_grid_argument('--lr', 'Learning rate for the l2, gdro, mmd, kcit, irm, mm, or mm-cf model')
_add_grid_argument('--l2_cost', 'L2 cost for the l2, gdro, mm, or mm-cf model')
_add_grid_argument('--dropout', 'Dropout for the l2, gdro, mm, or mm-cf model')
_add_grid_argument('--student_lr', 'Learning rate for the student')
_add_grid_argument('--student_l2_cost', 'L2 cost for the student')
_add_grid_argument('--student_dropout', 'Dropout for the student')

# Auxiliary-model settings: these lists are indexed by the seed value, so they
# need one entry per seed (at least max(seed) + 1 entries).
_add_grid_argument('--teacher_lr', 'Learning rate for the teacher (indexed by seed value)')
_add_grid_argument('--teacher_l2_cost', 'L2 cost for the teacher (indexed by seed value)')
_add_grid_argument('--teacher_dropout', 'Dropout for the teacher (indexed by seed value)')

_add_grid_argument('--sigma', 'Kernel bandwidth for the mmd or kcit model')
_add_grid_argument('--alpha', 'Cost of MMD or KCIT regularization')

_add_grid_argument('--mm_lr', 'Learning rate for the mediator model (indexed by seed value)')
_add_grid_argument('--mm_l2_cost', 'L2 cost for the mediator model (indexed by seed value)')
_add_grid_argument('--mm_dropout', 'Dropout for the mediator model (indexed by seed value)')
_add_grid_argument('--irm_penalty', 'IRM penalty')

parser.add_argument('--num_samples',
                    type=int,
                    nargs='?',
                    help='Size of data to use in training')

# Parse command line arguments
args = parser.parse_args()

# Run-level settings
model_name = args.model_name
dataset = args.dataset
datasets_seeds = args.datasets_seeds
batch_size = args.batch_size
num_epochs = args.num_epochs
num_samples = args.num_samples

# Hyperparameter value lists, each copied into a fresh list object
lr, l2_cost, dropout = list(args.lr), list(args.l2_cost), list(args.dropout)
student_lr, student_l2_cost, student_dropout = (
    list(args.student_lr), list(args.student_l2_cost), list(args.student_dropout))
teacher_lr, teacher_l2_cost, teacher_dropout = (
    list(args.teacher_lr), list(args.teacher_l2_cost), list(args.teacher_dropout))
sigma, alpha = list(args.sigma), list(args.alpha)
mm_lr, mm_l2_cost, mm_dropout = list(args.mm_lr), list(args.mm_l2_cost), list(args.mm_dropout)
irm_penalty = list(args.irm_penalty)

# Create the hyperparameter grid for the chosen model. The grid's keys are
# also the leading columns of the results dataframe.
if model_name in ('l2', 'gdro', 'mm', 'mm-cf'):
    hp_iterator = {'lr': lr, 'l2_cost': l2_cost, 'dropout': dropout}
elif model_name in ('mmd', 'kcit'):
    hp_iterator = {'lr': lr, 'sigma': sigma, 'alpha': alpha}
elif model_name in ('tipmi-nc', 'tipmi-cf', 'sd'):
    hp_iterator = {'student_lr': student_lr, 'student_l2_cost': student_l2_cost, 'student_dropout': student_dropout}
elif model_name == 'irm':
    hp_iterator = {'lr': lr, 'irm_penalty': irm_penalty}
else:
    # Without a grid nothing below can run; fail with a clear message instead
    # of a NameError on `columns`.
    raise ValueError(f"Unsupported --model_name {model_name!r}: no hyperparameter grid is defined for it")

# itertools.product over an empty list yields no combinations, so an omitted
# flag would silently evaluate no models at all. Fail fast instead.
missing_flags = [f'--{name}' for name, values in hp_iterator.items() if not values]
if missing_flags:
    raise ValueError(f"No values given for {', '.join(missing_flags)} (required by --model_name {model_name})")

# Keep track of results: hyperparameters first, then the seed and both metrics
columns = list(hp_iterator) + ['training_set', 'score', 'worst_group_score']
results_df = pd.DataFrame(columns=columns)

# Expand the grid into one dict per hyperparameter combination
keys, values = zip(*hp_iterator.items())
hp_iterator = [dict(zip(keys, v)) for v in itertools.product(*values)]


# Iterate over multiple different simulations (one per dataset seed).
# NOTE: mm_* and teacher_* hyperparameters below are indexed by the seed
# *value* (e.g. mm_lr[random_seed]), so those lists need max(seed) + 1 entries.
for random_seed in tqdm(datasets_seeds, position=0, desc='Random Seeds'):

    # Obtain dataset. Image datasets store training_<seed>/training.csv;
    # food-review datasets store training_<seed>.csv.
    if dataset == 'koa':
        csv_path = os.path.join(KOA_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'koa_double':
        csv_path = os.path.join(KOA_DOUBLE_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'waterbirds':
        csv_path = os.path.join(WB_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'waterbirds_double':
        csv_path = os.path.join(WB_DOUBLE_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'food_review':
        csv_path = os.path.join(FR_DATASET_DIR, f'training_{random_seed}.csv')
    elif dataset == 'food_review_double':
        csv_path = os.path.join(FR_DOUBLE_DATASET_DIR, f'training_{random_seed}.csv')
    else:
        # Otherwise full_df would be undefined and the next line would raise NameError
        raise ValueError(f"Unknown --dataset {dataset!r}")
    full_df = pd.read_csv(csv_path)

    if num_samples:
        print("Experiment is only using a subset of the data...")
        full_df = full_df.iloc[0:num_samples]

    # Split by row order: the first int(len * TRAINING_SIZE) rows train, the rest test
    split_idx = int(len(full_df) * const.TRAINING_SIZE)
    train_df = full_df[0:split_idx]
    test_df = full_df[split_idx:]

    # Optional fields each dataset object should expose for this model
    include_extra = model_name in ('tipmi-nc', 'sd', 'mm')
    include_group = model_name == 'gdro'
    include_aux = model_name in ('mmd', 'irm')
    include_med = model_name in ('mm', 'mm-cf')
    tokenizer = None
    mm_tokenizer = None
    if dataset in ('koa', 'koa_double'):
        train_dataset = KneeOA(train_df, include_group=include_group, include_aux=include_aux)
        teacher_dataset = KneeOA(train_df, include_extra=include_extra)
        mediator_dataset = KneeOA(train_df)
        test_dataset = KneeOA(test_df)
        test_group_dataset = KneeOA(test_df, include_group=True)
    elif dataset in ('waterbirds', 'waterbirds_double'):
        train_dataset = WaterBirds(train_df, include_group=include_group, include_aux=include_aux)
        teacher_dataset = WaterBirds(train_df, include_extra=include_extra)
        mediator_dataset = WaterBirds(train_df, include_med=include_med)
        test_dataset = WaterBirds(test_df)
        test_group_dataset = WaterBirds(test_df, include_group=True)
    elif dataset in ('food_review', 'food_review_double'):
        # Student/teacher text goes through BERT-tiny; mediator text through T5-small
        tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny', model_max_length=512)
        mm_tokenizer = AutoTokenizer.from_pretrained('t5-small', model_max_length=512)
        train_dataset = FoodReview(train_df, tokenizer, include_group=include_group, include_aux=include_aux, include_med=include_med)
        teacher_dataset = FoodReview(train_df, tokenizer, include_extra=include_extra)
        mediator_dataset = FoodReview(train_df, mm_tokenizer, include_med=include_med)
        test_dataset = FoodReview(test_df, tokenizer, include_med=include_med)
        test_group_dataset = FoodReview(test_df, tokenizer, include_med=include_med, include_group=True)

    # Create test data loaders
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    test_group_loader = DataLoader(test_group_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    group_info = test_dataset.get_group_info()

    # Models trained directly on train_dataset share one plain training loader
    if model_name in ('l2', 'teacher', 'gdro', 'mmd', 'kcit', 'irm'):
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Build the student's training data from an auxiliary mediator/teacher model
    mediator_model = None
    if model_name == 'mm':

        # Fit one mediator model on the whole training split
        mm_train_ind = np.arange(len(train_df))
        if dataset in ('food_review', 'food_review_double'):
            mm_train_loader = mediator_dataset.get_med_dataset(mm_train_ind)
        else:
            mm_train_loader = DataLoader(mediator_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

        mediator_model = train_mediator_model(dataset,
                                              mm_train_loader,
                                              mm_lr=mm_lr[random_seed],
                                              mm_l2_weight=mm_l2_cost[random_seed],
                                              mm_dropout=mm_dropout[random_seed],
                                              num_epochs=num_epochs,
                                              tokenizer=tokenizer,
                                              batch_size=batch_size
                                              )

        # Get predicted mediators. mm_tokenizer is forwarded exactly as in the
        # mm-cf branch (it is None for the image datasets).
        mediator_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        extra_info, mediators, labels = get_predicted_mediators(mediator_model, mediator_loader, dataset, mm_tokenizer=mm_tokenizer)
        train_dataset = MediatorDataset(extra_info, mediators, labels, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    elif model_name == 'mm-cf':

        # Cross-fitting: each fold's mediators come from a model trained on the other folds
        kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
        extra_info_list = []
        mediators_list = []
        labels_list = []
        for mm_train_ind, mm_val_ind in kfold.split(train_df):

            train_k_df = train_df.iloc[mm_train_ind]
            val_df = train_df.iloc[mm_val_ind]

            if dataset in ('koa', 'koa_double'):
                train_k_loader = DataLoader(KneeOA(train_k_df), batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
                val_dataset = KneeOA(val_df, include_extra=True)
            elif dataset in ('waterbirds', 'waterbirds_double'):
                train_k_loader = DataLoader(WaterBirds(train_k_df, include_med=True), batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
                val_dataset = WaterBirds(val_df, include_extra=True)
            elif dataset in ('food_review', 'food_review_double'):
                # The fold's training rows are taken from the full-split mediator dataset
                train_k_loader = mediator_dataset.get_med_dataset(mm_train_ind)
                val_dataset = FoodReview(val_df, tokenizer, include_extra=True)

            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

            # Train this fold's mediator model
            mediator_model = train_mediator_model(dataset,
                                                  train_k_loader,
                                                  mm_lr=mm_lr[random_seed],
                                                  mm_l2_weight=mm_l2_cost[random_seed],
                                                  mm_dropout=mm_dropout[random_seed],
                                                  num_epochs=num_epochs,
                                                  tokenizer=tokenizer,
                                                  batch_size=batch_size
                                                  )

            # Predict mediators for the held-out fold
            extra_info, mediators, labels = get_predicted_mediators(mediator_model, val_loader, dataset, mm_tokenizer=mm_tokenizer)

            extra_info_list = extra_info_list + extra_info
            if dataset not in ('food_review', 'food_review_double'):
                mediators_list.append(mediators)
            else:
                # Text mediators are accumulated as a flat list rather than tensors
                mediators_list = mediators_list + mediators
            labels_list.append(labels)

        # Stitch the folds back together
        if dataset not in ('food_review', 'food_review_double'):
            mediators_list = torch.concat(mediators_list, dim=0)
        labels_list = torch.concat(labels_list, dim=0)

        train_dataset = MediatorDataset(extra_info_list, mediators_list, labels_list, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    elif model_name == 'tipmi-nc':

        # Train the teacher on the full training split with a temporary loader
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_model = train_teacher_model(dataset, train_loader, teacher_lr[random_seed], teacher_l2_cost[random_seed], teacher_dropout[random_seed], num_epochs)
        # The student trains on the teacher's logits
        extra_info, teacher_logits = get_teacher_logits(teacher_model, teacher_loader, dataset)
        train_dataset = TeacherDataset(extra_info, teacher_logits, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    elif model_name == 'tipmi-cf':

        # Cross-fitting: each fold's logits come from a teacher trained on the other folds
        kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
        extra_info_list = []
        teacher_logits_list = []
        for teacher_train_ind, teacher_val_ind in kfold.split(train_df):

            train_k_df = train_df.iloc[teacher_train_ind]
            val_df = train_df.iloc[teacher_val_ind]

            if dataset in ('koa', 'koa_double'):
                train_k_dataset = KneeOA(train_k_df)
                val_dataset = KneeOA(val_df, include_extra=True)
            elif dataset in ('waterbirds', 'waterbirds_double'):
                train_k_dataset = WaterBirds(train_k_df)
                val_dataset = WaterBirds(val_df, include_extra=True)
            elif dataset in ('food_review', 'food_review_double'):
                train_k_dataset = FoodReview(train_k_df, tokenizer)
                val_dataset = FoodReview(val_df, tokenizer, include_extra=True)

            train_k_loader = DataLoader(train_k_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

            # Train this fold's teacher and collect logits on the held-out fold
            teacher_model = train_teacher_model(dataset, train_k_loader, teacher_lr[random_seed], teacher_l2_cost[random_seed], teacher_dropout[random_seed], num_epochs)
            extra_info, teacher_logits = get_teacher_logits(teacher_model, val_loader, dataset)
            extra_info_list = extra_info_list + extra_info
            teacher_logits_list.append(teacher_logits)

        teacher_logits_list = torch.concat(teacher_logits_list, dim=0)
        train_dataset = TeacherDataset(extra_info_list, teacher_logits_list, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Grid search: the inner bar sits below the seed bar instead of overwriting it
    for hp_args in tqdm(hp_iterator, position=1, leave=False, desc='Analysis'):

        if model_name == 'sd':
            # Train a teacher with the current student hyperparameters, then
            # build the student's training set from that teacher's logits
            sd_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
            sd_val_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
            teacher_model = train_teacher_model(dataset, sd_train_loader, hp_args.get('student_lr'), hp_args.get('student_l2_cost'), hp_args.get('student_dropout'), num_epochs, True)
            extra_info, teacher_logits = get_teacher_logits(teacher_model, sd_val_loader, dataset, True)
            sd_train_dataset = TeacherDataset(extra_info, teacher_logits, dataset, tokenizer=tokenizer)
            sd_train_loader = DataLoader(sd_train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

        # Train the model
        if model_name == 'l2':
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('lr'),
                          l2_cost=hp_args.get('l2_cost'),
                          dropout=hp_args.get('dropout')
                          )
        elif model_name == 'gdro':
            dataset_info = train_dataset.get_group_info()
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('lr'),
                          l2_cost=hp_args.get('l2_cost'),
                          dropout=hp_args.get('dropout'),
                          dataset_info=dataset_info
                          )
        elif model_name in ('mmd', 'kcit'):
            # Both regularised models take the same kernel settings and no L2/dropout
            mmd_info = {'sigma': hp_args.get('sigma'), 'alpha': hp_args.get('alpha')}
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('lr'),
                          l2_cost=0,
                          dropout=0,
                          mmd_info=mmd_info
                          )
        elif model_name in ('tipmi-nc', 'tipmi-cf'):
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('student_lr'),
                          l2_cost=hp_args.get('student_l2_cost'),
                          dropout=hp_args.get('student_dropout')
                          )
        elif model_name == 'sd':
            model = train(model_name,
                          dataset,
                          num_epochs,
                          sd_train_loader,
                          lr=hp_args.get('student_lr'),
                          l2_cost=hp_args.get('student_l2_cost'),
                          dropout=hp_args.get('student_dropout')
                          )
        elif model_name in ('mm', 'mm-cf'):
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('lr'),
                          l2_cost=hp_args.get('l2_cost'),
                          dropout=hp_args.get('dropout'),
                          mediator_model=mediator_model,
                          mm_tokenizer=mm_tokenizer,
                          tokenizer=tokenizer
                          )
        elif model_name == 'irm':
            model = train(model_name,
                          dataset,
                          num_epochs,
                          train_loader,
                          lr=hp_args.get('lr'),
                          l2_cost=0,
                          dropout=0,
                          irm_penalty=hp_args.get('irm_penalty')
                          )

        # Obtain results
        score = evaluate(model, test_loader, dataset, mediator_model=mediator_model, tokenizer=tokenizer, mm_tokenizer=mm_tokenizer)
        worst_group_score = evaluate_worst_group_accuracy(model, test_group_loader, group_info, dataset, mediator_model=mediator_model, tokenizer=tokenizer, mm_tokenizer=mm_tokenizer)

        # One row per run: the hyperparameters tried (hp_args keys are exactly
        # the hyperparameter columns of results_df) plus seed and both metrics
        new_row = dict(hp_args)
        new_row['training_set'] = random_seed
        new_row['score'] = score
        new_row['worst_group_score'] = worst_group_score
        results_df.loc[len(results_df)] = new_row
        print(results_df)



# Process results: write every run, then the best run per training set
file_name = model_name + '_' + dataset + '_' + str(num_samples) + '_' + str(time.time())
file_name_raw = 'raw_' + file_name

# Create the output directory if needed so a long run is not lost at the final write
os.makedirs(CROSS_VAL_RESULTS_DIRECTORY, exist_ok=True)
results_df.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name_raw), index=False)

# gdro, mmd and irm are selected on worst-group score; every other model on score
selection_metric = 'worst_group_score' if model_name in ('gdro', 'mmd', 'irm') else 'score'

# Keep the highest-scoring row per training set. NaN scores (e.g. a diverged
# run) are sorted first so tail(1) never prefers them over a finite score.
results_df_max = (results_df
                  .sort_values(selection_metric, na_position='first')
                  .groupby(['training_set'])
                  .tail(1)
                  .sort_values('training_set'))
results_df_max.to_csv(os.path.join(CROSS_VAL_RESULTS_DIRECTORY, file_name), index=False)

print("Results: ")
print(results_df)
print("Max Results: ")
print(results_df_max)