#from accelerate import Accelerator
import torch
import torch.nn.functional as F
from torch.amp import GradScaler
from torch import autocast
from utils.checkpoints import save_checkpoint, get_latest_checkpoint, load_checkpoint, cleanup_old_checkpoints
from utils.logger import log_to_csv
from tqdm import tqdm

def train_base_model(model, dataloader, config, backbone_name, latent_dim, experiment_name, seed=42, checkpoint_root='checkpoints'):
    """Train the base model with the specified configuration.
    Args:
        model: The model to be trained.
        dataloader: The dataloader for training data.
        config: Configuration dictionary containing training parameters.
        backbone_name: Name of the backbone model.
        latent_dim: Latent dimension of the model.
        experiment_name: Name of the experiment for logging and checkpointing.
        seed: Random seed for reproducibility.
        checkpoint_root: Root directory for saving checkpoints.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    loss_name = config['loss']
    checkpoint_dir = f"{checkpoint_root}/base/{experiment_name}/{loss_name}/"
    prefix = f"{backbone_name}_latent{latent_dim}_weight{config['fairness_weight']}_n{len(dataloader.dataset)}_seed{seed}_"
    resume_ckpt = get_latest_checkpoint(checkpoint_dir, prefix)

    start_epoch = 0
    match config['optimizer']:
        case 'Adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
        case 'AdamW':
            optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
        case _:
            raise NotImplementedError('Optimizer not supported')
    scaler = GradScaler(device.type)

    match loss_name:
        case 'cross_entropy'|'log'|'logistic':
            criterion = F.cross_entropy
            class_weights = torch.tensor([1.0, config['fairness_weight']], device=device)
        case _:
            raise NotImplementedError('Loss not supported')


    if resume_ckpt:
        print(f"Resuming from {resume_ckpt}")
        checkpoint = load_checkpoint(model, resume_ckpt, part='base')
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1

    model.train()
    for epoch in tqdm(range(start_epoch, config['train_epochs'])):
        for x, y in tqdm(dataloader, leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            
            with autocast(device_type = device.type):
                preds = model(x)
                loss = criterion(preds, y,weight=class_weights)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # if accelerator.is_main_process:
        ckpt_path = f"{checkpoint_dir}{prefix}epoch{epoch}.pt"
        save_checkpoint({
            'backbone': model.backbone.state_dict(),
            'projector': model.projector.state_dict(),
            'classifier': model.classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'loss': loss.item()
        }, ckpt_path)
        cleanup_old_checkpoints(checkpoint_dir, prefix, latest_epoch=epoch, final_epoch=config['train_epochs'] - 1)

        log_to_csv(f"logs/{experiment_name}/base_train_log.csv", {
            'epoch': epoch,
            'backbone': backbone_name,
            'latent_dim': latent_dim,
            'fairness_weight': config['fairness_weight'],
            'loss': loss.item()
        })
