import os

from models.tmae import CausalEEGAutoencoder2D
from training_utils import (
    concatenate_subjects,
    make_loader,
    save_35_subplots,
    save_plot_list,
    optimize_memory,
    set_dropout_rate,
    loss_with_l2_reg,
    log,
    construct_full_period_data_w_mask,
)

import numpy as np
from run_different_models import *
import torch
from torch import nn
from train import train, test
import copy
from log import log
import yaml
import argparse


def config_to_model_loss_init(dataset_config):
    """Build the model and loss function selected by ``dataset_config``.

    Only model names containing ``"tmaeformer_"`` are supported. As a side
    effect the configuration is dumped to ``dataset_config.yaml`` inside
    ``dataset_config.exp_dir`` and the model is written to the log.

    Args:
        dataset_config: Configuration object; this function reads
            ``MODEL_NAME``, ``exp_dir`` and ``log_path`` and serialises
            ``vars(dataset_config)``.

    Returns:
        A ``(model, loss_fn)`` tuple.

    Raises:
        ValueError: If ``MODEL_NAME`` does not name a supported model.
    """
    # Guard clause: reject unknown models before doing any work. The former
    # CrossEntropyLoss default was never returned (every path either replaced
    # it or raised), so it has been dropped.
    if "tmaeformer_" not in dataset_config.MODEL_NAME:
        raise ValueError("Model not found")

    model = tmaeformer_init(dataset_config)
    loss_fn = causal_masked_loss

    # Persist the exact configuration next to the experiment outputs.
    with open(os.path.join(dataset_config.exp_dir, "dataset_config.yaml"), "w") as f:
        yaml.dump(vars(dataset_config), f, default_flow_style=False)

    log(f"Model initialization: {model}", dataset_config.log_path)
    return model, loss_fn


def form_dataset(
    dataset_config,
    x_train_meta,
    y_train_meta,
    test_index,
    is_shuffle=True,
    is_float32=True,
    is_x_y_equivalent=False,
    is_in_fine_tune_aug=False,
):
    """Create a data loader from the per-subject arrays.

    The arrays are first merged by ``concatenate_subjects`` and the result is
    wrapped by ``make_loader``.

    Args:
        dataset_config: Configuration object; only ``BATCH_SIZE`` is read.
        x_train_meta: Per-subject inputs handed to ``concatenate_subjects``.
        y_train_meta: Per-subject targets handed to ``concatenate_subjects``.
        test_index: Subject indices forwarded to ``concatenate_subjects``
            (presumably the subjects to include - confirm in training_utils).
        is_shuffle: Forwarded to ``make_loader`` as ``shuffle``.
        is_float32: Forwarded to ``make_loader``.
        is_x_y_equivalent: Forwarded to ``concatenate_subjects``.
        is_in_fine_tune_aug: Accepted for call-site compatibility; not used
            by this function.

    Returns:
        Whatever ``make_loader`` returns for the merged data.
    """
    merged_inputs, merged_targets = concatenate_subjects(
        x_train_meta, y_train_meta, test_index, is_x_y_equivalent
    )
    return make_loader(
        merged_inputs,
        merged_targets,
        batch_size=dataset_config.BATCH_SIZE,
        shuffle=is_shuffle,
        is_float32=is_float32,
    )


def fine_tune_model_for_each_subjects(
    dataset_config,
    x_train_meta,
    y_train_meta,
    x_test_meta,
    y_test_meta,
    model_meta,
    is_float32,
    loss_fn,
    is_x_y_equivalent=False,
    epochs=1000,
    training_step="regeneration",
    lr=1e-4,
):
    """Fine-tune an independent copy of ``model_meta`` for every subject.

    For each subject index a deep copy of ``model_meta`` is trained with Adam
    on that subject's training loader. Per-subject loss curves (and, when
    ``training_step == "classification"``, test-accuracy curves) are saved as
    PNG files in ``dataset_config.exp_dir``.

    Args:
        dataset_config: Configuration object; reads ``num_subjects``,
            ``device``, ``WD``, ``exp_dir`` and ``log_path`` (plus whatever
            ``form_dataset`` and ``train`` read).
        x_train_meta: Per-subject training inputs.
        y_train_meta: Per-subject training targets.
        x_test_meta: Per-subject test inputs.
        y_test_meta: Per-subject test targets.
        model_meta: Base model; it is deep-copied and never trained directly.
        is_float32: Forwarded to ``form_dataset``.
        loss_fn: Loss callable forwarded to ``train``.
        is_x_y_equivalent: Forwarded to ``form_dataset``.
        epochs: Number of epochs forwarded to ``train``.
        training_step: Label used in output file names; the value
            ``"classification"`` additionally enables testing during training.
        lr: Adam learning rate.

    Returns:
        ``(models_fine_tuned, test_dataloader_list)``, both indexed by subject.
    """
    is_classification = training_step == "classification"

    models_fine_tuned = []
    test_dataloader_list = []
    losses_for_subjects = []
    testing_acc_for_subjects = []
    num_subjects = dataset_config.num_subjects

    print(dataset_config.num_subjects, "dataset_config.num_subjects")
    for subject in range(num_subjects):
        # Each subject trains its own deep copy so no parameters are shared
        # between subjects or with the base model.
        model_sub = copy.deepcopy(model_meta)
        model_sub.to(dataset_config.device)

        optimizer = torch.optim.Adam(
            model_sub.parameters(), lr=lr, weight_decay=dataset_config.WD
        )
        test_index = [subject]

        fine_tune_dataloader = form_dataset(
            dataset_config,
            x_train_meta,
            y_train_meta,
            test_index,
            is_shuffle=True,
            is_float32=is_float32,
            is_x_y_equivalent=is_x_y_equivalent,
            is_in_fine_tune_aug=False,
        )
        # The test loader keeps its order so results line up with the labels.
        test_dataloader = form_dataset(
            dataset_config,
            x_test_meta,
            y_test_meta,
            test_index,
            is_shuffle=False,
            is_float32=is_float32,
            is_x_y_equivalent=is_x_y_equivalent,
        )
        _, losses, testing_acc = train(
            fine_tune_dataloader,
            model_sub,
            loss_fn,
            optimizer,
            epochs,
            dataset_config,
            verbose=False,
            is_testing=is_classification,
            test_dataloader=test_dataloader,
        )
        losses_for_subjects.append(losses)
        if is_classification:
            testing_acc_for_subjects.append(testing_acc)
        # model_sub is already a private copy, so it is stored as-is.
        models_fine_tuned.append(model_sub)
        test_dataloader_list.append(test_dataloader)
        # NOTE(review): relies on the truthiness of whatever ``train`` returns
        # as testing_acc - confirm it is never a multi-element array.
        if testing_acc:
            log(f"fine tune for model each subject testing acc:,{subject} {testing_acc}", dataset_config.log_path)

    save_35_subplots(
        losses_for_subjects,
        os.path.join(dataset_config.exp_dir, training_step + "_fine_tune_losses.png"),
        plot_type="line",
    )
    if is_classification:
        save_35_subplots(
            testing_acc_for_subjects,
            os.path.join(dataset_config.exp_dir, training_step + "_fine_tune_testing_acc.png"),
            plot_type="line",
        )

    # Release cached memory once all subjects have been fine-tuned.
    optimize_memory()

    return models_fine_tuned, test_dataloader_list


def _model_to_regenerate_data(dataset_config, model, train_dataloader):
    """Run ``model`` over a loader and rebuild per-subject NumPy arrays.

    Every batch is passed through the model and combined with its input by
    ``construct_full_period_data_w_mask``. The concatenated result is reshaped
    to ``(num_subjects, -1, multi_band, channels, total_len_time_period)`` and
    permuted with ``(0, 4, 3, 2, 1)``, i.e. the returned inputs have axes
    ``(num_subjects, total_len_time_period, channels, multi_band, samples)``.

    Args:
        dataset_config: Configuration object; reads ``device``,
            ``num_subjects``, ``multi_band``, ``channels`` and
            ``total_len_time_period``.
        model: Model to evaluate; it is switched to eval mode.
        train_dataloader: Iterable of ``(X, y)`` tensor batches. It must be
            ordered so that samples can be split evenly across subjects.

    Returns:
        ``(x_meta, y_meta)`` NumPy arrays. ``y_meta`` has shape
        ``(num_subjects, -1)`` and holds the loader labels plus one
        (presumably converting 0-based to 1-based labels - confirm).
    """
    x_batches = []
    y_batches = []
    model.eval()
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive and the final .numpy() would fail on grad-requiring tensors.
    with torch.no_grad():
        for X, y in train_dataloader:
            X, y = X.to(dataset_config.device), y.to(dataset_config.device)
            pred = model(X)
            x_batches.append(construct_full_period_data_w_mask(pred, X))
            y_batches.append(y)

    X_cat = torch.cat(x_batches, dim=0)
    y_cat = torch.cat(y_batches, dim=0)

    # Drop the per-batch tensors before the host copies are made.
    del x_batches, y_batches

    x_train_meta = (
        X_cat.reshape(
            dataset_config.num_subjects,
            -1,
            dataset_config.multi_band,
            dataset_config.channels,
            dataset_config.total_len_time_period,
        )
        .permute((0, 4, 3, 2, 1))
        .detach()
        .cpu()
        .numpy()
    )

    y_train_meta = (
        (1 + y_cat.reshape(dataset_config.num_subjects, -1))
        .detach()
        .cpu()
        .numpy()
    )

    return x_train_meta, y_train_meta


def _model_to_regenerate_data_for_fine_tune(dataset_config, model, train_dataloader):
    """Regenerate one subject's data with its fine-tuned model.

    Every batch is passed through the model and combined with its input by
    ``construct_full_period_data_w_mask`` using
    ``known_len=dataset_config.time_period_real``. The concatenated result is
    reshaped to ``(1, -1, multi_band, channels, total_len_time_period)`` and
    permuted with ``(0, 4, 3, 2, 1)``, so the returned inputs have axes
    ``(1, total_len_time_period, channels, multi_band, samples)``.

    Args:
        dataset_config: Configuration object; reads ``device``,
            ``time_period_real``, ``multi_band``, ``channels`` and
            ``total_len_time_period``.
        model: Model to evaluate; it is switched to eval mode.
        train_dataloader: Iterable of ``(X, y)`` tensor batches.

    Returns:
        ``(x_meta, y_meta)`` NumPy arrays with a leading axis of length one.
        ``y_meta`` holds the loader labels plus one (presumably converting
        0-based to 1-based labels - confirm).
    """
    x_batches = []
    y_batches = []
    model.eval()
    # Inference only: no_grad stops each batch from retaining its autograd
    # graph until the concatenation below.
    with torch.no_grad():
        for X, y in train_dataloader:
            X, y = X.to(dataset_config.device), y.to(dataset_config.device)
            pred = model(X)
            x_batches.append(
                construct_full_period_data_w_mask(
                    pred, X, known_len=dataset_config.time_period_real
                )
            )
            y_batches.append(y)

    X_cat = torch.cat(x_batches, dim=0)
    y_cat = torch.cat(y_batches, dim=0)

    # Drop the per-batch tensors before the host copies are made.
    del x_batches, y_batches

    x_train_meta = (
        X_cat.reshape(
            1,
            -1,
            dataset_config.multi_band,
            dataset_config.channels,
            dataset_config.total_len_time_period,
        )
        .permute((0, 4, 3, 2, 1))
        .detach()
        .cpu()
        .numpy()
    )
    y_train_meta = (
        (1 + y_cat.reshape(1, -1))
        .detach()
        .cpu()
        .numpy()
    )
    return x_train_meta, y_train_meta


def regenerate_data_from_model(
    dataset_config,
    model,
    x_train_meta_tilde,
    y_train_meta,
    x_test_meta_tilde,
    y_test_meta,
    is_float32,
):
    """Regenerate train and test data for all subjects with a single model.

    Unshuffled loaders covering every subject are built for the train and
    test splits, and each is passed through ``_model_to_regenerate_data``.

    Args:
        dataset_config: Configuration object; ``num_subjects`` is read here.
        model: Model used for regeneration.
        x_train_meta_tilde: Training inputs fed to the model.
        y_train_meta: Training targets.
        x_test_meta_tilde: Test inputs fed to the model.
        y_test_meta: Test targets.
        is_float32: Forwarded to ``form_dataset``.

    Returns:
        ``(x_train_meta, y_train_meta, x_test_meta, y_test_meta)`` as produced
        by ``_model_to_regenerate_data`` for each split.
    """

    def _ordered_loader(inputs, targets):
        # Order must be kept so samples can be regrouped per subject later.
        return form_dataset(
            dataset_config,
            inputs,
            targets,
            np.arange(dataset_config.num_subjects),
            is_shuffle=False,
            is_float32=is_float32,
        )

    train_loader = _ordered_loader(x_train_meta_tilde, y_train_meta)
    test_loader = _ordered_loader(x_test_meta_tilde, y_test_meta)

    regenerated_x_train, regenerated_y_train = _model_to_regenerate_data(
        dataset_config, model, train_loader
    )
    regenerated_x_test, regenerated_y_test = _model_to_regenerate_data(
        dataset_config, model, test_loader
    )
    return (
        regenerated_x_train,
        regenerated_y_train,
        regenerated_x_test,
        regenerated_y_test,
    )


def regenerate_data_from_model_fine_tune(
    dataset_config,
    model_list,
    x_train_meta_tilde,
    x_train_meta,
    y_train_meta,
    x_test_meta_tilde,
    x_test_meta,
    y_test_meta,
    is_float32,
    is_x_y_equivalent=False,
):
    """Regenerate train and test data subject by subject.

    For every subject an unshuffled train and test loader is built from the
    ``*_tilde`` inputs, and that subject's model from ``model_list``
    regenerates both splits. The per-subject results are concatenated along
    axis 0.

    Args:
        dataset_config: Configuration object; ``num_subjects`` is read here.
        model_list: Models indexed by subject.
        x_train_meta_tilde: Training inputs fed to the models.
        x_train_meta: Not read; kept for call-site compatibility.
        y_train_meta: Training targets.
        x_test_meta_tilde: Test inputs fed to the models.
        x_test_meta: Not read; kept for call-site compatibility.
        y_test_meta: Test targets.
        is_float32: Forwarded to ``form_dataset``.
        is_x_y_equivalent: Forwarded to ``form_dataset``.

    Returns:
        ``(x_train_meta, y_train_meta, x_test_meta, y_test_meta)`` NumPy
        arrays with one leading entry per subject.
    """
    loader_options = {
        "is_shuffle": False,
        "is_float32": is_float32,
        "is_x_y_equivalent": is_x_y_equivalent,
    }
    train_inputs, train_targets = [], []
    test_inputs, test_targets = [], []

    for subject in range(dataset_config.num_subjects):
        train_loader = form_dataset(
            dataset_config, x_train_meta_tilde, y_train_meta, [subject], **loader_options
        )
        test_loader = form_dataset(
            dataset_config, x_test_meta_tilde, y_test_meta, [subject], **loader_options
        )

        subject_model = model_list[subject]
        subject_model.eval()

        # Train split first, then test split, both with this subject's model.
        subject_x_train, subject_y_train = _model_to_regenerate_data_for_fine_tune(
            dataset_config, subject_model, train_loader
        )
        subject_x_test, subject_y_test = _model_to_regenerate_data_for_fine_tune(
            dataset_config, subject_model, test_loader
        )

        train_inputs.append(subject_x_train)
        train_targets.append(subject_y_train)
        test_inputs.append(subject_x_test)
        test_targets.append(subject_y_test)

    # One concatenation per output, stacking subjects along the leading axis.
    stacked_x_train = np.concatenate(train_inputs, axis=0)
    stacked_y_train = np.concatenate(train_targets, axis=0)
    stacked_x_test = np.concatenate(test_inputs, axis=0)
    stacked_y_test = np.concatenate(test_targets, axis=0)

    return stacked_x_train, stacked_y_train, stacked_x_test, stacked_y_test



import random


import random
import numpy as np


def channel_swap_augmentation(x_data, swap_duration=3, swap_probability=0.3):
    """Randomly swap pairs of channels inside fixed-length time windows.

    The time axis is walked in consecutive, non-overlapping windows of
    ``swap_duration`` steps; a trailing remainder shorter than a full window
    is never touched. For each window, with probability ``swap_probability``,
    two distinct channels are drawn and their contents exchanged across all
    bands. Randomness comes from the global ``np.random`` state.

    Args:
        x_data: Array of shape ``(time_len, channels, bands)``. Not modified.
        swap_duration: Window length in time steps.
        swap_probability: Per-window probability of performing a swap.

    Returns:
        A new array with the same shape and dtype as ``x_data``. With fewer
        than two channels no swap is possible and an unchanged copy is
        returned.

    Raises:
        ValueError: If ``x_data`` is not 3-dimensional.
    """
    if x_data.ndim != 3:
        # Previously this fell through to an UnboundLocalError on time_len.
        raise ValueError(
            "channel_swap_augmentation expects a 3D array "
            f"(time_len, channels, bands), got shape {x_data.shape}"
        )

    x_augmented = x_data.copy()
    time_len, channels, _ = x_data.shape

    if channels < 2:
        # Drawing two distinct channels would fail; nothing can be swapped.
        return x_augmented

    for start_pos in range(0, time_len - swap_duration + 1, swap_duration):
        # Decide per window whether a swap happens at all.
        if np.random.random() < swap_probability:
            ch1, ch2 = np.random.choice(channels, 2, replace=False)
            window = slice(start_pos, start_pos + swap_duration)
            # Fancy indexing on the right-hand side yields a copy, so the two
            # channels are exchanged without an explicit temporary.
            x_augmented[window, [ch1, ch2], :] = x_augmented[window, [ch2, ch1], :]

    return x_augmented


def time_shift_augmentation(x_data, shift_range=(5, 15), zero_pad_probability=0.4):
    """Randomly delay a recording along the time axis, zero-padding the start.

    With probability ``zero_pad_probability`` a shift is drawn uniformly from
    the inclusive integer range ``shift_range``; the data is moved that many
    steps later in time, the first ``shift`` steps become zeros and the last
    ``shift`` steps are dropped. Randomness comes from the global
    ``np.random`` state.

    Args:
        x_data: Array of shape ``(time_len, channels, bands)``. Not modified.
        shift_range: ``(low, high)`` inclusive bounds of the shift in time
            steps; values are expected to be non-negative.
        zero_pad_probability: Probability that the shift is applied.

    Returns:
        A new array with the same shape and dtype as ``x_data``. If the drawn
        shift is at least ``time_len`` the result is all zeros.

    Raises:
        ValueError: If ``x_data`` is not 3-dimensional.
    """
    if x_data.ndim != 3:
        # Previously this fell through to an UnboundLocalError on time_len.
        raise ValueError(
            "time_shift_augmentation expects a 3D array "
            f"(time_len, channels, bands), got shape {x_data.shape}"
        )

    x_augmented = x_data.copy()
    time_len = x_data.shape[0]

    if np.random.random() < zero_pad_probability:
        shift_amount = np.random.randint(shift_range[0], shift_range[1] + 1)

        shifted_data = np.zeros_like(x_augmented)
        if shift_amount < time_len:
            # Steps [0, shift_amount) stay zero; the rest is the delayed signal.
            shifted_data[shift_amount:, :, :] = x_augmented[:time_len - shift_amount, :, :]
        x_augmented = shifted_data

    return x_augmented


def _merge_target_trial_axes(x_sample, y_sample):
    """Collapse the (target, trials) axes into one sample axis and add a leading subject axis."""
    time_len, channels, bands, target, trials = x_sample.shape
    x_merged = x_sample.reshape(time_len, channels, bands, target * trials)
    y_merged = y_sample.reshape(target * trials)
    return x_merged[None, ...], y_merged[None, ...]


def _chunk_mixed_subject(x_by_trial, y_by_trial, chunk_size):
    """Build one synthetic subject by stitching time chunks from random subjects and trials."""
    subjects, time_len, channels, bands, target, trials = x_by_trial.shape
    x_new = np.zeros((time_len, channels, bands, target, trials), dtype=np.float32)
    y_new = np.zeros((target, trials), dtype=np.float32)
    for trial_slot in range(trials):
        for target_idx in range(target):
            for start in range(0, time_len, chunk_size):
                # Every chunk comes from an independently drawn subject/trial
                # of the same target.
                src_subject = np.random.randint(subjects)
                src_trial = np.random.randint(trials)
                stop = start + chunk_size
                x_new[start:stop, :, :, target_idx, trial_slot] = (
                    x_by_trial[src_subject, start:stop, :, :, target_idx, src_trial]
                )
                # The label of the last chunk's source is the one that is kept.
                y_new[target_idx, trial_slot] = y_by_trial[src_subject, target_idx, src_trial]
    return x_new, y_new


def _transformed_subject(x_by_trial, y_by_trial, transform):
    """Build one synthetic subject by applying ``transform`` to a random subject's recording per (target, trial)."""
    subjects, time_len, channels, bands, target, trials = x_by_trial.shape
    x_new = np.zeros((time_len, channels, bands, target, trials), dtype=np.float32)
    y_new = np.zeros((target, trials), dtype=np.float32)
    for trial_idx in range(trials):
        for target_idx in range(target):
            src_subject = np.random.randint(subjects)
            # Source recording has shape (time_len, channels, bands).
            x_new[:, :, :, target_idx, trial_idx] = transform(
                x_by_trial[src_subject, :, :, :, target_idx, trial_idx]
            )
            y_new[target_idx, trial_idx] = y_by_trial[src_subject, target_idx, trial_idx]
    return x_new, y_new


def generate_augmented_data(x_train_meta, y_train_meta, chunk_size=5, trials=5, target=40,
                           chunk_ratio=0.5, channel_swap_ratio=0.3, time_shift_ratio=0.3,
                           apply_channel_swap=True, apply_time_shift=True):
    """Create synthetic subjects and append them to the training data.

    Three kinds of synthetic subjects are generated, each count being
    ``int(ratio * subjects)``:

    1. chunk-based: every ``chunk_size`` time steps are copied from a random
       subject and trial of the same target;
    2. channel swap: ``channel_swap_augmentation`` (window 20, probability
       0.3) applied to a random subject's recording;
    3. time shift: ``time_shift_augmentation`` (shift 5-15, probability 0.4)
       applied to a random subject's recording.

    The sample axis is interpreted as ``(target, trials)`` via a reshape.
    Randomness comes from the global ``np.random`` state.

    Args:
        x_train_meta: Array of shape
            ``(subjects, time_len, channels, bands, samples)``.
        y_train_meta: Array reshapeable to ``(subjects, target, trials)``.
        chunk_size: Time-chunk length for the chunk-based method.
        trials: Expected trials per target; replaced (with a printed warning)
            by ``samples // target`` when it does not match the data.
        target: Number of targets.
        chunk_ratio: Fraction of ``subjects`` to generate chunk-based.
        channel_swap_ratio: Fraction of ``subjects`` to generate by channel swap.
        time_shift_ratio: Fraction of ``subjects`` to generate by time shift.
        apply_channel_swap: Disable to skip channel-swap subjects.
        apply_time_shift: Disable to skip time-shift subjects.

    Returns:
        ``(x_all, y_all, x_augmented, y_augmented)`` as float32 arrays, where
        ``x_all``/``y_all`` are the originals followed by the synthetic
        subjects. When the ratios sum to zero, or every count rounds down to
        zero, the original arrays are returned unchanged in both positions.

    Raises:
        ValueError: If ``samples`` is not divisible by ``target``.
    """
    if chunk_ratio + channel_swap_ratio + time_shift_ratio == 0:
        return x_train_meta, y_train_meta, x_train_meta, y_train_meta
    subjects, time_len, channels, bands, samples = x_train_meta.shape

    # The trial count implied by the data wins over the parameter.
    actual_trials = samples // target
    if samples % target != 0:
        raise ValueError(f"Cannot divide {samples} samples by {target} targets evenly")
    if actual_trials != trials:
        print(f"Warning: Expected {trials} trials but data has {actual_trials} trials. Using actual trials.")
        trials = actual_trials

    x_train_meta_flat = x_train_meta.reshape(subjects, time_len, channels, bands, target, trials)
    y_train_meta_flat = y_train_meta.reshape(subjects, target, trials)

    num_chunk_augmented = int(chunk_ratio * subjects)
    num_channel_swap_augmented = int(channel_swap_ratio * subjects) if apply_channel_swap else 0
    num_time_shift_augmented = int(time_shift_ratio * subjects) if apply_time_shift else 0

    x_data_list = []
    y_data_list = []

    def _collect(x_new, y_new):
        x_merged, y_merged = _merge_target_trial_axes(x_new, y_new)
        x_data_list.append(x_merged)
        y_data_list.append(y_merged)

    # 1. Chunk-based synthetic subjects.
    print(f"Generating {num_chunk_augmented} chunk-based augmented samples...")
    for _ in range(num_chunk_augmented):
        _collect(*_chunk_mixed_subject(x_train_meta_flat, y_train_meta_flat, chunk_size))

    # 2. Channel-swap synthetic subjects.
    if apply_channel_swap and num_channel_swap_augmented > 0:
        print(f"Generating {num_channel_swap_augmented} channel swap augmented samples...")

        def _swap(recording):
            return channel_swap_augmentation(recording, swap_duration=20, swap_probability=0.3)

        for _ in range(num_channel_swap_augmented):
            _collect(*_transformed_subject(x_train_meta_flat, y_train_meta_flat, _swap))

    # 3. Time-shift synthetic subjects.
    if apply_time_shift and num_time_shift_augmented > 0:
        print(f"Generating {num_time_shift_augmented} time shift augmented samples...")

        def _shift(recording):
            return time_shift_augmentation(recording, shift_range=(5, 15), zero_pad_probability=0.4)

        for _ in range(num_time_shift_augmented):
            _collect(*_transformed_subject(x_train_meta_flat, y_train_meta_flat, _shift))

    if not x_data_list:
        # Every count rounded down to zero (e.g. very few subjects);
        # np.concatenate would raise on an empty list, so fall back to the
        # same result as the zero-ratio early return above.
        print("Warning: no augmented samples were generated; returning the original data.")
        return x_train_meta, y_train_meta, x_train_meta, y_train_meta

    x_data_np = np.concatenate(x_data_list, axis=0)
    y_data_np = np.concatenate(y_data_list, axis=0)

    total_augmented = num_chunk_augmented + num_channel_swap_augmented + num_time_shift_augmented
    print(f"Total augmented samples generated: {total_augmented}")
    print(f"  - Chunk-based: {num_chunk_augmented}")
    print(f"  - Channel swap: {num_channel_swap_augmented}")
    print(f"  - Time shift: {num_time_shift_augmented}")

    # Originals first, synthetic subjects appended along the subject axis.
    x_train_meta_aug = np.concatenate((x_train_meta, x_data_np), axis=0)
    y_train_meta_aug = np.concatenate((y_train_meta, y_data_np), axis=0)
    return (
        x_train_meta_aug.astype(np.float32),
        y_train_meta_aug.astype(np.float32),
        x_data_np.astype(np.float32),
        y_data_np.astype(np.float32),
    )
