from dataset import domain_dataset
from torch.utils.data import DataLoader
from model import CNN_baseline, Linear_baseline
import tqdm
import torch
import config as cfg
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import os
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Module-level TensorBoard writer, created with default arguments as soon as
# this module is imported. It is not referenced by the functions visible in
# this section of the file.
writer = SummaryWriter()


# train the model for every subject
def train_valid_model(eegdata, eeglabel, eegdomain, saveckpt):
    """Train the two-stage baseline and checkpoint each stage's best epoch.

    ``model_xz`` (``CNN_baseline``) is trained to predict the domain from the
    EEG input and also returns an intermediate state. ``model_zy``
    (``Linear_baseline``) is trained to predict the label from that state; the
    state is detached first, so the label loss never updates ``model_xz``.

    The data is split 80/20 with a fixed seed into training and validation
    parts. After every epoch both models are evaluated on the validation part
    and a model's weights are saved only when its mean validation loss is the
    lowest seen so far during this call.

    Args:
        eegdata: Samples; forwarded to ``train_test_split`` and
            ``domain_dataset``.
        eeglabel: Per-sample labels, used as targets for ``model_zy``.
        eegdomain: Per-sample domain ids, used as targets for ``model_xz``.
        saveckpt: Base checkpoint path. Its last five characters (presumably
            the ``.ckpt`` suffix - confirm with callers) are replaced by
            ``_xz.ckpt`` and ``_zy.ckpt`` to build the two output paths.

    Returns:
        None. The results are the checkpoint files and the printed metrics.
    """
    # ----------------------initial model------------------------
    # model_xz is used to learn p(z|x)
    # model_zy is used to learn p(y|z)
    model_xz = CNN_baseline().to(cfg.device)
    model_zy = Linear_baseline().to(cfg.device)

    x_train, x_valid, y_train, y_valid, z_train, z_valid = train_test_split(
        eegdata, eeglabel, eegdomain, test_size=0.2, random_state=2024)

    train_loader = DataLoader(dataset=domain_dataset(x_train, y_train, z_train),
                              batch_size=cfg.batch_size, shuffle=True)
    # Validation needs no shuffling; a fixed order keeps the mean of the
    # per-batch losses (the checkpoint criterion) identical between runs.
    valid_loader = DataLoader(dataset=domain_dataset(x_valid, y_valid, z_valid),
                              batch_size=cfg.batch_size, shuffle=False)

    # set the criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_xz = torch.optim.AdamW(model_xz.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    optimizer_zy = torch.optim.AdamW(model_zy.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    saveckpt_xz = saveckpt[:-5] + '_xz.ckpt'
    saveckpt_zy = saveckpt[:-5] + '_zy.ckpt'

    # Best validation losses so far. These must live outside the epoch loop:
    # resetting them every epoch would overwrite the checkpoint with the
    # latest epoch instead of keeping the best one.
    valid_loss_min_xz = float('inf')
    valid_loss_min_zy = float('inf')

    # ---------------------train and valid-----------
    for epoch in range(cfg.epoch_num):

        # ---------------------train---------------------
        model_xz.train()
        model_zy.train()

        num_correct_xz = 0
        num_correct_zy = 0
        num_samples = 0
        train_loss_xz = 0.0
        train_loss_zy = 0.0
        num_batches = 0

        for eeg, label, domain in tqdm.tqdm(train_loader, position=0, leave=True):
            eeg = eeg.to(cfg.device)
            label = label.to(cfg.device)
            domain = domain.to(cfg.device)

            # Stage 1: domain prediction from the raw input.
            pred_xz, state = model_xz(eeg)
            loss_xz = criterion(pred_xz, domain)
            optimizer_xz.zero_grad()
            loss_xz.backward()
            optimizer_xz.step()

            # Stage 2: label prediction from the detached state, so this
            # loss cannot propagate gradients into model_xz.
            pred_zy = model_zy(state.detach())
            loss_zy = criterion(pred_zy, label)
            optimizer_zy.zero_grad()
            loss_zy.backward()
            optimizer_zy.step()

            train_loss_xz += loss_xz.item()
            train_loss_zy += loss_zy.item()
            num_correct_xz += (pred_xz.argmax(dim=1) == domain).sum().item()
            num_correct_zy += (pred_zy.argmax(dim=1) == label).sum().item()
            num_samples += domain.size(0)
            num_batches += 1

        decoder_answer_xz = num_correct_xz / num_samples * 100
        decoder_answer_zy = num_correct_zy / num_samples * 100
        # Each line reports the epoch-mean loss of its own model.
        print(f"saveckpt: {saveckpt},epoch: {epoch},train loss_xz: {train_loss_xz / num_batches}, train_decoder_answer_xz: {decoder_answer_xz}%\n")
        print(f"saveckpt: {saveckpt},epoch: {epoch},train loss_zy: {train_loss_zy / num_batches}, train_decoder_answer_zy: {decoder_answer_zy}%\n")

        # ---------------------valid---------------------
        # Evaluation mode disables train-only behaviour (e.g. dropout,
        # batch-norm statistic updates) if the models contain such layers.
        model_xz.eval()
        model_zy.eval()

        num_correct_xz = 0
        num_correct_zy = 0
        num_samples = 0
        valid_loss_xz = 0.0
        valid_loss_zy = 0.0
        num_batches = 0

        with torch.no_grad():
            for eeg, label, domain in tqdm.tqdm(valid_loader, position=0, leave=True):
                eeg = eeg.to(cfg.device)
                label = label.to(cfg.device)
                domain = domain.to(cfg.device)

                pred_xz, state = model_xz(eeg)
                valid_loss_xz += criterion(pred_xz, domain).item()
                num_correct_xz += (pred_xz.argmax(dim=1) == domain).sum().item()

                pred_zy = model_zy(state)
                valid_loss_zy += criterion(pred_zy, label).item()
                num_correct_zy += (pred_zy.argmax(dim=1) == label).sum().item()

                num_samples += domain.size(0)
                num_batches += 1

        decoder_answer_xz = num_correct_xz / num_samples * 100
        decoder_answer_zy = num_correct_zy / num_samples * 100
        mean_valid_loss_xz = valid_loss_xz / num_batches
        mean_valid_loss_zy = valid_loss_zy / num_batches

        # Record the results of validation
        print(f"saveckpt: {saveckpt},epoch: {epoch},\n"
              f"valid loss_xz: {mean_valid_loss_xz}, valid_decoder_answer_xz: {decoder_answer_xz}%,\n"
              f"valid loss_zy: {mean_valid_loss_zy}, valid_decoder_answer_zy: {decoder_answer_zy}%\n")

        # save each model only when it improves on its best validation loss
        if valid_loss_min_xz > mean_valid_loss_xz:
            valid_loss_min_xz = mean_valid_loss_xz
            torch.save(model_xz.state_dict(), saveckpt_xz)

        if valid_loss_min_zy > mean_valid_loss_zy:
            valid_loss_min_zy = mean_valid_loss_zy
            torch.save(model_zy.state_dict(), saveckpt_zy)


def test_model(eegdata, eeglabel, eegdomain, saveckpt):
    """Evaluate the saved two-stage baseline and plot a t-SNE of its states.

    Loads the ``_xz`` and ``_zy`` checkpoints derived from ``saveckpt``, runs
    every test sample through ``model_xz`` (domain prediction, which also
    returns an intermediate state) and ``model_zy`` (label prediction from
    that state), and reports both accuracies. The collected states are
    projected to 2-D with t-SNE, coloured by domain id, and saved as a PNG
    under ``./tsne_test_xzy/``.

    Args:
        eegdata: Test samples; forwarded to ``domain_dataset``.
        eeglabel: Per-sample labels, compared with ``model_zy`` predictions.
        eegdomain: Per-sample domain ids, compared with ``model_xz``
            predictions and used to colour the t-SNE plot.
        saveckpt: Base checkpoint path. Its last five characters (presumably
            ``.ckpt`` - confirm with callers) are replaced by ``_xz.ckpt`` /
            ``_zy.ckpt``; the figure name is ``saveckpt[2:-5]`` with ``/``
            replaced by ``_`` (assumes a leading ``./`` - confirm).

    Returns:
        tuple: ``(res_xz, res_zy)``, the domain and label accuracies in
        percent.
    """
    # ----------------------initial model------------------------
    model_xz = CNN_baseline().to(cfg.device)
    model_zy = Linear_baseline().to(cfg.device)

    saveckpt_xz = saveckpt[:-5] + '_xz.ckpt'
    saveckpt_zy = saveckpt[:-5] + '_zy.ckpt'
    # map_location lets a checkpoint saved on one device load on another.
    model_xz.load_state_dict(torch.load(saveckpt_xz, map_location=cfg.device))
    model_zy.load_state_dict(torch.load(saveckpt_zy, map_location=cfg.device))
    # Evaluation mode disables train-only behaviour (e.g. dropout) if the
    # models contain such layers.
    model_xz.eval()
    model_zy.eval()

    # Though the train and valid procedures differ, the test data is always
    # fed one sample at a time. A fixed order keeps the t-SNE input, and
    # therefore the saved figure, reproducible.
    test_dataset = domain_dataset(eegdata, eeglabel, eegdomain)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # -------------------------test--------------------------------------------
    test_acc_xz = 0
    test_acc_zy = 0
    total_num = 0
    all_state = []
    all_domain = []

    with torch.no_grad():
        for eeg, label, domain in tqdm.tqdm(test_loader, position=0, leave=True):
            eeg = eeg.to(cfg.device)
            label = label.to(cfg.device)
            domain = domain.to(cfg.device)

            pred_xz, state = model_xz(eeg)
            all_state.append(state)
            all_domain.append(domain)
            test_acc_xz += (pred_xz.argmax(dim=1) == domain).sum().item()

            pred_zy = model_zy(state)
            test_acc_zy += (pred_zy.argmax(dim=1) == label).sum().item()

            total_num += domain.size(0)

    res_xz = 100 * test_acc_xz / total_num
    res_zy = 100 * test_acc_zy / total_num

    # -------------------------t-SNE of the intermediate state-----------------
    all_state = torch.cat(all_state, dim=0).cpu().numpy()
    targets = torch.cat(all_domain, dim=0).cpu().numpy()
    tsne = TSNE(n_components=2, random_state=2024)
    x_transformed = tsne.fit_transform(all_state)
    tsne_df = pd.DataFrame(np.column_stack((x_transformed, targets)), columns=['X', 'Y', "Targets"])
    # Replace the column rather than assigning through .loc[:, ...]: the
    # latter writes into the existing float column on recent pandas, so the
    # legend would show float labels.
    tsne_df["Targets"] = tsne_df["Targets"].astype(int)

    # FacetGrid creates its own figure and makes it current, so no separate
    # plt.figure() call is needed; an extra one would never be closed.
    g = sns.FacetGrid(data=tsne_df, hue='Targets', height=10)
    g.map(plt.scatter, 'X', 'Y').add_legend()

    # save the figure
    os.makedirs('./tsne_test_xzy/', exist_ok=True)
    # replace / with _ so the checkpoint path becomes a flat file name
    savedir = saveckpt[2:-5].replace('/', '_')
    plt.savefig('./tsne_test_xzy/' + savedir + ".png")
    plt.close()

    print(f"saveckpt: {saveckpt}, test accuracy_xz: {res_xz:.3f} %, test accuracy_zy: {res_zy:.3f} %")

    return res_xz, res_zy