import torch
import numpy as np
import wandb
from omegaconf import DictConfig
from hydra.utils import instantiate
from .compute_clf_preds import get_output_path
from tqdm import tqdm
from pathlib import Path

import utils

from dataset.base import CustomMaskDataset, CustomMask3DDataset

import logging
log = logging.getLogger(__name__)

# Process-wide setting: let float32 matmuls use faster, slightly lower-precision
# kernels where the hardware supports them (see the PyTorch docs for what
# "high" permits, e.g. TF32).
torch.set_float32_matmul_precision("high")

def get_fabric(config):
    """Instantiate the ``config.fabric`` object, seed it and launch it.

    Args:
        config: Hydra/OmegaConf config providing ``fabric`` (an instantiable
            node) and ``exp.seed``.

    Returns:
        The launched fabric object (presumably a ``lightning.fabric.Fabric``;
        TODO confirm against the config).
    """
    fab = instantiate(config.fabric)
    seed = config.exp.seed
    fab.seed_everything(seed)
    fab.launch()
    return fab

def get_components(config, fabric):
    """Build the classifier, evaluator and inpainter from ``config`` and wrap them with ``fabric``.

    The evaluator is built around the classifier. The inpainter is built with
    ``guidance=None`` and passed through ``torch.compile``; compiling the
    evaluator is currently disabled. Both the evaluator and the inpainter (but
    not the bare classifier) go through ``fabric.setup``.

    Returns:
        ``(inpainter, evaluator, classifier)``
    """
    clf = instantiate(config.classifier)

    make_evaluator = instantiate(config.evaluation)
    raw_evaluator = make_evaluator(model=clf)

    make_inpainter = instantiate(config.inpainter)
    compiled_inpainter = torch.compile(make_inpainter(guidance=None))

    # Order matters only for reproducibility of setup side effects: evaluator first.
    wrapped_evaluator = fabric.setup(raw_evaluator)
    wrapped_inpainter = fabric.setup(compiled_inpainter)

    return wrapped_inpainter, wrapped_evaluator, clf

def get_dataloader(config, fabric):
    """Instantiate the dataloader from ``config.dataset`` and wrap it with ``fabric``.

    If ``config.dataset.dataset.path_predictions`` is None, it is set in place
    on ``config`` to the absolute path of ``data.csv`` inside the directory
    returned by ``get_output_path(config)``.

    Raises:
        FileNotFoundError: if the auto-derived predictions file does not exist.
            This is an explicit raise rather than an ``assert``, so the check
            survives ``python -O``.
    """
    if config.dataset.dataset.path_predictions is None:
        log.info('Auto-Setting path_predictions')
        predictions_path = Path(get_output_path(config)) / 'data.csv'
        log.info('path_predictions set to: %s', predictions_path)
        if not predictions_path.exists():
            raise FileNotFoundError(f"Predictions file not found at {predictions_path}")
        config.dataset.dataset.path_predictions = predictions_path.absolute()

    return fabric.setup_dataloaders(instantiate(config.dataset))


def _median_mask_bounds(mask: np.ndarray):
    """Return the per-axis median start/end indices of the True region of ``mask``.

    For each axis, every 1-D line through ``mask`` along that axis that contains
    at least one True value contributes two values: the index of its first True
    value (start, inclusive) and one past the index of its last True value
    (end, exclusive). Lines without any True value are ignored (NaN +
    ``np.nanmedian``).

    Args:
        mask: boolean ndarray. It should contain at least one True value;
            otherwise every median is NaN.

    Returns:
        ``(starts, ends)``: two float ndarrays of length ``mask.ndim``.
    """
    # Flipping *every* axis turns "last True" into "first True" along each axis.
    # The remaining axes get flipped too, which only permutes the values being
    # medianed, so the medians are unaffected.
    mask_rev = np.flip(mask)
    axes = range(mask.ndim)
    starts = np.array([
        np.nanmedian(np.where(mask.any(axis=ax), mask.argmax(axis=ax), np.nan))
        for ax in axes
    ])
    dist_from_end = np.array([
        np.nanmedian(np.where(mask_rev.any(axis=ax), mask_rev.argmax(axis=ax), np.nan))
        for ax in axes
    ])
    ends = np.asarray(mask.shape) - dist_from_end
    return starts, ends


def _grid_placements(grid_steps):
    """Return all integer grid coordinates, one row each, for ``grid_steps`` cells per axis.

    Rows are enumerated in the order produced by ``np.meshgrid`` (default 'xy'
    indexing) followed by a transpose.
    """
    n_dims = len(grid_steps)
    return np.array(np.meshgrid(*[np.arange(x) for x in grid_steps])).T.reshape(-1, n_dims)


def _grid_center(xyz, grid_offsets, min_ranges):
    """Return the integer center of grid cell ``xyz``: ``offset * (coord + 0.5) + start`` per axis."""
    return [int(offset * (coord + 0.5) + min_off)
            for offset, min_off, coord in zip(grid_offsets, min_ranges, xyz)]


def run(config: DictConfig):
    """Inject, inpaint and evaluate at every cell of a grid laid over each sample's mask.

    For every batch from the dataloader:
      1. compute the median extent of the region where ``ts == 1``;
      2. split that extent into ``config.exp.grid_steps`` cells per axis;
      3. for every cell, prepare an injection with ``ds.prepare_injection``
         around the (dataset-corrected) cell center;
      4. inpaint the injected region ``config.exp.batch_multip`` times;
      5. pass the results to the evaluator.

    Batches whose mask is empty are skipped with a warning, since their grid
    extent is undefined.
    """
    utils.preprocess_config(config)
    utils.setup_wandb(config)

    log.info('Launching Fabric')
    fabric = get_fabric(config)
    batch_multip = config.exp.batch_multip
    grid_steps = config.exp.grid_steps

    log.info('Building components')
    inpainter, evaluator, classifier = get_components(config, fabric)

    log.info('Initializing dataloader')
    dataloader = get_dataloader(config, fabric)
    ds = dataloader.dataset

    # The grid enumeration depends only on the config, so build it once.
    grid_placements = _grid_placements(grid_steps)

    with fabric.init_tensor():

        for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Batches'):
            log.info('Batch: %s', idx)

            batch_idx, batch_labels, batch_pred_labels, orig, ts = batch

            nodule_mask, nodule_true_region, nodule_size = ds.get_nodule()

            # NOTE(review): `target` is not used below. It is kept in case
            # utils.get_target_id validates the predictions; confirm and drop.
            target = utils.get_target_id(config, batch_pred_labels).long()  # Make sure targets are ints

            # Boolean mask of the region where ts == 1. This assumes squeeze()
            # removes the batch dim (batch size 1); TODO confirm.
            mask = ts.squeeze().cpu().numpy() == 1
            if not mask.any():
                # Without any True voxel, the medians are NaN and int(NaN) would crash.
                log.warning('Batch %s: mask is empty, skipping grid placement', idx)
                continue

            min_ranges, max_ranges = _median_mask_bounds(mask)
            grid_offsets = [int(extent / steps)
                            for extent, steps in zip(max_ranges - min_ranges, grid_steps)]

            for xyz in tqdm(grid_placements, desc='Grid positions'):
                ## 1. injection
                center = _grid_center(xyz, grid_offsets, min_ranges)
                center = ds.get_correct_center(center, orig.squeeze().shape)
                wandb.log({'grid_pos': "_".join([str(x) for x in xyz]),
                           'center_pos': "_".join([str(x) for x in center])})
                log.debug('grid position %s -> center %s (offsets %s)', xyz, center, grid_offsets)

                i2sb_region, i2sb_mask, inj_pos = ds.prepare_injection(
                    orig.squeeze(),
                    ts.squeeze(),
                    nodule_mask,
                    nodule_true_region,
                    nodule_size,
                    center,
                )
                batch_injects = i2sb_region.unsqueeze(0)
                batch_masks = i2sb_mask.unsqueeze(0)

                # NOTE(review): orig already comes batched from the dataloader; confirm
                # the extra unsqueeze(0) is the shape set_precondition_imgs expects.
                classifier.set_precondition_imgs(orig.unsqueeze(0), [inj_pos])

                utils_logger = utils.VideoUtilsLogger()
                utils_logger.log_original(batch_idx, batch_injects)
                utils_logger.log_original_mask_overlay(batch_idx, batch_injects, batch_masks)

                batch_injects = batch_injects.float()

                ## 2. inpainting
                log.info('Beginning inpainting')
                batch_inps = [None] * batch_multip
                batch_guidance_classes = [None] * batch_multip
                batch_imgs_rep, batch_maps_rep = utils.get_batch_to_inp(
                    config, batch_injects, batch_masks)

                for idx_inner in range(batch_multip):
                    log.info('Inpainting batch: %s/%s', idx_inner + 1, batch_multip)
                    batch_guidance_classes_ = ds.get_guidance_classes(
                        config, fabric, batch_labels, batch_pred_labels)
                    with torch.no_grad():
                        batch_inps_ = inpainter.inpaint(
                            batch_imgs_rep, batch_maps_rep, batch_guidance_classes_)
                    batch_inps[idx_inner] = batch_inps_
                    batch_guidance_classes[idx_inner] = batch_guidance_classes_

                batch_inps = torch.cat(batch_inps)
                batch_guidance_classes = torch.cat(batch_guidance_classes)
                utils_logger.log_inpaints(batch_idx, batch_inps)

                ## 3. evaluate
                log.info('Evaluating inpaints')
                # Move inputs to the device fabric placed the evaluator on,
                # instead of a hard-coded .cuda().
                eval_input = {
                    'batch_idx': fabric.to_device(batch_idx.float()),
                    'batch_imgs': fabric.to_device(batch_injects.float()),
                    'batch_inps': fabric.to_device(batch_inps.float()),
                    'batch_labels': batch_labels,
                    'batch_pred_labels': batch_pred_labels,
                    'batch_guidance_classes': batch_guidance_classes}
                evaluator.evaluate(config, eval_input)

        inpainter.on_end()