import torch
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
from tqdm import tqdm
from .models import BinaryPreTrainedNet, JswNet, BertClassifier, JswNetMediator
from .loss import LossComputer, MMD_Loss, KCIT_Loss, IRM
import numpy as np
import warnings
import segmentation_models_pytorch as smp
from transformers import DataCollatorForSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate as evaluate_rogue
from more_itertools import sliced
from transformers import pipeline
import os
from .datasets import TeacherDataset
from torch.utils.tensorboard import SummaryWriter


# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Disable Hugging Face tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# These warnings are emitted while calculating worst group accuracy
# (e.g. for groups with no samples); silence them.
warnings.filterwarnings("ignore", "Mean of empty slice")
warnings.filterwarnings("ignore", "invalid value encountered in scalar divide")

def train(
        model_name,
        dataset_name,
        num_epochs,
        train_loader,
        lr,
        l2_cost,
        dropout,
        dataset_info = None,
        mmd_info = None,
        mediator_model = None,
        mm_tokenizer = None,
        tokenizer = None,
        irm_penalty = None
        ):
    """
    Used for creating and training a model.

    :param model_name: The training method ('l2', 'gdro', 'irm', 'mmd', 'kcit', 'mm', 'mm-cf',
        'tipmi-nc', 'tipmi-cf', 'sd'); selects the loss computed by get_loss
    :param dataset_name: Name of the dataset being used; also selects the architecture
    :param num_epochs: Number of epochs the model will be trained for
    :param train_loader: The training dataset loader; each batch is unpacked by get_loss
    :param lr: Learning rate
    :param l2_cost: The L2 regularization cost (Adam weight_decay)
    :param dropout: Dropout passed to BertClassifier (food review datasets only)
    :param dataset_info: Information used for calculating GDRO loss
    :param mmd_info: Info used for calculating Weighted MMD / KCIT loss ('sigma' and 'alpha' keys)
    :param mediator_model: Forwarded to get_loss
    :param mm_tokenizer: Forwarded to get_loss
    :param tokenizer: Forwarded to get_loss (used for the food review 'mm'/'mm-cf' methods)
    :param irm_penalty: Passed to IRM as its penalty argument when model_name == 'irm'
    :return: The trained model, switched to eval mode
    """
    # Get model
    if dataset_name in ('food_review', 'food_review_double'):
        model = BertClassifier(dropout, dropout).to(device)
    elif dataset_name in ('koa', 'koa_double') and model_name in ('mm', 'mm-cf'):
        model = JswNet().to(device)
    else:
        model = BinaryPreTrainedNet('resnet').to(device)
    model.train()

    # Obtain optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_cost)

    # Used for GDRO and IRM
    loss_computer = None
    if model_name == 'gdro':
        loss_function = torch.nn.BCELoss(reduction = 'none')
        loss_computer = LossComputer(loss_function, True, dataset_info, alpha=1)
    elif model_name == 'irm':
        loss_function = torch.nn.BCELoss()
        loss_computer = IRM(loss_function, irm_penalty)

    writer = SummaryWriter()
    try:
        # Train the model over n epochs
        iteration = 0
        for _ in tqdm(range(num_epochs), position=2, leave=False, desc='Epochs'):

            # Iterate through dataset
            for batch in train_loader:
                loss = get_loss(
                                model,
                                model_name,
                                dataset_name,
                                batch,
                                loss_computer=loss_computer,
                                mmd_info=mmd_info,
                                mediator_model=mediator_model,
                                mm_tokenizer = mm_tokenizer,
                                tokenizer = tokenizer
                                )
                # Log a plain Python float rather than the graph-attached tensor
                writer.add_scalar("Loss/train", loss.item(), iteration)
                iteration += 1

                # Perform backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    finally:
        # Make sure the event file is written and closed even if training fails
        writer.flush()
        writer.close()

    model.eval()
    return model



def get_loss(
        model,
        model_name,
        dataset_name,
        batch,
        loss_computer,
        mmd_info,
        mediator_model,
        mm_tokenizer,
        tokenizer
        ):
    """
    Compute the training loss of `model` on a single batch.

    :param model: The model being trained
    :param model_name: Training method; one of 'l2', 'tipmi-nc', 'tipmi-cf', 'sd', 'gdro',
        'mmd', 'mm', 'mm-cf', 'irm', 'kcit'
    :param dataset_name: 'food_review' / 'food_review_double' use the text path (x is a
        tokenized dict); every other name uses the tensor path
    :param batch: Tuple (x, y, z). For 'tipmi-nc'/'tipmi-cf'/'sd', y is regressed with MSE
        against the model's logits; for the other methods y is the BCE target
    :param loss_computer: LossComputer ('gdro') or IRM ('irm'); unused otherwise
    :param mmd_info: Dict with 'sigma' and 'alpha' for the 'mmd' and 'kcit' methods
    :param mediator_model: Not used by this function
    :param mm_tokenizer: Not used by this function
    :param tokenizer: Tokenizes the text mediators z for food review 'mm'/'mm-cf'
    :return: The loss tensor that train() backpropagates
    :raises ValueError: If model_name is not a supported method
    """
    supported = ('l2', 'tipmi-nc', 'tipmi-cf', 'sd', 'gdro', 'mmd', 'mm', 'mm-cf', 'irm', 'kcit')
    if model_name not in supported:
        # Without this, no branch assigns `loss` and `return loss` raises UnboundLocalError
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {supported}")

    (x, y, z) = batch

    distill_methods = ('tipmi-nc', 'tipmi-cf', 'sd')
    mediator_methods = ('mm', 'mm-cf')

    # Obtain loss functions needed for each method ('gdro' and 'irm' use loss_computer)
    if model_name in distill_methods:
        loss_function = torch.nn.MSELoss()
    elif model_name in ('l2', 'mmd', 'mm', 'mm-cf', 'kcit'):
        loss_function = torch.nn.BCELoss()

    if dataset_name in ('food_review', 'food_review_double'):
        # The 'mm' methods classify from the text mediators z instead of x
        if model_name not in mediator_methods:
            mask = x['attention_mask'].to(device)
            input_id = x['input_ids'].squeeze(1).to(device)
            y = y.to(device).to(dtype=torch.float32)

        # Calculate the loss
        if model_name == 'l2':
            yhat = model(input_id, mask)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_function(yhat, y)

        elif model_name in distill_methods:
            # y presumably holds teacher logits (see get_teacher_logits)
            student_logits = model.forward_logits(input_id, mask)
            student_logits = torch.squeeze(student_logits)
            loss = loss_function(student_logits, y)

        elif model_name == 'gdro':
            z = z.to(device).to(dtype=torch.float32)
            yhat = model(input_id, mask)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_computer.loss(yhat, y, dataset_name, group_idx=z, is_training=True)

        elif model_name == 'mmd':
            z = z.to(device).to(dtype=torch.float32)
            yhat = model(input_id, mask)
            y = torch.unsqueeze(y, dim=1)
            ce_loss = loss_function(yhat, y)
            sigma = torch.tensor(mmd_info['sigma']).to(device)
            alpha = torch.tensor(mmd_info['alpha']).to(device)
            mmd_loss = MMD_Loss(sigma)(yhat, y, z)
            loss = ce_loss + alpha * mmd_loss

        elif model_name in mediator_methods:
            summary_tokens = tokenizer(z, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
            summary_masks = summary_tokens.attention_mask.to(device)
            summary_input_ids = summary_tokens.input_ids.squeeze(1).to(device)
            yhat = model(summary_input_ids, summary_masks)
            y = y.to(device).to(dtype=torch.float32)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_function(yhat, y)

        elif model_name == 'irm':
            z = z.to(device).to(dtype=torch.float32)
            yhat = model(input_id, mask)
            y = torch.unsqueeze(y, dim=1)
            z = torch.unsqueeze(z, dim=1)
            loss = loss_computer.loss(yhat, y, z)

        elif model_name == 'kcit':
            # The mediator's token ids (as floats) are the conditioning variable for KCIT
            input_id_med = z['input_ids'].squeeze(1).to(device).to(dtype=torch.float32)
            yhat = model(input_id, mask)
            y = torch.unsqueeze(y, dim=1)
            ce_loss = loss_function(yhat, y)
            sigma = torch.tensor(mmd_info['sigma']).to(device)
            alpha = torch.tensor(mmd_info['alpha']).to(device)
            kcit_loss = KCIT_Loss(epsilon=1e-3, sigma=sigma)(yhat, y, input_id_med)
            loss = ce_loss + alpha * kcit_loss

    else:
        # Load data onto GPU (if it exists)
        x = x.to(device).to(dtype=torch.float32)
        y = y.to(device).to(dtype=torch.float32)
        z = z.to(device).to(dtype=torch.float32)

        # Calculate the loss
        if model_name == 'l2':
            yhat = model(x)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_function(yhat, y)

        elif model_name in distill_methods:
            student_logits = model.forward_logits(x)
            student_logits = torch.squeeze(student_logits)
            loss = loss_function(student_logits, y)

        elif model_name == 'gdro':
            yhat = model(x)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_computer.loss(yhat, y, dataset_name, group_idx=z, is_training=True)

        elif model_name == 'mmd':
            yhat = model(x)
            y = torch.unsqueeze(y, dim=1)
            ce_loss = loss_function(yhat, y)
            sigma = torch.tensor(mmd_info['sigma']).to(device)
            alpha = torch.tensor(mmd_info['alpha']).to(device)
            mmd_loss = MMD_Loss(sigma)(yhat, y, z)
            loss = ce_loss + alpha * mmd_loss

        elif model_name in mediator_methods:
            # The mediator-based classifier is fed z, not x
            yhat = model(z)
            y = torch.unsqueeze(y, dim=1)
            loss = loss_function(yhat, y)

        elif model_name == 'irm':
            yhat = model(x)
            y = torch.unsqueeze(y, dim=1)
            z = torch.unsqueeze(z, dim=1)
            loss = loss_computer.loss(yhat, y, z)

        elif model_name == 'kcit':
            yhat = model(x)
            y = torch.unsqueeze(y, dim=1)
            ce_loss = loss_function(yhat, y)
            sigma = torch.tensor(mmd_info['sigma']).to(device)
            alpha = torch.tensor(mmd_info['alpha']).to(device)
            kcit_loss = KCIT_Loss(epsilon=1e-3, sigma=sigma)(yhat, y, z)
            loss = ce_loss + alpha * kcit_loss

    return loss



def evaluate(model, data_loader, dataset, mediator_model=None, tokenizer=None, mm_tokenizer=None):
    """
    Compute the AUROC of `model` over every batch of `data_loader`.

    :param model: Classifier to evaluate
    :param data_loader: Yields (x, y, _) batches
    :param dataset: Dataset name; selects how inputs (or mediators) are built
    :param mediator_model: If given, the classifier is applied to predicted mediators:
        the image multiplied by the thresholded (sigmoid > 0.5) predicted mask (waterbirds),
        the mediator model's output (koa), or a summary generated with `mm_tokenizer` and
        re-tokenized with `tokenizer` (food review)
    :param tokenizer: Classifier tokenizer (food review with a mediator model)
    :param mm_tokenizer: Mediator tokenizer (food review with a mediator model)
    :return: sklearn roc_auc_score of the concatenated labels and predictions
    """
    text_names = ('food_review', 'food_review_double')

    with torch.no_grad():
        scores, targets = torch.Tensor().to(device), torch.Tensor().to(device)

        for x, y, _ in data_loader:
            if mediator_model is None:
                # Direct prediction from the raw inputs
                if dataset in text_names:
                    attention = x['attention_mask'].to(device)
                    token_ids = x['input_ids'].squeeze(1).to(device)
                    y = y.to(device).to(dtype=torch.float32)
                    yhat = model(token_ids, attention)
                else:
                    x = x.to(device)
                    y = y.to(device)
                    yhat = model(x)
            elif dataset in ('waterbirds', 'waterbirds_double'):
                x = x.to(device).to(dtype=torch.float32)
                keep = (torch.sigmoid(mediator_model(x)) > 0.5).float()
                yhat = model(x * keep)
                y = y.to(device)
            elif dataset in ('koa', 'koa_double'):
                x = x.to(device).to(dtype=torch.float32)
                yhat = model(mediator_model(x))
                y = y.to(device)
            elif dataset in text_names:
                encoded = mm_tokenizer(x, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
                generated = mediator_model.generate(
                    input_ids=encoded.input_ids.to(device),
                    attention_mask=encoded.attention_mask.to(device),
                    max_length=512,
                )
                summaries = mm_tokenizer.batch_decode(generated, skip_special_tokens=True)
                summary = tokenizer(summaries, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
                summary_mask = summary.attention_mask.to(device)
                summary_ids = summary.input_ids.squeeze(1).to(device)
                yhat = model(summary_ids, summary_mask)
                y = y.to(device).to(dtype=torch.float32)

            scores = torch.cat((scores, yhat))
            targets = torch.cat((targets, y))

        return roc_auc_score(targets.cpu().numpy(), scores.cpu().numpy())



def evaluate_worst_group_accuracy(model, data_loader, group_info, dataset, mediator_model=None, tokenizer=None, mm_tokenizer=None, return_list=False):
    """
    Compute the accuracy of `model` within each group and return the worst one.

    A single decision threshold is picked on the pooled predictions of all groups by
    maximising tpr - fpr (Youden's J) over sklearn's roc_curve. Each group's accuracy
    is the fraction of its samples whose thresholded prediction equals the label; a
    group with no samples yields NaN.

    :param model: Classifier to evaluate
    :param data_loader: Yields (x, y, group) batches
    :param group_info: Dict with 'n_groups'; groups are the integer ids 0..n_groups-1
    :param dataset: Dataset name; selects how inputs (or mediators) are built
    :param mediator_model: If given, the classifier is applied to predicted mediators
        (built the same way as in evaluate)
    :param tokenizer: Classifier tokenizer (food review with a mediator model)
    :param mm_tokenizer: Mediator tokenizer (food review with a mediator model)
    :param return_list: If True, return the per-group accuracies instead of the minimum
    :return: np.nanmin of the per-group accuracies, or the list itself if return_list
    """
    with torch.no_grad():
        all_yhat, all_y, all_group = torch.Tensor().to(device), torch.Tensor().to(device), torch.Tensor().to(device)
        for x, y, group in data_loader:

            group = group.to(device)
            if mediator_model is not None:
                if dataset in ('waterbirds', 'waterbirds_double'):
                    # Get mediator: image multiplied by the thresholded predicted mask
                    x = x.to(device).to(dtype=torch.float32)
                    logits_mask = mediator_model(x)
                    prob_mask = torch.sigmoid(logits_mask)
                    pred_mask = (prob_mask > 0.5).float()
                    masked_img = x * pred_mask
                    # Make prediction using mediator
                    yhat = model(masked_img)
                    y = y.to(device)
                elif dataset in ('koa', 'koa_double'):
                    # Get mediator
                    x = x.to(device).to(dtype=torch.float32)
                    zhat = mediator_model(x)
                    # Make prediction using mediator
                    yhat = model(zhat)
                    y = y.to(device)
                elif dataset in ('food_review', 'food_review_double'):
                    # Get mediator: generated summary, re-tokenized for the classifier
                    input_tokens = mm_tokenizer(x, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
                    input_ids = input_tokens.input_ids.to(device)
                    attention_mask = input_tokens.attention_mask.to(device)
                    output = mediator_model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=512)
                    summaries = mm_tokenizer.batch_decode(output, skip_special_tokens=True)
                    summary_tokens = tokenizer(summaries, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
                    summary_masks = summary_tokens.attention_mask.to(device)
                    summary_input_ids = summary_tokens.input_ids.squeeze(1).to(device)
                    # Make prediction using mediator
                    yhat = model(summary_input_ids, summary_masks)
                    y = y.to(device).to(dtype=torch.float32)

            else:
                if dataset in ('food_review', 'food_review_double'):
                    mask = x['attention_mask'].to(device)
                    input_id = x['input_ids'].squeeze(1).to(device)
                    y = y.to(device).to(dtype=torch.float32)
                    yhat = model(input_id, mask)
                else:
                    x = x.to(device)
                    y = y.to(device)
                    yhat = model(x)

            all_yhat = torch.cat((all_yhat, yhat))
            all_y = torch.cat((all_y, y))
            all_group = torch.cat((all_group, group))

        # The threshold depends only on the pooled predictions, so compute it once
        # instead of once per group.
        fpr, tpr, thresholds = roc_curve(all_y.cpu().numpy(), all_yhat.cpu().numpy())
        threshold = thresholds[np.argmax(tpr - fpr)]

        accuracy_list = []
        for i in range(group_info['n_groups']):
            group_indices = all_group == i
            yhat = all_yhat[group_indices]
            y = all_y[group_indices]
            group = all_group[group_indices]

            labels = torch.where(yhat >= threshold, 1, 0)
            # Tensor division by an empty group's size (0) yields NaN rather than raising
            accuracy = torch.sum((labels.squeeze() == y.squeeze()).long()) / group.shape[0]
            accuracy_list.append(accuracy.cpu().numpy())

        if return_list:
            return accuracy_list
        return np.nanmin(accuracy_list)



def evaluate_teacher(model, data_loader, dataset):
    """
    Compute the AUROC of a teacher model.

    The teacher is fed the third element of each batch (the privileged information),
    not the first.

    :param model: Teacher classifier
    :param data_loader: Yields (x, y, z) batches; z is passed to the model
    :param dataset: For 'food_review'/'food_review_double', z is a tokenized dict
    :return: sklearn roc_auc_score of the concatenated labels and predictions
    """
    is_text = dataset in ('food_review', 'food_review_double')

    with torch.no_grad():
        preds = torch.Tensor().to(device)
        targets = torch.Tensor().to(device)

        for _, y, privileged in data_loader:
            if not is_text:
                privileged = privileged.to(device)
                y = y.to(device)
                yhat = model(privileged)
            else:
                attention = privileged['attention_mask'].to(device)
                token_ids = privileged['input_ids'].squeeze(1).to(device)
                y = y.to(device).to(dtype=torch.float32)
                yhat = model(token_ids, attention)

            preds = torch.cat((preds, yhat))
            targets = torch.cat((targets, y))

        score = roc_auc_score(targets.cpu().numpy(), preds.cpu().numpy())
        return score


def train_teacher_model(dataset, train_loader, lr, l2_weight, dropout, num_epochs, sd=False):
    """
    Train a teacher classifier with BCE loss and Adam.

    By default the teacher is trained on the privileged information (third batch
    element). With `sd=True` it is trained on the regular inputs (first batch element).

    :param dataset: Dataset name; selects architecture and input handling
    :param train_loader: Yields (x, y, z) batches
    :param lr: Adam learning rate
    :param l2_weight: Adam weight_decay
    :param dropout: Dropout passed to BertClassifier (food review only)
    :param num_epochs: Number of passes over train_loader
    :param sd: If True, train on x instead of the privileged z
    :return: The trained model in eval mode
    """
    text_datasets = ('food_review', 'food_review_double')
    koa_datasets = ('koa', 'koa_double')
    image_datasets = koa_datasets + ('waterbirds', 'waterbirds_double')

    # Load the model (KOA uses JswNet on privileged data, a ResNet otherwise)
    if dataset in text_datasets:
        model = BertClassifier(dropout, dropout)
    elif dataset in koa_datasets and not sd:
        model = JswNet()
    else:
        model = BinaryPreTrainedNet('resnet')
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_weight)
    loss_function = torch.nn.BCELoss()

    for _ in range(num_epochs):
        for inputs, y, privileged in train_loader:
            # SD does not use privileged information
            x = inputs if sd else privileged

            if dataset in text_datasets:
                mask = x['attention_mask'].to(device)
                input_id = x['input_ids'].squeeze(1).to(device)
                target = torch.unsqueeze(y.to(device).to(dtype=torch.float32), dim=1)
                loss = loss_function(model(input_id, mask), target)
            elif dataset in image_datasets:
                features = x.to(device).to(dtype=torch.float32)
                target = torch.unsqueeze(y.to(device).to(dtype=torch.float32), dim=1)
                loss = loss_function(model(features), target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model.eval()
    return model



def get_teacher_logits(teacher_model, val_loader, dataset, sd=False):
    """
    Run a trained teacher over `val_loader` and collect its logits.

    :param teacher_model: Model exposing forward_logits
    :param val_loader: Yields (x, y, z, extra) batches
    :param dataset: Dataset name; food review inputs are tokenized dicts
    :param sd: If True, feed x to the teacher; otherwise feed the privileged z
    :return: (extra_list, logits). extra_list concatenates the `extra` entries of the
        kept batches. logits is a CPU tensor of the squeezed teacher logits,
        concatenated along dim 0. Batches whose squeezed logits contain at most one
        element (e.g. a final batch of size 1) are skipped entirely.
    """
    is_text = dataset in ('food_review', 'food_review_double')
    kept_extra = []
    logit_chunks = []

    with torch.no_grad():
        for x, _y, z, extra in val_loader:
            inputs = x if sd else z

            if is_text:
                mask = inputs['attention_mask'].to(device)
                input_id = inputs['input_ids'].squeeze(1).to(device)
                logits = teacher_model.forward_logits(input_id, mask)
            else:
                logits = teacher_model.forward_logits(inputs.to(device))
            logits = torch.squeeze(logits)

            # A single logit squeezes to a 0-d tensor, which cannot be concatenated; skip it
            if logits.numel() <= 1:
                continue
            kept_extra.extend(extra)
            logit_chunks.append(logits.cpu())

    return kept_extra, torch.concat(logit_chunks, dim=0)



def train_mediator_model(dataset, train_loader, mm_lr, mm_l2_weight=None, mm_dropout=None, num_epochs=None, tokenizer=None, batch_size=32):
    """
    Train the model that predicts the mediator z from the input x.

    - koa / koa_double: JswNetMediator('resnet') trained on z with MSE loss (Adam).
    - waterbirds / waterbirds_double: smp U-Net (resnet50 encoder, ImageNet weights,
      one output channel) trained on z with binary Dice loss on logits (Adam).
    - food_review / food_review_double: t5-small fine-tuned with Seq2SeqTrainer;
      `train_loader` is passed as the trainer's train_dataset.

    :param dataset: Dataset name (see above)
    :param train_loader: Training data; yields (x, y, z) batches for the image datasets
    :param mm_lr: Learning rate
    :param mm_l2_weight: Weight decay; None means no weight decay (0.0)
    :param mm_dropout: Not used
    :param num_epochs: Number of training epochs
    :param tokenizer: Tokenizer for the seq2seq data collator and trainer (food review)
    :param batch_size: Per-device train batch size (food review)
    :return: The trained model in eval mode
    :raises ValueError: If dataset is not one of the supported names
    """
    koa = dataset in ('koa', 'koa_double')
    waterbirds = dataset in ('waterbirds', 'waterbirds_double')
    food_review = dataset in ('food_review', 'food_review_double')
    if not (koa or waterbirds or food_review):
        raise ValueError(f"Unsupported dataset for mediator model: {dataset!r}")

    # torch's Adam rejects weight_decay=None, so the None default means "no decay"
    weight_decay = 0.0 if mm_l2_weight is None else mm_l2_weight

    # Load the model
    if koa:
        model = JswNetMediator('resnet')
    elif waterbirds:
        model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
        )
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    model.to(device)
    model.train()

    if food_review:
        args = Seq2SeqTrainingArguments(
            output_dir='',  # NOTE(review): empty output_dir — confirm where checkpoints should go
            learning_rate=mm_lr,
            weight_decay=weight_decay,
            per_device_train_batch_size=batch_size,
            num_train_epochs=num_epochs,
            save_total_limit=1,
            eval_accumulation_steps=1,
            per_device_eval_batch_size=1,
            predict_with_generate=True
        )
        data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

        trainer = Seq2SeqTrainer(
            model,
            args,
            train_dataset=train_loader,
            data_collator=data_collator,
            tokenizer=tokenizer
        )

        trainer.train()

    else:
        # Obtain optimizer and loss
        optimizer = torch.optim.Adam(model.parameters(), lr=mm_lr, weight_decay=weight_decay)
        if koa:
            loss_function = torch.nn.MSELoss()
        else:
            loss_function = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        for _ in range(num_epochs):
            for x, _y, z in train_loader:
                # Load data onto GPU (if it exists) and perform forward pass.
                # koa: regressed mediator; waterbirds: mask logits. Both are compared to z.
                x = x.to(device).to(dtype=torch.float32)
                z = z.to(device).to(dtype=torch.float32)
                loss = loss_function(model(x), z)

                # Perform backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    model.eval()
    return model


def evaluate_mediator_model(model, data_loader, dataset, mm_tokenizer=None, batch_size=32):
    """
    Evaluate a mediator model by averaging a per-batch metric weighted by batch size.

    - koa: MSE between model(x) and z (lower is better)
    - waterbirds: binary Dice loss between mask logits and z (lower is better)
    - food_review: ROUGE-1 of summaries generated from x against the references z
      (a score: higher is better, despite the variable name)

    :param model: Trained mediator model
    :param data_loader: Yields (x, y, z) batches
    :param dataset: Dataset name
    :param mm_tokenizer: Tokenizer for the seq2seq mediator (food review)
    :param batch_size: Not used
    :return: 0-d numpy array for koa / waterbirds; the plain averaged number for food review
    """
    # Obtain loss
    if dataset in ('koa', 'koa_double'):
        loss_function = torch.nn.MSELoss()
    elif dataset in ('waterbirds', 'waterbirds_double'):
        loss_function = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    is_text = dataset in ('food_review', 'food_review_double')
    rouge = None  # loaded once, on first use, instead of once per batch
    loss = 0
    num_samples = 0

    with torch.no_grad():

        for x, y, z in data_loader:
            num_samples += len(z)

            if dataset in ('koa', 'koa_double'):
                x = x.to(device).to(dtype=torch.float32)
                z = z.to(device).to(dtype=torch.float32)
                zhat = model(x)
                loss += loss_function(zhat, z) * len(z)

            elif dataset in ('waterbirds', 'waterbirds_double'):
                x = x.to(device).to(dtype=torch.float32)
                z = z.to(device).to(dtype=torch.float32)
                logits_mask = model(x)
                loss += loss_function(logits_mask, z) * len(z)

            elif is_text:
                if rouge is None:
                    rouge = evaluate_rogue.load('rouge')
                input_tokens = mm_tokenizer(x, max_length=512, truncation=True, padding='max_length', return_tensors='pt')
                input_ids = input_tokens.input_ids.to(device)
                attention_mask = input_tokens.attention_mask.to(device)
                output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=512)
                summaries = mm_tokenizer.batch_decode(output, skip_special_tokens=True)
                result = rouge.compute(predictions=summaries, references=z)
                loss += result['rouge1'] * len(z)

        # Return mean_loss
        mean_loss = loss / num_samples
        if is_text:
            return mean_loss
        return mean_loss.cpu().detach().numpy()



def get_predicted_mediators(mediator_model, val_loader, dataset, mm_tokenizer=None):
    """
    Collect mediator-model predictions over `val_loader`.

    - waterbirds: the image multiplied by the thresholded (sigmoid > 0.5) predicted
      mask, the same mediator construction that evaluate uses
    - koa: mediator_model(x)
    - food_review: the batch's z is passed through unchanged (nothing is generated here)

    :param mediator_model: Trained mediator model (used for the image datasets)
    :param val_loader: Yields (x, y, z, extra) batches
    :param dataset: Dataset name
    :param mm_tokenizer: Currently unused
    :return: (extra_list, mediators, labels). extra_list concatenates the `extra`
        entries of the kept batches. For image datasets, mediators is a CPU tensor
        concatenated along dim 0; for food review it is a list. labels is the
        concatenated CPU tensor of y. Batches with a single sample/element are skipped.
    :raises ValueError: If dataset is not one of the supported names
    """
    is_text = dataset in ('food_review', 'food_review_double')
    waterbirds = dataset in ('waterbirds', 'waterbirds_double')
    koa = dataset in ('koa', 'koa_double')
    if not (is_text or waterbirds or koa):
        raise ValueError(f"Unsupported dataset for mediator prediction: {dataset!r}")

    extra_list = []
    mediators_list = []
    labels_list = []

    with torch.no_grad():
        for x, y, z, extra in val_loader:

            if is_text:
                # NOTE(review): no summaries are generated here; z is used as-is. Confirm intended.
                # Prevents bug from occurring when final batch size is 1
                if len(z) > 1:
                    extra_list = extra_list + list(extra)
                    labels_list.append(y.cpu())
                    mediators_list = mediators_list + list(z)
                continue

            # Get the predicted mediators
            x = x.to(device).to(dtype=torch.float32)
            if waterbirds:
                logits_mask = mediator_model(x)
                prob_mask = torch.sigmoid(logits_mask)
                pred_mask = (prob_mask > 0.5).float()
                mediators = x * pred_mask
            else:
                mediators = mediator_model(x)

            # Save the predicted mediators (not the ground-truth z).
            # Prevents bug from occurring when final batch size is 1.
            if mediators.numel() > 1:
                extra_list = extra_list + list(extra)
                labels_list.append(y.cpu())
                mediators_list.append(mediators.cpu())

    if not is_text:
        mediators_list = torch.concat(mediators_list, dim=0)
    labels_list = torch.concat(labels_list, dim=0)

    return extra_list, mediators_list, labels_list