import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.cuda.amp as amp
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
import utils as utils

def train_atomecg(model_atomecg,
                  train_dataloader,
                  lr,
                  epochs,
                  device,
                  aug_gaussian=False,
                  log_path=None,
                  save_model_dir=None,
                  checkpoint_path=None,
                  *,
                  log_every=600,
                  checkpoint_every=13500,
                  save_every_epochs=2):
    """Pre-train ``model_atomecg`` with a reconstruction + VQ objective.

    Each step minimises ``F.mse_loss(x_recon, data) + vq_loss`` where
    ``vq_loss``, ``x_recon`` and ``perplexities`` come from the model's
    forward pass. Steps whose loss is NaN or infinite are skipped.

    Args:
        model_atomecg: Model called as ``model_atomecg(data)`` and expected to
            return a 5-tuple ``(vq_loss, x_recon, _, _, perplexities)``.
        train_dataloader: Iterable of batches; each batch must support
            ``.float().to(device)``. ``len(train_dataloader)`` is required
            when ``log_path`` is given (used to compute the global step).
        lr: Learning rate forwarded to ``utils.get_optimizer('adamw', ...)``.
        epochs: Total number of epochs (exclusive upper bound of the loop).
        device: Target device for the batches and for checkpoint loading.
        aug_gaussian: If true, ``utils.add_noise(data, noise_level=0.01)`` is
            applied to every batch before the forward pass.
        log_path: TensorBoard log directory; no logging/printing of running
            metrics happens when it is ``None``.
        save_model_dir: Directory for periodic and final ``state_dict``
            snapshots; created if missing. Nothing is saved when ``None``.
        checkpoint_path: File used both to resume from (if it exists) and to
            write resumable checkpoints to.
        log_every: Log running averages every this many batches (positive).
        checkpoint_every: Write a checkpoint every this many batches
            (positive).
        save_every_epochs: Save a model snapshot every this many epochs
            (positive).
    """
    # Local imports keep this function self-contained with respect to the
    # module-level import block.
    from collections import defaultdict, deque
    from itertools import islice

    # Only the most recent `window` values are ever averaged, so bounded
    # deques are used instead of lists that grow for the whole run.
    window = 100
    recon_history = deque(maxlen=window)
    vq_history = deque(maxlen=window)
    perplexity_history = defaultdict(lambda: deque(maxlen=window))

    atomecg_optimizer = utils.get_optimizer('adamw', model_atomecg, lr)
    model_atomecg.train()

    start_epoch = 0
    start_batch = 0

    # Load checkpoint
    if checkpoint_path and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. saved on GPU, resumed on CPU).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_atomecg.load_state_dict(checkpoint['model_atomecg_state_dict'])
        atomecg_optimizer.load_state_dict(checkpoint['atomecg_optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_batch = checkpoint['batch']
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}, batch {start_batch}")

    # Make sure every output location exists before the first torch.save.
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
    if checkpoint_path:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    writer = SummaryWriter(log_path) if log_path else None
    try:
        for epoch in tqdm(range(start_epoch, epochs), desc="Pre-Training Progress"):
            # Only the epoch we resume into starts at a non-zero batch offset.
            first_batch = start_batch if epoch == start_epoch else 0
            batches = iter(train_dataloader)

            # Skip batches that have already been processed if resuming.
            # islice stops quietly if the loader is shorter than the offset.
            # NOTE: this only fast-forwards the iterator; with a shuffling
            # loader the skipped batches are not necessarily the ones that
            # were consumed before the interruption.
            for _ in islice(batches, first_batch):
                pass

            for batch_idx, data_batch in enumerate(batches, start=first_batch):
                data = data_batch.float().to(device)

                if aug_gaussian:
                    data = utils.add_noise(data, noise_level=0.01)

                atomecg_optimizer.zero_grad()

                vq_loss, x_recon, _, _, perplexities = model_atomecg(data)
                recon_error = F.mse_loss(x_recon, data)
                loss = recon_error + vq_loss

                # Skip the update for NaN *and* Inf losses: back-propagating
                # either would push non-finite values into the weights.
                if not torch.isfinite(loss):
                    print(f"Warning: NaN/Inf loss detected at epoch {epoch}, batch {batch_idx}. Skipping step.")
                    torch.cuda.empty_cache()
                    continue

                loss.backward()
                atomecg_optimizer.step()

                # Store metrics
                recon_history.append(recon_error.item())
                vq_history.append(vq_loss.item())
                for head, perplexity in enumerate(perplexities):
                    perplexity_history[head].append(perplexity.item())

                # Logging
                if writer is not None and (batch_idx + 1) % log_every == 0:
                    global_step = epoch * len(train_dataloader) + batch_idx
                    avg_recon_error = np.mean(recon_history)
                    avg_vq_loss = np.mean(vq_history)

                    writer.add_scalar('Loss/Reconstruction', avg_recon_error, global_step)
                    writer.add_scalar('Loss/VQ', avg_vq_loss, global_step)

                    print(f"\nEpoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}]")
                    print(f"Reconstruction Error: {avg_recon_error:.4f}")
                    print(f"VQ Loss: {avg_vq_loss:.4f}")

                    for head, history in perplexity_history.items():
                        avg_perplexity = np.mean(history)
                        writer.add_scalar(f'Metrics/Perplexity_Head_{head}', avg_perplexity, global_step)
                        print(f'Perplexity (Head {head}): {avg_perplexity:.3f}')
                    print()

                # Save checkpoint periodically
                if checkpoint_path and (batch_idx + 1) % checkpoint_every == 0:
                    torch.save({
                        'epoch': epoch,
                        'batch': batch_idx + 1,
                        'model_atomecg_state_dict': model_atomecg.state_dict(),
                        'atomecg_optimizer_state_dict': atomecg_optimizer.state_dict(),
                    }, checkpoint_path)
                    print(f"Checkpoint saved at epoch {epoch}, batch {batch_idx + 1}")

            if save_model_dir and (epoch + 1) % save_every_epochs == 0:
                torch.save(model_atomecg.state_dict(),
                           os.path.join(save_model_dir, f"atomecg_epoch_{epoch+1}.pth"))
    finally:
        # Flush and close the event file even if training raised.
        if writer is not None:
            writer.close()

    if save_model_dir:
        torch.save(model_atomecg.state_dict(),
                   os.path.join(save_model_dir, 'atomecg_final.pth'))


def finetune_atomecg(signal_length,
                     model_atomecg,
                     train_dataloader,
                     lr,
                     epochs,
                     device,
                     log_path=None,
                     save_model_dir=None,
                     checkpoint_path=None,
                     *,
                     log_every=5,
                     checkpoint_every=2000,
                     save_every_epochs=15):
    """Fine-tune ``model_atomecg`` with a reconstruction + VQ objective.

    Each step minimises ``F.mse_loss(x_recon, data) + vq_loss`` where
    ``vq_loss``, ``x_recon`` and ``perplexities`` come from the model's
    forward pass. Steps whose loss is NaN or infinite are skipped.

    Args:
        signal_length: If the last dimension of a batch equals this value,
            the batch's last two axes are swapped (``permute(0, 2, 1)``)
            before it is fed to the model.
        model_atomecg: Model called as ``model_atomecg(data)`` and expected to
            return a 5-tuple ``(vq_loss, x_recon, _, _, perplexities)``.
        train_dataloader: Iterable of ``(data, _)`` pairs; the second element
            is ignored. ``len(train_dataloader)`` is required when
            ``log_path`` is given (used to compute the global step).
        lr: Learning rate forwarded to ``utils.get_optimizer('adamw', ...)``.
        epochs: Total number of epochs (exclusive upper bound of the loop).
        device: Target device for the batches and for checkpoint loading.
        log_path: TensorBoard log directory; no logging/printing of running
            metrics happens when it is ``None``.
        save_model_dir: Directory for periodic and final ``state_dict``
            snapshots; created if missing. Nothing is saved when ``None``.
        checkpoint_path: File used both to resume from (if it exists) and to
            write resumable checkpoints to.
        log_every: Log running averages every this many batches (positive).
        checkpoint_every: Write a checkpoint every this many batches
            (positive).
        save_every_epochs: Save a model snapshot every this many epochs
            (positive).
    """
    # Local imports keep this function self-contained with respect to the
    # module-level import block.
    from collections import defaultdict, deque
    from itertools import islice

    # Only the most recent `window` values are ever averaged, so bounded
    # deques are used instead of lists that grow for the whole run.
    window = 100
    recon_history = deque(maxlen=window)
    vq_history = deque(maxlen=window)
    perplexity_history = defaultdict(lambda: deque(maxlen=window))

    atomecg_optimizer = utils.get_optimizer('adamw', model_atomecg, lr)
    model_atomecg.train()

    start_epoch = 0
    start_batch = 0

    # Load checkpoint
    if checkpoint_path and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. saved on GPU, resumed on CPU).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model_atomecg.load_state_dict(checkpoint['model_atomecg_state_dict'])
        atomecg_optimizer.load_state_dict(checkpoint['atomecg_optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_batch = checkpoint['batch']
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}, batch {start_batch}")

    # Make sure every output location exists before the first torch.save
    # (previously only the periodic snapshot created save_model_dir, so the
    # final save could fail on short runs).
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
    if checkpoint_path:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    writer = SummaryWriter(log_path) if log_path else None
    try:
        for epoch in tqdm(range(start_epoch, epochs), desc="Fine-tuning Progress"):
            # Only the epoch we resume into starts at a non-zero batch offset.
            first_batch = start_batch if epoch == start_epoch else 0
            batches = iter(train_dataloader)

            # Skip batches that have already been processed if resuming.
            # islice stops quietly if the loader is shorter than the offset.
            # NOTE: this only fast-forwards the iterator; with a shuffling
            # loader the skipped batches are not necessarily the ones that
            # were consumed before the interruption.
            for _ in islice(batches, first_batch):
                pass

            for batch_idx, data_batch in enumerate(batches, start=first_batch):
                data, _ = data_batch
                data = data.float()
                # Swap the last two axes when the signal axis comes last
                # (presumably the model expects it on axis 1 — confirm).
                if data.shape[-1] == signal_length:
                    data = data.permute(0, 2, 1)
                data = data.to(device)

                atomecg_optimizer.zero_grad()

                vq_loss, x_recon, _, _, perplexities = model_atomecg(data)

                # Main reconstruction loss
                recon_error = F.mse_loss(x_recon, data)
                loss = recon_error + vq_loss

                # NOTE: an additional masked-patch reconstruction loss
                # (MAE + RMSE on masked patches) was previously sketched here
                # in commented-out form; it is disabled. Recover it from
                # version control if masking is re-enabled.

                # Skip the update for NaN *and* Inf losses: back-propagating
                # either would push non-finite values into the weights.
                if not torch.isfinite(loss):
                    print(f"Warning: NaN/Inf loss detected at epoch {epoch}, batch {batch_idx}. Skipping step.")
                    torch.cuda.empty_cache()
                    continue

                loss.backward()
                atomecg_optimizer.step()

                # Store metrics
                recon_history.append(recon_error.item())
                vq_history.append(vq_loss.item())
                for head, perplexity in enumerate(perplexities):
                    perplexity_history[head].append(perplexity.item())

                # Logging
                if writer is not None and (batch_idx + 1) % log_every == 0:
                    global_step = epoch * len(train_dataloader) + batch_idx
                    avg_recon_error = np.mean(recon_history)
                    avg_vq_loss = np.mean(vq_history)

                    writer.add_scalar('Loss/Reconstruction', avg_recon_error, global_step)
                    writer.add_scalar('Loss/VQ', avg_vq_loss, global_step)

                    print(f"\nEpoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}]")
                    print(f"Reconstruction Error: {avg_recon_error:.4f}")
                    print(f"VQ Loss: {avg_vq_loss:.4f}")

                    for head, history in perplexity_history.items():
                        avg_perplexity = np.mean(history)
                        writer.add_scalar(f'Metrics/Perplexity_Head_{head}', avg_perplexity, global_step)
                        print(f'Perplexity (Head {head}): {avg_perplexity:.3f}')
                    print()

                # Save checkpoint
                if checkpoint_path and (batch_idx + 1) % checkpoint_every == 0:
                    torch.save({
                        'epoch': epoch,
                        'batch': batch_idx + 1,
                        'model_atomecg_state_dict': model_atomecg.state_dict(),
                        'atomecg_optimizer_state_dict': atomecg_optimizer.state_dict(),
                    }, checkpoint_path)
                    print(f"Checkpoint saved at epoch {epoch}, batch {batch_idx + 1}")

            if save_model_dir and (epoch + 1) % save_every_epochs == 0:
                torch.save(model_atomecg.state_dict(),
                           os.path.join(save_model_dir, f"atomecg_finetuned_epoch_{epoch+1}.pth"))
    finally:
        # Flush and close the event file even if training raised.
        if writer is not None:
            writer.close()

    if save_model_dir:
        torch.save(model_atomecg.state_dict(),
                   os.path.join(save_model_dir, 'atomecg_finetuned_final.pth'))