import torch
import numpy as np
import random
import os
import warnings
import time

from data import load_mnist_data, split_data_stratified, create_dataloaders
from tools import add_gaussian_noise, visualize_samples, visualize_reconstructed_images, plot_errors
from models import Encoder, Decoder
from train  import baseline, adaptation, evaluation
from config import default_cfg, parse_arguments


# --- 6. Main Function ---
def main():
    """Run one encoder/decoder experiment end to end and save its artifacts.

    Steps, in order:

    1. Build the configuration from ``default_cfg`` plus command-line overrides.
    2. Select the device and, when ``cfg["random_seed"]`` is set, seed the RNGs.
    3. Create ``results/<dataset_name>/<job_folder>``.
    4. Load the train/test images, split each into source and target parts,
       and add Gaussian noise to the target parts.
    5. Train for ``cfg["num_epochs"]`` epochs with either ``adaptation`` or
       ``baseline`` (chosen by ``cfg["adaptation_flag"]``), evaluating after
       every epoch and checkpointing whenever the source validation error
       improves.
    6. Write figures, ``log.txt`` and ``config.txt`` into the result directory.

    Returns:
        None. Returns early (after printing a message) when either the train
        or the test split contains no images.
    """
    # Start from the defaults and override them with any command-line arguments.
    cfg = parse_arguments(default_cfg)

    log_list = []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    # Compare on ``device.type``: a ``torch.device`` is never equal to a plain
    # string, so ``device == 'cuda'`` would silently always be False.
    if device.type == 'cuda':
        print(' -- {}'.format(torch.cuda.get_device_name(0)))

    # Seed every RNG for reproducibility (skipped when no seed is configured).
    if cfg["random_seed"] is not None:
        _seed_everything(cfg["random_seed"], device)

    job_folder = _job_folder_name(cfg)
    result_dir = 'results/{}/{}'.format(cfg["dataset_name"], job_folder)
    if os.path.exists(result_dir):
        warnings.warn(f"The directory '{result_dir}' already exists. Files may be overwritten.", UserWarning)
    # exist_ok avoids a crash when another job creates the directory between
    # the existence check above and this call.
    os.makedirs(result_dir, exist_ok=True)

    print('result dir: ', result_dir)

    # 1. Load Data
    tr_images, tr_labels = load_mnist_data(selected_labels=cfg["selected_labels"],
                                           num_samples_per_label=cfg["num_samples_per_label"],
                                           train=True, dataset_name=cfg["dataset_name"])
    if tr_images.shape[0] == 0:
        print("No train data loaded, exiting.")
        return

    val_images, val_labels = load_mnist_data(selected_labels=cfg["selected_labels"], num_samples_per_label=None,
                                             train=False, dataset_name=cfg["dataset_name"])
    if val_images.shape[0] == 0:
        print("No test data loaded, exiting.")
        return

    # 2. Split Data (the label halves of the split are not used below)
    print("Training data:")
    source_train_images, _, target_train_images, _ = split_data_stratified(tr_images, tr_labels,
                                                                           target_split_ratio=cfg["target_split_ratio"],
                                                                           random_state=cfg["random_seed"])
    print("Test data:")
    source_val_images, _, target_val_images, _ = split_data_stratified(val_images, val_labels,
                                                                       target_split_ratio=cfg["target_split_ratio"],
                                                                       random_state=cfg["random_seed"])

    # 3. Add Noise to Target Data
    noisy_target_train_images = add_gaussian_noise(
        target_train_images,
        mean=cfg["noise_mean"],
        std_dev=cfg["noise_sd"]
    )

    noisy_target_val_images = add_gaussian_noise(
        target_val_images,
        mean=cfg["noise_mean"],
        std_dev=cfg["noise_sd"]
    )

    # 4. Create DataLoaders
    source_train_loader = create_dataloaders(source_train_images, batch_size=cfg["batch_size"], drop_last=True)
    source_val_loader = create_dataloaders(source_val_images, batch_size=cfg["batch_size"], shuffle=False)

    # The target loaders pair each noisy image with its clean original, which
    # is passed through the ``labels`` argument.
    noisy_target_train_loader = create_dataloaders(images=noisy_target_train_images, labels=target_train_images,
                                                   batch_size=cfg["batch_size"], drop_last=True)
    noisy_target_val_loader = create_dataloaders(images=noisy_target_val_images, labels=target_val_images,
                                                 batch_size=cfg["batch_size"], shuffle=False)

    # 5. Visualize Samples
    visualize_samples(source_train_images, num_images=cfg["num_visualize"], title="Source Domain Samples",
                      save_dir=result_dir, file_name='source_train_samples.png')

    # NOTE(review): the next two figures show *training* target images but are
    # saved under 'target_val_*' names. The names are kept so existing output
    # paths do not change -- confirm which was intended.
    visualize_samples(target_train_images, num_images=cfg["num_visualize"], title="Target Domain Samples",
                      save_dir=result_dir, file_name='target_val_samples.png')

    visualize_samples(noisy_target_train_images, num_images=cfg["num_visualize"], save_dir=result_dir,
                      title=r"Target Domain Samples (Noise Added: N({:.2f}, {:.2f}$^2$))".format(cfg['noise_mean'], cfg['noise_sd']),
                      file_name='target_val_samples_noisy.png',
    )

    # NOTE(review): the models are not moved to ``device`` here; presumably the
    # train/evaluation helpers do that with the ``device`` they receive -- verify.
    encoder_f = Encoder()
    decoder_g = Decoder()

    # A single optimizer updates the encoder and the decoder jointly.
    opt = torch.optim.Adam(list(encoder_f.parameters()) + list(decoder_g.parameters()), lr=cfg["lr"])

    result_dict = {'err_src_tr': [],
                   'da_loss': [],
                   'err_src_val': [],
                   'err_trg_tr': [],
                   'err_trg_val': [],
                   }

    best_val_loss = float('inf')
    # Bound up front so the reporting code after the loop cannot hit a
    # NameError when no epoch runs or no epoch ever improves the best loss
    # (for example when the validation error is NaN).
    best_epoch = None
    reconstructed_img = None
    ground_truth_img = None

    for epoch in range(cfg['num_epochs']):

        start_time = time.time()

        if cfg['adaptation_flag']:
            error_src_tr, da_loss, lamda = adaptation(source_tr_x=source_train_loader, target_tr_x=noisy_target_train_loader,
                                                      f=encoder_f, g=decoder_g, opt=opt, adapt_loss_mode=cfg["adapt_loss_mode"],
                                                      lamda=cfg["lambda"], highest_moment=cfg["highest_moment"],
                                                      geo_adapt_metric=cfg["geo_adapt_metric"], epoch=epoch+1,
                                                      initial_epochs=cfg["initial_epochs"], device=device)
            result_dict['da_loss'].append(da_loss)

        else:
            error_src_tr = baseline(source_train_loader, encoder_f, decoder_g, opt, device=device)

        # Only the training call is timed; the evaluations below are excluded.
        end_time = time.time()

        result_dict['err_src_tr'].append(error_src_tr)

        error_src_val, _, _ = evaluation(f=encoder_f, g=decoder_g, data_loader=source_val_loader, device=device)
        error_trg_tr, _, _ = evaluation(f=encoder_f, g=decoder_g, data_loader=noisy_target_train_loader, device=device)
        error_trg_val, reconstructed_img, ground_truth_img = evaluation(f=encoder_f, g=decoder_g, data_loader=noisy_target_val_loader,
                                                                        device=device, return_images=True)

        result_dict['err_src_val'].append(error_src_val)
        result_dict['err_trg_tr'].append(error_trg_tr)
        result_dict['err_trg_val'].append(error_trg_val)

        ep_time = (end_time - start_time)/60.0

        if cfg['adaptation_flag']:
            output_log = (f'Epoch [{epoch + 1}/{cfg["num_epochs"]}], err_src_tr: {error_src_tr:.5f}, err_src_val: {error_src_val:.5f}, '
                          f'--------- , DA: {da_loss:.2e}, --------- , '
                          f'err_trg_tr: {error_trg_tr:.4f}, err_trg_val: {error_trg_val:.4f}, lambda: {lamda:.0e}, train_time (min): {ep_time:.2f}'
                          )
        else:
            output_log = (f'Epoch [{epoch + 1}/{cfg["num_epochs"]}], err_src_tr: {error_src_tr:.5f}, err_src_val: {error_src_val:.5f}, '
                          f'--------- , err_trg_tr: {error_trg_tr:.4f}, err_trg_val: {error_trg_val:.4f}, train_time (min): {ep_time:.2f}'
                          )

        # Every epoch is kept for log.txt; only every n-th one is printed.
        log_list.append(output_log)

        if (epoch+1) % cfg['every_nth_epoch'] == 0:
            print(output_log, flush=True)

        # Model selection uses the *source* validation error only.
        if error_src_val < best_val_loss:
            best_val_loss = error_src_val
            best_epoch = output_log
            torch.save(encoder_f.state_dict(), os.path.join(result_dir, 'encoder_f.pt'))
            torch.save(decoder_g.state_dict(), os.path.join(result_dir, 'decoder_g.pt'))

    if result_dict['err_src_tr']:
        # reconstructed_img / ground_truth_img come from the last epoch's
        # evaluation, not from the best checkpoint saved above.
        visualize_reconstructed_images(images1=ground_truth_img, images2=reconstructed_img, num_to_show=4,
                                       save_dir=result_dir, file_name='reconstructed_images.png')
        plot_errors(err_src=result_dict['err_src_tr'], err_trg=result_dict['err_trg_tr'], title='Training Error',
                    save_dir=result_dir, file_name='training_error.png')
        plot_errors(err_src=result_dict['err_src_val'], err_trg=result_dict['err_trg_val'], title='Validation Error',
                    save_dir=result_dir, file_name='validation_error.png')
    else:
        warnings.warn("No epochs were run; skipping reconstruction and error plots.", UserWarning)

    with open(os.path.join(result_dir, 'log.txt'), 'w', encoding='utf-8') as log_file:
        log_file.writelines([string + '\n' for string in log_list])
        log_file.write('\n\n')
        log_file.write('Best Epoch:\n')
        if best_epoch is not None:
            log_file.write(best_epoch)
        else:
            log_file.write('None (the source validation error never improved)')

    # Save configuration parameters
    with open(os.path.join(result_dir, 'config.txt'), 'w', encoding='utf-8') as cfg_file:
        cfg_file.write("Configuration Parameters:\n")
        cfg_file.write("=" * 30 + "\n\n")
        for key, value in cfg.items():
            cfg_file.write(f"{key}: {value}\n")

    print("--- Script Finished ---")


def _seed_everything(seed, device):
    """Seed the Python, PyTorch and NumPy RNGs; on CUDA also force deterministic cuDNN."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if multi-GPU
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _job_folder_name(cfg):
    """Return the result sub-folder name derived from the adaptation settings and job index."""
    if not cfg['adaptation_flag']:
        return '{}_{}'.format('baseline', cfg["job_index"])
    if cfg["adapt_loss_mode"] in ["homm", "cmd"]:
        # For these two loss modes the folder name also carries the highest moment.
        return '{}{}_{}'.format(cfg["adapt_loss_mode"], cfg["highest_moment"], cfg["job_index"])
    return '{}_{}'.format(cfg["adapt_loss_mode"], cfg["job_index"])


# --- Execute Main ---
# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()