import os
from dataclasses import asdict
from os.path import isfile, join
from typing import List, Optional, Tuple, Literal, Union

import torch
import numpy as np
import wandb
from einops import rearrange
from jaxtyping import jaxtyped, Float, Int, Shaped
from beartype import beartype as typechecker
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
from pytorch_lightning.tuner.tuning import Tuner
from torchvision.utils import make_grid

from conf.dataset import ValueRange
from conf.main_config import GlobalConfiguration
from conf.model import LoggingParams, SubLoggingParams

# RGB colour (0-255 per channel) of every class; the position in the list is
# the class index.
# NOTE(review): entries 1-19 look like the Cityscapes palette with index 0
# (black) presumably reserved for background/unlabelled pixels -- confirm.
colors = [
    [0, 0, 0],
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [0, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
]

# Same palette, keyed by class index.
label_colours = dict(enumerate(colors))


@jaxtyped(typechecker=typechecker)
def mask2rgb(
    mask: Int[torch.Tensor, 'b h w'],
    return_tensor: bool = True,
):
    """
    Colourise a batch of class-index masks with the ``label_colours`` palette.

    Mask should not be one-hot-encoded: every pixel holds a class index.

    Args:
        mask: integer class indices of shape (b, h, w).
        return_tensor: when True the result is a ``torch.Tensor`` moved to the
            device of ``mask``; otherwise the NumPy array is returned.

    Returns:
        Array/tensor of shape (b, 3, h, w) whose channels are the palette
        colours divided by 255. Pixels holding a negative index are not
        recoloured (their raw value is divided by 255 as well).

    Raises:
        KeyError: if an index in ``0..mask.max()`` has no palette entry.
    """
    device = mask.device
    max_class = int(mask.max())

    mask = mask.cpu().numpy()
    # Paint on int64 copies so that palette values up to 255 cannot overflow
    # a narrow input dtype such as int8.
    channels = [mask.astype(np.int64) for _ in range(3)]

    for class_i in range(max_class + 1):
        selected = mask == class_i
        for channel, value in zip(channels, label_colours[class_i]):
            channel[selected] = value

    # Insert the colour axis right in front of (h, w) and rescale to [0, 1].
    rgb = np.stack(channels, axis=-3) / 255.0

    if return_tensor:
        return torch.Tensor(rgb).to(device)
    return rgb


@jaxtyped(typechecker=typechecker)
def rgb2mask(
    rgb: Shaped[torch.Tensor, '3 h w'],
    normalize_in: Optional[Literal["11", "01"]] = None,
    one_hot: bool = False,
    nb_class: Optional[int] = None,
) -> Union[Float[torch.Tensor, 'h w'], Float[torch.Tensor, 'c h w']]:
    """
    Convert one colour-coded image back to class indices using ``colors``.

    Args:
        rgb: image of shape (3, h, w).
        normalize_in: range of the incoming values. ``"01"`` means [0, 1],
            ``"11"`` means [-1, 1], ``None`` means the values are already
            0-255 colour codes.
        one_hot: when True return a one-hot encoding instead of indices.
        nb_class: number of classes of the one-hot encoding (required when
            ``one_hot`` is True).

    Returns:
        Float tensor of shape (h, w) holding class indices, or of shape
        (nb_class, h, w) when ``one_hot`` is True.

    Raises:
        AssertionError: if a pixel matches no palette colour, or if
            ``one_hot`` is True and ``nb_class`` is None.
    """
    # Undo the normalisation and round to the nearest colour code: scaling by
    # 255 maps 1.0 back to 255, and rounding protects against float error
    # pushing a channel just below the intended integer.
    if normalize_in == "01":
        rgb = torch.round(rgb * 255.0)
    elif normalize_in == "11":
        rgb = torch.round((rgb + 1) * 127.5)

    rgb = rgb.int()
    _, h, w = rgb.shape

    device = rgb.device

    # -1 marks pixels that have not been matched to any class yet.
    mask = torch.full((h, w), -1, dtype=torch.int32, device=device)
    # Build the palette on the same device as the image so the comparison
    # also works for CUDA inputs.
    palette = torch.tensor(colors, dtype=rgb.dtype, device=device)
    for class_i, colour in enumerate(palette):
        issame = (rgb == colour.view(3, 1, 1)).all(dim=0)
        mask[issame] = class_i

    assert (mask != -1).all(), f"Found -1 in mask: {mask=}"

    if one_hot:
        assert nb_class is not None, f"Need to specify {nb_class=}"
        mask = torch.nn.functional.one_hot(mask.long(), num_classes=nb_class).permute(2, 0, 1)

    return mask.float()


def is_logging_time(logging_params: SubLoggingParams, current_epoch, batch_idx, stage) -> bool:
    """
    Tell whether something should be logged at this epoch/batch for `stage`.

    `logging_params.frequencies` and `logging_params.log_first` are indexed by
    stage: 0 for a stage containing 'train', 1 for 'valid', 2 for 'test'.
    Logging happens when the counter selected by `logging_params.logging_mode`
    ('epoch' -> `current_epoch`, 'batch' -> `batch_idx`) is a multiple of the
    stage frequency; counter 0 only counts when `log_first` is set.
    Returns False when `stage` is not listed in `logging_params.stages` or
    when the logging mode is None. Raises `Exception` for an unknown stage or
    logging mode.
    """
    # Position of the stage inside the per-stage settings.
    for stage_index, stage_key in enumerate(('train', 'valid', 'test')):
        if stage_key in stage:
            break
    else:
        raise Exception(f'Unknown stage: {stage=}')

    frequency = logging_params.frequencies[stage_index]
    log_first = logging_params.log_first[stage_index]

    if stage not in logging_params.stages:
        return False

    mode = logging_params.logging_mode
    if mode is None:
        return False
    if mode == 'epoch':
        counter = current_epoch
    elif mode == 'batch':
        counter = batch_idx
    else:
        raise Exception(f'Unknown logging mode: {logging_params.logging_mode=}')

    if counter % frequency != 0:
        return False
    # The very first step (counter 0) is only logged when requested.
    return bool(log_first) or counter != 0


def get_file(path: str):
    """
    Return the checkpoint from the _model_save folder for testing.

    If `path` is a directory, the regular files whose name contains
    'best_model' are ranked by the integer found before the first '_' of
    their name and the path of the highest one is returned. Any other `path`
    is returned untouched.
    """
    if not os.path.isdir(path):
        return path

    def leading_number(file_name):
        # File names are expected to look like '<int>_...'.
        return int(file_name.split('_')[0])

    candidates = [
        name for name in os.listdir(path)
        if isfile(join(path, name)) and 'best_model' in name
    ]
    return join(path, max(candidates, key=leading_number))


def display_tensor(tensor: torch.Tensor, unnormalize: bool = False, dpi: Optional[int] = None):
    """
    Debugging function to display a tensor on screen.

    `unnormalize` applies (x + 1) / 2 before plotting, a 4-D input is tiled
    into a single grid image, and `dpi` (when given) opens a new figure with
    that resolution.
    """
    image = (tensor + 1) / 2 if unnormalize else tensor
    if image.dim() == 4:  # a batch axis is present -> tile the batch into one grid
        image = make_grid(image, padding=20)
    if dpi is not None:
        plt.figure(dpi=dpi)
    # imshow wants channels last.
    plt.imshow(image.permute(1, 2, 0).cpu())
    plt.show()


def display_mask(tensor: torch.Tensor):
    """
    Debugging function to display a mask on screen.

    A float tensor is treated as per-class scores and reduced to class indices
    with an argmax over dim 1; any other tensor is passed to `mask2rgb` as is.
    The coloured batch is tiled into a grid and shown with matplotlib.
    """
    if 'FloatTensor' in tensor.type():
        # The argmax over the class axis already has the (b, h, w) layout that
        # mask2rgb accepts; re-adding a channel axis would make it fail.
        tensor = torch.argmax(tensor, dim=1)
    tensor = mask2rgb(tensor)
    if len(tensor.shape) == 4:
        tensor = make_grid(tensor, padding=20)
    plt.imshow(tensor.permute(1, 2, 0).cpu())
    plt.show()


def normalize_value_range(tensor: torch.Tensor, value_range: ValueRange, clip: bool = False):
    """
    Map `tensor` to the [0, 1] convention according to `value_range`.

    `Zero` / `ZeroUnbound` inputs are returned unchanged, `One` / `OneUnbound`
    inputs go through (x + 1) / 2. With `clip` the result is clamped to
    [0, 1]. Raises `Exception` for any other `value_range`.
    NOTE(review): presumably `Zero*` means data already in [0, 1] and `One*`
    means data in [-1, 1] -- confirm against `ValueRange`.
    """
    if value_range in (ValueRange.Zero, ValueRange.ZeroUnbound):
        res = tensor
    elif value_range in (ValueRange.One, ValueRange.OneUnbound):
        res = (tensor + 1) / 2
    else:
        raise Exception(f'Unknown value range: {value_range=}')

    if clip:
        return torch.clamp(res, 0., 1.)
    return res


def get_undersample_indices(src_length: int, nb_samples: int, strategy: str = 'uniform', quad_factor: float = 0.8) -> List[int]:
    """
    Pick `nb_samples` positions out of a sequence of `src_length` elements.

    The returned positions index the source sequence directly and select the
    same elements as `undersample_list`:

    * `nb_samples == 1` -> only the last position; `nb_samples == 2` -> first
      and last.
    * `src_length <= nb_samples` -> every position.
    * otherwise the first and last positions are always kept and
      `nb_samples - 2` positions are drawn from the interior:
        - 'uniform': constant stride `interior_length // (nb_samples - 2)`,
          starting at the first interior element;
        - 'quad_end': offsets are the squares of evenly spaced values up to
          `sqrt(interior_length * quad_factor)`, i.e. densest near the start;
        - 'quad_start': the mirrored pattern, densest near the end.

    Raises:
        ValueError: if `strategy` is not one of the names above.
    """
    assert nb_samples > 0
    last = src_length - 1
    if nb_samples == 1:
        return [last]
    if nb_samples == 2:
        return [0, last]

    if src_length <= nb_samples:
        return list(range(src_length))

    # The two end points are always kept; the rest comes from the interior.
    nb_interior = src_length - 2
    nb_picks = nb_samples - 2

    if strategy == 'uniform':
        step = nb_interior // nb_picks
        offsets = [i * step for i in range(nb_picks)]
    elif strategy in ('quad_start', 'quad_end'):
        offsets = ((np.linspace(0, np.sqrt(nb_interior * quad_factor), nb_picks)) ** 2).astype(int) + 1
        if strategy == 'quad_start':
            offsets = [nb_interior - i for i in offsets][::-1]
    else:
        raise ValueError(f'{strategy=} is not an available discretization method.')

    # Offsets are relative to the interior (which starts at position 1), so
    # shift them back; int() also strips NumPy scalar types.
    return [0] + [int(offset) + 1 for offset in offsets] + [last]


def undersample_list(src_list: List, nb_samples: int, strategy: str = 'uniform', quad_factor: float = 0.8, return_indices: bool = False) -> Union[list, Tuple[list, List[int]]]:
    """
    Keep `nb_samples` elements of `src_list`.

    * `nb_samples == 1` -> only the last element; `nb_samples == 2` -> first
      and last.
    * `len(src_list) <= nb_samples` -> every element.
    * otherwise the first and last elements are always kept and
      `nb_samples - 2` elements are drawn from the interior:
        - 'uniform': constant stride `interior_length // (nb_samples - 2)`,
          starting at the first interior element;
        - 'quad_end': offsets are the squares of evenly spaced values up to
          `sqrt(interior_length * quad_factor)`, i.e. densest near the start;
        - 'quad_start': the mirrored pattern, densest near the end.

    Returns:
        The selected elements, or `(elements, indices)` when `return_indices`
        is True, where `indices[k]` is the position of `elements[k]` in
        `src_list`.

    Raises:
        ValueError: if `strategy` is not one of the names above.
    """
    assert nb_samples > 0
    src_length = len(src_list)
    last = src_length - 1

    if nb_samples in (1, 2):
        indices = [last] if nb_samples == 1 else [0, last]
    elif src_length <= nb_samples:
        if not return_indices:
            return src_list
        indices = list(range(src_length))
    else:
        # The two end points are always kept; the rest comes from the interior.
        nb_interior = src_length - 2
        nb_picks = nb_samples - 2

        if strategy == 'uniform':
            step = nb_interior // nb_picks
            offsets = [i * step for i in range(nb_picks)]
        elif strategy in ('quad_start', 'quad_end'):
            offsets = ((np.linspace(0, np.sqrt(nb_interior * quad_factor), nb_picks)) ** 2).astype(int) + 1
            if strategy == 'quad_start':
                offsets = [nb_interior - i for i in offsets][::-1]
        else:
            raise ValueError(f'{strategy=} is not an available discretization method.')

        # Offsets are relative to the interior (which starts at position 1),
        # so shift them back to positions in `src_list`.
        indices = [0] + [int(offset) + 1 for offset in offsets] + [last]

    res = [src_list[i] for i in indices]

    if return_indices:
        return res, indices
    return res


@jaxtyped(typechecker=typechecker)
def broadcast_modes_to_pixels(
    datas: List[Float[torch.Tensor, 'b _ci h w']],
    modes: Float[torch.Tensor, 'b n_dom'],
) -> Float[torch.Tensor, 'b c h w']:
    """
    Expand the per-domain mode of every sample to a pixel-wise map.

    For each domain `i`, `modes[:, i]` is repeated over the channel and
    spatial shape of `datas[i]`; the per-domain maps are then concatenated
    along the channel axis, so the output has as many channels as all the
    tensors in `datas` together.
    """
    mode_per_dom = []
    for dom_i, data in enumerate(datas):
        _, c, h, w = data.shape
        # (b,) -> (b, 1, 1, 1) -> (b, c, h, w)
        mode_per_dom.append(modes[:, dom_i].view(-1, 1, 1, 1).repeat(1, c, h, w))

    return torch.cat(mode_per_dom, dim=1)


@jaxtyped(typechecker=typechecker)
def broadcast_modes_to_pixels_shape(
    b,
    n_dom,
    c,
    h,
    w,
    modes: Float[torch.Tensor, 'b n_dom'],
) -> Float[torch.Tensor, 'b c h w']:
    """
    Expand the per-domain mode of every sample to a pixel-wise map, given the
    target shape explicitly.

    For each of the `n_dom` domains, `modes[:, i]` is repeated to
    `(batch, c, h, w)`; the maps are concatenated along the channel axis, so
    the output has `n_dom * c` channels. The `b` argument is not used by the
    body: the batch size is taken from `modes`.
    """
    mode_per_dom = [
        # (b,) -> (b, 1, 1, 1) -> (b, c, h, w)
        modes[:, dom_i].view(-1, 1, 1, 1).repeat(1, c, h, w)
        for dom_i in range(n_dom)
    ]

    return torch.cat(mode_per_dom, dim=1)


def learning_rate_finder(cfg: GlobalConfiguration, trainer, model, data_module):
    """
    Run the Lightning learning-rate finder when it is enabled in `cfg`.

    Does nothing unless
    `cfg.trainer_params.learning_rate_finder_params.auto_lr_find` is set.
    Otherwise every (lr, loss) pair of the sweep is logged to wandb, the
    suggestion is stored in the wandb summary and, when `pick_suggestion` is
    set and a suggestion exists, written to `model.learning_rate`.
    With `exit_after_pick` the wandb run is finished and the process exits.

    Raises:
        SystemExit: when `exit_after_pick` is set.
    """
    lr_finder_params = cfg.trainer_params.learning_rate_finder_params
    if not lr_finder_params.auto_lr_find:
        return

    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(
        model=model,
        datamodule=data_module,
        **asdict(lr_finder_params.pl_params),
    )
    results = lr_finder.results
    for lr, loss in zip(results['lr'], results['loss']):
        wandb.log({
            'lr_finder/lr': lr, 'lr_finder/loss': loss,
        })

    suggestion = lr_finder.suggestion()
    wandb.summary['lr_finder/suggestion'] = suggestion
    if lr_finder_params.pick_suggestion:
        print(f'Pick suggestion {suggestion=}')
        if suggestion is not None:
            model.learning_rate = suggestion
    else:
        print(f'Do not pick suggestion {suggestion=}')

    if lr_finder_params.exit_after_pick:
        wandb.finish()
        # Raise SystemExit directly: `quit()` is an interactive helper injected
        # by the `site` module and is not guaranteed to exist in programs.
        raise SystemExit


def batch_size_finder(cfg: GlobalConfiguration, trainer, model, data_module):
    """
    Run the Lightning batch-size scaler when it is enabled in `cfg`.

    Does nothing unless
    `cfg.trainer_params.batch_size_finder_params.auto_batch_size_finder` is
    set. Otherwise the batch size found by the tuner is stored in the wandb
    summary. With `exit_after_pick` the wandb run is finished and the process
    exits.

    Raises:
        SystemExit: when `exit_after_pick` is set.
    """
    bs_finder_params = cfg.trainer_params.batch_size_finder_params
    if not bs_finder_params.auto_batch_size_finder:
        return

    tuner = Tuner(trainer)
    found_batch_size = tuner.scale_batch_size(
        model=model,
        datamodule=data_module,
        **asdict(bs_finder_params.pl_params),
    )
    wandb.summary['batch_size_finder/found_batch_size'] = found_batch_size

    if bs_finder_params.exit_after_pick:
        wandb.finish()
        # Raise SystemExit directly: `quit()` is an interactive helper injected
        # by the `site` module and is not guaranteed to exist in programs.
        raise SystemExit


@jaxtyped(typechecker=typechecker)
def augmentWithBackground(segmentations_maps: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b c+1 h w']:
    """
    Compute the background channel and add it at the start of the segmentation maps.

    A pixel belongs to the background (value 1.0) when every one of its `c`
    channels is below 0.5, otherwise the background value is 0.0.
    """
    background = (segmentations_maps < 0.5).all(dim=1, keepdim=True).float()
    return torch.cat([background, segmentations_maps], dim=1)


def read_list_from_file(file_path: str) -> list[int]:
    """
    Read a comma-separated list of integers from `file_path`.

    Entries that are empty or whitespace-only (e.g. after a trailing comma)
    are skipped.
    """
    with open(file_path, 'r') as handle:
        tokens = handle.read().split(',')
    return [int(token) for token in tokens if token.strip()]


def write_list_to_file(file_path: str, integer_list: list[int]):
    """
    Write `integer_list` to `file_path` as one comma-separated line.

    Refuses to overwrite: raises ValueError when `file_path` is already an
    existing file.
    """
    already_there = os.path.isfile(file_path)
    if already_there:
        raise ValueError(f'File {file_path} already exist.')
    with open(file_path, 'w') as handle:
        handle.write(','.join(map(str, integer_list)))


def freeze(model):
    """
    Disable gradient computation for every parameter of `model`.
    """
    for weight in model.parameters():
        weight.requires_grad = False


def get_hack_mode(hack_mode: Tuple[Tuple[int, ...]], modes_init: Float[torch.Tensor, 'b n_dom']) -> Float[torch.Tensor, 'b n_dom']:
    """
    Draw, for every sample of the batch, one of the mode tuples listed in
    `hack_mode` (uniformly, via `np.random.randint`).

    Only the shape, device and dtype of `modes_init` are used; its values are
    ignored.
    """
    b, n_dom = modes_init.shape
    drawn = [hack_mode[np.random.randint(0, len(hack_mode))] for _ in range(b)]

    return torch.tensor(drawn, device=modes_init.device, dtype=modes_init.dtype).reshape(b, n_dom)


def norm_minmax(x: Float[torch.Tensor, 'b cndom h w'], dim_per_dom: list[int]) -> Float[torch.Tensor, 'b cndom h w']:
    """
    Put whatever data into [0;1] range, linear.

    `x` is split along the channel axis into domains of `dim_per_dom`
    channels. Each domain is rescaled independently for every sample, using
    the minimum and maximum taken over all channels and pixels of that domain.
    A domain that is constant for a sample is mapped to 0 instead of NaN.
    """
    b = x.shape[0]
    normalised = []
    for domi in x.split(dim_per_dom, dim=1):
        flat = domi.reshape([b, -1])
        min_per_sample = flat.min(dim=1)[0].view(b, 1, 1, 1)
        max_per_sample = flat.max(dim=1)[0].view(b, 1, 1, 1)

        span = max_per_sample - min_per_sample
        # max == min would be a division by zero; divide by 1 so that the
        # (all-zero) numerator yields 0.
        span[span == 0] = 1
        normalised.append((domi - min_per_sample) / span)

    return torch.cat(normalised, dim=1)


def norm_max(x: Float[torch.Tensor, 'b cndom h w'], dim_per_dom: list[int]) -> Float[torch.Tensor, 'b cndom h w']:
    """
    Put whatever data into [0;1]: negative values are clamped to 0, then each
    domain is rescaled with `norm_minmax` (so whatever minimum remains after
    the clamp is still subtracted).
    """
    non_negative = x.clamp(min=0.)
    return norm_minmax(non_negative, dim_per_dom)
