import torch
from .datasets import KneeOA, WaterBirds, FoodReview, TeacherDataset, MediatorDataset
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
import time
import os
import argparse
from transformers import BertTokenizer, AutoTokenizer, AutoProcessor
from .utils import train, evaluate, evaluate_teacher, train_teacher_model, get_teacher_logits, train_mediator_model, get_predicted_mediators
from sklearn.model_selection import KFold
from . import const
import numpy as np

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using the following device: {device}")

# Autograd anomaly detection: reports the forward operation behind a failing
# backward pass (e.g. one producing NaNs). PyTorch documents this mode as a
# debugging aid that slows execution.
torch.autograd.set_detect_anomaly(True)

# Training/testing data locations for each dataset and its "_double" variant.
WB_DATASET_DIR, WB_DOUBLE_DATASET_DIR = const.WB_DATASET_DIR, const.WB_DOUBLE_DATASET_DIR
KOA_DATASET_DIR, KOA_DOUBLE_DATASET_DIR = const.KOA_DATASET_DIR, const.KOA_DOUBLE_DATASET_DIR
FR_DATASET_DIR, FR_DOUBLE_DATASET_DIR = const.FR_DATASET_DIR, const.FR_DOUBLE_DATASET_DIR

# Directory that receives the results CSV.
RESULTS_DIRECTORY = const.RESULTS_DIRECTORY

# Test distributions evaluated after training; each becomes a results column.
TEST_DISTRIBUTIONS = const.TEST_DISTRIBUTIONS

# Number of KFold splits used for cross fitting (tipmi-cf, mm-cf).
NUM_FOLDS = const.NUM_FOLDS

# Worker processes passed to every DataLoader.
NUM_WORKERS = const.NUM_WORKERS


# Generate command line arguments.
# List-valued hyperparameters (nargs='*') hold one value per dataset seed; the
# experiment loop indexes them by the seed VALUE (e.g. args.lr[seed]), so each
# list needs at least max(--datasets_seeds) + 1 entries.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name',
                    type=str,
                    required=True,
                    choices=['l2', 'gdro', 'mmd', 'kcit', 'irm', 'teacher', 'sd',
                             'tipmi-nc', 'tipmi-cf', 'mm', 'mm-cf'],
                    help='Type of model to train -- can be "l2", "gdro", "mmd", "kcit", "irm", '
                         '"teacher", "sd", "tipmi-nc", "tipmi-cf", "mm", or "mm-cf"')
parser.add_argument('--dataset',
                    type=str,
                    required=True,
                    choices=['koa', 'koa_double', 'waterbirds', 'waterbirds_double',
                             'food_review', 'food_review_double'],
                    help='Dataset to use -- can be "koa", "koa_double", "waterbirds", '
                         '"waterbirds_double", "food_review", or "food_review_double"')
parser.add_argument('--datasets_seeds',
                    type=int,
                    nargs='*',
                    default=[],
                    help='Seeds of the generated datasets to train and evaluate on; '
                         'each seed value also indexes the per-seed hyperparameter lists.')
parser.add_argument('--batch_size',
                    type=int,
                    help='Batch size for all models')
parser.add_argument('--num_epochs',
                    type=int,
                    help='Number of epochs for all models')
parser.add_argument('--lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate(s) for the l2, gdro, mmd, kcit, irm, mm, or mm-cf model (one per seed)')
parser.add_argument('--l2_cost',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost(s) for the l2, gdro, mm, or mm-cf model (one per seed)')
parser.add_argument('--dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout(s) for the l2, gdro, mm, or mm-cf model (one per seed)')
parser.add_argument('--student_lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate for the student')
parser.add_argument('--student_l2_cost',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost for the student')
parser.add_argument('--student_dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout for the student')
parser.add_argument('--teacher_lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate for the teacher')
parser.add_argument('--teacher_l2_cost',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost for the teacher')
parser.add_argument('--teacher_dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout for the teacher')
parser.add_argument('--sigma',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Kernel bandwidth for MMD / KCIT')
parser.add_argument('--alpha',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Cost of MMD / KCIT regularization')
parser.add_argument('--mm_lr',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Learning rate for the mediator model')
parser.add_argument('--mm_l2_cost',
                    type=float,
                    nargs='*',
                    default=[],
                    help='L2 cost for the mediator model')
parser.add_argument('--mm_dropout',
                    type=float,
                    nargs='*',
                    default=[],
                    help='Dropout for the mediator model')
parser.add_argument('--irm_penalty',
                    type=float,
                    nargs='*',
                    default=[],
                    help='IRM penalty')
parser.add_argument('--num_samples',
                    type=int,
                    nargs='?',
                    help='Size of data to use in training (first N rows); all rows when omitted')


# Parse the command line and expose the options used by the experiment loop.
args = parser.parse_args()
model_name, dataset = args.model_name, args.dataset
datasets_seeds = args.datasets_seeds
batch_size, num_epochs = args.batch_size, args.num_epochs
num_samples = args.num_samples

# Results table: one column per test distribution (keyed by its string form);
# the loop below appends one row per dataset seed.
columns = [str(distribution) for distribution in TEST_DISTRIBUTIONS]
results_df = pd.DataFrame(columns=columns)

# Main experiment loop: for each dataset seed, build the training data required
# by the chosen method, train the model, then score it on every test distribution
# and append one row to results_df.
for random_seed in tqdm(datasets_seeds, position=0, desc='Random Seeds'):

    # Per-seed hyperparameters. Each list-valued option is indexed by the seed
    # VALUE (not its position in --datasets_seeds), so every list used by the
    # chosen method needs at least max(datasets_seeds) + 1 entries.
    if model_name in ('l2', 'gdro'):
        lr = args.lr[random_seed]
        l2_cost = args.l2_cost[random_seed]
        dropout = args.dropout[random_seed]
    elif model_name in ('tipmi-nc', 'tipmi-cf'):
        student_lr = args.student_lr[random_seed]
        student_l2_cost = args.student_l2_cost[random_seed]
        student_dropout = args.student_dropout[random_seed]
        teacher_lr = args.teacher_lr[random_seed]
        teacher_l2_cost = args.teacher_l2_cost[random_seed]
        teacher_dropout = args.teacher_dropout[random_seed]
    elif model_name == 'teacher':
        # The stand-alone teacher needs only the teacher hyperparameters, which
        # are read by train_teacher_model further down.
        teacher_lr = args.teacher_lr[random_seed]
        teacher_l2_cost = args.teacher_l2_cost[random_seed]
        teacher_dropout = args.teacher_dropout[random_seed]
    elif model_name in ('mmd', 'kcit'):
        lr = args.lr[random_seed]
        sigma = args.sigma[random_seed]
        alpha = args.alpha[random_seed]
    elif model_name == 'irm':
        lr = args.lr[random_seed]
        irm_penalty = args.irm_penalty[random_seed]
    elif model_name in ('mm', 'mm-cf'):
        lr = args.lr[random_seed]
        l2_cost = args.l2_cost[random_seed]
        dropout = args.dropout[random_seed]
        mm_lr = args.mm_lr[random_seed]
        mm_l2_cost = args.mm_l2_cost[random_seed]
        mm_dropout = args.mm_dropout[random_seed]
    elif model_name == 'sd':
        student_lr = args.student_lr[random_seed]
        student_l2_cost = args.student_l2_cost[random_seed]
        student_dropout = args.student_dropout[random_seed]
    else:
        # Fail before any data is loaded instead of with a NameError later.
        raise ValueError(f'Unknown model_name: {model_name!r}')

    # Locate and load this seed's training split.
    if dataset == 'koa':
        csv_path = os.path.join(KOA_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'koa_double':
        csv_path = os.path.join(KOA_DOUBLE_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'waterbirds':
        csv_path = os.path.join(WB_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'waterbirds_double':
        csv_path = os.path.join(WB_DOUBLE_DATASET_DIR, f'training_{random_seed}', 'training.csv')
    elif dataset == 'food_review':
        csv_path = os.path.join(FR_DATASET_DIR, f'training_{random_seed}.csv')
    elif dataset == 'food_review_double':
        csv_path = os.path.join(FR_DOUBLE_DATASET_DIR, f'training_{random_seed}.csv')
    else:
        raise ValueError(f'Unknown dataset: {dataset!r}')
    full_df = pd.read_csv(csv_path)

    # Optionally train on only the first num_samples rows.
    if num_samples:
        print("Experiment is only using a subset of the data...")
        full_df = full_df.iloc[0:num_samples]

    # Optional dataset fields each method needs.
    include_extra = model_name in ('tipmi-nc', 'sd', 'mm')
    include_group = model_name == 'gdro'
    include_aux = model_name in ('mmd', 'irm')
    include_med = model_name in ('mm', 'mm-cf')

    # Build the datasets. The tokenizers are only created for the text datasets
    # and stay None otherwise.
    mm_tokenizer = None
    tokenizer = None
    if dataset in ('koa', 'koa_double'):
        train_dataset = KneeOA(full_df, include_group=include_group, include_aux=include_aux)
        teacher_dataset = KneeOA(full_df, include_extra=include_extra)
        mediator_dataset = KneeOA(full_df)
    elif dataset in ('waterbirds', 'waterbirds_double'):
        train_dataset = WaterBirds(full_df, include_group=include_group, include_aux=include_aux)
        teacher_dataset = WaterBirds(full_df, include_extra=include_extra)
        mediator_dataset = WaterBirds(full_df, include_med=include_med)
    elif dataset in ('food_review', 'food_review_double'):
        tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny', model_max_length=512)
        mm_tokenizer = AutoTokenizer.from_pretrained('t5-small', model_max_length=512)
        train_dataset = FoodReview(full_df, tokenizer, include_group=include_group, include_aux=include_aux, include_med=include_med)
        mediator_dataset = FoodReview(full_df, mm_tokenizer, include_med=include_med)
        teacher_dataset = FoodReview(full_df, tokenizer, include_extra=include_extra)

    # Methods that train directly on train_dataset (including the stand-alone teacher).
    if model_name in ('l2', 'gdro', 'mmd', 'irm', 'kcit', 'teacher'):
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Mediator methods: fit a mediator model, then train on its predictions.
    mediator_model = None
    if model_name == 'mm':

        # Mediator training data. Both food-review variants go through the
        # dataset's get_med_dataset helper (exactly as mm-cf does below); its
        # result is passed straight to train_mediator_model as the loader.
        if dataset in ('food_review', 'food_review_double'):
            mm_train_ind = np.arange(0, len(full_df))
            mm_train_loader = mediator_dataset.get_med_dataset(mm_train_ind)
        else:
            mm_train_loader = DataLoader(mediator_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

        # Train the mediator model on all training rows.
        mediator_model = train_mediator_model(dataset,
                                              mm_train_loader,
                                              mm_lr=mm_lr,
                                              mm_l2_weight=mm_l2_cost,
                                              mm_dropout=mm_dropout,
                                              num_epochs=num_epochs,
                                              tokenizer=tokenizer,
                                              batch_size=batch_size
                                              )

        # Predict mediators for every training row.
        mediator_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        extra_info, mediators, labels = get_predicted_mediators(mediator_model, mediator_loader, dataset, mm_tokenizer=mm_tokenizer)

        # The final model trains on (extra info, predicted mediators, labels).
        train_dataset = MediatorDataset(extra_info, mediators, labels, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    if model_name == 'mm-cf':

        # Cross fitting: for each fold, train a mediator model on the other folds
        # and predict mediators only for the held-out fold, so no row's predicted
        # mediators come from a model that was trained on that row.
        kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
        extra_info_list = []
        mediators_list = []
        labels_list = []
        for mm_train_ind, mm_val_ind in kfold.split(full_df):

            train_k_df = full_df.iloc[mm_train_ind]
            val_df = full_df.iloc[mm_val_ind]

            # Fold-specific mediator training loader and held-out dataset.
            if dataset in ('koa', 'koa_double'):
                train_k_dataset = KneeOA(train_k_df)
                train_k_loader = DataLoader(train_k_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
                val_dataset = KneeOA(val_df, include_extra=True)
            elif dataset in ('waterbirds', 'waterbirds_double'):
                train_k_dataset = WaterBirds(train_k_df, include_med=True)
                train_k_loader = DataLoader(train_k_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
                val_dataset = WaterBirds(val_df, include_extra=True)
            elif dataset in ('food_review', 'food_review_double'):
                # Fold rows are selected by index from the full mediator_dataset.
                train_k_loader = mediator_dataset.get_med_dataset(mm_train_ind)
                val_dataset = FoodReview(val_df, tokenizer, include_extra=True, include_med=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

            # Train this fold's mediator model.
            mediator_model = train_mediator_model(dataset,
                                                  train_k_loader,
                                                  mm_lr=mm_lr,
                                                  mm_l2_weight=mm_l2_cost,
                                                  mm_dropout=mm_dropout,
                                                  num_epochs=num_epochs,
                                                  tokenizer=tokenizer,
                                                  batch_size=batch_size
                                                  )

            # Predict mediators for the held-out fold and accumulate them.
            extra_info, mediators, labels = get_predicted_mediators(mediator_model, val_loader, dataset, mm_tokenizer=mm_tokenizer)
            extra_info_list = extra_info_list + extra_info
            if dataset in ('food_review', 'food_review_double'):
                mediators_list = mediators_list + mediators
            else:
                mediators_list.append(mediators)
            labels_list.append(labels)

        # Stitch the held-out predictions from all folds together (text
        # mediators are kept as a flat Python list, others are concatenated).
        if dataset not in ('food_review', 'food_review_double'):
            mediators_list = torch.concat(mediators_list, dim=0)
        labels_list = torch.concat(labels_list, dim=0)

        # NOTE(review): mediator_model is left as the LAST fold's model, and that
        # is the model passed to train()/evaluate() below -- confirm intended.
        train_dataset = MediatorDataset(extra_info_list, mediators_list, labels_list, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # tipmi-nc: train a teacher on all rows, then train the student on the
    # teacher's logits computed over the extra-info dataset.
    elif model_name == 'tipmi-nc':

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_model = train_teacher_model(dataset, train_loader, teacher_lr, teacher_l2_cost, teacher_dropout, num_epochs)
        extra_info, teacher_logits = get_teacher_logits(teacher_model, teacher_loader, dataset)
        train_dataset = TeacherDataset(extra_info, teacher_logits, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # sd: same pipeline as tipmi-nc, but the teacher uses the student
    # hyperparameters and the extra True flag.
    elif model_name == 'sd':

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_loader = DataLoader(teacher_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
        teacher_model = train_teacher_model(dataset, train_loader, student_lr, student_l2_cost, student_dropout, num_epochs, True)
        extra_info, teacher_logits = get_teacher_logits(teacher_model, teacher_loader, dataset, True)
        train_dataset = TeacherDataset(extra_info, teacher_logits, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # tipmi-cf: cross-fitted teacher -- each fold's logits come from a teacher
    # trained on the remaining folds.
    elif model_name == 'tipmi-cf':

        kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
        extra_info_list = []
        teacher_logits_list = []
        for teacher_train_ind, teacher_val_ind in kfold.split(full_df):

            train_k_df = full_df.iloc[teacher_train_ind]
            val_df = full_df.iloc[teacher_val_ind]

            if dataset in ('koa', 'koa_double'):
                train_k_dataset = KneeOA(train_k_df)
                val_dataset = KneeOA(val_df, include_extra=True)
            elif dataset in ('waterbirds', 'waterbirds_double'):
                train_k_dataset = WaterBirds(train_k_df)
                val_dataset = WaterBirds(val_df, include_extra=True)
            elif dataset in ('food_review', 'food_review_double'):
                train_k_dataset = FoodReview(train_k_df, tokenizer)
                val_dataset = FoodReview(val_df, tokenizer, include_extra=True)

            train_k_loader = DataLoader(train_k_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

            teacher_model = train_teacher_model(dataset, train_k_loader, teacher_lr, teacher_l2_cost, teacher_dropout, num_epochs)
            extra_info, teacher_logits = get_teacher_logits(teacher_model, val_loader, dataset)
            extra_info_list = extra_info_list + extra_info
            teacher_logits_list.append(teacher_logits)

        teacher_logits_list = torch.concat(teacher_logits_list, dim=0)
        train_dataset = TeacherDataset(extra_info_list, teacher_logits_list, dataset, tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Train the final model.
    if model_name == 'teacher':
        model = train_teacher_model(dataset,
                                    train_loader,
                                    teacher_lr,
                                    teacher_l2_cost,
                                    teacher_dropout,
                                    num_epochs
                                    )
    elif model_name == 'l2':
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=lr,
                      l2_cost=l2_cost,
                      dropout=dropout
                      )
    elif model_name == 'gdro':
        dataset_info = train_dataset.get_group_info()
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=lr,
                      l2_cost=l2_cost,
                      dropout=dropout,
                      dataset_info=dataset_info
                      )
    elif model_name in ('mmd', 'kcit'):
        # mmd and kcit share the same regularizer settings; train() receives
        # model_name to tell them apart.
        mmd_info = {'sigma': sigma, 'alpha': alpha}
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=lr,
                      l2_cost=0,
                      dropout=0,
                      mmd_info=mmd_info
                      )
    elif model_name in ('tipmi-nc', 'tipmi-cf', 'sd'):
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=student_lr,
                      l2_cost=student_l2_cost,
                      dropout=student_dropout
                      )
    elif model_name in ('mm', 'mm-cf'):
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=lr,
                      l2_cost=l2_cost,
                      dropout=dropout,
                      mediator_model=mediator_model,
                      mm_tokenizer=mm_tokenizer,
                      tokenizer=tokenizer
                      )
    elif model_name == 'irm':
        model = train(model_name,
                      dataset,
                      num_epochs,
                      train_loader,
                      lr=lr,
                      l2_cost=0,
                      dropout=0,
                      irm_penalty=irm_penalty)

    # Evaluate the model on each test distribution.
    new_row = {}
    for distribution in tqdm(TEST_DISTRIBUTIONS, position=1, desc='Distributions', leave=False):
        dist_key = str(distribution)

        # Locate the test split for this seed and distribution.
        if dataset == 'koa':
            csv_path = os.path.join(KOA_DATASET_DIR, f'testing_{random_seed}', dist_key, 'testing.csv')
        elif dataset == 'koa_double':
            csv_path = os.path.join(KOA_DOUBLE_DATASET_DIR, f'testing_{random_seed}', dist_key, 'testing.csv')
        elif dataset == 'waterbirds':
            csv_path = os.path.join(WB_DATASET_DIR, f'testing_{random_seed}', dist_key, 'testing.csv')
        elif dataset == 'waterbirds_double':
            csv_path = os.path.join(WB_DOUBLE_DATASET_DIR, f'testing_{random_seed}', dist_key, 'testing.csv')
        elif dataset == 'food_review':
            csv_path = os.path.join(FR_DATASET_DIR, f'testing_{random_seed}', f'testing_{distribution}.csv')
        elif dataset == 'food_review_double':
            csv_path = os.path.join(FR_DOUBLE_DATASET_DIR, f'testing_{random_seed}', f'testing_{distribution}.csv')
        test_df = pd.read_csv(csv_path)

        if dataset in ('koa', 'koa_double'):
            test_dataset = KneeOA(test_df)
        elif dataset in ('waterbirds', 'waterbirds_double'):
            test_dataset = WaterBirds(test_df, include_med=include_med)
        else:
            test_dataset = FoodReview(test_df, tokenizer, include_med=include_med)

        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

        # Score the model on this distribution.
        if model_name == 'teacher':
            score = evaluate_teacher(model, test_loader, dataset)
        else:
            score = evaluate(model, test_loader, dataset, mediator_model=mediator_model, tokenizer=tokenizer, mm_tokenizer=mm_tokenizer)
        new_row[dist_key] = score

    results_df.loc[len(results_df)] = new_row
    print(results_df)

# Save results: one row per dataset seed, one column per test distribution.
# Create the output directory first; DataFrame.to_csv does not create parent
# directories and would otherwise fail after all training has finished.
os.makedirs(RESULTS_DIRECTORY, exist_ok=True)
file_name = os.path.join(RESULTS_DIRECTORY, model_name + '_' + dataset + '_' + str(num_samples) + '_' + str(time.time()))
results_df.to_csv(file_name, index=False)
print("Results: ")
print(results_df)