import os
from typing import List, Optional

import IPython
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader

from logger import Logger
from visualize import visualize

def train_eval_loop(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int,
    device: torch.device,
    run_name: str,
    project_path: str,
    print_log_freq: int = 100,
    image_log_freq: int = 1000,
    num_images_log: int = 8,
    current_epoch: int = 0,
    use_wandb: bool = True,
    configs: Optional[dict] = None,
):
    """
    Train and evaluate the model for several epochs.

    After every epoch a checkpoint is written to ``latest.pth`` inside
    ``project_path``. Every tenth epoch it is additionally written to
    ``<epoch>.pth``, and whenever the evaluation loss improves on the best
    value seen during this call it is written to ``best.pth``.

    Args:
        model: model to train
        optimizer: optimizer to use
        train_loader: dataloader for training
        val_loader: dataloader for evaluation
        epochs: number of epochs to train
        device: device to train on
        run_name: name of the run, forwarded to ``train`` and ``evaluate``
        project_path: folder to save checkpoints in
        print_log_freq: frequency (in batches) of printing the loss logs
        image_log_freq: frequency (in batches) of visualizing a batch
        num_images_log: forwarded to ``train`` and ``evaluate``
        current_epoch: epoch to start training from
        use_wandb: whether to log to wandb or not
        configs: configuration dict forwarded to ``train`` and ``evaluate``
            (they read ``"gpu_ids"``; ``train`` also reads ``"batch_size"``).
            ``None`` is treated as an empty dict.
    """
    # A ``{}`` literal as default would be one dict shared by every call.
    if configs is None:
        configs = {}

    latest_path = os.path.join(project_path, 'latest.pth')
    best_path = os.path.join(project_path, 'best.pth')

    # The best loss is tracked per call only; it is not restored when
    # resuming from a non-zero ``current_epoch``.
    best_eval_avg_loss = float('inf')
    last_epoch = current_epoch + epochs - 1

    for epoch in range(current_epoch, current_epoch + epochs):
        print(f'Start training epoch {epoch}/{last_epoch}')

        train(
            model,
            optimizer,
            train_loader,
            device,
            run_name,
            project_path,
            epoch,
            print_log_freq,
            image_log_freq,
            num_images_log,
            use_wandb,
            configs,
        )

        eval_avg_loss = evaluate(
            model,
            val_loader,
            device,
            run_name,
            project_path,
            epoch,
            print_log_freq,
            image_log_freq,
            num_images_log,
            use_wandb,
            configs,
        )

        # The whole model and optimizer objects are stored (not just their
        # state dicts) because load_model and get_saved_optimizer in this
        # file read them back as objects.
        checkpoint = {
            "epoch": epoch,
            "model": model,
            "optimizer": optimizer,
            "eval_avg_loss": eval_avg_loss,
        }

        if (epoch + 1) % 10 == 0:
            numbered_path = os.path.join(project_path, f"{epoch}.pth")
            torch.save(checkpoint, numbered_path)
        torch.save(checkpoint, latest_path)

        # evaluate() returns None when the model exposes no loss named
        # 'loss'; comparing None with a float would raise TypeError, so the
        # best checkpoint is simply not updated in that case.
        if eval_avg_loss is not None and eval_avg_loss < best_eval_avg_loss:
            best_eval_avg_loss = eval_avg_loss
            torch.save(checkpoint, best_path)
        
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device,
    run_name: str,
    project_path: str,
    epoch: int,
    print_log_freq: int,
    image_log_freq: int,
    num_images_log: int,
    use_wandb: bool,
    configs: dict,
):
    """
    Run one training epoch over ``train_loader``.

    Args:
        model: model to train; must expose ``loss_names`` (on ``model.module``
            when more than one GPU id is configured)
        optimizer: optimizer stepped once per batch
        train_loader: dataloader yielding dict batches
        device: device every batch entry is moved to
        run_name: passed through to ``visualize``
        project_path: not used by this function
        epoch: index of the current epoch (used in logs)
        print_log_freq: print the loss logs every this many batches; also the
            window size handed to each ``Logger``
        image_log_freq: call ``visualize`` every this many batches
        num_images_log: not used by this function
        use_wandb: whether to log to wandb
        configs: must contain ``"gpu_ids"`` and ``"batch_size"``
    """
    model.train()

    # With more than one GPU id the loss names live on ``model.module`` and
    # each output holds several values that are reduced before use.
    multi_gpu = len(configs["gpu_ids"]) > 1
    loss_owner = model.module if multi_gpu else model
    loggers: List[Logger] = [
        Logger(loss_name, 'train', window_size=print_log_freq)
        for loss_name in loss_owner.loss_names
    ]

    num_batches = len(train_loader)
    batch_size = configs['batch_size']

    for iteration, data in enumerate(train_loader):
        # Move every entry of the batch to the target device
        for name, value in data.items():
            data[name] = value.to(device)

        outputs = model(data)

        # Backpropagate and update the parameters
        optimizer.zero_grad()
        loss = outputs['loss']
        (loss.sum() if multi_gpu else loss).backward()
        optimizer.step()

        # Feed the latest value of every tracked loss to its logger
        for logger in loggers:
            value = outputs[logger.name]
            if multi_gpu:
                value = value.mean()
            logger.log_data(value.item())

        # Stage the metrics for wandb; they are committed at the end of the
        # iteration together with anything logged in between.
        if use_wandb:
            data_log = {logger.full_name(): logger.latest() for logger in loggers}
            total_samples = (epoch * num_batches + iteration) * batch_size
            data_log['navigation_train_epochs'] = (
                total_samples / train_loader.dataset.navigation_total_steps
            )
            wandb.log(data_log, commit=False)

        if iteration % print_log_freq == 0:
            prefix = f"(epoch {epoch}) (batch {iteration}/{num_batches - 1}) "
            for logger in loggers:
                print(prefix + logger.display())
            print()

        if iteration % image_log_freq == 0:
            visualize(data, outputs, run_name, 'train', epoch, iteration, use_wandb)

        if use_wandb:
            # Commit the wandb logs
            wandb.log({})
            
def evaluate(
    model: nn.Module,
    val_loader: DataLoader,
    device: torch.device,
    run_name: str,
    project_path: str,
    epoch: int,
    print_log_freq: int,
    image_log_freq: int,
    num_images_log: int,
    use_wandb: bool,
    configs: dict,
):
    """
    Run the model over ``val_loader`` without gradients and report its losses.

    Args:
        model: model to evaluate; must expose ``loss_names`` (on
            ``model.module`` when more than one GPU id is configured)
        val_loader: dataloader yielding dict batches
        device: device every batch entry is moved to
        run_name: passed through to ``visualize``
        project_path: not used by this function
        epoch: index of the current epoch (used in logs)
        print_log_freq: print the loss logs every this many batches; also the
            window size handed to each ``Logger``
        image_log_freq: call ``visualize`` every this many batches
        num_images_log: not used by this function
        use_wandb: whether to log the averaged losses to wandb
        configs: must contain ``"gpu_ids"``

    Returns:
        The average reported by the logger named ``'loss'``, or ``None`` when
        the model has no such loss.
    """
    model.eval()

    # With more than one GPU id the loss names live on ``model.module`` and
    # each output holds several values that are reduced before use.
    multi_gpu = len(configs["gpu_ids"]) > 1
    loss_owner = model.module if multi_gpu else model
    loggers: List[Logger] = [
        Logger(loss_name, 'val', window_size=print_log_freq)
        for loss_name in loss_owner.loss_names
    ]

    num_batches = len(val_loader)

    with torch.no_grad():
        for iteration, data in enumerate(val_loader):
            # Move every entry of the batch to the target device
            for name, value in data.items():
                data[name] = value.to(device)

            outputs = model(data)

            # Feed the latest value of every tracked loss to its logger
            for logger in loggers:
                value = outputs[logger.name]
                if multi_gpu:
                    value = value.mean()
                logger.log_data(value.item())

            if iteration % print_log_freq == 0:
                prefix = f"(epoch {epoch}) (batch {iteration}/{num_batches - 1}) "
                for logger in loggers:
                    print(prefix + logger.display())
                print()

            if iteration % image_log_freq == 0:
                visualize(data, outputs, run_name, 'val', epoch, iteration, use_wandb)

    # Summarize the epoch: collect each logger's average and print it
    epoch_prefix = f"(epoch {epoch}) "
    data_log = {}
    for logger in loggers:
        data_log[logger.full_name()] = logger.average()
        print(epoch_prefix + logger.display())
    print()
    if use_wandb:
        # Staged only; not committed here
        wandb.log(data_log, commit=False)

    # Pick out the average of the loss named 'loss', if there is one
    eval_avg_loss = None
    for logger in loggers:
        if logger.name != 'loss':
            continue
        eval_avg_loss = logger.average()
    return eval_avg_loss

def load_model(model, checkpoint: dict) -> None:
    """Copy the weights of the model object stored under ``checkpoint["model"]`` into ``model``."""
    # The checkpoint holds a whole model object, so its state dict is
    # extracted here rather than read directly.
    model.load_state_dict(checkpoint["model"].state_dict())
    
def get_saved_optimizer(
    checkpoint: dict, device: torch.device
) -> torch.optim.Optimizer:
    """Return the optimizer stored under ``checkpoint["optimizer"]`` with its state moved to ``device``."""
    saved_optimizer = checkpoint["optimizer"]
    # The move happens in place on the stored optimizer object.
    optimizer_to(saved_optimizer, device)
    return saved_optimizer

def optimizer_to(optim, device):
    """Move the tensors held in ``optim.state`` (and their gradients) to ``device`` in place.

    State entries may be tensors themselves or dicts whose values are
    tensors; anything else is left untouched.
    """

    def _relocate(tensor):
        # Rebind the storage (and the gradient's storage, if any) so the
        # tensor object itself stays the same.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)
