from pathlib import Path
from functools import partial
import cc3d

import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from hydra.utils import instantiate
from lightning.fabric import Fabric
from omegaconf import DictConfig
from shapiq import ExactComputer
from torch.nn import functional as F
from tqdm import tqdm

import utils
import wandb
from dataset.base import DatasetWithRegionsInAnOrgan
from utils import shnap as shnap_utils
from utils.segmentations import roi_center_crop_slices, find_segment_centers, map_segments
from utils.misc import list_from_string
from utils.region_extractor import CTRegionExtractor

from .compute_clf_preds import get_output_path


import logging
# Module-level logger, named after this module so records can be filtered per module.
log = logging.getLogger(__name__)


def get_fabric(config):
    """Create, seed and launch the Fabric object described by the config.

    Args:
        config: Experiment config; ``config.fabric`` is instantiated and
            ``config.exp.seed`` is used as the global seed.

    Returns:
        The launched Fabric instance.
    """
    launched = instantiate(config.fabric)
    # Seed first, then launch, so the seed is in place before launch() runs.
    launched.seed_everything(config.exp.seed)
    launched.launch()
    return launched


def get_components(config, fabric):
    """Build the classifier and the inpainter and wrap both with Fabric.

    Args:
        config: Experiment config providing ``config.classifier`` and
            ``config.inpainter``.
        fabric: Fabric object used to set up both modules.

    Returns:
        A ``(inpainter, classifier)`` tuple -- note that the inpainter comes first.
    """
    # Instantiate both models up front, classifier first.
    clf = instantiate(config.classifier)
    # config.inpainter instantiates to a callable that still needs `guidance`.
    build_inpainter = instantiate(config.inpainter)
    inp = build_inpainter(guidance = None)

    # Classifier: wrap with Fabric, mark `forward_all_years` as a forward
    # method and switch to eval mode. It is not passed through torch.compile.
    clf = fabric.setup(clf)
    clf.mark_forward_method("forward_all_years")
    clf.eval()

    # Inpainter: wrap with Fabric, compile, then mark `inpaint` as a forward method.
    inp = torch.compile(fabric.setup(inp))
    inp.mark_forward_method("inpaint")

    return inp, clf


def get_dataloader(config, fabric):
    """Instantiate ``config.dataset`` and pass the result through Fabric.

    Args:
        config: Experiment config providing ``config.dataset``.
        fabric: Fabric object whose ``setup_dataloaders`` wraps the loader.

    Returns:
        Whatever ``fabric.setup_dataloaders`` returns for the instantiated object.
    """
    loader = instantiate(config.dataset)
    wrapped_loader = fabric.setup_dataloaders(loader)
    return wrapped_loader


def run(config: DictConfig):
    """Run the region-attribution ("shnap") experiment over the whole dataloader.

    For every batch (batch size must be 1) the function:

    1. obtains an attention map from the classifier and extracts the main
       regions from the nodule segmentation and/or the thresholded attention;
    2. optionally extracts background regions from the lung mask, depending
       on ``config.exp.bg_masks``;
    3. inpaints every region and pastes the result into a copy of the image;
    4. optionally saves the region segmentation and the inpainted image as
       NIfTI files (``config.exp.seg_save_dir`` / ``config.exp.inp_save_dir``);
    5. evaluates coalitions of regions with ``shnapCoalitionEvaluator`` and
       computes exact "SV" and "k-SII" indices (optionally "BV") per year
       with shapiq's ``ExactComputer``;
    6. logs tables, plots, videos and per-region metrics to Weights & Biases.

    Batches for which no main region (or, when background regions are
    required, no lung region) is found are skipped.

    Args:
        config: Experiment configuration. Read here: ``fabric``, ``classifier``,
            ``inpainter``, ``dataset``, ``region_extractor``, ``plots_factory``
            and the ``exp`` section.

    Raises:
        ValueError: If the dataset is not a ``DatasetWithRegionsInAnOrgan``, or
            if an unknown mask / background strategy is configured.
    """
    torch.set_float32_matmul_precision("high")

    # TODO: classifier guidance in reverse direction of target_id
    utils.preprocess_config(config)
    utils.setup_wandb(config)

    log.info('Launching Fabric')
    fabric: Fabric = get_fabric(config)

    log.info('Building components')
    inpainter, classifier = get_components(config, fabric)

    log.info('Initializing dataloader')
    dataloader = get_dataloader(config, fabric)
    dataset = dataloader.dataset
    utils_logger = utils.VideoUtilsLogger()

    log.info("Initializing region extractor")
    region_extractor: CTRegionExtractor = instantiate(config.region_extractor)

    log.info("Initializing ShapIQ plots")
    shapiq_plots_factory = instantiate(config.plots_factory)
    shapiq_plot_years = list_from_string(config.exp.shapiq_plot_years, dtype=int)

    log.info("Initializing our plot functions")
    create_risks_plot = partial(shnap_utils.grouped_data_bar_plot,
        y_label="Risk", title="Risk values when including a given region",
        regions_start_idx=1, bg_strategy=config.exp.bg_masks
    )
    create_deltas_plot = partial(shnap_utils.grouped_data_bar_plot,
        y_label="Delta", title="Risk delta when including a given region",
        regions_start_idx=0, bg_strategy=config.exp.bg_masks
    )
    create_shnaps_plot = partial(shnap_utils.grouped_data_bar_plot,
        y_label="shnap", title="shnap values associated with a given region",
        regions_start_idx=0, bg_strategy=config.exp.bg_masks
    )

    # Mask strategy for nodule segmentation (or attention)
    mask_strategy: shnap_utils.MASK_STRATEGY = config.exp.main_masks
    # Configured background strategy for shnap. This value is never modified;
    # each batch works on its own copy (see `batch_bg_strategy` below).
    bg_strategy: shnap_utils.BG_STRATEGY = config.exp.bg_masks


    @torch.no_grad()
    def get_region_mask(nodule_seg: torch.Tensor, attention: torch.Tensor, for_bg: bool = False):
        """Return the mask/segmentation from which the main regions are extracted.

        With ``for_bg=True`` the union of the nodule voxels and the
        above-threshold attention voxels is returned. Otherwise the result
        depends on ``mask_strategy``: the nodule segmentation itself
        ("original") or the thresholded attention ("attention").

        Raises:
            ValueError: If ``mask_strategy`` is neither of the two.
        """
        if for_bg:
            return (nodule_seg > 0) | (attention > config.exp.attention_threshold)

        if mask_strategy == "original":
            return nodule_seg
        elif mask_strategy == "attention":
            return attention > config.exp.attention_threshold
        else:
            raise ValueError(f"Unknown mask strategy: {mask_strategy}")


    def _get_save_dir(base_dir: str, metadata: dict):
        """Return (and create) the output directory, nested per patient when known."""
        save_dir = Path(base_dir)
        if "patient_id" in metadata:
            save_dir = save_dir / str(metadata["patient_id"])
        save_dir.mkdir(parents=True, exist_ok=True)
        return save_dir

    def _symlink_to_original(base_dir: Path, metadata: dict):
        """Create ``<series_id>_image.nii.gz`` in ``base_dir`` as a symlink to the original image.

        Nothing is done when an entry already exists at that path.
        """
        link_path = base_dir / f"{str(metadata['series_id'])}_image.nii.gz"
        # Test the link path itself, not its resolved target: resolving first
        # would follow an existing link, so a dangling link would look
        # "missing" and the new symlink would be created at the target's
        # location instead of inside `base_dir`.
        if link_path.is_symlink() or link_path.exists():
            return None
        original_path = Path(metadata['image_path']).resolve()
        link_path.symlink_to(original_path)
        return None

    @torch.no_grad()
    def save_segmentation(segmentation: torch.Tensor, metadata: dict):
        """
        Save the region segmentation as NIfTI next to a symlink to the original image.

        segmentation: (W, H, D) torch tensor
        metadata: dict (``series_id``, ``original_affine`` and ``image_path`` are read)
        """
        save_dir = _get_save_dir(config.exp.seg_save_dir, metadata)
        filename = f"{str(metadata['series_id'])}_seg.nii.gz"
        filepath = save_dir / filename
        log.info(f"Saving segmentation mask to {filepath}")

        transformed_segmentation = dataset.reverse_mask_transform(
            segmentation, metadata["original_affine"], as_nifty=True
        )
        nib.save(transformed_segmentation, filepath)
        _symlink_to_original(save_dir, metadata)

    @torch.no_grad()
    def save_inpainted_image(image: torch.Tensor, metadata: dict):
        """
        Save the inpainted image as NIfTI next to a symlink to the original image.

        image: (W, H, D) torch tensor
        metadata: dict (``series_id``, ``original_affine`` and ``image_path`` are read)
        """
        save_dir = _get_save_dir(config.exp.inp_save_dir, metadata)
        filename = f"{str(metadata['series_id'])}_inpaint.nii.gz"
        filepath = save_dir / filename
        log.info(f"Saving inpainted image to {filepath}")

        transformed_image = dataset.reverse_transform(
            image, metadata["original_affine"], as_nifty=True
        )
        nib.save(transformed_image, filepath)

        # Symlink the original image for reference
        _symlink_to_original(save_dir, metadata)
        return None


    log.info("Starting experiment")
    with fabric.init_tensor():
        for idx, batch in tqdm(enumerate(dataloader), total = len(dataloader), desc = 'Batches'):
            torch.cuda.empty_cache()
            log.info(f'Batch: {idx}')

            ########################################################
            # Data loading: CT images, nodule segmentations, lung masks
            ########################################################
            if not isinstance(dataloader.dataset, DatasetWithRegionsInAnOrgan):
                raise ValueError("Dataset must be a DatasetWithRegionsInAnOrgan")

            # Per-batch copy of the configured strategy. A sample without a
            # lung mask downgrades only *its own* strategy to "ignore"; the
            # configured value stays intact for the following samples.
            batch_bg_strategy = bg_strategy

            batch_idxs, batch_img, _, _, batch_lungs_mask, batch_nodule_seg, metadata = batch
            assert len(batch_img) == 1, 'Batch size must be 1'

            # Unbatch nodules
            nodule_seg = batch_nodule_seg[0] # (W, H, D)

            # Get attention
            batch_attention = classifier.get_attention(batch_img)[0]
            attention = batch_attention[0, 0] # (W, H, D)

            regions_mask_or_seg = get_region_mask(
                nodule_seg, attention, for_bg=(batch_bg_strategy == "bgonly")
            )

            # Extract main regions (ensure multiclass mask)
            regions_seg, n_regions = region_extractor.get_region_segments(regions_mask_or_seg)
            if regions_seg is None:
                log.warning("No regions found in the segmentation")
                continue
            regions_seg = fabric.to_device(regions_seg)

            # Handle the lungs mask
            if batch_lungs_mask is not None:
                lungs_mask = batch_lungs_mask[0] # (W, H, D)
                lungs_mask = fabric.to_device(lungs_mask)
            else:
                lungs_mask = None
                batch_bg_strategy = "ignore" # No lung mask provided

            # Handle metadata
            metadata: dict = metadata[0]

            # Cleanup
            del batch_nodule_seg, batch_lungs_mask, batch_attention
            torch.cuda.empty_cache()


            ########################################################
            # Handling of background regions
            ########################################################
            if batch_bg_strategy == "ignore":
                lungs_seg = None
            else:
                # Extract lung regions
                lungs_seg, lung_region_centers = region_extractor.get_lung_region_segments(
                    lungs_mask,
                    regions_seg
                )
                if lungs_seg is not None:
                    segment_map = region_extractor.get_optimal_bg_subset(
                        lung_region_centers,
                        lungs_mask,
                        nodule_centers=find_segment_centers(regions_seg)
                    )
                    if segment_map is not None:
                        lungs_seg, _ = map_segments(lungs_seg, segment_map)
                else:
                    log.warning("No lung regions found in the segmentation")

                if lungs_seg is None:
                    continue

                if batch_bg_strategy == "bgonly":
                    # The background regions become the main (and only) regions.
                    regions_seg = lungs_seg
                    n_regions = regions_seg.max().int().item()
                    lungs_seg = None

            ########################################################
            # Save the segmentation masks in Nifti format
            ########################################################
            if config.exp.seg_save_dir is not None:
                save_segmentation(regions_seg.clone(), metadata)

            ########################################################
            # Inpainting regions
            ########################################################
            with torch.no_grad():
                regions = []
                regions_masks = []
                region_inps = []
                region_attention_sum = []
                region_attention_max = []

                batch_img_inp = batch_img.clone() # (1, 1, W, H, D)
                def do_region_inpaint(region_mask):
                    """Inpaint one region on a crop and paste it into `batch_img_inp`."""
                    mask_slices = roi_center_crop_slices(region_mask) # 3D
                    # Index with a *tuple* of slices: a non-tuple sequence used
                    # as a multidimensional index is deprecated in PyTorch.
                    batch_slices = (slice(None), slice(None), *mask_slices) # 5D

                    batch_mask = region_mask.unsqueeze(0).unsqueeze(0) # (1, 1, W, H, D)
                    batch_mask = batch_mask[batch_slices] # (1, 1, W', H', D')
                    batch_region = batch_img[batch_slices].clone() # (1, 1, W', H', D')

                    utils_logger.log_original(batch_idxs, batch_region)
                    utils_logger.log_attr_maps(batch_idxs, batch_mask)
                    utils_logger.log_original_mask_overlay(batch_idxs, batch_region, batch_mask)

                    batch_region_inp = inpainter.inpaint(batch_region, batch_mask, None)

                    utils_logger.log_inpaints(batch_idxs, batch_region_inp)

                    # Replace the region in the original image (only inside the mask)
                    batch_img_inp[batch_slices] = batch_img_inp[batch_slices] * (1 - batch_mask) + batch_region_inp * batch_mask

                    regions.append(batch_region[0])
                    regions_masks.append(batch_mask[0])
                    region_inps.append(batch_region_inp[0])
                    del batch_region, batch_mask, batch_region_inp


                # Main regions are labelled 1..n_regions in `regions_seg`.
                for region_idx in tqdm(range(1, n_regions + 1), desc = 'Region inpainting'):
                    region_mask = (regions_seg == region_idx)
                    region_attention_sum.append(torch.sum(attention[region_mask]).detach().cpu().item())
                    region_attention_max.append(torch.max(attention[region_mask]).detach().cpu().item())
                    do_region_inpaint(region_mask.float())

                # Background regions (if any) are appended after the main ones.
                if lungs_seg is not None:
                    for bg_idx in tqdm(range(1, lungs_seg.max().int().item() + 1), desc = 'Background region inpainting'):
                        region_mask = (lungs_seg == bg_idx)
                        region_attention_sum.append(torch.sum(attention[region_mask]).detach().cpu().item())
                        region_attention_max.append(torch.max(attention[region_mask]).detach().cpu().item())
                        do_region_inpaint(region_mask.float())

                regions = torch.stack(regions) # (N, 1, W', H', D')
                regions_masks = torch.stack(regions_masks) # (N, 1, W', H', D')
                region_inps = torch.stack(region_inps) # (N, 1, W', H', D')
                video = utils.create_single_video_summary(
                    regions, regions_masks, region_inps
                )

            ########################################################
            # Saving the inpainted image
            ########################################################
            if config.exp.inp_save_dir is not None:
                save_inpainted_image(batch_img_inp[0, 0].clone(), metadata)

            ########################################################
            # Combining regions for shnap calculations
            ########################################################
            # Work on a copy so the in-place label merge below cannot touch
            # a tensor that is still referenced elsewhere.
            regions_seg = regions_seg.clone()
            if lungs_seg is not None:
                if batch_bg_strategy == "ignore":
                    pass
                elif batch_bg_strategy == "concat":
                    # Every background region keeps its own label, shifted past the main ones.
                    regions_seg += (lungs_seg + n_regions * (lungs_seg > 0))
                elif batch_bg_strategy == "combine":
                    # All background regions share the single label n_regions + 1.
                    regions_seg += (n_regions + 1) * (lungs_seg > 0)
                else:
                    raise ValueError(f"Unknown background strategy: {batch_bg_strategy}")


            ########################################################
            # Calculating shnap values
            ########################################################
            torch.cuda.empty_cache() # Clear GPU memory
            true_idx = batch_idxs[0].cpu().item()

            # Evaluator will calculate v-values for each year
            # and cache the results for each shapiq game
            evaluator = shnap_utils.shnapCoalitionEvaluator(
                classifier=classifier,
                original=batch_img,
                inpaint=batch_img_inp,
                region_seg=regions_seg,
                batch_size=config.exp.v_values_batch_size,
                bg_start_idx=n_regions,  # Backgrounds start after nodules
                bg_strategy=batch_bg_strategy,
            )

            shnap_values = []
            for year in range(classifier.num_years):
                # SHAPIQ plots for base risk and year1 risk only (rest is redundant)
                shnap_game = shnap_utils.shnapGame(evaluator, value_idx=year)
                exact_computer = ExactComputer(n_players=shnap_game.n_players, game=shnap_game)
                sv_exact = exact_computer(index="SV")

                sii_exact = exact_computer(index="k-SII")
                shnap_values.append(sv_exact.get_n_order_values(1))

                if year == 0:
                    # Pairwise interactions, with the first-order values on the diagonal.
                    base_interactions = sii_exact.get_n_order_values(2)
                    np.fill_diagonal(base_interactions, sii_exact.get_n_order_values(1))
                    wandb.log({
                        f"report/image_{true_idx}/base_interactions": pd.DataFrame(base_interactions, columns=range(len(base_interactions))),
                    })

                if year in shapiq_plot_years:
                    if year == 0:
                        prefix = f"report/image_{true_idx}/shnap_base/"
                    else:
                        prefix = f"report/image_{true_idx}/shnap_year{year}/"

                    wandb.log(shapiq_plots_factory.create_plots(
                        sii_exact * 100.0, shnap_game.player_names, prefix=prefix
                    ))

                    if config.exp.compute_banzhaf:
                        bv_exact = exact_computer(index="BV")
                        prefix = prefix.replace("shnap", "nbv")
                        wandb.log(shapiq_plots_factory.create_plots(
                            bv_exact * 100.0, shnap_game.player_names, prefix=prefix
                        ))

            # One column per year (column 0 is logged as the "base" value below).
            shnaps = np.stack(shnap_values, axis=1)
            predictions = evaluator.get_predictions()
            # Row 0 of `predictions` is used as the reference the deltas are taken against.
            deltas = predictions[1:] - predictions[0]
            breakdown = evaluator.breakdown(0)
            shnap_r2 = evaluator.sv_quality(shnaps, predictions[0])

            ########################################################
            # Logging to WandB
            ########################################################
            shnap_df = pd.DataFrame({
                "nodule": range(len(shnaps)),
                "base_shnap": shnaps[:, 0],
                "base_delta": deltas[:, 0]
            })

            metadata["nodules"] = n_regions
            metadata_df = pd.DataFrame([
                {'Key': str(k), 'Value': str(v)} for k, v in metadata.items()
            ])
            # For base risk
            deltas_plot_base = create_deltas_plot(deltas[:, 0], bg_start_idx=n_regions)
            risk_plot_base = create_risks_plot(predictions[:, 0], bg_start_idx=n_regions+1)
            shnaps_plot_base = create_shnaps_plot(shnaps[:, 0], bg_start_idx=n_regions+1)
            # For per year risk
            deltas_plot = create_deltas_plot(deltas[:, 1:], bg_start_idx=n_regions)
            risk_plot = create_risks_plot(predictions[:, 1:], bg_start_idx=n_regions+1)
            shnaps_plot = create_shnaps_plot(shnaps[:, 1:], bg_start_idx=n_regions+1)
            # Values table
            v_table = evaluator.get_all_values_as_df()
            # Breakdown plot
            breakdown_plot = shnap_utils.plot_breakdown(breakdown, scale_by=100.0)

            # 4x brighter copy of the video; widened to uint16 first so the
            # multiplication cannot wrap around before clipping to [0, 255].
            video_bright = np.clip(video.astype(np.uint16) * 4, 0, 255).astype(np.uint8)

            wandb.log({
                f"report/image_{true_idx}/video": wandb.Video(video, format="mp4"),
                f"report/image_{true_idx}/video_bright": wandb.Video(video_bright, format="mp4"),
                f"report/image_{true_idx}/shnap": wandb.Table(dataframe=shnap_df),
                f"report/image_{true_idx}/metadata": wandb.Table(dataframe=metadata_df),
                f"report/image_{true_idx}/deltas_plot_base": wandb.Plotly(deltas_plot_base),
                f"report/image_{true_idx}/risks_plot_base": wandb.Plotly(risk_plot_base),
                f"report/image_{true_idx}/shnaps_plot_base": wandb.Plotly(shnaps_plot_base),
                f"report/image_{true_idx}/risks_plot": wandb.Plotly(risk_plot),
                f"report/image_{true_idx}/deltas_plot": wandb.Plotly(deltas_plot),
                f"report/image_{true_idx}/shnaps_plot": wandb.Plotly(shnaps_plot),
                f"report/image_{true_idx}/v_table": wandb.Table(dataframe=v_table),
                f"report/image_{true_idx}/breakdown_plot": wandb.Plotly(breakdown_plot),
            })

            # Log per-nodule metrics
            for mask_true_idx in range(len(shnaps)):
                log_dict = {
                    "metrics/image": true_idx,
                    "metrics/nodule": mask_true_idx,
                    "metrics/attention_sum": region_attention_sum[mask_true_idx],
                    "metrics/attention_max": region_attention_max[mask_true_idx],
                    "metrics/base_shnap": shnaps[mask_true_idx][0],
                    "metrics/base_risk": predictions[1:][mask_true_idx][0],
                    "metrics/base_delta": deltas[mask_true_idx][0]
                }
                for i in range(1, shnaps.shape[1]):
                    log_dict[f"metrics/year{i}_shnap"] = shnaps[mask_true_idx][i]
                    log_dict[f"metrics/year{i}_delta"] = deltas[mask_true_idx][i]
                    log_dict[f"metrics/year{i}_risk"] = predictions[1:][mask_true_idx][i]

                wandb.log(log_dict)

            # Overall R2 per year
            log_dict = {}
            for year in range(classifier.num_years):
                log_dict[f"metrics/r2_year{year}"] = shnap_r2[year]
            wandb.log(log_dict)

            log.info('Batch done')

        inpainter.on_end()
        log.info('Experiment done')

    wandb.finish()
