import os
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch import flatten
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

from preprocess import filter_eeg, form_dataset_config, preprocess_data
from training_utils import concatenate_subjects, make_loader, save_plot_list
from test import test
import copy

# Optional mixed-precision (AMP) support: try to import autocast/GradScaler
# from torch.cuda.amp and record in a module-level flag whether that worked,
# so training code can fall back to full precision when it did not.
try:
    from torch.cuda.amp import autocast, GradScaler
    MIXED_PRECISION_AVAILABLE = True
except ImportError:
    MIXED_PRECISION_AVAILABLE = False


def _run_training_batch(X, y, model, loss_fn, optimizer, dataset_config, scaler, aux_loss):
    """Run one optimisation step on a single batch and return its loss tensor.

    Args:
        X, y: Input and target tensors for the batch; both are moved to
            ``dataset_config.device``.
        model: Model being trained. If it exposes ``last_aux_loss`` and
            ``dataset_config.add_aux_loss`` is truthy, that value is added to
            the loss and accumulated into ``aux_loss[-1]``.
        loss_fn: Callable invoked as ``loss_fn(pred, y)``.
        optimizer: Optimizer stepped once per call.
        dataset_config: Config object providing ``device`` and ``add_aux_loss``.
        scaler: A ``GradScaler`` to train under autocast, or ``None`` for
            plain full-precision training.
        aux_loss: Per-epoch list of accumulated auxiliary losses (mutated).

    Returns:
        The loss tensor of this batch (including the auxiliary term, if any).
    """
    X, y = X.to(dataset_config.device), y.to(dataset_config.device)
    optimizer.zero_grad()

    def compute_loss():
        pred = model.forward(X)
        batch_loss = loss_fn(pred, y)
        if hasattr(model, 'last_aux_loss') and dataset_config.add_aux_loss:
            batch_loss += model.last_aux_loss
            aux_loss[-1] += model.last_aux_loss.detach().cpu().numpy()
        return batch_loss

    if scaler is not None:
        # Forward pass under autocast; the scaler handles backward/step/update.
        with autocast():
            loss = compute_loss()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss = compute_loss()
        loss.backward()
        optimizer.step()
    return loss


def train(
    dataloader,
    model,
    loss_fn,
    optimizer,
    epochs,
    dataset_config,
    verbose=False,
    use_mixed_precision=True,
    is_return_models=False,
    is_testing=False,
    test_dataloader=None,
):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Args:
        dataloader: Sized iterable yielding ``(X, y)`` batches.
        model: Model to train; moved to ``dataset_config.device`` and put in
            train mode.
        loss_fn: Callable invoked as ``loss_fn(pred, y)``.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of epochs to run.
        dataset_config: Config object; this function reads ``device``,
            ``add_aux_loss``, ``MODEL_NAME`` and ``exp_dir`` from it.
            ``device`` may be a string or a ``torch.device``.
        verbose: When true, show a tqdm progress bar per epoch and print the
            loss of the epoch's last batch.
        use_mixed_precision: Request AMP training; only honoured when the
            device is a CUDA device and the AMP imports succeeded.
        is_return_models: When true, snapshot the model every 100 epochs
            (epoch 0 excluded) via ``model.copy_model()`` if available,
            otherwise ``copy.deepcopy``.
        is_testing: When true and ``test_dataloader`` is given, call ``test``
            every 50 epochs (epoch 0 excluded) and collect its result.
        test_dataloader: Data passed to ``test`` for those periodic runs.

    Returns:
        ``(models, losses, testing_acc)``: the model snapshots, the per-epoch
        sum of batch losses, and the collected ``test`` results.

    Side effects:
        When ``dataset_config.add_aux_loss`` is truthy and ``is_testing`` is
        false, the per-epoch auxiliary losses are plotted to
        ``classification_train_aux_loss.png`` inside ``dataset_config.exp_dir``.
    """
    # NOTE (original author): this routine is reportedly run four times in the
    # pipeline (regeneration, fine-tune, classification, fine-tune), with test
    # validation and the auxiliary loss used on the classification fine-tune
    # stage -- confirm against the callers.
    model.to(dataset_config.device)
    models = []
    model.train()
    losses = []
    aux_loss = []

    # Enable AMP only when requested, running on CUDA, and importable.
    # str() lets the device be either a plain string or a torch.device.
    scaler = None
    if (
        use_mixed_precision
        and str(dataset_config.device).startswith('cuda')
        and MIXED_PRECISION_AVAILABLE
    ):
        scaler = GradScaler()

    testing_acc = []
    for epoch in range(epochs):
        loss_total = 0
        if hasattr(model, 'last_aux_loss'):
            aux_loss.append(0)
        if is_testing and test_dataloader is not None and epoch % 50 == 0 and epoch != 0:
            testing_acc.append(test(
                test_dataloader,
                model,
                loss_fn,
                dataset_config.device,
                is_testing_generation=(dataset_config.MODEL_NAME == "tmae"),
            ))
            # Restore train mode in case the evaluation changed it.
            model.train()

        batches = dataloader
        if verbose:
            batches = tqdm(
                dataloader,
                total=len(dataloader),
                desc=f"Epoch {epoch+1}/{epochs}",
            )

        loss = None
        for X, y in batches:
            loss = _run_training_batch(
                X, y, model, loss_fn, optimizer, dataset_config, scaler, aux_loss
            )
            loss_total += loss.detach().cpu().numpy()

        # Reports the loss of the last batch; skipped when no batch was seen.
        if verbose and loss is not None:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

        losses.append(loss_total)
        if is_return_models and epoch % 100 == 0 and epoch != 0:
            if hasattr(model, 'copy_model'):
                models.append(model.copy_model())
            else:
                models.append(copy.deepcopy(model))

    if dataset_config.add_aux_loss and not is_testing:
        save_plot_list(
            aux_loss,
            os.path.join(dataset_config.exp_dir, "classification_train_aux_loss.png"),
        )
    return models, losses, testing_acc




def freeze_model(model):
    """Turn off gradient tracking on every parameter of ``model``.

    Iterates over ``model.parameters()`` and sets ``requires_grad`` to
    ``False`` on each item in place. Returns ``None``.
    """
    all_params = model.parameters()
    for tensor in all_params:
        tensor.requires_grad = False
