import json
import logging
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import psutil  # For memory monitoring

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter
def print(*args, **kwargs):
    """Send ``print``-style output to the root logger at INFO level.

    This shadows the builtin ``print`` for the whole module, so every
    ``print(...)`` call in this file is logged rather than written to stdout.
    Positional arguments are converted with ``str`` and joined with ``sep``
    (a single space by default or when ``sep=None``, matching the builtin).
    ``end``, ``file`` and ``flush`` are accepted for call compatibility but
    ignored: the configured logging handlers decide formatting and destination.
    """
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    logging.info(sep.join(map(str, args)))
    
def print_memory_usage():
    """Report this process's resident memory and, if CUDA is available, GPU memory.

    Sizes are converted from bytes to MiB (divided by 1024 twice) and shown
    with two decimals. Output goes through this module's ``print``.
    """
    def _mib(num_bytes):
        return num_bytes / 1024 / 1024

    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Memory usage: {_mib(rss_bytes):.2f} MB")

    if not torch.cuda.is_available():
        return

    allocated_mib = _mib(torch.cuda.memory_allocated())
    reserved_mib = _mib(torch.cuda.memory_reserved())
    print(f"GPU memory: {allocated_mib:.2f} MB allocated, "
          f"{reserved_mib:.2f} MB reserved")
    
def set_seed(seed):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Seed once at import time so every run of this module starts from the same RNG state.
set_seed(42)


def _select_device():
    """Return the Apple MPS device when available, otherwise the CPU."""
    # NOTE(review): CUDA is never selected here even though print_memory_usage
    # reports CUDA memory — confirm this is intended before running on a CUDA host.
    return torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def _ids_to_list(sample_ids):
    """Convert a batch of sample ids to plain Python values that hash by value."""
    # Indexing a tensor yields 0-d tensors, which hash by object identity; used as
    # dict keys they would never match across batches/epochs. ``tolist`` gives ints.
    if hasattr(sample_ids, "tolist"):
        return sample_ids.tolist()
    return list(sample_ids)


def _append_jsonl(path, records):
    """Append each dict in ``records`` to ``path`` as one JSON object per line."""
    with open(path, "a") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _checkpoint_name(val_acc):
    """File name of the checkpoint saved for validation accuracy ``val_acc``."""
    return f"best_model_acc_{val_acc:.2f}.pth"


def _train_one_epoch(model, train_loader, criterion, optimizer, device,
                     global_epoch, history_predictions):
    """Run one optimisation pass over ``train_loader``.

    Appends each sample's predicted label to ``history_predictions[sample_id]``
    (in place) and returns ``(mean_batch_loss, accuracy_percent, dynamics)``,
    where ``dynamics`` holds one ``{"guid", "logits_epoch_<global_epoch>", "gold"}``
    record per sample.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    dynamics = []
    logits_key = "logits_epoch_{}".format(global_epoch)

    for step, (inputs, labels, sample_ids) in enumerate(train_loader):
        try:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Copy predictions/labels/logits to the host once per batch rather
            # than once per sample (each per-element .item() forces a device sync).
            batch_ids = _ids_to_list(sample_ids)
            batch_preds = predicted.tolist()
            batch_golds = labels.tolist()
            batch_logits = outputs.detach().cpu().tolist()
            for sid, pred, gold, logits in zip(batch_ids, batch_preds, batch_golds, batch_logits):
                history_predictions.setdefault(sid, []).append(pred)
                dynamics.append({"guid": sid, logits_key: logits, "gold": gold})
        except Exception as e:
            # logging.exception records the traceback together with the message.
            logging.exception("Error in batch %d: %s", step + 1, e)
            raise
        num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("train_loader yielded no samples; cannot compute loss/accuracy")
    return running_loss / num_batches, 100 * correct / total, dynamics


def _update_top_checkpoints(model, val_acc, best_val_accs, save_dir, keep=3):
    """Save a checkpoint when ``val_acc`` beats the worst retained accuracy.

    ``best_val_accs`` is the descending list of retained accuracies. When it
    grows past ``keep`` entries, the lowest one is dropped and its file is
    deleted, unless a retained accuracy still maps to the same file name.

    Returns:
        ``(updated_best_val_accs, saved)``.
    """
    if val_acc <= min(best_val_accs, default=-np.inf):
        return best_val_accs, False

    print(f"Saving top model with accuracy: {val_acc:.2f}%")
    torch.save(model.state_dict(), os.path.join(save_dir, _checkpoint_name(val_acc)))
    best_val_accs = sorted(best_val_accs + [val_acc], reverse=True)

    if len(best_val_accs) > keep:
        worst_val_acc = best_val_accs[-1]
        best_val_accs = best_val_accs[:keep]
        worst_name = _checkpoint_name(worst_val_acc)
        # Accuracies that round to the same two decimals share one file (possibly
        # the one just written); only delete it if no retained entry still uses it.
        if worst_name not in {_checkpoint_name(acc) for acc in best_val_accs}:
            worst_path = os.path.join(save_dir, worst_name)
            if os.path.exists(worst_path):
                os.remove(worst_path)
                print(f"Deleted old model with accuracy: {worst_val_acc:.2f}%")
    return best_val_accs, True


def train(model,
          train_loader,
          val_loader,
          epochs,
          lr,
          factor,
          save_dir,
          writer=None,
          history_predictions=None,
          patience=20,
          epoch_start=0):
    """Train ``model`` with Adam and ReduceLROnPlateau, keeping the top-3 checkpoints.

    Each epoch this function:

    * trains for one pass over ``train_loader`` with cross-entropy loss;
    * evaluates with ``validate`` and steps the scheduler on validation accuracy;
    * appends per-sample logits to
      ``<save_dir>/training_dynamics/training_dynamics_epoch_<epoch_start + epoch + 1>.jsonl``;
    * saves ``best_model_acc_<acc>.pth`` when the accuracy enters the top 3,
      deleting the checkpoint that drops out;
    * appends a summary record to ``<save_dir>/training_log.jsonl``.

    Training stops early after ``patience`` consecutive epochs in which the
    validation accuracy does not enter the top 3.

    Args:
        model: ``nn.Module`` to train; moved to MPS if available, else CPU.
        train_loader: iterable of ``(inputs, labels, sample_ids)`` batches.
        val_loader: iterable of ``(inputs, labels, _)`` batches.
        epochs: maximum number of epochs to run in this call.
        lr: initial Adam learning rate (weight decay is fixed at 1e-4).
        factor: LR multiplier applied by ``ReduceLROnPlateau``.
        save_dir: output directory for checkpoints and logs; created if missing.
        writer: optional TensorBoard ``SummaryWriter``.
        history_predictions: optional dict mapping sample id to the list of
            predicted labels; updated in place. A fresh dict is used when ``None``.
        patience: early-stopping patience, in epochs.
        epoch_start: offset added to the epoch index for TensorBoard steps, the
            dynamics file names/keys and the log, so resumed runs keep numbering.

    Returns:
        ``(model, epochs_run)``: the model with the best-validation weights
        restored (if any checkpoint was saved) and the number of epochs run.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    import copy  # local import: only used to snapshot the best weights

    if history_predictions is None:
        history_predictions = {}

    device = _select_device()
    model.to(device)
    os.makedirs(save_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # The monitored metric is validation accuracy, hence mode='max'.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           factor=factor,
                                                           patience=10)

    no_improvement_count = 0
    best_val_accs = []  # retained top accuracies, sorted descending
    best_model_state = None
    epochs_run = 0

    if writer is not None:
        writer.add_text('Training Log', f"Training started with lr={lr}, epochs={epochs}")

    for epoch in range(epochs):
        global_epoch = epoch + epoch_start
        epochs_run = epoch + 1
        print(f"----------------Epoch {epoch + 1}/{epochs}----------------")
        print_memory_usage()

        epoch_loss, epoch_acc, epoch_dynamics = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            global_epoch, history_predictions)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        val_acc = validate(model, val_loader, device)

        if writer is not None:
            writer.add_scalar('Accuracy/val', val_acc, global_epoch)
            writer.add_scalar('Loss/train', epoch_loss, global_epoch)
            writer.add_scalar('Accuracy/train', epoch_acc, global_epoch)
            writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], global_epoch)
            writer.add_scalars('Accuracy', {'train': epoch_acc, 'val': val_acc}, global_epoch)
            writer.flush()

        scheduler.step(val_acc)

        dynamics_dir = os.path.join(save_dir, "training_dynamics")
        os.makedirs(dynamics_dir, exist_ok=True)
        _append_jsonl(
            os.path.join(dynamics_dir, f"training_dynamics_epoch_{global_epoch + 1}.jsonl"),
            epoch_dynamics)

        is_new_best = val_acc > max(best_val_accs, default=-np.inf)
        best_val_accs, saved = _update_top_checkpoints(model, val_acc, best_val_accs, save_dir)
        if saved:
            no_improvement_count = 0
            if is_new_best:
                # state_dict() returns references to the live parameter tensors,
                # which later optimizer steps modify in place; keep a real copy.
                best_model_state = copy.deepcopy(model.state_dict())
        else:
            no_improvement_count += 1

        _append_jsonl(os.path.join(save_dir, "training_log.jsonl"), [{
            "epoch": global_epoch,
            "train_loss": epoch_loss,
            "train_acc": epoch_acc,
            "val_acc": val_acc,
            "learning_rate": optimizer.param_groups[0]['lr'],
        }])

        if no_improvement_count >= patience:
            print(f"Validation accuracy has not improved for {patience} epochs. Stopping training.")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loading best model with validation accuracy: {max(best_val_accs):.2f}%")

    return model, epochs_run

def validate(model, val_loader, device):
    """Return the top-1 accuracy of ``model`` on ``val_loader``, in percent.

    The model is switched to eval mode (and left there) and evaluated without
    gradient tracking. The predicted class is the argmax of the model output
    along dim 1.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        val_loader: iterable of ``(inputs, labels, _)`` batches.
        device: device the inputs and labels are moved to.

    Returns:
        Accuracy as a float percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels, _ in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy")

    val_acc = 100 * correct / total
    print(f"Validation Accuracy: {val_acc:.2f}%")
    return val_acc

