import torch
from src.utils.losses import *


def train_model(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one training epoch with the routine selected by ``args``.

    ``args.supervised`` takes precedence: when it is truthy the supervised
    routine runs regardless of ``args.model_name``. Otherwise
    ``args.model_name`` selects one of the self-supervised routines defined
    in this module.

    Args:
        args: Run configuration; ``supervised`` and ``model_name`` are read
            here, and the object is forwarded to the selected routine.
        logger: Forwarded unchanged to the selected routine.
        model: Forwarded unchanged to the selected routine.
        train_loader: Forwarded unchanged to the selected routine.
        optimizer: Forwarded unchanged to the selected routine.
        epoch: Forwarded unchanged to the selected routine.
        device: Forwarded unchanged to the selected routine.

    Raises:
        ValueError: If ``args.supervised`` is falsy and ``args.model_name``
            does not match any known routine.
    """
    forwarded = (args, logger, model, train_loader, optimizer, epoch, device)

    # The supervised flag wins over whatever model name is configured.
    if args.supervised:
        train_supervised(*forwarded)
        return

    routines = {
        "SDMI": train_SDMI,
        "SimSiam": train_SimSiam,
        "SimCLR": train_SimCLR,
        "BYOL": train_BYOL,
        "MoCo": train_MoCo,
        "JMI": train_JMI,
        "BarlowTwins": train_BarlowTwins,
        "VICReg": train_VICReg,
        "SimSiam_SDMI": train_SimSiam_SDMI,
    }
    routine = routines.get(args.model_name)
    if routine is None:
        raise ValueError(f"Unknown model name: {args.model_name}")
    routine(*forwarded)


def train_supervised(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one supervised training epoch and record its mean loss.

    Each batch is unpacked as ``(images, labels)``, moved to ``device``,
    passed through ``model`` and scored with ``cross_entropy_loss``. One
    optimizer step is taken per batch. The mean of the per-batch losses is
    appended to ``model.loss_history`` and logged.

    Args:
        args: Not read by this routine; accepted so that every ``train_*``
            routine shares the same signature.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(images, labels)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    per_batch = []

    for batch in train_loader:
        inputs, targets = batch
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        batch_loss = cross_entropy_loss(model(inputs), targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # .item() already yields a plain Python float detached from autograd.
        per_batch.append(batch_loss.item())

    epoch_mean = torch.mean(torch.tensor(per_batch)).item()
    model.loss_history.append(epoch_mean)
    logger.info(f'Epoch-{epoch} | Loss: {epoch_mean:.6f}')
    

def train_SDMI(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one alternating E-step / M-step epoch for the SDMI model.

    The loader is traversed twice. In the first pass (E-step) the E branch
    is in train mode, the M branch is in eval mode, the M outputs are
    detached and only ``optimizer_E`` is stepped. In the second pass (M-step)
    the roles are swapped and only ``optimizer_M`` is stepped. The mean loss
    of each pass is appended to ``model.E_loss_history`` and
    ``model.M_loss_history`` respectively, then both are logged.

    Args:
        args: Configuration; ``temperature`` is passed to ``dv_bound_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable returning four outputs unpacked as
            ``(E_1, E_2, M_1, M_2)``; must expose ``E_encoder``,
            ``E_projector``, ``M_encoder``, ``M_projector``,
            ``E_loss_history`` and ``M_loss_history``.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Two-element sequence ``(optimizer_E, optimizer_M)``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    optimizer_E, optimizer_M = optimizer

    def _single_pass(active_optimizer, update_e_branch):
        """Traverse the loader once, stepping only ``active_optimizer``."""
        pass_losses = []
        for first_view, second_view in train_loader:
            first_view = first_view.to(device, non_blocking=True, dtype=torch.float)
            second_view = second_view.to(device, non_blocking=True, dtype=torch.float)

            e_1, e_2, m_1, m_2 = model(first_view, second_view)
            if update_e_branch:
                learner_1, learner_2, fixed_1, fixed_2 = e_1, e_2, m_1, m_2
            else:
                learner_1, learner_2, fixed_1, fixed_2 = m_1, m_2, e_1, e_2

            # The branch that is not being updated is cut out of the graph.
            fixed_1, fixed_2 = fixed_1.detach(), fixed_2.detach()

            # Each learner output is paired with the fixed output of the other view.
            loss = 0.5 * (
                dv_bound_loss(learner_1, fixed_2, args.temperature)
                + dv_bound_loss(learner_2, fixed_1, args.temperature)
            )

            active_optimizer.zero_grad()
            loss.backward()
            active_optimizer.step()

            pass_losses.append(loss.item())
        return pass_losses

    # E-step: E branch in train mode, M branch in eval mode.
    model.E_encoder.train()
    model.E_projector.train()
    model.M_encoder.eval()
    model.M_projector.eval()
    e_step_losses = _single_pass(optimizer_E, True)

    # M-step: modes swapped.
    model.E_encoder.eval()
    model.E_projector.eval()
    model.M_encoder.train()
    model.M_projector.train()
    m_step_losses = _single_pass(optimizer_M, False)

    e_mean = torch.tensor(e_step_losses).mean().item()
    model.E_loss_history.append(e_mean)

    m_mean = torch.tensor(m_step_losses).mean().item()
    model.M_loss_history.append(m_mean)

    logger.info(f'Epoch-{epoch} | E-Step Loss: {e_mean:.6f} | M-Step Loss: {m_mean:.6f}')
    

def train_SimSiam(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one SimSiam training epoch and record its mean loss.

    The model returns four outputs per batch, unpacked as two predictions
    followed by two representations. The loss is the average of
    ``cosine_similarity_loss`` applied to each prediction and the
    representation of the other view.

    Args:
        args: Not read by this routine; accepted so that every ``train_*``
            routine shares the same signature.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    step_losses = []

    for first_view, second_view in train_loader:
        first_view = first_view.to(device, dtype=torch.float, non_blocking=True)
        second_view = second_view.to(device, dtype=torch.float, non_blocking=True)

        pred_a, pred_b, repr_a, repr_b = model(first_view, second_view)

        # Cross pairing: prediction of one view against the other view's output.
        a_vs_b = cosine_similarity_loss(pred_a, repr_b)
        b_vs_a = cosine_similarity_loss(pred_b, repr_a)
        loss = 0.5 * (a_vs_b + b_vs_a)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_losses.append(loss.item())

    loss_tensor = torch.tensor(step_losses)
    epoch_mean = loss_tensor.mean().item()
    model.loss_history.append(epoch_mean)
    logger.info(f'Epoch-{epoch} | Loss: {epoch_mean:.6f}')


def train_BYOL(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one BYOL training epoch and record its mean loss.

    The model returns four outputs per batch, unpacked as two predictions
    followed by two target representations. The loss is the average of
    ``mse_loss`` applied to each prediction and the target representation of
    the other view. After every optimizer step
    ``model.update_target_encoder(args)`` is called.

    Args:
        args: Configuration; forwarded to ``model.update_target_encoder``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()``, ``update_target_encoder`` and a
            ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    step_losses = []

    for pair in train_loader:
        view_a, view_b = pair
        view_a = view_a.to(device, non_blocking=True, dtype=torch.float)
        view_b = view_b.to(device, non_blocking=True, dtype=torch.float)

        online_a, online_b, target_a, target_b = model(view_a, view_b)
        first_term = mse_loss(online_a, target_b)
        second_term = mse_loss(online_b, target_a)
        loss = 0.5 * (first_term + second_term)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Runs after the step so it sees the freshly updated parameters.
        # NOTE(review): presumably a momentum update of the target network;
        # confirm against the model implementation.
        model.update_target_encoder(args)

        step_losses.append(loss.item())

    epoch_mean = torch.mean(torch.tensor(step_losses)).item()
    model.loss_history.append(epoch_mean)
    logger.info(f'Epoch-{epoch} | Loss: {epoch_mean:.6f}')


def train_MoCo(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one MoCo training epoch and record its mean loss.

    The model returns four outputs per batch, unpacked as two queries
    followed by two keys. The loss is the average of ``contrastive_loss``
    applied to each query and the key of the other view. After every
    optimizer step ``model.update_momentum_encoder(args)`` is called.

    Args:
        args: Configuration; ``temperature`` is passed to
            ``contrastive_loss`` and the whole object is forwarded to
            ``model.update_momentum_encoder``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()``, ``update_momentum_encoder`` and
            a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    running = []

    for left, right in train_loader:
        left = left.to(device, dtype=torch.float, non_blocking=True)
        right = right.to(device, dtype=torch.float, non_blocking=True)

        q_left, q_right, k_left, k_right = model(left, right)

        # Queries are always contrasted with the key of the opposite view.
        left_term = contrastive_loss(q_left, k_right, args.temperature)
        right_term = contrastive_loss(q_right, k_left, args.temperature)
        loss = 0.5 * (left_term + right_term)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Called once per batch, after the optimizer step.
        model.update_momentum_encoder(args)

        running.append(loss.item())

    mean_loss = torch.tensor(running).mean().item()
    model.loss_history.append(mean_loss)
    logger.info(f'Epoch-{epoch} | Loss: {mean_loss:.6f}')


def train_JMI(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one JMI training epoch and record its mean loss.

    The model returns two representations per batch. The loss is the average
    of ``dv_bound_loss`` evaluated in both argument orders; neither
    representation is detached.

    Args:
        args: Configuration; ``temperature`` is passed to ``dv_bound_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    collected = []

    for pair in train_loader:
        view_a, view_b = pair
        view_a = view_a.to(device, non_blocking=True, dtype=torch.float)
        view_b = view_b.to(device, non_blocking=True, dtype=torch.float)

        z_a, z_b = model(view_a, view_b)
        a_then_b = dv_bound_loss(z_a, z_b, args.temperature)
        b_then_a = dv_bound_loss(z_b, z_a, args.temperature)
        loss = 0.5 * (a_then_b + b_then_a)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        collected.append(loss.item())

    mean_loss = torch.mean(torch.tensor(collected)).item()
    model.loss_history.append(mean_loss)
    logger.info(f'Epoch-{epoch} | Loss: {mean_loss:.6f}')


def train_SimCLR(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one SimCLR training epoch and record its mean loss.

    The model returns two representations per batch. The loss is the average
    of ``infoNCE_loss`` evaluated in both argument orders.

    Args:
        args: Configuration; ``temperature`` is passed to ``infoNCE_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    running = []

    for first_view, second_view in train_loader:
        first_view = first_view.to(device, dtype=torch.float, non_blocking=True)
        second_view = second_view.to(device, dtype=torch.float, non_blocking=True)

        emb_1, emb_2 = model(first_view, second_view)

        # Both argument orders are evaluated and averaged.
        forward_term = infoNCE_loss(emb_1, emb_2, args.temperature)
        reverse_term = infoNCE_loss(emb_2, emb_1, args.temperature)
        loss = 0.5 * (forward_term + reverse_term)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running.append(loss.item())

    loss_tensor = torch.tensor(running)
    epoch_mean = loss_tensor.mean().item()
    model.loss_history.append(epoch_mean)
    logger.info(f'Epoch-{epoch} | Loss: {epoch_mean:.6f}')


def train_SimSiam_SDMI(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one SimSiam_SDMI training epoch and record its mean loss.

    The model returns two representations per batch. The loss is the average
    of ``dv_bound_loss`` evaluated in both argument orders, where the second
    argument is detached in each term so gradients flow only through the
    first argument.

    Args:
        args: Configuration; ``temperature`` is passed to ``dv_bound_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    collected = []

    for pair in train_loader:
        view_a, view_b = pair
        view_a = view_a.to(device, non_blocking=True, dtype=torch.float)
        view_b = view_b.to(device, non_blocking=True, dtype=torch.float)

        z_a, z_b = model(view_a, view_b)

        # Stop-gradient on the second argument of each term.
        a_term = dv_bound_loss(z_a, z_b.detach(), args.temperature)
        b_term = dv_bound_loss(z_b, z_a.detach(), args.temperature)
        loss = 0.5 * (a_term + b_term)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        collected.append(loss.item())

    mean_loss = torch.tensor(collected).mean().item()
    model.loss_history.append(mean_loss)
    logger.info(f'Epoch-{epoch} | Loss: {mean_loss:.6f}')


def train_BarlowTwins(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one BarlowTwins training epoch and record its mean loss.

    The model returns two representations per batch. The loss is the average
    of ``diversity_loss`` evaluated in both argument orders.

    Args:
        args: Configuration; ``lambd`` is passed to ``diversity_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    step_losses = []

    for left, right in train_loader:
        left = left.to(device, dtype=torch.float, non_blocking=True)
        right = right.to(device, dtype=torch.float, non_blocking=True)

        z_left, z_right = model(left, right)
        left_first = diversity_loss(z_left, z_right, args.lambd)
        right_first = diversity_loss(z_right, z_left, args.lambd)
        loss = 0.5 * (left_first + right_first)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_losses.append(loss.item())

    epoch_mean = torch.mean(torch.tensor(step_losses)).item()
    model.loss_history.append(epoch_mean)
    logger.info(f'Epoch-{epoch} | Loss: {epoch_mean:.6f}')


def train_VICReg(args, logger, model, train_loader, optimizer, epoch, device):
    """Run one VICReg training epoch and record its mean loss.

    The model returns two representations per batch. The loss is the average
    of ``VICReg_loss`` evaluated in both argument orders with the same
    coefficients.

    Args:
        args: Configuration; ``similarity_coeff``, ``variance_coeff``,
            ``covariance_coeff`` and ``eps`` are passed positionally, in that
            order, to ``VICReg_loss``.
        logger: Object whose ``info`` method receives the epoch summary.
        model: Callable exposing ``train()`` and a ``loss_history`` list.
        train_loader: Iterable yielding ``(view_1, view_2)`` pairs.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epoch: Epoch identifier; only used in the log message.
        device: Destination passed to ``Tensor.to`` for every batch.
    """
    model.train()
    running = []

    for pair in train_loader:
        view_a, view_b = pair
        view_a = view_a.to(device, non_blocking=True, dtype=torch.float)
        view_b = view_b.to(device, non_blocking=True, dtype=torch.float)

        z_a, z_b = model(view_a, view_b)

        # Same trailing arguments for both terms, in the order the loss expects.
        loss_settings = (
            args.similarity_coeff,
            args.variance_coeff,
            args.covariance_coeff,
            args.eps,
        )
        a_term = VICReg_loss(z_a, z_b, *loss_settings)
        b_term = VICReg_loss(z_b, z_a, *loss_settings)
        loss = 0.5 * (a_term + b_term)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running.append(loss.item())

    loss_tensor = torch.tensor(running)
    mean_loss = loss_tensor.mean().item()
    model.loss_history.append(mean_loss)
    logger.info(f'Epoch-{epoch} | Loss: {mean_loss:.6f}')