import os
import time
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as cfg
from data_utils import get_dataloaders
from model import EnhancedCVAE
from loss_utils import loss_function, adaptive_kl_weight as compute_adaptive_kl_weight

def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, save_path):
    """Write a resumable checkpoint for ``epoch`` into ``save_path``.

    The file is named ``checkpoint_epoch_{epoch}.pth``. It stores the epoch
    number, the model and optimizer ``state_dict()`` and the given losses.
    ``save_path`` is not created here, so it must already exist.
    """
    target_file = os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
    )
    torch.save(payload, target_file)
    print(f'Checkpoint saved at epoch {epoch}')

def train_epoch(model, train_loader, optimizer, device,
                kl_weight=1.0, epoch=None, total_epochs=100, max_grad_norm=1.0):
    """Run one optimisation pass over ``train_loader`` and return averaged metrics.

    Each batch is unpacked as ``(input_seq, target_seq)``. The model is called
    as ``model(target_seq, input_seq)`` and is expected to return
    ``(recon_batch, mu, log_var)``. The loss comes from ``loss_function``.

    Args:
        model: module to train; it is switched to train mode.
        train_loader: iterable of ``(input_seq, target_seq)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device each batch is moved to.
        kl_weight: KL weight forwarded to ``loss_function``. It is also echoed
            back unchanged under ``'adaptive_kl_weight'`` in the result.
        epoch: forwarded to ``loss_function``.
        total_epochs: forwarded to ``loss_function``.
        max_grad_norm: gradients are clipped to this total norm before each
            step. The default 1.0 matches the previous hard-coded value;
            ``None`` disables clipping.

    Returns:
        dict with ``'loss'``, ``'recon_loss'`` and ``'kl_loss'`` averaged
        per sample (each batch weighted by its size), plus
        ``'adaptive_kl_weight'``.

    Raises:
        ValueError: if ``train_loader`` yields no samples. Previously this
            surfaced as a ZeroDivisionError.
    """
    model.train()
    totals = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
    total_samples = 0

    for input_seq, target_seq in train_loader:
        input_seq, target_seq = input_seq.to(device), target_seq.to(device)

        optimizer.zero_grad()

        recon_batch, mu, log_var = model(target_seq, input_seq)

        # The 4th value returned by loss_function (its adaptive weight) is not
        # used here.
        loss, recon, kl, _ = loss_function(
            recon_batch, target_seq, mu, log_var,
            kl_weight=kl_weight,
            epoch=epoch,
            total_epochs=total_epochs
        )

        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Assumes dim 0 of input_seq is the batch dimension.
        # TODO confirm against the data loader.
        batch_size = input_seq.size(0)
        total_samples += batch_size

        # Accumulate sums weighted by batch size so that the final average is
        # per sample, even when the last batch is smaller.
        totals['loss'] += loss.item() * batch_size
        totals['recon_loss'] += recon.item() * batch_size
        totals['kl_loss'] += kl.item() * batch_size

    if total_samples == 0:
        raise ValueError(
            "train_loader yielded no samples; cannot average training metrics"
        )

    train_metrics = {key: total / total_samples for key, total in totals.items()}
    train_metrics['adaptive_kl_weight'] = kl_weight
    return train_metrics

def validate(model, test_loader, device, kl_weight=1.0, epoch=None, total_epochs=100):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    The model is put in eval mode, and every batch runs under
    ``torch.no_grad()``. Each batch is unpacked as ``(input_seq, target_seq)``.
    The model is called as ``model(target_seq, input_seq)`` and is expected to
    return ``(recon_batch, mu, log_var)``.

    Args:
        model: module to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(input_seq, target_seq)`` batches.
        device: device each batch is moved to.
        kl_weight: KL weight forwarded to ``loss_function``. It is also echoed
            back unchanged under ``'adaptive_kl_weight'`` in the result.
        epoch: forwarded to ``loss_function``.
        total_epochs: forwarded to ``loss_function``.

    Returns:
        dict with ``'loss'``, ``'recon_loss'`` and ``'kl_loss'`` averaged
        per sample (each batch weighted by its size), plus
        ``'adaptive_kl_weight'``.

    Raises:
        ValueError: if ``test_loader`` yields no samples. Previously this
            surfaced as a ZeroDivisionError.
    """
    model.eval()
    totals = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
    total_samples = 0

    with torch.no_grad():
        for input_seq, target_seq in test_loader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)

            recon_batch, mu, log_var = model(target_seq, input_seq)

            # The 4th value returned by loss_function (its adaptive weight) is
            # not used here.
            loss, recon, kl, _ = loss_function(
                recon_batch, target_seq, mu, log_var,
                kl_weight=kl_weight,
                epoch=epoch,
                total_epochs=total_epochs
            )

            # Assumes dim 0 of input_seq is the batch dimension.
            # TODO confirm against the data loader.
            batch_size = input_seq.size(0)
            total_samples += batch_size

            totals['loss'] += loss.item() * batch_size
            totals['recon_loss'] += recon.item() * batch_size
            totals['kl_loss'] += kl.item() * batch_size

    if total_samples == 0:
        raise ValueError(
            "test_loader yielded no samples; cannot average validation metrics"
        )

    val_metrics = {key: total / total_samples for key, total in totals.items()}
    val_metrics['adaptive_kl_weight'] = kl_weight
    return val_metrics

class EarlyStopping:
    """Flag when a monitored validation loss stops improving.

    Call the instance once per epoch with the latest validation loss.

    A loss counts as an improvement only if it is lower than the best loss
    seen so far by more than ``min_delta``. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes True. A NaN loss never
    compares lower, so it counts as non-improving.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: if True, print a message when early stopping triggers.
        min_delta: minimum decrease required to count as an improvement.
            The default 0.0 reproduces a plain ``<`` comparison.
    """

    def __init__(self, patience=5, verbose=False, min_delta=0.0):
        if min_delta < 0:
            raise ValueError(f"min_delta must be non-negative, got {min_delta}")
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without improvement
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter``, ``best_loss`` and ``early_stop``."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print(f'Early stopping triggered after {self.patience} epochs without improvement.')

def _append_log_row(log_file, epoch, train_metrics, val_metrics, current_lr, kl_weight):
    """Append one tab-separated row of per-epoch metrics to ``log_file``."""
    with open(log_file, 'a') as f:
        f.write(f"{epoch}\t"
                f"{train_metrics['loss']:.4f}\t"
                f"{train_metrics['recon_loss']:.4f}\t"
                f"{train_metrics['kl_loss']:.4f}\t"
                f"{val_metrics['loss']:.4f}\t"
                f"{val_metrics['recon_loss']:.4f}\t"
                f"{val_metrics['kl_loss']:.4f}\t"
                f"{current_lr:.6f}\t"
                f"{kl_weight:.4f}\n")

def _save_model_state(path, model, optimizer, epoch, val_loss):
    """Save epoch, model/optimizer ``state_dict()`` and ``val_loss`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, path)

def _plot_loss_curves(train_losses, val_losses, out_path):
    """Plot per-epoch train/validation losses to ``out_path`` (best effort).

    Any exception is printed rather than raised, so a plotting problem never
    aborts the end of training. The figure is closed even if saving fails.
    """
    fig = None
    try:
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 5))
        # Epochs are numbered from 1, so label the x-axis accordingly.
        epoch_axis = range(1, len(train_losses) + 1)
        plt.plot(epoch_axis, train_losses, label='Train Loss')
        plt.plot(epoch_axis, val_losses, label='Validation Loss')
        plt.title('Training and Validation Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_path)
    except Exception as e:
        print(f"Error while plotting loss curves: {e}")
    finally:
        if fig is not None:
            plt.close(fig)

def train(
    model,
    train_loader,
    test_loader,
    optimizer,
    scheduler,
    device,
    epochs,
    kl_weight,
    save_path,
    save_interval=100,
    patience=5
):
    """Train ``model`` for up to ``epochs`` epochs with logging and checkpoints.

    Each epoch:

    1. Compute the epoch's KL weight from ``kl_weight`` via
       ``compute_adaptive_kl_weight``.
    2. Train one epoch, then validate.
    3. Append metrics to ``save_path/training_log.txt`` and print them.
    4. Step the scheduler.
    5. Save ``best_model.pth`` whenever the validation loss improves.
    6. Every ``save_interval`` epochs, save ``checkpoint_epoch_{epoch}.pth``.
    7. Stop early after ``patience`` epochs without improvement.

    When training ends, a loss plot is written to
    ``save_path/results/loss_curves.png``.

    Args:
        model: module to train.
        train_loader: iterable of training batches.
        test_loader: iterable of validation batches.
        optimizer: optimizer that updates ``model``'s parameters.
        scheduler: LR scheduler. ``ReduceLROnPlateau`` is stepped with the
            validation loss; other schedulers are stepped with no argument.
            ``None`` skips scheduling.
        device: device batches are moved to.
        epochs: maximum number of epochs, numbered from 1.
        kl_weight: base KL weight. Each epoch's adaptive weight is derived
            from it; the raw value is used for the initial evaluation.
        save_path: output directory, created if missing.
        save_interval: period, in epochs, of numbered checkpoints.
            0 or None disables them.
        patience: early-stopping patience, in epochs.

    Returns:
        ``(train_losses, val_losses)``: lists of per-epoch mean losses.
    """
    os.makedirs(save_path, exist_ok=True)
    log_file = os.path.join(save_path, 'training_log.txt')
    results_dir = os.path.join(save_path, 'results')
    os.makedirs(results_dir, exist_ok=True)
    best_model_path = os.path.join(save_path, 'best_model.pth')

    with open(log_file, 'w') as f:
        f.write("Epoch\tTrain Loss\tTrain Recon\tTrain KL\t"
                "Val Loss\tVal Recon\tVal KL\tLR\tKL Weight\n")

    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    print("Initial model evaluation before training...")
    initial_val_metrics = validate(model, test_loader, device, kl_weight)
    print(f"Initial validation loss: {initial_val_metrics['loss']:.4f}")

    for epoch in range(1, epochs + 1):
        adaptive_kl_weight = compute_adaptive_kl_weight(
            base_weight=kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )

        # Pass total_epochs explicitly. Without it, train_epoch fell back to
        # its default of 100, while validate received the real epoch count.
        train_metrics = train_epoch(
            model,
            train_loader,
            optimizer,
            device,
            kl_weight=adaptive_kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )

        val_metrics = validate(
            model,
            test_loader,
            device,
            kl_weight=adaptive_kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )
        val_loss = val_metrics['loss']

        # Read before scheduler.step(), so this is the LR that was used for
        # this epoch.
        current_lr = optimizer.param_groups[0]['lr']

        train_losses.append(train_metrics['loss'])
        val_losses.append(val_loss)

        _append_log_row(log_file, epoch, train_metrics, val_metrics,
                        current_lr, adaptive_kl_weight)

        print(f'Epoch {epoch}/{epochs}: '
              f'Train Loss={train_metrics["loss"]:.4f} '
              f'(Recon={train_metrics["recon_loss"]:.4f}, '
              f'KL={train_metrics["kl_loss"]:.4f}), '
              f'Val Loss={val_metrics["loss"]:.4f} '
              f'(Recon={val_metrics["recon_loss"]:.4f}, '
              f'KL={val_metrics["kl_loss"]:.4f}), '
              f'LR={current_lr:.6f}, '
              f'KL Weight={adaptive_kl_weight:.4f}')

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_model_state(best_model_path, model, optimizer, epoch, best_val_loss)
            print(f"Best model saved, validation loss: {best_val_loss:.4f}")

        # Save the periodic checkpoint before the early-stop check. Otherwise
        # an epoch that is both a save epoch and the stopping epoch would lose
        # its checkpoint.
        if save_interval and epoch % save_interval == 0:
            checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')
            _save_model_state(checkpoint_path, model, optimizer, epoch, val_loss)
            print(f"Checkpoint saved: {checkpoint_path}")

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered, training stopped.")
            break

    _plot_loss_curves(train_losses, val_losses,
                      os.path.join(results_dir, 'loss_curves.png'))

    return train_losses, val_losses
