import os
from models.cvt import CvT
from models.guney_net import GuneyNet
from models.tmae import CausalEEGAutoencoder2D
from training_utils import (
    config_to_model_loss_init,
    form_dataset,
    save_plot_list,
    optimize_memory,
    set_dropout_rate,
    loss_with_l2_reg,
    log,
    save_35_subplots,
    construct_full_period_data_w_mask,
)

import numpy as np
from run_different_models import *
import torch
from torch import nn
from train import train, test
import copy
from log import log
import yaml
import argparse
from run_training import *
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from test import plot_t_sne
import time

def _split_train_test_trials(x, y, dataset_config):
    """Split ``x`` and ``y`` into train/test parts along the trial axis.

    ``x`` is reshaped to its first four dimensions plus ``(target, trials)``
    and ``y`` to ``(-1, target, trials)``. The first ``trials_train`` trials
    form the train split and the remaining ones the test split; in each split
    the target and trial axes are flattened back into one trailing axis.

    Returns:
        ``(x_train, x_test, y_train, y_test)``.
    """
    n_train = dataset_config.trials_train
    x = x.reshape(*x.shape[:4], dataset_config.target, dataset_config.trials)
    y = y.reshape(-1, dataset_config.target, dataset_config.trials)
    lead_shape = x.shape[:4]
    x_train = x[:, :, :, :, :, :n_train].reshape(*lead_shape, -1)
    x_test = x[:, :, :, :, :, n_train:].reshape(*lead_shape, -1)
    y_train = y[:, :, :n_train].reshape(y.shape[0], -1)
    y_test = y[:, :, n_train:].reshape(y.shape[0], -1)
    return x_train, x_test, y_train, y_test


def _set_fine_tune_dropout(model, dataset_config):
    """Apply the fine-tuning dropout rates from ``dataset_config`` to ``model``."""
    spatial_layers = (
        "spatial_dropout",
        "spatial_dropout_1",
        "spatial_dropout_2",
        "spatial_dropout_3",
    )
    for layer_name in spatial_layers:
        set_dropout_rate(
            model,
            layer_name,
            dataset_config.guney_origin_spatial_dropout_fine_tune,
        )
    set_dropout_rate(
        model,
        "time1_dropout",
        dataset_config.guney_origin_time1_dropout_fine_tune,
    )


def run_train_model_v3(x, y, dataset_config, cv_index=0):
    """Train, fine-tune per subject and evaluate the model named in the config.

    The flow is:

    1. Split ``x``/``y`` into train and test trials.
    2. Train a first model. When ``"tmae"`` is part of ``MODEL_NAME`` this is
       the regeneration model, trained to map the first ``time_period_real``
       entries of axis 1 of the augmented data to the full augmented data;
       otherwise it is trained on the (non-augmented) train split and labels.
    3. When ``MODEL_NAME`` contains ``"tmae"`` together with ``"cvt"`` or
       ``"guney"``: regenerate the train/test data with the first model
       (optionally after per-subject fine-tuning), augment it again and train
       the model on the regenerated data with its labels.
    4. Fine-tune per subject, test each subject and log the accuracies.

    Args:
        x: Input array with at least four dimensions whose remaining elements
            can be reshaped to ``(target, trials)``. The first axis is
            presumably the subject axis -- TODO confirm against the caller.
        y: Labels, reshapeable to ``(-1, target, trials)``.
        dataset_config: Experiment configuration object (model name, epochs,
            learning rates, augmentation ratios, log paths, device, ...).
        cv_index: Cross-validation fold index. Currently unused; kept so that
            existing callers keep working.

    Returns:
        Tuple ``(mean_accuracy, std_accuracy)`` over all subjects.

    Raises:
        RuntimeError: If ``MODEL_NAME`` is exactly ``"tmae"``. No per-subject
            fine-tuned models or test dataloaders are built for that
            configuration, so the evaluation stage cannot run.
    """
    model_name = dataset_config.MODEL_NAME
    is_tmae = "tmae" in model_name
    is_generation_only = model_name == "tmae"

    num_subjects = dataset_config.num_subjects
    n_rows = max(1, dataset_config.EPOCHS // 100)
    acc_W_o_tune = np.zeros((n_rows, num_subjects))
    is_float32 = model_name != "guney"
    acc = np.zeros((num_subjects, 1))

    x_train_meta, x_test_meta, y_train_meta, y_test_meta = _split_train_test_trials(
        x, y, dataset_config
    )

    train_dataloader = form_dataset(
        dataset_config,
        x_train_meta,
        y_train_meta,
        np.arange(len(x_train_meta)),
        is_shuffle=True,
        is_float32=is_float32,
    )
    print(x_train_meta.shape, "before augment")
    x_train_meta_aug, y_train_meta_aug, x_data_np, _ = generate_augmented_data(
        x_train_meta,
        y_train_meta,
        chunk_size=dataset_config.chunk_size,
        trials=dataset_config.trials_train,
        target=dataset_config.target,
        chunk_ratio=dataset_config.chunk_ratio_regen,
        channel_swap_ratio=dataset_config.channel_swap_ratio_regen,
        time_shift_ratio=dataset_config.time_shift_ratio_regen,
        apply_channel_swap=True,
        apply_time_shift=True,
    )
    print(x_train_meta_aug.shape, "after augment")

    if is_tmae:
        # The regeneration model sees only the first ``time_period_real``
        # entries of axis 1 and is trained to reproduce the full input.
        period = dataset_config.time_period_real
        x_train_meta_tilde_aug = x_train_meta_aug[:, :period, :, :, :]
        x_train_meta_tilde = x_train_meta[:, :period, :, :, :]
        x_test_meta_tilde = x_test_meta[:, :period, :, :, :]
        print(
            x_train_meta_tilde_aug.shape,
            x_train_meta_aug.shape,
            is_float32,
            "tmae dtype",
        )
        train_dataloader = form_dataset(
            dataset_config,
            x_train_meta_tilde_aug,
            x_train_meta_aug,
            np.arange(len(x_train_meta_tilde_aug)),
            is_shuffle=True,
            is_float32=is_float32,
            is_x_y_equivalent=True,
        )

    model, loss_fn = config_to_model_loss_init(dataset_config)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=dataset_config.LR_tmae if is_tmae else dataset_config.LR,
        weight_decay=dataset_config.WD,
    )
    epochs = dataset_config.REGENERATION_EPOCHS if is_tmae else dataset_config.EPOCHS

    start_time = time.time()
    models, losses, _ = train(
        train_dataloader,
        model,
        loss_fn,
        optimizer,
        epochs,
        dataset_config,
        verbose=not is_tmae,
        is_return_models=dataset_config.verbose,
    )
    elapsed = time.time() - start_time
    log(f"generation train in {elapsed} seconds", dataset_config.log_path)
    print("first train done")
    save_plot_list(
        losses, os.path.join(dataset_config.exp_dir, "step_1_train_loss.png")
    )

    if is_tmae and ("cvt" in model_name or "guney" in model_name):
        if dataset_config.fine_tune_regeneration:
            start_time = time.time()
            models_fine_tuned, _ = fine_tune_model_for_each_subjects(
                dataset_config,
                x_train_meta_tilde,
                x_train_meta,
                x_test_meta_tilde,
                x_test_meta,
                model,
                is_float32,
                loss_fn,
                is_x_y_equivalent=True,
                epochs=dataset_config.EPOCHS_FINE_TUNE_FOR_REG,
                training_step="regeneration",
                lr=dataset_config.LR_tmae_fine_tune,
            )
            elapsed = time.time() - start_time
            log(
                f"regenerate data from model fine tune {elapsed} seconds",
                dataset_config.log_path,
            )
            print("fine tune done")

            start_time = time.time()
            x_train_meta, y_train_meta, x_test_meta, y_test_meta = (
                regenerate_data_from_model_fine_tune(
                    dataset_config,
                    models_fine_tuned,
                    x_train_meta_tilde,
                    x_train_meta,
                    y_train_meta,
                    x_test_meta_tilde,
                    x_test_meta,
                    y_test_meta,
                    is_float32,
                )
            )
            elapsed = time.time() - start_time
            # A separator was missing here, which glued "stage 1" onto the
            # elapsed time in the log output.
            log(
                "regenerate data from model fine tune done ,doing testing for "
                f"stage 1 in {elapsed} seconds",
                dataset_config.log_path,
            )
        else:
            x_train_meta, y_train_meta, x_test_meta, y_test_meta = (
                regenerate_data_from_model(
                    dataset_config,
                    model,
                    x_train_meta_tilde,
                    y_train_meta,
                    x_test_meta_tilde,
                    y_test_meta,
                    is_float32,
                )
            )

        # The second stage uses a substring check on the model name, unlike
        # the equality check used for the first stage.
        is_float32 = "guney" not in model_name
        print(x_train_meta.shape, x_data_np.shape, 'before final augmentation')
        x_train_meta_aug, y_train_meta_aug, _, _ = generate_augmented_data(
            x_train_meta,
            y_train_meta,
            chunk_size=dataset_config.chunk_size,
            trials=dataset_config.trials_train,
            target=dataset_config.target,
            chunk_ratio=dataset_config.chunk_ratio_guney,
            channel_swap_ratio=dataset_config.channel_swap_ratio_guney,
            time_shift_ratio=dataset_config.time_shift_ratio_guney,
            apply_channel_swap=True,
            apply_time_shift=True,
        )

        train_dataloader = form_dataset(
            dataset_config,
            x_train_meta_aug,
            y_train_meta_aug,
            np.arange(len(x_train_meta_aug)),
            is_shuffle=True,
            is_float32=is_float32,
        )

        if model_name == "tmaeformer_guney_moe":
            model = guney_int_moe(dataset_config)
            if dataset_config.guney_origin:
                loss_fn = loss_with_l2_reg(
                    nn.CrossEntropyLoss(), model, dataset_config.L2_REG
                )
        log(f"Model initialization for step 2: {model}", dataset_config.log_path)
        model.to(dataset_config.device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=dataset_config.LR, weight_decay=dataset_config.WD
        )

        print(len(train_dataloader), 'train guney dataset')
        start_time = time.time()
        models, losses, _ = train(
            train_dataloader,
            model,
            loss_fn,
            optimizer,
            dataset_config.EPOCHS,
            dataset_config,
            verbose=False,
            is_return_models=dataset_config.verbose,
        )
        elapsed = time.time() - start_time
        log(f"classification train in {elapsed} seconds", dataset_config.log_path)
        save_plot_list(
            losses,
            os.path.join(dataset_config.exp_dir, "classification_train_loss.png"),
        )

    model_meta = model.copy_model() if hasattr(model, 'copy_model') else model

    if 'moe' in model_name:
        # NOTE(review): the dropout rates are set on ``model`` after
        # ``model_meta`` was taken from it. If ``copy_model`` returns an
        # independent copy, these rates never reach the fine-tuned models
        # because fine-tuning starts from ``model_meta`` -- confirm intent.
        _set_fine_tune_dropout(model, dataset_config)

    if is_generation_only:
        # Previously this configuration skipped fine-tuning and then crashed
        # with an UnboundLocalError on ``models_fine_tuned`` in the test loop.
        # Fail at the same point with an explicit message instead.
        raise RuntimeError(
            'MODEL_NAME == "tmae" is not supported by run_train_model_v3: '
            "no per-subject fine-tuned models or test dataloaders are "
            "created for a generation-only model, so it cannot be evaluated."
        )

    start_time = time.time()
    models_fine_tuned, test_dataloader_list = fine_tune_model_for_each_subjects(
        dataset_config,
        x_train_meta,
        y_train_meta,
        x_test_meta,
        y_test_meta,
        model_meta,
        is_float32,
        loss_fn,
        epochs=dataset_config.EPOCHS_FINE_TUNE,
        training_step="classification",
        lr=dataset_config.LR_fine_tune,
    )
    elapsed = time.time() - start_time
    log(
        f"fine tune for classification in train {elapsed} seconds",
        dataset_config.log_path,
    )

    pred_array_all = []
    y_array_all = []
    start_time = time.time()
    for subject in range(num_subjects):
        model_sub = models_fine_tuned[subject]
        test_dataloader = test_dataloader_list[subject]

        if dataset_config.verbose:
            # Score every model returned by the last ``train`` call. The loop
            # variable is deliberately not called ``model`` so the trained
            # model is not rebound.
            for i, snapshot in enumerate(models):
                acc_W_o_tune[i, subject] = test(
                    test_dataloader,
                    snapshot,
                    loss_fn,
                    dataset_config.device,
                    is_testing_generation=is_generation_only,
                )

        acc[subject], pred_array, y_array = test(
            test_dataloader,
            model_sub,
            loss_fn,
            dataset_config.device,
            is_testing_generation=is_generation_only,
            is_plot_t_sne=True,
        )
        pred_array_all.append(pred_array)
        y_array_all.append(y_array)

        if dataset_config.verbose:
            # The last row holds the accuracy of the model before per-subject
            # fine-tuning (it overwrites any snapshot score stored there).
            acc_W_o_tune[-1, subject] = test(
                test_dataloader,
                model_meta,
                loss_fn,
                dataset_config.device,
                is_testing_generation=is_generation_only,
            )
        log(
            f"Subject {subject + 1} Accuracy: {acc_W_o_tune[:,subject]}",
            dataset_config.log_final_results_path,
        )
        log(
            f"Subject {subject + 1} Accuracy: {acc[subject]}",
            dataset_config.log_final_results_path,
        )
    elapsed = time.time() - start_time
    log(f"testing for all subjects in {elapsed} seconds", dataset_config.log_path)
    log(
        f"Mean Accuracy: {np.mean(acc)} +- {np.std(acc)}",
        dataset_config.log_final_results_path,
    )
    log(
        f"Mean Accuracy: {np.mean(acc_W_o_tune,axis = 1)} +- {np.std(acc_W_o_tune,axis =1)}",
        dataset_config.log_final_results_path,
    )
    return (np.mean(acc), np.std(acc))
