import warnings
import os
import time
import argparse
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from torch.nn.parallel import DistributedDataParallel

from dataloaders.train_dataset_reverb_utilis import create_dataloader, create_dataset
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import (SEMamba, SEAtMamba, HyperSEMamba, SEMambaAt, SEMambaReAt,
                              SEMambaCo2dReAt, SEMambaCoDe2dReAt, SEHyperMambaCoDe2dReAt, SEMambaCoDe2dReGuMaAt)
from models.loss import pesq_score, phase_losses
from models.discriminator import MetricDiscriminator, batch_pesq
from utils.util import (
    load_ckpts, load_optimizer_states, save_checkpoint,
    build_env, load_config, initialize_seed,
    print_gpu_info, log_model_info)
from models.simulate_paths import (
    instance_simulator, process_signals_through_primary_path,
    process_signals_through_secondary_path, randomize_reverberation_time)


# Suppress every FutureWarning for the whole process.
warnings.simplefilter(action='ignore', category=FutureWarning)
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one;
# this helps most when input shapes stay the same from one iteration to the next.
torch.backends.cudnn.benchmark = True


def setup_optimizers(models, cfg):
    """Create one AdamW optimizer for the generator and one for the discriminator.

    Args:
        models: ``(generator, discriminator)`` pair of modules.
        cfg: experiment config; reads ``training_cfg.learning_rate``,
            ``training_cfg.adam_b1`` and ``training_cfg.adam_b2``.

    Returns:
        ``(optim_g, optim_d)``, both configured with the same lr and betas.
    """
    generator, discriminator = models
    train_cfg = cfg['training_cfg']
    shared_hparams = {
        'lr': train_cfg['learning_rate'],
        'betas': (train_cfg['adam_b1'], train_cfg['adam_b2']),
    }
    return (optim.AdamW(generator.parameters(), **shared_hparams),
            optim.AdamW(discriminator.parameters(), **shared_hparams))


def setup_schedulers(optimizers, cfg, last_epoch):
    """Attach exponential learning-rate decay to the generator/discriminator optimizers.

    Args:
        optimizers: ``(optim_g, optim_d)``; ``optim_d`` may be ``None`` when no
            discriminator is trained.
        cfg: experiment config; ``training_cfg.lr_decay`` is the per-step gamma.
        last_epoch: forwarded to ``ExponentialLR`` so a resumed run continues the decay.

    Returns:
        ``(scheduler_g, scheduler_d)``; ``scheduler_d`` is ``None`` when ``optim_d`` is.
    """
    optim_g, optim_d = optimizers
    gamma = cfg['training_cfg']['lr_decay']

    def exponential_decay(optimizer):
        return optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=last_epoch)

    # Tuple elements evaluate left to right, so the generator scheduler is still built first.
    return exponential_decay(optim_g), (None if optim_d is None else exponential_decay(optim_d))


def get_simulator(cfg, device):
    """Instantiate the acoustic-path simulator described by ``cfg['rir_cfg']``.

    The path version is read from ``rir_cfg.path_version``, falling back to
    ``rir_cfg.version``. The simulator type must be ``"RIR"`` or ``"PyRoom"``;
    ``hp_filter`` is forwarded only for the ``"RIR"`` type.

    Returns:
        ``(reverberation_times, simulator)`` as produced by ``instance_simulator``.

    Raises:
        ValueError: for any other simulator type.
    """
    rir_cfg = cfg['rir_cfg']
    try:
        version = rir_cfg['path_version']
    except KeyError:
        version = rir_cfg['version']
    print("Path RIR version: ", version)

    simulator_type = rir_cfg['type']
    if simulator_type not in ("RIR", "PyRoom"):
        raise ValueError("Unknown simulator type")

    # Arguments common to both simulator types.
    simulator_kwargs = dict(
        simulator_type=simulator_type,
        sr=cfg['stft_cfg']['sampling_rate'],
        reverberation_times=rir_cfg['reverberation_times'],
        rir_samples=rir_cfg['rir_samples'],
        device=device,
        version=version,
    )
    if simulator_type == "RIR":
        simulator_kwargs['hp_filter'] = rir_cfg['hp_filter']

    reverberation_times, simulator = instance_simulator(**simulator_kwargs)
    return reverberation_times, simulator


def get_model_and_load_checkpoints(rank, device, num_gpus, args, use_discriminator, cfg):
    """Build the models, restore checkpoints, and create optimizers and LR schedulers.

    The generator class is chosen from ``cfg['model_cfg']['model_type']``. Weights
    and optimizer states come from ``load_ckpts`` / ``load_optimizer_states``; models
    are wrapped in ``DistributedDataParallel`` when ``num_gpus > 1`` and CUDA is
    available. A discriminator optimizer is always created, but its scheduler is only
    created when ``use_discriminator`` is true.

    Returns:
        ``(generator, discriminator, optim_g, optim_d, scheduler_g, scheduler_d,
        steps, last_epoch, last_generator_checkpoint, last_discriminator_checkpoint)``.

    Raises:
        ValueError: if the configured model type is not recognised.
    """
    # Config model_type -> generator class. "SEMamba" and "SEHyperMamba" both map to SEMamba.
    generator_registry = {
        "SEMambaCoDe2dReGuMaAt": SEMambaCoDe2dReGuMaAt,
        "SEHyperMambaCoDe2dReAt": SEHyperMambaCoDe2dReAt,
        "SEMambaCoDe2dReAt": SEMambaCoDe2dReAt,  # the ASE-TM model
        "SEMambaCo2dReAt": SEMambaCo2dReAt,
        "SEMambaReAt": SEMambaReAt,
        "SEAtMamba": SEAtMamba,
        "SEMamba": SEMamba,
        "SEHyperMamba": SEMamba,
        "SEMambaAt": SEMambaAt,
        "HyperSEMamba": HyperSEMamba,
    }
    generator_cls = generator_registry.get(cfg['model_cfg']['model_type'])
    if generator_cls is None:
        raise ValueError("Unknown model type")
    generator = generator_cls(cfg).to(device)
    discriminator = MetricDiscriminator().to(device)

    if rank == 0:
        log_model_info(rank, generator, args.exp_path)

    # Restore weights from the last saved checkpoints (if any).
    (state_dict_g, state_dict_do, steps, last_epoch,
     last_generator_checkpoint, last_discriminator_checkpoint) = load_ckpts(args, device)
    restore_plan = (
        (generator, state_dict_g, 'generator'),
        (discriminator, state_dict_do, 'discriminator'),
    )
    for module, checkpoint, key in restore_plan:
        if checkpoint is not None:
            module.load_state_dict(checkpoint[key], strict=False)

    if num_gpus > 1 and torch.cuda.is_available():
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device)

    optim_g, optim_d = setup_optimizers((generator, discriminator), cfg)
    load_optimizer_states((optim_g, optim_d), state_dict_do, state_dict_g)
    # No discriminator scheduler when the discriminator is not trained.
    scheduled_d = optim_d if use_discriminator else None
    scheduler_g, scheduler_d = setup_schedulers((optim_g, scheduled_d), cfg, last_epoch)

    return (generator, discriminator, optim_g, optim_d, scheduler_g, scheduler_d,
            steps, last_epoch, last_generator_checkpoint, last_discriminator_checkpoint)


def _rotate_checkpoints(history, new_path, delete_older):
    """Return the updated ``[older, newest]`` history after ``new_path`` was written.

    When ``delete_older`` is truthy, the ``older`` file is removed from disk. Nothing
    is deleted for ``None`` entries or for the file that was just written. A file
    that is already gone is reported rather than crashing the training run. If
    ``new_path`` overwrote ``newest`` in place (the same step number saved twice),
    the history is left as ``[older, new_path]`` so no file is tracked twice.
    """
    older, newest = history
    if newest is not None and os.path.abspath(newest) == os.path.abspath(new_path):
        return [older, new_path]
    if (delete_older and older is not None
            and os.path.abspath(older) != os.path.abspath(new_path)):
        try:
            os.remove(older)
        except FileNotFoundError:
            print(f"Checkpoint {older} was already removed; skipping deletion.")
    return [newest, new_path]


def checkpoiting(
        generator, discriminator, optim_g, optim_d, steps, epoch, exp_path,
        two_last_generator_checkpoints, two_last_discriminator_checkpoints, use_discriminator,
        retain_one_checkpoint, num_gpus):
    """Save generator (and discriminator) checkpoints for ``steps`` and rotate old files.

    Args:
        generator, discriminator: models; their ``.module`` is saved when ``num_gpus > 1``.
        optim_g, optim_d: optimizers whose state is stored with the models.
        steps, epoch: training position stored in the checkpoints; ``steps`` also
            names the files (``g_<steps:08d>.pth`` / ``do_<steps:08d>.pth``).
        exp_path: directory the checkpoints are written to.
        two_last_generator_checkpoints, two_last_discriminator_checkpoints:
            ``[older, newest]`` lists of checkpoint paths (entries may be ``None``).
        use_discriminator: whether a discriminator checkpoint is written.
        retain_one_checkpoint: when truthy, the ``older`` generator checkpoint is
            deleted, so at most two generator checkpoints remain on disk.
        num_gpus: GPU count from the config.

    Returns:
        The updated ``(two_last_generator_checkpoints, two_last_discriminator_checkpoints)``.
    """
    gen_path = f"{exp_path}/g_{steps:08d}.pth"
    save_checkpoint(
        gen_path,
        {
            'generator': (generator.module if num_gpus > 1 else generator).state_dict(),
            'optim_g': optim_g.state_dict(),
            'steps': steps,
            'epoch': epoch})
    two_last_generator_checkpoints = _rotate_checkpoints(
        two_last_generator_checkpoints, gen_path, delete_older=retain_one_checkpoint)

    if use_discriminator:
        disc_path = f"{exp_path}/do_{steps:08d}.pth"
        save_checkpoint(
            disc_path,
            {
                'discriminator': (discriminator.module if num_gpus > 1 else discriminator).state_dict(),
                'optim_g': optim_g.state_dict(),
                'optim_d': optim_d.state_dict(),
                'steps': steps,
                'epoch': epoch})
        # NOTE(review): unlike generator checkpoints, discriminator checkpoints are always
        # rotated regardless of retain_one_checkpoint. Confirm this asymmetry is intended.
        two_last_discriminator_checkpoints = _rotate_checkpoints(
            two_last_discriminator_checkpoints, disc_path, delete_older=True)

    return two_last_generator_checkpoints, two_last_discriminator_checkpoints


def validation_step(
        generator, simulator, reverberation_times, validation_loader,
        steps, n_fft, hop_size, win_size, compress_factor, sw, start_b, device, cfg):
    """Score the generator on the validation set and log the results to TensorBoard.

    Each batch follows the training pipeline:
    - a reverberation time is drawn with ``randomize_reverberation_time``;
    - the noisy audio goes through the simulator's primary path;
    - the generator output goes through the secondary path;
    - their sum ``audio_g`` is compared with the clean reference.

    NOTE(review): the reverberation time and ``sef_factor="random"`` are re-drawn
    on every call. Validation scores are therefore presumably not reproducible
    between calls unless the simulator's randomness is seeded. Confirm this before
    trusting small best-PESQ improvements.

    Args:
        generator: model called as ``generator(noisy_mag, noisy_pha)``.
        simulator, reverberation_times: acoustic-path simulator and its T60 choices.
        validation_loader: iterable of 8-tuples from the validation dataset.
        steps: global training step, used for logging.
        n_fft, hop_size, win_size, compress_factor: STFT settings.
        sw: TensorBoard ``SummaryWriter``.
        start_b: ``time.time()`` at the start of the current training batch (for the s/b printout).
        device: device the batch tensors are moved to.
        cfg: experiment config forwarded to the loss / PESQ helpers.

    Returns:
        ``(val_pesq_score, val_mag_err, val_pha_err)``: PESQ over all validation
        utterances and the per-batch mean magnitude and phase losses.

    Raises:
        ValueError: if ``validation_loader`` yields no batches.
    """
    torch.cuda.empty_cache()
    audios_r, audios_g = [], []
    val_mag_err_tot = 0.0
    val_pha_err_tot = 0.0
    val_com_err_tot = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in validation_loader:
            clean_audio, clean_mag, clean_pha, clean_com, \
                noisy_audio, noisy_mag, noisy_pha, _norm_factor = batch  # [B, 1, F, T], F = nfft // 2+ 1, T = nframes

            # Plain tensors suffice; torch.autograd.Variable is a deprecated wrapper that just returns a tensor.
            clean_audio = clean_audio.to(device, non_blocking=True)
            noisy_audio = noisy_audio.to(device, non_blocking=True)

            clean_mag = clean_mag.to(device, non_blocking=True)
            clean_pha = clean_pha.to(device, non_blocking=True)
            clean_com = clean_com.to(device, non_blocking=True)

            noisy_mag = noisy_mag.to(device, non_blocking=True)
            noisy_pha = noisy_pha.to(device, non_blocking=True)

            # TODO: without t60 the signals could be processed through the primary path beforehand
            t60 = randomize_reverberation_time(reverberation_times)
            # Process noise through primary path
            noisy_audio = process_signals_through_primary_path(noisy_audio, simulator, t60)
            # Process the generated signal through the secondary path and add the primary-path noise
            mag_g, pha_g, _ = generator(noisy_mag, noisy_pha)
            audio_g = mag_phase_istft(mag_g, pha_g, n_fft, hop_size, win_size, compress_factor)
            audio_g = process_signals_through_secondary_path(
                audio_g, simulator, t60, sef_factor="random")
            audio_g = audio_g + noisy_audio

            # STFT the summed signal; the losses compare it with the clean target
            mag_g, pha_g, com_g = mag_phase_stft(audio_g, n_fft, hop_size, win_size, compress_factor)

            audios_r += torch.split(clean_audio, 1, dim=0)  # [1, T] * B
            audios_g += torch.split(audio_g, 1, dim=0)

            val_mag_err_tot += F.mse_loss(clean_mag, mag_g).item()
            val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g, cfg)
            val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item()
            val_com_err_tot += F.mse_loss(clean_com, com_g).item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("validation_loader yielded no batches")

        val_mag_err = val_mag_err_tot / num_batches
        val_pha_err = val_pha_err_tot / num_batches
        val_com_err = val_com_err_tot / num_batches
        val_pesq_score = pesq_score(audios_r, audios_g, cfg).item()
        print('Steps : {:d}, PESQ Score: {:4.3f}, s/b : {:4.3f}'.format(
            steps, val_pesq_score, time.time() - start_b))
        sw.add_scalar("Validation/PESQ Score", val_pesq_score, steps)
        sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps)
        sw.add_scalar("Validation/Phase Loss", val_pha_err, steps)
        sw.add_scalar("Validation/Complex Loss", val_com_err, steps)

        return val_pesq_score, val_mag_err, val_pha_err


def _unwrap(model, num_gpus):
    """Return ``model.module`` for a DDP-wrapped model (``num_gpus > 1``), else ``model``."""
    return model.module if num_gpus > 1 else model


def _batch_to_device(batch, device):
    """Move one training batch to ``device``, dropping the unused normalisation factor.

    ``batch`` is the 8-tuple yielded by the dataset: ``(clean_audio, clean_mag,
    clean_pha, clean_com, noisy_audio, noisy_mag, noisy_pha, norm_factor)``.
    Returns the first seven entries as device tensors, in the same order.
    """
    clean_audio, clean_mag, clean_pha, clean_com, \
        noisy_audio, noisy_mag, noisy_pha, _norm_factor = batch
    return tuple(
        tensor.to(device, non_blocking=True)
        for tensor in (clean_audio, clean_mag, clean_pha, clean_com, noisy_audio, noisy_mag, noisy_pha))


def _read_best_pesq(best_pesq_file):
    """Return ``(best_pesq, best_pesq_step)`` stored in ``best_pesq_file``, or ``(0.0, 0)`` if it is absent."""
    if not os.path.exists(best_pesq_file):
        return 0.0, 0  # Initialize if no previous record
    with open(best_pesq_file, "r") as f:
        best_pesq, best_pesq_step = f.readline().strip().split()
    best_pesq, best_pesq_step = float(best_pesq), int(best_pesq_step)
    print("read best pesq from file: ", best_pesq, best_pesq_step)
    return best_pesq, best_pesq_step


def _save_best_checkpoints(generator, discriminator, optim_g, optim_d, steps, epoch, exp_path,
                           use_discriminator, num_gpus):
    """Write ``best_g.pth`` and, when the discriminator is trained, ``best_do.pth`` into ``exp_path``."""
    save_checkpoint(
        f"{exp_path}/best_g.pth",
        {
            'generator': _unwrap(generator, num_gpus).state_dict(),
            'optim_g': optim_g.state_dict(),
            'steps': steps,
            'epoch': epoch})
    if use_discriminator:
        save_checkpoint(
            f"{exp_path}/best_do.pth",
            {
                'discriminator': _unwrap(discriminator, num_gpus).state_dict(),
                'optim_g': optim_g.state_dict(),
                'optim_d': optim_d.state_dict(),
                'steps': steps,
                'epoch': epoch})


def train(rank, args, cfg):
    """Train the generator (and optionally the metric discriminator) through simulated acoustic paths.

    Each step runs the following pipeline:
    - draw a reverberation time;
    - pass the noisy input through the simulator's primary path;
    - pass the generator output through the secondary path;
    - train on their sum ``audio_g`` against the clean reference.

    Rank 0 also does the following:
    - logs to stdout and TensorBoard;
    - rotates checkpoints;
    - runs validation;
    - keeps the best-PESQ model in ``best_g.pth`` / ``best_do.pth``.

    Args:
        rank: process index; selects the CUDA device ``cuda:<rank>``.
        args: parsed CLI arguments (uses ``exp_path`` and ``new_loss``; also passed to ``load_ckpts``).
        cfg: experiment config dict.

    Raises:
        RuntimeError: if ``cfg['env_setting']['num_gpus']`` is below 1.
        ValueError: if the generator loss becomes NaN (checked on rank 0).
    """
    env_cfg = cfg['env_setting']
    num_gpus = env_cfg['num_gpus']
    loss_weights = cfg['training_cfg']['loss']
    n_fft, hop_size, win_size = \
        (cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size'])
    compress_factor = cfg['model_cfg']['compress_factor']
    if num_gpus < 1:
        raise RuntimeError("Mamba needs GPU acceleration")
    # NOTE(review): no torch.distributed process group is initialised here, which
    # DistributedDataParallel requires. Multi-GPU configs (num_gpus > 1) presumably fail as written.
    device = torch.device('cuda:{:d}'.format(rank))

    # Create simulator
    reverberation_times, simulator = get_simulator(cfg, device)

    # Check if discriminator is needed
    use_discriminator = loss_weights['metric'] != 0

    # Create models and load checkpoints
    generator, discriminator, optim_g, optim_d, scheduler_g, scheduler_d, \
        steps, last_epoch, last_generator_checkpoint, last_discriminator_checkpoint = \
        get_model_and_load_checkpoints(rank, device, num_gpus, args, use_discriminator, cfg)

    # Create trainset and train_loader
    trainset = create_dataset(cfg, train=True, split=True, device=device)
    train_loader = create_dataloader(trainset, cfg, train=True)

    if not use_discriminator:
        discriminator = None
        optim_d = None

    # Validation data and the TensorBoard writer exist on rank 0 only.
    sw = validation_loader = None
    if rank == 0:
        # split=True is forced because the generator trims its output, so the noise and
        # the generated audio would otherwise differ in length.
        validset = create_dataset(cfg, train=False, split=True, device=device)
        validation_loader = create_dataloader(validset, cfg, train=False)
        sw = SummaryWriter(os.path.join(args.exp_path, 'logs'))

    # Start training
    generator.train()
    if use_discriminator:
        discriminator.train()

    # Load previous best PESQ if resuming training
    best_pesq_file = f"{args.exp_path}/best_pesq.log"
    best_pesq, best_pesq_step = _read_best_pesq(best_pesq_file)
    # [older, newest] checkpoint paths; rotation keeps the last two.
    two_last_generator_checkpoints = [None, last_generator_checkpoint]
    two_last_discriminator_checkpoints = [None, last_discriminator_checkpoint]

    for epoch in range(max(0, last_epoch), cfg['training_cfg']['training_epochs']):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        for batch in train_loader:
            if rank == 0:
                start_b = time.time()
            (clean_audio, clean_mag, clean_pha, clean_com,
             noisy_audio, noisy_mag, noisy_pha) = _batch_to_device(batch, device)
            # Discriminator targets are sized from the actual batch (dim 0), so a
            # shorter final batch does not mismatch a fixed config batch size.
            one_labels = torch.ones(clean_audio.size(0), device=device)

            # TODO: without t60 the signals could be processed through the primary path beforehand
            t60 = randomize_reverberation_time(reverberation_times)
            # Process noise through primary path
            noisy_audio = process_signals_through_primary_path(noisy_audio, simulator, t60)
            # Process the generated signal through the secondary path and add the primary-path noise
            mag_g, pha_g, _ = generator(noisy_mag, noisy_pha)
            audio_g = mag_phase_istft(mag_g, pha_g, n_fft, hop_size, win_size, compress_factor)
            audio_g = process_signals_through_secondary_path(audio_g, simulator, t60, sef_factor="random")
            audio_g = audio_g + noisy_audio

            # STFT the summed signal; all losses below compare it with the clean target
            mag_g, pha_g, com_g = mag_phase_stft(audio_g, n_fft, hop_size, win_size, compress_factor)

            # Discriminator
            # ------------------------------------------------------- #
            if use_discriminator:
                # PESQ is only used as the discriminator's regression target, so the
                # CPU-side computation is skipped when no discriminator is trained.
                audio_list_r = list(clean_audio.cpu().numpy())
                audio_list_g = list(audio_g.detach().cpu().numpy())
                batch_pesq_score = batch_pesq(audio_list_r, audio_list_g, cfg)

                optim_d.zero_grad()
                metric_r = discriminator(clean_mag, clean_mag)
                metric_g = discriminator(clean_mag, mag_g.detach())
                loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

                if batch_pesq_score is not None:
                    loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
                else:
                    loss_disc_g = 0

                loss_disc_all = loss_disc_r + loss_disc_g

                loss_disc_all.backward()
                optim_d.step()
            # ------------------------------------------------------- #

            # Generator
            # ------------------------------------------------------- #
            optim_g.zero_grad()

            # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
            # Magnitude loss (L2, plus L1 with --new_loss)
            if args.new_loss:
                loss_mag = F.l1_loss(clean_mag, mag_g) + F.mse_loss(clean_mag, mag_g)
            else:
                loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g, cfg)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # Time loss (L1, plus L2 with --new_loss)
            if args.new_loss:
                loss_time = F.l1_loss(clean_audio, audio_g) + F.mse_loss(clean_audio, audio_g)
            else:
                loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            if use_discriminator:
                metric_g = discriminator(clean_mag, mag_g)
                loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
            else:
                loss_metric = 0
            # Consistency Loss
            _, _, rec_com = mag_phase_stft(audio_g, n_fft, hop_size, win_size, compress_factor, addeps=True)
            loss_con = F.mse_loss(com_g, rec_com) * 2

            loss_gen_all = (
                loss_metric * loss_weights['metric'] +
                loss_mag * loss_weights['magnitude'] +
                loss_pha * loss_weights['phase'] +
                loss_com * loss_weights['complex'] +
                loss_time * loss_weights['time'] +
                loss_con * loss_weights['consistancy']
            )

            loss_gen_all.backward()
            optim_g.step()
            # ------------------------------------------------------- #

            if rank == 0:
                # Check for divergence before anything is logged or checkpointed, so NaN
                # weights never replace the rotating checkpoints or best_g.pth.
                if torch.isnan(loss_gen_all).any():
                    raise ValueError("NaN values found in loss_gen_all")

                log_stdout = steps % env_cfg['stdout_interval'] == 0
                log_summary = steps % env_cfg['summary_interval'] == 0
                # The loss values are needed by both stdout and TensorBoard logging, so
                # compute them whenever either one is due.
                if log_stdout or log_summary:
                    # Reuse the losses computed for backward. Complex and consistency losses
                    # are reported without their x2 training factor.
                    metric_error = loss_metric.item() if use_discriminator else 0
                    disc_error = loss_disc_all.item() if use_discriminator else 0
                    mag_error = loss_mag.item()
                    pha_error = loss_pha.item()
                    com_error = loss_com.item() / 2
                    time_error = loss_time.item()
                    con_error = loss_con.item() / 2

                # STDOUT logging
                if log_stdout:
                    print(
                        'Steps : {:d}, Gen Loss: {:4.3f}, Disc Loss: {:4.3f}, Metric Loss: {:4.3f}, '
                        'Mag Loss: {:4.3f}, Pha Loss: {:4.3f}, Com Loss: {:4.3f}, Time Loss: {:4.3f}, '
                        'Cons Loss: {:4.3f}, s/b : {:4.3f}'.format(
                            steps, loss_gen_all.item(), disc_error, metric_error, mag_error,
                            pha_error, com_error, time_error, con_error, time.time() - start_b
                        )
                    )

                # Checkpointing
                if steps % env_cfg['checkpoint_interval'] == 0 and steps != 0:
                    two_last_generator_checkpoints, two_last_discriminator_checkpoints = checkpoiting(
                        generator, discriminator, optim_g, optim_d, steps, epoch, args.exp_path,
                        two_last_generator_checkpoints, two_last_discriminator_checkpoints, use_discriminator,
                        env_cfg['retain_one_checkpoint'], num_gpus)

                # Tensorboard summary logging
                if log_summary:
                    sw.add_scalar("Training/Generator Loss", loss_gen_all.item(), steps)
                    if use_discriminator:
                        sw.add_scalar("Training/Discriminator Loss", disc_error, steps)
                        sw.add_scalar("Training/Metric Loss", metric_error, steps)
                    sw.add_scalar("Training/Magnitude Loss", mag_error, steps)
                    sw.add_scalar("Training/Phase Loss", pha_error, steps)
                    sw.add_scalar("Training/Complex Loss", com_error, steps)
                    sw.add_scalar("Training/Time Loss", time_error, steps)
                    sw.add_scalar("Training/Consistancy Loss", con_error, steps)

                # Validation
                if steps % env_cfg['validation_interval'] == 0 and steps != 0:
                    generator.eval()
                    val_pesq_score, val_mag_err, val_pha_err = validation_step(
                        generator, simulator, reverberation_times, validation_loader,
                        steps, n_fft, hop_size, win_size, compress_factor, sw, start_b, device, cfg)
                    generator.train()

                    # Keep the best validation PESQ (file + model checkpoints)
                    if val_pesq_score >= best_pesq:
                        best_pesq, best_pesq_step = val_pesq_score, steps
                        with open(best_pesq_file, "w") as f:
                            f.write(f"{best_pesq} {best_pesq_step}\n")
                        _save_best_checkpoints(
                            generator, discriminator, optim_g, optim_d, steps, epoch,
                            args.exp_path, use_discriminator, num_gpus)
                    print(
                        f"valid: PESQ {val_pesq_score}, Mag_loss {val_mag_err}, Phase_loss {val_pha_err}. "
                        f"Best_PESQ: {best_pesq} at step {best_pesq_step}")
            steps += 1

        scheduler_g.step()
        if use_discriminator:
            scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

    # Flush pending TensorBoard events and release the log file.
    if sw is not None:
        sw.close()


# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
def main():
    """Command-line entry point: prepare the experiment folder and train on rank 0.

    NOTE(review): only ``train(0, ...)`` is launched and no per-GPU processes are
    spawned, so configs with ``num_gpus > 1`` are presumably unsupported as written.
    """
    import shutil  # only needed to copy the SEMamba base checkpoint

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_folder', required=True,
                        help='Path to the folder where the experiment will be saved')
    parser.add_argument('--exp_name', required=True, help='Name of the experiment')
    parser.add_argument('--config', required=True, help='Path to the config file')
    parser.add_argument('--avoid_SEMamba_base', default=False, action='store_true',
                        help="Avoid fine-tuning SEMamba base")
    # store_false: the new loss is on by default; passing --new_loss turns it OFF.
    parser.add_argument('--new_loss', default=True, action='store_false', help='Add flag to avoid using new loss')
    args = parser.parse_args()

    # Create the experiment folder if it does not exist yet
    args.exp_path = os.path.join(args.exp_folder, args.exp_name)
    os.makedirs(args.exp_path, exist_ok=True)
    # Copy the SEMamba base checkpoint into the experiment folder for fine-tuning.
    # shutil avoids building a shell command out of user-supplied paths.
    if not args.avoid_SEMamba_base:
        base_ckpt = os.path.join(args.exp_folder, "g_00000001.pth")
        if os.path.isfile(base_ckpt):
            shutil.copy(base_ckpt, args.exp_path)
        else:
            print(f"Warning: SEMamba base checkpoint {base_ckpt} not found; nothing copied.")

    # Load the config file
    cfg = load_config(args.config)
    seed = cfg['env_setting']['seed']

    initialize_seed(seed)
    build_env(args.config, 'config.yaml', args.exp_path)
    if torch.cuda.is_available():
        num_available_gpus = torch.cuda.device_count()
        print(f"Number of GPUs available: {num_available_gpus}")
        print_gpu_info(num_available_gpus, cfg)
    else:
        print("CUDA is not available.")

    train(0, args, cfg)


if __name__ == '__main__':
    main()
