from dataset import end2end_dataset
from torch.utils.data import DataLoader
from model import CNN_baseline
import tqdm
import torch
import config as cfg
import torch.nn as nn
from sklearn.model_selection import KFold,train_test_split
from torch.utils.tensorboard import SummaryWriter
import os
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Module-level TensorBoard writer, constructed as soon as this module is imported.
# NOTE(review): not referenced by train_valid_model/test_model in this file view.
# Per the PyTorch docs, SummaryWriter() with no arguments creates a fresh log
# directory (under ./runs/ by default) as an import-time side effect — confirm
# it is used elsewhere before keeping it.
writer = SummaryWriter()

# train the model for every subject
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one full pass of ``model`` over ``loader``.

    When ``optimizer`` is given, the model is switched to training mode and a
    gradient step is taken after every batch. Otherwise the model is switched to
    eval mode and run with gradients disabled.

    Returns:
        (mean_batch_loss, accuracy_percent) as plain Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    training = optimizer is not None
    # Set the mode explicitly every pass: a previous eval pass must not leave
    # dropout / batch-norm in eval mode for the next training pass.
    model.train(training)

    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    with torch.set_grad_enabled(training):
        for eeg, label in tqdm.tqdm(loader, position=0, leave=True):
            eeg = eeg.to(cfg.device)
            label = label.to(cfg.device)

            pred, _ = model(eeg)
            loss = criterion(pred, label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() detaches the value, so each batch's autograd graph is freed
            # instead of being chained into an epoch-long running sum.
            total_loss += loss.item()
            num_batches += 1

            predictions = pred.argmax(dim=1)
            num_correct += (predictions == label).sum().item()
            num_samples += predictions.size(0)

    if num_batches == 0:
        raise ValueError("data loader yielded no batches")
    return total_loss / num_batches, 100.0 * num_correct / num_samples


# train the model for every subject
def train_valid_model(x_train, y_train, x_valid, y_valid, saveckpt):
    """Train a fresh ``CNN_baseline`` and keep the checkpoint with the lowest validation loss.

    Args:
        x_train, y_train: training inputs/labels, wrapped by ``end2end_dataset``.
        x_valid, y_valid: validation inputs/labels, wrapped by ``end2end_dataset``.
        saveckpt: path the model ``state_dict`` is written to whenever the mean
            validation loss improves (the first epoch always writes it).

    Returns:
        (train_accuracy, valid_accuracy) in percent from the last epoch, or
        ``(0, 0)`` if ``cfg.epoch_num`` is 0.
    """
    model = CNN_baseline().to(cfg.device)

    train_loader = DataLoader(dataset=end2end_dataset(x_train, y_train),
                              batch_size=cfg.batch_size, shuffle=True)
    # Validation order does not need shuffling; a fixed order keeps the
    # per-epoch validation loss (and hence checkpoint selection) reproducible.
    valid_loader = DataLoader(dataset=end2end_dataset(x_valid, y_valid),
                              batch_size=cfg.batch_size, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # inf (not a magic cap) guarantees a checkpoint is written after the first epoch.
    valid_loss_min = float("inf")
    train_acc = 0
    valid_acc = 0
    for epoch in range(cfg.epoch_num):
        _, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        print(f"saveckpt: {saveckpt}, epoch: {epoch}, train_decoder_answer: {train_acc}%\n")

        valid_loss, valid_acc = _run_epoch(model, valid_loader, criterion)
        print(f"saveckpt: {saveckpt}, epoch: {epoch}, "
              f"valid loss: {valid_loss} , valid_decoder_answer: {valid_acc}%\n")

        # Please note that for the densenet model,
        # the result presented here is a classification accuracy of 1/128s rather than 1s
        if valid_loss < valid_loss_min:
            valid_loss_min = valid_loss
            torch.save(model.state_dict(), saveckpt)

    return train_acc, valid_acc

def _save_tsne_plot(features, targets, saveckpt):
    """Embed ``features`` with 2-D t-SNE and save a scatter plot coloured by ``targets``.

    The figure is written to ``./tsne_test_xy/<name>.png``. ``<name>`` is the
    checkpoint path, normalised (which drops a leading ``./``), with its
    extension removed and path separators replaced by ``_``. For example,
    ``./ckpt/sub1.ckpt`` becomes ``ckpt_sub1.png``.
    """
    tsne = TSNE(n_components=2, random_state=2024)
    embedded = tsne.fit_transform(features)

    # Build the frame column-wise so Targets keeps an integer dtype.
    # np.column_stack would upcast it to float, and the legend would then show
    # 0.0, 1.0, ...
    tsne_df = pd.DataFrame({
        "X": embedded[:, 0],
        "Y": embedded[:, 1],
        "Targets": targets.reshape(-1).astype(int),
    })
    # FacetGrid creates (and makes current) its own figure, so no plt.figure() is needed.
    g = sns.FacetGrid(data=tsne_df, hue="Targets", height=8)
    g.map(plt.scatter, "X", "Y").add_legend()

    os.makedirs("./tsne_test_xy/", exist_ok=True)
    stem = os.path.splitext(os.path.normpath(saveckpt))[0]
    stem = stem.replace(os.sep, "_").replace("/", "_")
    plt.savefig(os.path.join("./tsne_test_xy/", stem + ".png"))
    # Close the FacetGrid figure so repeated calls do not accumulate open figures.
    plt.close()


def test_model(eegdata, eeglabel, saveckpt):
    """Evaluate the checkpoint at ``saveckpt`` on one subject's test data.

    Each sample is classified one at a time (batch size 1). The model's second
    output (the feature ``state``) is collected for a t-SNE plot, which is saved
    to disk as a side effect.

    Args:
        eegdata, eeglabel: test inputs/labels, wrapped by ``end2end_dataset``.
        saveckpt: path of a ``state_dict`` written by ``train_valid_model``.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: if the test set is empty.
    """
    model = CNN_baseline().to(cfg.device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(saveckpt, map_location=cfg.device))
    model.eval()

    # Although train and valid processing differ, the test data is always
    # fed one second at a time. Order does not affect accuracy; keeping it
    # fixed makes the t-SNE embedding reproducible.
    test_loader = DataLoader(dataset=end2end_dataset(eegdata, eeglabel), batch_size=1, shuffle=False)

    num_correct = 0
    num_samples = 0
    all_state = []
    all_label = []
    with torch.no_grad():
        for eeg, label in tqdm.tqdm(test_loader, position=0, leave=True):
            eeg = eeg.to(cfg.device)
            label = label.to(cfg.device)
            pred, state = model(eeg)

            predictions = pred.argmax(dim=1)
            num_correct += (predictions == label).sum().item()
            num_samples += predictions.size(0)

            # Move to host memory per batch so features don't pile up on the device.
            all_state.append(state.cpu())
            all_label.append(label.cpu())

    if num_samples == 0:
        raise ValueError(f"no test samples to evaluate checkpoint {saveckpt}")

    _save_tsne_plot(torch.cat(all_state, dim=0).numpy(),
                    torch.cat(all_label, dim=0).numpy(),
                    saveckpt)

    res = 100 * num_correct / num_samples
    print(f"saveckpt: {saveckpt}, test_decoder_answer: {res}%\n")

    return res