import torch
from torchvision import datasets, transforms
from torch.utils.data import DistributedSampler
import torch.distributed as dist
from tqdm import tqdm
import torchmetrics
import sys

from src.models import get_model
from src.optimizers import get_opt
from src.utils import utils

def _build_datasets(dataset, root, download):
    """Instantiate the raw train/test datasets for ``dataset``.

    Args:
        dataset: One of 'MNIST', 'FashionMNIST', 'CIFAR10' or 'Tiny ImageNet'.
        root: Directory that holds (or will receive) the data.
        download: Forwarded to the MNIST / FashionMNIST / CIFAR10 constructors.
            The Tiny ImageNet branch uses ``ImageFolder`` and is not given a
            download flag, so its folders must already exist under ``root``.

    Returns:
        Tuple ``(train_data, val_data, test_dataset)``. ``val_data`` is None
        except for Tiny ImageNet, where it is a second view of the training
        folder that applies the evaluation transforms.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
    """
    import os  # local import: only needed to build the Tiny ImageNet paths

    val_data = None

    if dataset == 'MNIST':
        tensor_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_data = datasets.MNIST(root=root, train=True, transform=tensor_transform, download=download)
        test_dataset = datasets.MNIST(root=root, train=False, transform=tensor_transform, download=download)

    elif dataset == 'FashionMNIST':
        tensor_transform = transforms.ToTensor()
        train_data = datasets.FashionMNIST(root=root, train=True, transform=tensor_transform, download=download)
        test_dataset = datasets.FashionMNIST(root=root, train=False, transform=tensor_transform, download=download)

    elif dataset == 'CIFAR10':
        # Augmentation is applied to the training view only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        train_data = datasets.CIFAR10(root=root, train=True, download=download, transform=transform_train)
        test_dataset = datasets.CIFAR10(root=root, train=False, download=download, transform=transform_test)

    elif dataset == 'Tiny ImageNet':
        train_transforms = transforms.Compose([
            transforms.RandomCrop(64, padding=8),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.480, 0.448, 0.398],
                                 std=[0.277, 0.269, 0.282]),
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.480, 0.448, 0.398],
                                 std=[0.277, 0.269, 0.282]),
        ])

        # Honour ``root`` instead of a hard-coded './data'; with the default
        # root this resolves to the same './data/tiny-imagenet-200/...' paths.
        tiny_root = os.path.join(root, 'tiny-imagenet-200')
        train_dir = os.path.join(tiny_root, 'train')
        # Same images twice: one view with augmentation for training and one
        # without, from which the validation split is taken.
        train_data = datasets.ImageFolder(root=train_dir, transform=train_transforms)
        val_data = datasets.ImageFolder(root=train_dir, transform=val_transforms)
        test_dataset = datasets.ImageFolder(root=os.path.join(tiny_root, 'val'), transform=val_transforms)

    else:
        raise NotImplementedError(f'Did not recognize dataset "{dataset}"')

    return train_data, val_data, test_dataset


def create_datasets_and_loader(dataset='MNIST', batch_size=64, root='./data',
                               download=True, max_samples=None, train_split=0.9,
                               return_datasets=False,
                               distributed=False,
                               num_workers=8):
    """Build train / validation / test datasets and their data loaders.

    Args:
        dataset: 'MNIST', 'FashionMNIST', 'CIFAR10' or 'Tiny ImageNet'.
        batch_size: Batch size used by all three loaders.
        root: Data directory. Tiny ImageNet is read from
            ``<root>/tiny-imagenet-200/{train,val}``.
        download: Whether the torchvision datasets may download missing data.
        max_samples: If truthy, keep only the first ``max_samples`` items of
            the training data and of the test data (mainly for debugging).
        train_split: Fraction of the training data kept for training; the
            remainder becomes the validation set.
        return_datasets: If True, also return the three dataset objects.
        distributed: If True, the training loader uses a ``DistributedSampler``
            (validation and test loaders are not sharded).
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or, when
        ``return_datasets`` is True,
        ``((train_loader, val_loader, test_loader),
        (train_dataset, val_dataset, test_dataset))``.

    Raises:
        NotImplementedError: If ``dataset`` is not recognised.

    Note:
        The train/validation split draws from the global torch RNG, so it is
        only reproducible (and only identical across distributed ranks) if the
        caller seeds torch before calling this function.
    """
    train_data, val_data, test_dataset = _build_datasets(dataset, root, download)

    # Potentially consider subset of data (primarily when debugging)
    if max_samples:
        train_indices = list(range(min(max_samples, len(train_data))))
        train_data = torch.utils.data.Subset(train_data, train_indices)

        test_indices = list(range(min(max_samples, len(test_dataset))))
        test_dataset = torch.utils.data.Subset(test_dataset, test_indices)

    # Split train_data into train and validation sets based on train_split
    train_size = int(len(train_data) * train_split)
    val_size = len(train_data) - train_size
    if val_data is not None:
        # A shared permutation gives disjoint train/val items while letting the
        # validation items come from the augmentation-free view. If max_samples
        # was applied, position i of the subset is still item i of val_data, so
        # the indices remain valid for both views.
        indices = torch.randperm(len(train_data)).tolist()
        train_dataset = torch.utils.data.Subset(train_data, indices[:train_size])
        val_dataset = torch.utils.data.Subset(val_data, indices[train_size:])
    else:
        # NOTE(review): the validation part inherits the training transforms
        # here (for CIFAR10 that includes the random crop / flip).
        train_dataset, val_dataset = torch.utils.data.random_split(train_data, [train_size, val_size])

    # Prepare for distributed learning
    train_sampler = DistributedSampler(train_dataset) if distributed else None

    # Create relevant loaders; shuffle and sampler are mutually exclusive, so
    # shuffling is only requested when no sampler is in use.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        sampler=train_sampler, pin_memory=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, num_workers=num_workers)

    if return_datasets:
        return (train_loader, val_loader, test_loader), (train_dataset, val_dataset, test_dataset)
    return train_loader, val_loader, test_loader


def get_loaders(conf, distributed=False, download=False):
    """Create the train, validation and test loaders described by ``conf``.

    Only ``conf.dataset`` and ``conf.max_samples`` are read; every other
    option of ``create_datasets_and_loader`` keeps its default value.

    Args:
        conf: Configuration object exposing ``dataset`` and ``max_samples``.
        distributed: Forwarded to ``create_datasets_and_loader``.
        download: Forwarded to ``create_datasets_and_loader``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    loader_options = dict(
        dataset=conf.dataset,
        download=download,
        max_samples=conf.max_samples,
        distributed=distributed,
    )
    train_loader, val_loader, test_loader = create_datasets_and_loader(**loader_options)
    return train_loader, val_loader, test_loader

def get_train_objs(conf, device=None):
    """Build the model, optimizer, LR scheduler and loss for a training run.

    Args:
        conf: Configuration object passed to ``get_model``, ``get_opt`` and
            ``utils.get_loss``.
        device: Device the model is moved to. When falsy, CUDA is used if it
            is available and the CPU otherwise.

    Returns:
        Tuple ``(model, opt, lr_scheduler, criterion)``.
    """
    if not device:
        fallback_name = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(fallback_name)

    net = get_model(conf).to(device)
    optimizer, scheduler = get_opt(conf, net)
    loss_fn = utils.get_loss(conf)
    return net, optimizer, scheduler, loss_fn




# Metric names containing any of these substrings are left out of the
# progress-bar text built by train_step / validation_step; they are still
# written to the summary writer.
DO_NOT_DISPLAY_METRIC = ['layer_sparsity']

def train_step(epoch, model, criterion, opt, dataloader,  device, writer=None,
               metric_collection=None, distributed=False):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        epoch: Index of the current epoch; used for the progress-bar label and
            to derive the global step for per-batch logging.
        model: Model to train (switched to train mode here).
        criterion: Callable ``criterion(outputs, y)`` returning the loss.
        opt: Optimizer whose ``step()`` is called after every backward pass.
        dataloader: Sized iterable yielding ``(x, y)`` batches.
        device: ``torch.device`` the batches are moved to.
        writer: Optional summary writer exposing ``add_scalar``.
        metric_collection: Optional metric object that accepts
            ``update(preds=..., target=..., model=...)``, ``compute()`` and
            ``reset()``.
        distributed: If True, the logged loss is averaged over all ranks.

    Returns:
        float: Sum of the logged per-batch losses divided by
        ``len(dataloader)`` (the same value written to 'Epoch/avg_loss').
    """
    # Prepare for an epoch of training
    model.train()
    running_loss = 0.0
    global_step = epoch * len(dataloader)

    is_main_process = not dist.is_initialized() or dist.get_rank() == 0
    on_gpu = device.type != 'cpu'

    # Every rank builds the bar so the loop below is identical everywhere;
    # it is simply disabled outside the main process.
    pbar = tqdm(enumerate(dataloader), total=len(dataloader),
                desc=f"Epoch {epoch}", leave=on_gpu,
                file=sys.stderr,  # Send prints to the error file, not output file
                dynamic_ncols=not on_gpu,  # Ensure pretty prints on HPC
                disable=not is_main_process,
                )
    for batch_idx, (x, y) in pbar:
        x, y = x.to(device), y.to(device)

        # Forward pass
        outputs = model(x)

        # Compute metrics
        loss = criterion(outputs, y)
        if metric_collection:
            preds = torch.argmax(outputs, dim=1)
            metric_collection.update(preds=preds, target=y, model=model)

        # Update the current model
        model.zero_grad()
        loss.backward()
        opt.step()

        # Account for distributed learning: report the mean loss over ranks.
        # A detached copy is reduced so the reduction never touches the graph.
        loss_val = loss.detach().clone()
        if distributed:
            dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)
            loss_val /= dist.get_world_size()

        # End of batch logging
        batch_loss = loss_val.item()
        step = global_step + batch_idx
        running_loss += batch_loss
        if writer:
            writer.add_scalar('Train/loss', batch_loss, step)
            writer.add_scalar('epoch', epoch, step)
        display_str = f"Loss: {batch_loss:.4f}"
        if metric_collection:
            # Metric state is only reset at the end of the epoch, so this is
            # computed from everything accumulated so far in the epoch.
            results = metric_collection.compute()
            for name, value in results.items():
                if writer:
                    writer.add_scalar(f'Train/{name}', value, step)
                if not any(skip in name for skip in DO_NOT_DISPLAY_METRIC):
                    display_str += f" | {name}: {value:.3f}"

        if is_main_process:
            pbar.set_postfix_str(display_str)  # Update tqdm progress bar

    # End of epoch logging
    avg_loss = running_loss / len(dataloader)
    if writer:
        writer.add_scalar('Epoch/avg_loss', avg_loss, epoch)

    # Reset internal state to be ready for new data
    if metric_collection:
        metric_collection.reset()

    # Returned for parity with validation_step; callers that ignore the
    # result are unaffected.
    return avg_loss


def validation_step(epoch, model, criterion, dataloader,  device, writer,
                    metric_collection=None,  pbar=None):
    """Evaluate ``model`` on ``dataloader`` and log the results.

    Args:
        epoch: Epoch index used as the step for everything written to
            ``writer``.
        model: Model to evaluate (switched to eval mode here and left there).
        criterion: Callable ``criterion(logits, y)`` returning the loss.
        dataloader: Sized iterable yielding ``(x, y)`` batches.
        device: Device the batches are moved to.
        writer: Summary writer exposing ``add_scalar``, or None to skip
            writer logging entirely.
        metric_collection: Optional metric object that accepts
            ``update(preds=..., target=..., model=...)``, ``compute()`` and
            ``reset()``.
        pbar: Optional progress bar whose postfix is set to the summary text.

    Returns:
        float: Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            running_loss += loss.item()
            if metric_collection:
                preds = torch.argmax(logits, dim=1)
                # Model-based metrics (e.g. sparsity) should have a flag to
                #   prevent recomputation every batch
                metric_collection.update(preds=preds, target=y, model=model)

    avg_loss = running_loss / len(dataloader)
    if writer:
        writer.add_scalar('Val/avg_loss', avg_loss, epoch)
    display_str = f"| val_avg_loss: {avg_loss:.3f}"
    if metric_collection:
        results = metric_collection.compute()
        for name, value in results.items():
            # The writer is optional: guard metric logging the same way as the
            # loss logging above so writer=None does not raise.
            if writer:
                writer.add_scalar(f'Val/{name}', value, epoch)
            if name == 'accuracy':
                display_str += f" | val_{name}: {value:.3f}"
            elif not any(skip in name for skip in DO_NOT_DISPLAY_METRIC):
                display_str += f" | {name}: {value:.3f}"

        # Reset internal state to be ready for new data
        metric_collection.reset()
    if pbar:  # Update training bar if provided
        pbar.set_postfix_str(display_str)

    return avg_loss



def test(model, dataloader, device, verbose=True, num_classes = 10):
    """Measure multiclass accuracy of ``model`` over ``dataloader``.

    Args:
        model: Model to evaluate (switched to eval mode here and left there).
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches and the accuracy metric are moved to.
        verbose: When True, show a progress bar and print the final accuracy.
        num_classes: Number of classes given to the accuracy metric.

    Returns:
        float: Accuracy as a percentage (metric value multiplied by 100).
    """
    acc_tracker = torchmetrics.classification.Accuracy(
        task="multiclass", num_classes=num_classes
    ).to(device)

    # Put model in eval mode
    model.eval()

    batches = tqdm(dataloader, disable=not verbose)
    with torch.no_grad():
        for images, labels in batches:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            acc_tracker.update(predicted, labels)

    # Compute final accuracy
    accuracy = acc_tracker.compute().item() * 100
    if verbose:
        print(f"\nTest Accuracy: {accuracy:.2f}%")
    return accuracy
