from model.backbone import FeatureExtractor, TransformerEncoder, SleepMLP
from model.backbone import FeatureExtractor_Face, TransformerEncoder_Face, EmotionMLP_Face
from model.backbone import FeatureExtractor_BCI2000, TransformerEncoder_BCI2000, MotorImageryBCI2000
from model.backbone import FeatureExtractor_TUEV, TransformerEncoder_TUEV, TUEVMLP_TUEV
import torch.nn as nn
import torch
from sklearn.metrics import classification_report, f1_score, confusion_matrix
from utils.util import ModelConfig
import os


def pretraining(train_dl, val_dl, args):
    if args.dataset == 'ISRUC':
        multi_modality_model = train_block(train_dl, val_dl, args)
    elif args.dataset == 'FACED':
        multi_modality_model = face_train_block(train_dl, val_dl, args)
    elif args.dataset == 'BCI2000':
        multi_modality_model = BCI2000_train_block(train_dl, val_dl, args)
    else:
        multi_modality_model = None

    state_f = multi_modality_model[0].state_dict()
    for key in state_f.keys():
        state_f[key] = state_f[key].to(torch.device("cpu"))

    state_encoder = multi_modality_model[1].state_dict()
    for key in state_encoder.keys():
        state_encoder[key] = state_encoder[key].to(torch.device("cpu"))

    state_sleep = multi_modality_model[2].state_dict()
    for key in state_sleep.keys():
        state_sleep[key] = state_sleep[key].to(torch.device("cpu"))

    if not os.path.exists(f"Pretrain/{args.source_percent}/"):
        os.makedirs(f"Pretrain/{args.source_percent}/")

    torch.save(state_f,
               f"feature_extractor_parameter_{args.seed}.pkl")
    torch.save(state_encoder,
               f"feature_encoder_parameter_{args.seed}.pkl")
    torch.save(state_sleep,
               f"sleep_classifier_parameter_{args.seed}.pkl")


def train_block(train_dl, val_dl, args):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    total_acc = []
    total_f1 = []
    best_acc = 0
    best_f1 = 0
    best_epoch = 0

    feature_extractor = FeatureExtractor(args).float().to(device)
    sleep_classifier = SleepMLP(args).float().to(device)
    feature_encoder = TransformerEncoder(args).float().to(device)

    # loss function
    classifier_criterion = nn.CrossEntropyLoss().to(device)

    # optimizer
    optimizer_encoder = torch.optim.Adam(list(feature_extractor.parameters())
                                         + list(sleep_classifier.parameters())
                                         + list(feature_encoder.parameters()), lr=args.lr,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay)

    model_param = ModelConfig(args.dataset)

    for epoch in range(1, args.pretrain_epoch + 1):
        print(f"--------------------Epoch{epoch}---Pretrain-----------------------------")
        feature_extractor.train()
        sleep_classifier.train()
        feature_encoder.train()
        running_loss = 0.0
        for batch_idx, data in enumerate(train_dl):
            x, _, label = data[0].to(device), data[1].to(device), data[2].to(device)
            epoch_size = model_param.EpochLength
            eog = x[:, :, :2, :]
            eeg = x[:, :, 2:, :]
            eog = eog.view(-1, model_param.EogNum, epoch_size)
            eeg = eeg.view(-1, model_param.EegNum, epoch_size)
            eeg_eog_feature = feature_extractor(eeg, eog)
            eeg_eog_feature = feature_encoder(eeg_eog_feature)
            pred = sleep_classifier(eeg_eog_feature)

            loss_classifier = classifier_criterion(pred, label.long())

            optimizer_encoder.zero_grad()
            loss_classifier.backward()
            torch.nn.utils.clip_grad_norm_(sleep_classifier.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_extractor.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 1.0)
            optimizer_encoder.step()

            running_loss += loss_classifier.item()

            if batch_idx % 10 == 9:
                print('\n [%d,  %5d] total_loss: %.3f ' % (epoch, batch_idx + 1, running_loss / 10))
                running_loss = 0.0

        if epoch % 1 == 0:
            print(f" -------------------Epoch{epoch}---Val------------------------------")
            report = dev_block((feature_extractor, feature_encoder, sleep_classifier),
                               val_dl, args, model_param)
            total_acc.append(report[0])
            total_f1.append(report[1])

        if total_acc[-1] > best_acc:
            best_acc = total_acc[-1]
            best_f1 = total_f1[-1]
            best_epoch = epoch
            best_modality_feature = (feature_extractor, feature_encoder, sleep_classifier)
        print("dev_acc:", total_acc)
        print("dev_macro_f1:", total_f1)
    else:
        print(f"Pretraining: Best Epoch:{best_epoch}  Best ACC:{best_acc}  Best F1:{best_f1}")

    return best_modality_feature


def dev_block(model, val_dl, args, model_param):
    if type(model) == tuple:
        model[0].eval()
        model[1].eval()
        model[2].eval()
    else:
        model.eval()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()

    y_pred = []
    y_test = []
    with torch.no_grad():
        correct = 0.0
        total = 0.0
        dev_mean_loss = 0.0
        for batch_idx, data in enumerate(val_dl):
            x, _, labels = data[0].to(device), data[1].to(device), data[2].to(device)
            epoch_size = model_param.EpochLength
            eog = x[:, :, :2, :]
            eeg = x[:, :, 2:, :]
            eog = eog.view(-1, model_param.EogNum, epoch_size)
            eeg = eeg.view(-1, model_param.EegNum, epoch_size)
            eeg_eog_feature = model[0](eeg, eog)
            eeg_eog_feature = model[1](eeg_eog_feature)
            prediction = model[2](eeg_eog_feature)

            dev_loss = criterion(prediction, labels.long())
            dev_mean_loss += dev_loss.item()

            _, predicted = torch.max(prediction.data, dim=1)
            predicted, labels = torch.flatten(predicted), torch.flatten(labels)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            acc = correct / total
            count = batch_idx
            predicted = predicted.tolist()
            y_pred.extend(predicted)
            labels = labels.tolist()
            y_test.extend(labels)

        macro_f1 = f1_score(y_test, y_pred, average="macro")
        print('dev loss:', dev_mean_loss / count, 'Accuracy on sleep:', acc, 'F1 score on sleep:', macro_f1, )
        print(classification_report(y_test, y_pred, target_names=['Sleep stage W',
                                                                  'Sleep stage 1',
                                                                  'Sleep stage 2',
                                                                  'Sleep stage 3/4',
                                                                  'Sleep stage R']))
        confusion_mtx = confusion_matrix(y_test, y_pred)
        print(confusion_mtx)

        report = (acc, macro_f1)
        return report


def face_train_block(train_dl, val_dl, args):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    total_acc = []
    total_f1 = []
    best_acc = 0
    best_f1 = 0
    best_epoch = 0

    feature_extractor = FeatureExtractor_Face(args).float().to(device)
    emotion_classifier = EmotionMLP_Face(args).float().to(device)
    feature_encoder = TransformerEncoder_Face(args).float().to(device)

    # loss function
    classifier_criterion = nn.CrossEntropyLoss().to(device)

    # optimizer
    optimizer_encoder = torch.optim.Adam(list(feature_extractor.parameters())
                                         + list(emotion_classifier.parameters())
                                         + list(feature_encoder.parameters()), lr=args.lr,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay)

    model_param = ModelConfig(args.dataset)

    for epoch in range(1, args.pretrain_epoch + 1):
        print(f"--------------------Epoch{epoch}---Pretrain-----------------------------")
        feature_extractor.train()
        emotion_classifier.train()
        feature_encoder.train()
        running_loss = 0.0
        for batch_idx, data in enumerate(train_dl):
            x, _, label = data[0].to(device), data[1].to(device), data[2].to(device)

            ff = feature_extractor(x)
            ff = feature_encoder(ff)
            pred = emotion_classifier(ff)

            loss_classifier = classifier_criterion(pred, label.long())

            optimizer_encoder.zero_grad()
            loss_classifier.backward()
            torch.nn.utils.clip_grad_norm_(emotion_classifier.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_extractor.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 1.0)
            optimizer_encoder.step()

            running_loss += loss_classifier.item()

            if batch_idx % 3 == 0:
                print('\n [%d,  %5d] total_loss: %.3f ' % (epoch, batch_idx + 1, running_loss / 10))
                running_loss = 0.0

        if epoch % 1 == 0:
            print(f" -------------------Epoch{epoch}---Val------------------------------")
            report = face_dev_block((feature_extractor, feature_encoder, emotion_classifier),
                               val_dl, args, model_param)
            total_acc.append(report[0])
            total_f1.append(report[1])

        if total_acc[-1] > best_acc:
            best_acc = total_acc[-1]
            best_f1 = total_f1[-1]
            best_epoch = epoch
            best_modality_feature = (feature_extractor, feature_encoder, emotion_classifier)
        print("dev_acc:", total_acc)
        print("dev_macro_f1:", total_f1)
    else:
        print(f"Pretraining: Best Epoch:{best_epoch}  Best ACC:{best_acc}  Best F1:{best_f1}")

    return best_modality_feature


def face_dev_block(model, val_dl, args, model_param):
    if type(model) == tuple:
        model[0].eval()
        model[1].eval()
        model[2].eval()
    else:
        model.eval()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()

    y_pred = []
    y_test = []
    with torch.no_grad():
        correct = 0.0
        total = 0.0
        dev_mean_loss = 0.0
        for batch_idx, data in enumerate(val_dl):
            x, _, labels = data[0].to(device), data[1].to(device), data[2].to(device)
            ff = model[0](x)
            ff = model[1](ff)
            prediction = model[2](ff)

            dev_loss = criterion(prediction, labels.long())
            dev_mean_loss += dev_loss.item()

            _, predicted = torch.max(prediction.data, dim=1)
            predicted, labels = torch.flatten(predicted), torch.flatten(labels)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            acc = correct / total
            count = batch_idx
            predicted = predicted.tolist()
            y_pred.extend(predicted)
            labels = labels.tolist()
            y_test.extend(labels)

        macro_f1 = f1_score(y_test, y_pred, average="macro")
        print('dev loss:', dev_mean_loss / count, 'Accuracy on sleep:', acc, 'F1 score on sleep:', macro_f1, )
        print(classification_report(y_test, y_pred, target_names=model_param.ClassNamesFace))
        confusion_mtx = confusion_matrix(y_test, y_pred)
        print(confusion_mtx)

        report = (acc, macro_f1)
        return report


def BCI2000_train_block(train_dl, val_dl, args):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    total_acc = []
    total_f1 = []
    best_acc = 0
    best_f1 = 0
    best_epoch = 0

    feature_extractor = FeatureExtractor_BCI2000(args).float().to(device)
    emotion_classifier = MotorImageryBCI2000(args).float().to(device)
    feature_encoder = TransformerEncoder_BCI2000(args).float().to(device)

    # loss function
    classifier_criterion = nn.CrossEntropyLoss().to(device)

    # optimizer
    optimizer_encoder = torch.optim.Adam(list(feature_extractor.parameters())
                                         + list(emotion_classifier.parameters())
                                         + list(feature_encoder.parameters()), lr=args.lr,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay)

    model_param = ModelConfig(args.dataset)

    for epoch in range(1, args.pretrain_epoch + 1):
        print(f"--------------------Epoch{epoch}---Pretrain-----------------------------")
        feature_extractor.train()
        emotion_classifier.train()
        feature_encoder.train()
        running_loss = 0.0
        for batch_idx, data in enumerate(train_dl):
            x, _, label = data[0].to(device), data[1].to(device), data[2].to(device)

            ff = feature_extractor(x)
            ff = feature_encoder(ff)
            pred = emotion_classifier(ff)

            loss_classifier = classifier_criterion(pred, label.long())

            optimizer_encoder.zero_grad()
            loss_classifier.backward()
            torch.nn.utils.clip_grad_norm_(emotion_classifier.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_extractor.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 1.0)
            optimizer_encoder.step()

            running_loss += loss_classifier.item()

            if batch_idx % 3 == 0:
                print('\n [%d,  %5d] total_loss: %.3f ' % (epoch, batch_idx + 1, running_loss / 10))
                running_loss = 0.0

        if epoch % 1 == 0:
            print(f" -------------------Epoch{epoch}---Val------------------------------")
            report = BCI2000_dev_block((feature_extractor, feature_encoder, emotion_classifier),
                               val_dl, args, model_param)
            total_acc.append(report[0])
            total_f1.append(report[1])

        if total_acc[-1] > best_acc:
            best_acc = total_acc[-1]
            best_f1 = total_f1[-1]
            best_epoch = epoch
            best_modality_feature = (feature_extractor, feature_encoder, emotion_classifier)
        print("dev_acc:", total_acc)
        print("dev_macro_f1:", total_f1)
    else:
        print(f"Pretraining: Best Epoch:{best_epoch}  Best ACC:{best_acc}  Best F1:{best_f1}")

    return best_modality_feature


def BCI2000_dev_block(model, val_dl, args, model_param):
    if type(model) == tuple:
        model[0].eval()
        model[1].eval()
        model[2].eval()
    else:
        model.eval()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()

    y_pred = []
    y_test = []
    with torch.no_grad():
        correct = 0.0
        total = 0.0
        dev_mean_loss = 0.0
        for batch_idx, data in enumerate(val_dl):
            x, _, labels = data[0].to(device), data[1].to(device), data[2].to(device)
            ff = model[0](x)
            ff = model[1](ff)
            prediction = model[2](ff)

            dev_loss = criterion(prediction, labels.long())
            dev_mean_loss += dev_loss.item()

            _, predicted = torch.max(prediction.data, dim=1)
            predicted, labels = torch.flatten(predicted), torch.flatten(labels)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            acc = correct / total
            count = batch_idx
            predicted = predicted.tolist()
            y_pred.extend(predicted)
            labels = labels.tolist()
            y_test.extend(labels)

        macro_f1 = f1_score(y_test, y_pred, average="macro")
        print('dev loss:', dev_mean_loss / count, 'Accuracy on MI:', acc, 'F1 score on MI:', macro_f1, )
        print(classification_report(y_test, y_pred, target_names=model_param.ClassNamesBCI2000))
        confusion_mtx = confusion_matrix(y_test, y_pred)
        print(confusion_mtx)

        report = (acc, macro_f1)
        return report

