#from accelerate import Accelerator
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler
from torch import autocast
from utils.checkpoints import save_checkpoint, get_latest_checkpoint, cleanup_old_checkpoints
from utils.logger import log_to_csv
from tqdm import tqdm

def finetune_head(model, dataloader, config, weight, backbone_name, latent_dim, experiment_name, seed, checkpoint_root='checkpoints'):
    """Train a new linear classification head on top of the model's latents.

    A fresh ``nn.Linear`` classifier is created and optimised with Adam.
    ``model.backbone`` and ``model.projector`` are only evaluated under
    ``torch.no_grad()``, so the classifier's parameters are the only ones
    updated. After every epoch a checkpoint is written, older checkpoints are
    cleaned up, and one row is appended to the experiment's CSV log. If a
    checkpoint matching this run's prefix is found, training resumes from the
    epoch after it.

    Args:
        model: Object exposing ``backbone`` and ``projector`` callables;
            ``projector.out_features`` sets the classifier input width.
            NOTE(review): the model is neither moved to the device nor put in
            eval mode here -- presumably the caller does both; confirm.
        dataloader: Iterable of ``(x, y)`` batches; ``dataloader.dataset``
            must support ``len()``.
        config: Mapping with the keys ``num_classes``, ``learning_rate``,
            ``loss`` (``'cross_entropy'`` or ``'square'``) and
            ``finetune_epochs``.
        weight: Fairness weight applied to class 1 (class weight for
            cross-entropy, per-sample weight for the square loss).
        backbone_name: Name of the backbone, used in file names and logs.
        latent_dim: Latent dimension, used in file names and logs.
        experiment_name: Experiment name used for the checkpoint and log paths.
        seed: Random seed, used only as part of the checkpoint prefix here.
        checkpoint_root: Root directory for saving checkpoints.

    Returns:
        The trained ``nn.Linear`` classifier.

    Raises:
        NotImplementedError: If ``config['loss']`` is not a supported loss.
        ValueError: If the dataloader yields no batches during an epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classifier = nn.Linear(model.projector.out_features, config['num_classes']).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=config['learning_rate'])
    loss_name = config['loss']
    checkpoint_dir = f"{checkpoint_root}/heads/{experiment_name}/{loss_name}/"
    prefix = f"{backbone_name}_latent{latent_dim}_weight{weight}_n{len(dataloader.dataset)}_seed{seed}_"
    resume_ckpt = get_latest_checkpoint(checkpoint_dir, prefix)

    if loss_name == 'cross_entropy':
        criterion = F.cross_entropy
        # Two entries only: this branch assumes a binary problem
        # (config['num_classes'] == 2), otherwise F.cross_entropy will reject it.
        class_weights = torch.tensor([1.0, weight], device=device)
    elif loss_name == 'square':
        def criterion(logits, targets, weight):
            # Weighted squared error between P(class 1) and the label.
            # dim=1 is spelled out: relying on the implicit softmax dimension
            # is deprecated and emits a warning on every call.
            positive_prob = F.softmax(logits, dim=1)[:, 1]
            return F.mse_loss(positive_prob, targets, weight=weight)
    else:
        raise NotImplementedError('Loss not supported')

    start_epoch = 0
    scaler = GradScaler(device.type)
    if resume_ckpt:
        print(f"Resuming head from {resume_ckpt}")
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(resume_ckpt, map_location=device)
        classifier.load_state_dict(checkpoint['classifier'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1

    for epoch in tqdm(range(start_epoch, config['finetune_epochs']), leave=False):
        last_loss = None
        for x, y in tqdm(dataloader, leave=False):
            # square loss only allows sample weights rather than class weights
            if loss_name == 'square':
                class_weights = torch.where(y == 1, weight, 1).to(device)

            # The feature extractor is frozen: no graph is built for it.
            with torch.no_grad():
                features = model.backbone(x.to(device))
                latent = model.projector(features)

            with autocast(device_type=device.type):
                logits = classifier(latent)
                loss = criterion(logits, y.to(device), weight=class_weights)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            last_loss = loss

        if last_loss is None:
            # Without this guard the code below would fail with an opaque
            # UnboundLocalError on `loss`.
            raise ValueError('dataloader yielded no batches; cannot fine-tune the head')

        # Loss of the final batch of the epoch (not an epoch average); read
        # once so the device is synchronised a single time.
        epoch_loss = last_loss.item()

        ckpt_path = f"{checkpoint_dir}{prefix}epoch{epoch}.pt"
        save_checkpoint({
            'classifier': classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'loss': epoch_loss,
            'fairness_weight': weight
        }, ckpt_path)
        cleanup_old_checkpoints(checkpoint_dir, prefix, latest_epoch=epoch, final_epoch=config['finetune_epochs'] - 1)

        log_to_csv(f"logs/{experiment_name}/head_finetune_log.csv", {
            'epoch': epoch,
            'backbone': backbone_name,
            'latent_dim': latent_dim,
            'fairness_weight': weight,
            'n': len(dataloader.dataset),
            'loss': epoch_loss
        })
    return classifier
