import os
import time
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as cfg
from data_utils import get_dataloaders
from model import EnhancedCVAE
from loss_utils import loss_function, adaptive_kl_weight as compute_adaptive_kl_weight

def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, save_path):
    """Save model/optimizer state plus loss values for ``epoch``.

    The file is written to ``<save_path>/checkpoint_epoch_<epoch>.pth``.
    ``save_path`` is created if it does not exist yet, since ``torch.save``
    does not create missing parent directories.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number; stored in the checkpoint and used in the filename.
        train_loss: Training loss recorded alongside the weights.
        val_loss: Validation loss recorded alongside the weights.
        save_path: Directory that receives the checkpoint file.

    Returns:
        The full path of the written checkpoint file.
    """
    os.makedirs(save_path, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss
    }
    checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f'Checkpoint saved at epoch {epoch}')
    return checkpoint_path

def train_epoch(model, train_loader, optimizer, device,
                kl_weight=1.0, epoch=None, total_epochs=100, max_grad_norm=1.0):
    """Run one optimization pass over ``train_loader``.

    For every ``(input_seq, target_seq)`` batch the model is called as
    ``model(target_seq, input_seq)`` and the result is scored by
    ``loss_function``. The loss is backpropagated, gradients are optionally
    clipped, and the optimizer takes a step.

    Args:
        model: Module returning ``(recon_batch, mu, log_var)``.
        train_loader: Iterable yielding ``(input_seq, target_seq)`` tensor pairs.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        kl_weight: KL weight forwarded to ``loss_function``; also reported back
            under ``'adaptive_kl_weight'``.
        epoch: Current epoch, forwarded to ``loss_function``.
        total_epochs: Total epoch count, forwarded to ``loss_function``.
        max_grad_norm: Max total gradient norm for clipping; ``None`` disables
            clipping.

    Returns:
        Dict with sample-weighted means for ``'loss'``, ``'recon_loss'`` and
        ``'kl_loss'``, plus ``'adaptive_kl_weight'``.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    sums = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
    total_samples = 0
    for input_seq, target_seq in train_loader:
        input_seq, target_seq = input_seq.to(device), target_seq.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(target_seq, input_seq)
        # The fourth return value (adaptive weight) is not used here.
        loss, recon, kl, _ = loss_function(
            recon_batch, target_seq, mu, log_var,
            kl_weight=kl_weight,
            epoch=epoch,
            total_epochs=total_epochs
        )
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch mean.
        batch_size = input_seq.size(0)
        total_samples += batch_size
        sums['loss'] += loss.item() * batch_size
        sums['recon_loss'] += recon.item() * batch_size
        sums['kl_loss'] += kl.item() * batch_size
    if total_samples == 0:
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")
    train_metrics = {key: total / total_samples for key, total in sums.items()}
    train_metrics['adaptive_kl_weight'] = kl_weight
    return train_metrics

def validate(model, test_loader, device, kl_weight=1.0, epoch=None, total_epochs=100):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Each ``(input_seq, target_seq)`` batch is scored the same way as in
    training: the model is called as ``model(target_seq, input_seq)`` and the
    result is passed to ``loss_function``. No parameters are updated.

    Args:
        model: Module returning ``(recon_batch, mu, log_var)``.
        test_loader: Iterable yielding ``(input_seq, target_seq)`` tensor pairs.
        device: Device the batches are moved to.
        kl_weight: KL weight forwarded to ``loss_function``; also reported back
            under ``'adaptive_kl_weight'``.
        epoch: Current epoch (may be ``None``), forwarded to ``loss_function``.
        total_epochs: Total epoch count, forwarded to ``loss_function``.

    Returns:
        Dict with sample-weighted means for ``'loss'``, ``'recon_loss'`` and
        ``'kl_loss'``, plus ``'adaptive_kl_weight'``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    sums = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}
    total_samples = 0
    with torch.no_grad():
        for input_seq, target_seq in test_loader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            recon_batch, mu, log_var = model(target_seq, input_seq)
            # The fourth return value (adaptive weight) is not used here.
            loss, recon, kl, _ = loss_function(
                recon_batch, target_seq, mu, log_var,
                kl_weight=kl_weight,
                epoch=epoch,
                total_epochs=total_epochs
            )
            # Weight each batch by its size so the mean is per-sample.
            batch_size = input_seq.size(0)
            total_samples += batch_size
            sums['loss'] += loss.item() * batch_size
            sums['recon_loss'] += recon.item() * batch_size
            sums['kl_loss'] += kl.item() * batch_size
    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute validation metrics")
    val_metrics = {key: total / total_samples for key, total in sums.items()}
    val_metrics['adaptive_kl_weight'] = kl_weight
    return val_metrics

class EarlyStopping:
    """Flag when a monitored loss has stopped improving.

    Call the instance once per epoch with the latest loss. A loss counts as an
    improvement only if it is lower than the best loss so far by more than
    ``min_delta``. After ``patience`` consecutive calls without improvement,
    ``early_stop`` is set to True and stays True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print a message when early stopping triggers.
        min_delta: Minimum decrease that counts as an improvement. The default
            of 0.0 accepts any strict decrease.
    """

    def __init__(self, patience=5, verbose=False, min_delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without improvement
        self.best_loss = float('inf')  # lowest loss that counted as an improvement
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print(f'Early stopping triggered after {self.patience} epochs without improvement.')

def train(
    model,
    train_loader,
    test_loader,
    optimizer,
    scheduler,
    device,
    epochs,
    kl_weight,
    save_path,
    save_interval=100,
    patience=5
):
    """Train ``model`` with per-epoch validation, logging, checkpoints and early stopping.

    Each epoch does the following:
      1. Computes an adaptive KL weight via ``compute_adaptive_kl_weight``.
      2. Runs one training and one validation pass.
      3. Appends a tab-separated row to ``<save_path>/training_log.txt``.
      4. Steps the LR scheduler.
      5. Writes ``best_model.pth`` when the validation loss improves.
      6. Writes ``checkpoint_epoch_<N>.pth`` every ``save_interval`` epochs.
      7. Stops after ``patience`` epochs without improvement.

    When the loop ends, a loss-curve plot is written to
    ``<save_path>/results/loss_curves.png``. This is best-effort: plotting
    errors are printed, not raised.

    Args:
        model: Model to train.
        train_loader: Training data loader.
        test_loader: Validation data loader.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler. ``ReduceLROnPlateau`` is stepped with the
            validation loss; any other scheduler is stepped without arguments.
            ``None`` skips scheduling.
        device: Device for the batches.
        epochs: Maximum number of epochs.
        kl_weight: Base KL weight passed to ``compute_adaptive_kl_weight``.
        save_path: Output directory (created if missing).
        save_interval: Epochs between periodic checkpoints. A falsy value
            (0/None) disables periodic checkpoints.
        patience: Early-stopping patience in epochs.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses for the epochs run.
    """
    os.makedirs(save_path, exist_ok=True)
    log_file = os.path.join(save_path, 'training_log.txt')
    results_dir = os.path.join(save_path, 'results')
    os.makedirs(results_dir, exist_ok=True)
    with open(log_file, 'w', encoding='utf-8') as f:
        f.write("Epoch\tTrain Loss\tTrain Recon\tTrain KL\t"
                "Val Loss\tVal Recon\tVal KL\tLR\tKL Weight\n")
    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    print("Initial model evaluation before training...")
    initial_val_metrics = validate(model, test_loader, device, kl_weight)
    print(f"Initial validation loss: {initial_val_metrics['loss']:.4f}")
    for epoch in range(1, epochs + 1):
        adaptive_kl_weight = compute_adaptive_kl_weight(
            base_weight=kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )
        # Pass total_epochs explicitly. Without it, train_epoch falls back to
        # its default of 100 while validation uses `epochs`, so the two
        # losses would be computed with different schedule horizons.
        train_metrics = train_epoch(
            model,
            train_loader,
            optimizer,
            device,
            kl_weight=adaptive_kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )
        val_metrics = validate(
            model,
            test_loader,
            device,
            kl_weight=adaptive_kl_weight,
            epoch=epoch,
            total_epochs=epochs
        )
        val_loss = val_metrics['loss']
        current_lr = optimizer.param_groups[0]['lr']
        train_losses.append(train_metrics['loss'])
        val_losses.append(val_loss)
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(f"{epoch}\t"
                    f"{train_metrics['loss']:.4f}\t"
                    f"{train_metrics['recon_loss']:.4f}\t"
                    f"{train_metrics['kl_loss']:.4f}\t"
                    f"{val_metrics['loss']:.4f}\t"
                    f"{val_metrics['recon_loss']:.4f}\t"
                    f"{val_metrics['kl_loss']:.4f}\t"
                    f"{current_lr:.6f}\t"
                    f"{adaptive_kl_weight:.4f}\n")
        print(f'Epoch {epoch}/{epochs}: '
              f'Train Loss={train_metrics["loss"]:.4f} '
              f'(Recon={train_metrics["recon_loss"]:.4f}, '
              f'KL={train_metrics["kl_loss"]:.4f}), '
              f'Val Loss={val_metrics["loss"]:.4f} '
              f'(Recon={val_metrics["recon_loss"]:.4f}, '
              f'KL={val_metrics["kl_loss"]:.4f}), '
              f'LR={current_lr:.6f}, '
              f'KL Weight={adaptive_kl_weight:.4f}')
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_training_state(
                os.path.join(save_path, 'best_model.pth'),
                model, optimizer, epoch, best_val_loss
            )
            print(f"Best model saved, validation loss: {best_val_loss:.4f}")
        # Save the periodic checkpoint before the early-stopping check.
        # Otherwise an interval epoch that also triggers early stopping would
        # exit without writing its checkpoint.
        if save_interval and epoch % save_interval == 0:
            checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')
            _save_training_state(checkpoint_path, model, optimizer, epoch, val_loss)
            print(f"Checkpoint saved: {checkpoint_path}")
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered. Training terminated.")
            break
    _plot_loss_curves(train_losses, val_losses, results_dir)
    return train_losses, val_losses


def _save_training_state(path, model, optimizer, epoch, val_loss):
    """Write epoch, model/optimizer state dicts and ``val_loss`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, path)


def _plot_loss_curves(train_losses, val_losses, results_dir):
    """Save train/validation loss curves to ``results_dir/loss_curves.png`` (best-effort)."""
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(train_losses, label='Train Loss')
        ax.plot(val_losses, label='Validation Loss')
        ax.set_title('Training and Validation Losses')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(results_dir, 'loss_curves.png'))
    except Exception as e:
        # Plotting is deliberately non-fatal: the losses are still returned.
        print(f"Error plotting loss curves: {e}")
    finally:
        # Always release the figure, even when plotting fails partway.
        if fig is not None:
            plt.close(fig)
