import argparse
import os
import numpy as np
import time
import logging

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, accuracy_score, f1_score, roc_curve, auc

from dataloader import AudioVisualDataset, af_collate_fn
from cl_model import Fusion, CognitiveLoadFeatureExtractor
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss computed on top of ``F.cross_entropy``.

    For every sample the cross-entropy ``ce`` is rescaled to
    ``alpha * (1 - pt) ** gamma * ce`` with ``pt = exp(-ce)``, so samples the
    model already gets right with high confidence (``pt`` close to 1)
    contribute less to the total.

    Args:
        alpha: Constant multiplier applied to every sample's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        reduction: ``'mean'`` or ``'sum'`` to reduce over samples; any other
            value returns the unreduced per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` against ``targets``.

        Both arguments are forwarded unchanged to ``F.cross_entropy``.
        """
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the quantity whose negative log is the CE term.
        pt = torch.exp(-per_sample_ce)
        modulator = (1 - pt) ** self.gamma
        focal = self.alpha * modulator * per_sample_ce

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

def _protocol_pair(value):
    """Parse one ``--protocols`` item written as ``TRAIN_CSV,TEST_CSV``."""
    parts = [part.strip() for part in value.split(',')]
    if len(parts) != 2 or not all(parts):
        raise argparse.ArgumentTypeError(
            "expected 'TRAIN_CSV,TEST_CSV', got {!r}".format(value))
    return parts

def parse_options(argv=None):
    """Parse command-line options and apply their global side effects.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read the process command line, as before.

    Returns:
        The parsed ``argparse.Namespace`` with two fields post-processed:
        ``device`` is converted to a ``torch.device`` and ``model_name`` is
        expanded with the model type, encoder count and adapter settings.

    Side effects:
        Seeds the torch RNG with ``--seed`` and creates the ``--log``
        directory when it does not exist yet.
    """
    parser = argparse.ArgumentParser(description="WILTY baseline repo")
    parser.add_argument('--device', type=str, default="cuda:0", help='the gpu id used for predict')
    parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate')
    parser.add_argument('--batch_size', type=int, default=16, help='initial batchsize')
    parser.add_argument('--num_epochs', type=int, default=20, help='total training epochs')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--multi', action='store_true', help="multitask learning with multiple losses")
    parser.add_argument('--num_encoders', type=int, default=4, help="number of transformer encoders for each modality")
    parser.add_argument('--adapter', action='store_true', help="indicator of using adapter")
    parser.add_argument('--adapter_type', type=str, default='efficient_conv', help='adapter type: nlp, efficient_conv')
    parser.add_argument('--log', type=str, default="Final_logs", help='log and save model name')
    # `type=list` would split a command-line string into single characters,
    # which can never unpack into a (train, test) pair. Each item is instead
    # parsed as "TRAIN_CSV,TEST_CSV"; the default value is left untouched.
    parser.add_argument('--protocols', type=_protocol_pair, nargs='+',
                        default=[['train_fold_1.csv', 'test_fold_1.csv'], ['train_fold_2.csv', 'test_fold_2.csv'], ['train_fold_3.csv', 'test_fold_3.csv']],
                        help='protocols for train/test, one or more TRAIN_CSV,TEST_CSV pairs')
    parser.add_argument('--model_name', type=str, default='DOLOS_')
    parser.add_argument('--model_to_train', type=str, default='fusion', help='model to train: audio, vision, fusion')
    parser.add_argument('--fusion_type', type=str, default='cross2', help='fusion type: concat, cross2')

    parser.add_argument('--pretrained_cognitive_model_paths_1', type=str, default='mental_demand.pth')
    parser.add_argument('--pretrained_cognitive_model_paths_2', type=str, default='effort_half.pth')
    parser.add_argument('--pretrained_cognitive_model_paths_3', type=str, default='temporal_demand_full.pth')

    opts = parser.parse_args(argv)
    torch.manual_seed(opts.seed)
    opts.device = torch.device(opts.device)

    # The adapter type is only appended to the name when an adapter is in use.
    opts.model_name = "{}{}_Encoders_{}_Adapter_{}".format(
        opts.model_name, opts.model_to_train, opts.num_encoders, opts.adapter)
    if opts.adapter:
        opts.model_name += "_type_{}".format(opts.adapter_type)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(opts.log, exist_ok=True)

    return opts

def train_one_epoch(args, train_data_loader, model, optimizer, loss_fn, scheduler):
    """Run one optimisation pass over ``train_data_loader``.

    Args:
        args: Object exposing ``device``; every batch is moved there.
        train_data_loader: Iterable yielding ``(waves, faces, labels)`` batches.
        model: Callable as ``model(waves, faces)`` returning a 3-tuple whose
            first element is the class scores; the other two are ignored.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable as ``loss_fn(scores, labels)`` returning a scalar
            tensor.
        scheduler: Learning-rate scheduler stepped once, after the last batch.

    Returns:
        Tuple ``(mean_loss, predictions, labels)``: the mean of the per-batch
        loss values, the concatenated per-sample argmax predictions, and the
        concatenated labels (both tensors stay on ``args.device``).
    """
    batch_losses = []
    batch_predictions = []
    batch_labels = []

    model.train()
    for step, (waves, faces, labels) in enumerate(train_data_loader):
        # squeeze(1) drops axis 1 of the audio batch when it has size 1
        # (presumably a singleton channel axis -- confirm against the dataset).
        waves = waves.squeeze(1).to(args.device)
        faces = faces.to(args.device)
        labels = labels.to(args.device)

        optimizer.zero_grad()
        scores, _, _ = model(waves, faces)

        batch_loss = loss_fn(scores, labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        batch_losses.append(loss_value)
        batch_predictions.append(torch.argmax(scores, dim=1))
        batch_labels.append(labels)

        if step % 10 == 0:
            print("iter {}, loss {:.5f}".format(step, loss_value))

    # The schedule advances per epoch, not per batch.
    scheduler.step()

    return np.mean(batch_losses), torch.cat(batch_predictions), torch.cat(batch_labels)

def val_one_epoch(args, val_data_loader, model, loss_fn):
    """Evaluate ``model`` over ``val_data_loader`` without updating it.

    Args:
        args: Object exposing ``device``; every batch is moved there.
        val_data_loader: Iterable yielding ``(waves, faces, labels)`` batches.
        model: Callable as ``model(waves, faces)`` returning a 3-tuple whose
            first element is the class scores; the other two are ignored.
        loss_fn: Callable as ``loss_fn(scores, labels)`` returning a scalar
            tensor.

    Returns:
        Tuple ``(mean_loss, predictions, labels)``: the mean of the per-batch
        loss values, the concatenated per-sample argmax predictions, and the
        concatenated labels (both tensors stay on ``args.device``).
    """
    batch_losses = []
    batch_predictions = []
    batch_labels = []

    model.eval()
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for waves, faces, labels in val_data_loader:
            # squeeze(1) drops axis 1 of the audio batch when it has size 1
            # (presumably a singleton channel axis -- confirm against the dataset).
            waves = waves.squeeze(1).to(args.device)
            faces = faces.to(args.device)
            labels = labels.to(args.device)

            scores, _, _ = model(waves, faces)

            batch_losses.append(loss_fn(scores, labels).item())
            batch_predictions.append(torch.argmax(scores, dim=1))
            batch_labels.append(labels)

    return np.mean(batch_losses), torch.cat(batch_predictions), torch.cat(batch_labels)

def evaluation(labels, preds):
    """Compute accuracy, F1 and ROC-AUC for ``preds`` against ``labels``.

    ``preds`` is used both as the predicted labels (accuracy, F1) and as the
    score argument of ``roc_curve``, with class ``1`` as the positive class.
    The callers in this file pass argmax class indices, so the ROC curve is
    built from hard decisions rather than continuous scores.

    Returns:
        Tuple ``(accuracy, f1, auc)``.
    """
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    # The thresholds returned by roc_curve are not needed for the area.
    false_pos_rate, true_pos_rate, _ = roc_curve(labels, preds, pos_label=1)
    return accuracy, f1, auc(false_pos_rate, true_pos_rate)

def train_test(log_name, args):
    """Train and evaluate a fresh model per protocol, appending results to a log.

    For every ``[train_file, test_file]`` pair in ``args.protocols`` a new
    ``Fusion`` model is built, the three cognitive-load feature extractors are
    attached to it, and it is trained for ``args.num_epochs`` epochs. Metrics
    for each epoch, followed by the metrics and classification report of the
    epoch with the highest test accuracy, are appended to ``log_name``.

    Args:
        log_name: Path of the text log; opened in append mode.
        args: Options namespace as produced by ``parse_options``.
    """
    # Context manager: the log is flushed and closed even if training raises.
    with open(log_name, 'a') as f:
        for P in args.protocols:
            print("\n\nCurrent protocol.....................", P)
            train, test = P

            f.write("\n\nTrain file = " + train.split('.')[0])
            f.write("\nTest file = " + test.split('.')[0])

            # FIXME(review): no dataset/DataLoader is built from `train` and
            # `test` anywhere in this function, so both epoch helpers receive
            # None and will fail as soon as they iterate over their loader.
            # Build the two loaders here (the file imports AudioVisualDataset,
            # af_collate_fn and DataLoader, presumably for this purpose); the
            # dataset constructor arguments are not visible from this file.
            train_loader = None
            test_loader = None

            print("\t Dataset Loaded")

            cognitive_model_1 = CognitiveLoadFeatureExtractor(args.pretrained_cognitive_model_paths_1).to(args.device)
            cognitive_model_2 = CognitiveLoadFeatureExtractor(args.pretrained_cognitive_model_paths_2).to(args.device)
            cognitive_model_3 = CognitiveLoadFeatureExtractor(args.pretrained_cognitive_model_paths_3).to(args.device)

            model = Fusion(args.fusion_type, args.num_encoders, args.adapter, args.adapter_type, args.multi).to(args.device)

            # NOTE(review): the extractors are attached before the optimizer
            # is created; if they register parameters on `model`, those are
            # optimised too -- confirm whether they are meant to be frozen.
            model.cognitive_feature_extractor_a = cognitive_model_1
            model.cognitive_feature_extractor_b = cognitive_model_2
            model.cognitive_feature_extractor_c = cognitive_model_3

            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
            # Halve the learning rate every 5 epochs.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
            loss_fn = nn.CrossEntropyLoss()
            best_acc = 0.0
            results = ""
            val_results = ""

            print("\t Started Training")
            for epoch in range(args.num_epochs):
                if (epoch + 1) % 5 == 0:
                    print('\t\t Epoch....', epoch + 1)

                loss, preds, labels = train_one_epoch(args, train_loader, model, optimizer, loss_fn, scheduler)
                train_acc, train_f1, train_auc = evaluation(labels.cpu().numpy(), preds.cpu().numpy())
                val_loss, val_preds, val_labels = val_one_epoch(args, test_loader, model, loss_fn)
                # Converted once; reused for the metrics and the report below.
                val_labels_np = val_labels.cpu().numpy()
                val_preds_np = val_preds.cpu().numpy()
                val_acc, val_f1, val_auc = evaluation(val_labels_np, val_preds_np)

                summary = ("epoch {}, train_acc {:.5f}, train_f1: {:.5f}, train_auc:{:.5f} "
                           "test_acc {:.5f}, test_f1: {:.5f}, test_auc:{:.5f}").format(
                               epoch, train_acc, train_f1, train_auc, val_acc, val_f1, val_auc)
                print(summary)
                f.write(summary + "\n")

                # Keep the metrics of the epoch with the highest test accuracy.
                if val_acc > best_acc:
                    best_acc = val_acc
                    results = "best results are acc {:.5f}, f1: {:.5f}, auc:{:.5f} ".format(val_acc, val_f1, val_auc)
                    val_results = classification_report(val_labels_np, val_preds_np, target_names=["truth", "deception"])
            print("results:\n\n")
            print(results)
            f.write("****************\n")
            f.write(results)
            f.write("\n\n")
            f.write(val_results)
            f.write("\n\n")

if __name__ == "__main__":
    opts = parse_options()

    log_name = '0927_crossentropy_threeconcat.txt'
    # Header describing the run configuration. Mode 'w' starts a fresh log;
    # train_test() then appends its per-protocol results to the same file.
    header = [
        "\nOptimizer and LR = Adam, " + str(opts.lr),
        "\nBatch Size = " + str(opts.batch_size),
        "\nEpochs = " + str(opts.num_epochs),
        "\nNum Encoders = " + str(opts.num_encoders),
        "\nAdapter = " + str(opts.adapter),
    ]
    if opts.adapter:
        header.append("\nAdapter Type = " + opts.adapter_type)
    header.append("\n------------------------------------------")
    header.append("\n------------------------------------------")

    # The with-block closes the file; no explicit close() is needed after it.
    with open(log_name, 'w') as f:
        f.writelines(header)

    train_test(log_name, args=opts)
