import os
import os.path as osp

import numpy as np
import torch
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
from mmengine.runner.checkpoint import save_checkpoint
from mmengine.evaluator import Evaluator

from ..model.wrapped_models import WrappedModels
from ..model.wrapped_networks import WrappedEncoderDecoder
from ..utils.local_activation_checkpointing import turn_on_activation_checkpointing

from typing import List, Dict
from copy import deepcopy

from mmseg.datasets import BaseSegDataset
from tqdm import tqdm

from mmseg.structures import SegDataSample
from mmseg.models.utils.wrappers import resize

from einops import rearrange

from mmseg.visualization import SegLocalVisualizer
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from mmengine.logging import MMLogger, print_log
from PIL import Image
from itertools import chain

class BaseTTARunner(Runner):
    """Base runner for test-time adaptation (TTA) of segmentation models.

    mmengine's ``Runner`` is used only to build the model, logger, hooks and
    visualizer: every train/val/test loop, optimizer and evaluator argument is
    passed as ``None``. Adaptation is driven by :meth:`tta`, which walks the
    configured tasks one by one. Subclasses implement :meth:`tta_one_batch`
    (adapt on a batch and feed the evaluator) and :meth:`test_one_batch`
    (evaluation only), and may override :meth:`config_tta_model`.

    Extra config keys read by this class: ``tta_optim_wrapper``, ``tasks``,
    ``tta_data_loader``, ``continual``, ``tta_evaluator``,
    ``activation_checkpointing``, ``visualizing``, ``back_test``, ``debug``
    and (in :meth:`tta`) ``save_checkpoint``.
    """

    def __init__(self, cfg):
        # Initialize model, logger, hooks and so on through mmengine; all loop,
        # optimizer and evaluator slots are disabled because TTA builds its own.
        super().__init__(
            model=cfg['model'],
            work_dir=cfg['work_dir'],
            train_dataloader=None,
            val_dataloader=None,
            test_dataloader=None,
            train_cfg=None,
            val_cfg=None,
            test_cfg=None,
            auto_scale_lr=None,
            optim_wrapper=None,
            param_scheduler=None,
            val_evaluator=None,
            test_evaluator=None,
            default_hooks=cfg.get('default_hooks'),
            custom_hooks=cfg.get('custom_hooks'),
            data_preprocessor=cfg.get('data_preprocessor'),
            load_from=cfg.get('load_from'),
            resume=cfg.get('resume', False),
            launcher=cfg.get('launcher', 'none'),
            env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
            log_processor=cfg.get('log_processor'),
            log_level=cfg.get('log_level', 'INFO'),
            visualizer=cfg.get('visualizer'),
            default_scope=cfg.get('default_scope', 'mmengine'),
            randomness=cfg.get('randomness', dict(seed=None)),
            experiment_name=cfg.get('experiment_name'),
            cfg=cfg,
        )

        self.optim_wrapper = cfg.get("tta_optim_wrapper")
        self.tasks = cfg.get("tasks")
        self.data_loader = cfg.get("tta_data_loader")
        # When falsy, tta() resets the model to the snapshot below before each task.
        self.continual = cfg.get("continual")
        self.evaluator = cfg.get("tta_evaluator")

        # Initialize the model's weights, then let the subclass configure it
        # (choose which parameters get requires_grad=True, modify BN, ...).
        self._init_model_weights()
        self.config_tta_model()

        # Snapshot of the configured initial weights; restored by reset_model().
        self.state_dict = deepcopy(self.model.state_dict())

        ori_model = self.model.module if is_model_wrapper(self.model) else self.model
        # Optionally enable activation checkpointing on the listed sub-modules.
        modules = cfg.get('activation_checkpointing', None)
        if modules is not None:
            self.logger.info(f'Enabling the "activation_checkpointing" feature'
                             f' for sub-modules: {modules}')
            turn_on_activation_checkpointing(ori_model, modules)

        # Turn the optimizer config into an optim wrapper.
        self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)

        self.tasks = self.build_tta_tasks(self.tasks)

        self.tta_visualizer = SegLocalVisualizer(vis_backends=dict(type='LocalVisBackend'), save_dir=self.work_dir)
        self.visualizing = cfg.get("visualizing", False)

        # Output folders written by visualize().
        self.selected_dir = osp.join(self.work_dir, "selected")
        os.makedirs(self.selected_dir, exist_ok=True)

        self.score_dir = osp.join(self.work_dir, "active_score")
        os.makedirs(self.score_dir, exist_ok=True)

        # If truthy, tta() re-evaluates every task after adaptation finishes.
        self.back_test = cfg.get("back_test", False)

        # Not read elsewhere in this class; presumably consumed by subclasses.
        self.debug = cfg.get("debug", False)

    def config_tta_model(self):
        """Hook for subclasses to configure the model for TTA (no-op here)."""
        pass

    @staticmethod
    def build_tta_tasks(tasks):
        """
        Normalize the task specification to a list; each element is one
        dataset on which test-time adaptation is performed.

        :param tasks: Dict or List[Dict], or a BaseSegDataset / List[dataset]
        :return: List[Dict] or List[dataset]
        :raises TypeError: if ``tasks`` is none of the accepted forms
        """
        if isinstance(tasks, (dict, BaseSegDataset)):
            tasks = [tasks]  # single task
        if not isinstance(tasks, list):
            raise TypeError(
                "tasks must be a dict, a BaseSegDataset or a list of them, "
                f"got {type(tasks).__name__}")
        return tasks

    def reset_model(self):
        """Restore the model weights captured at the end of ``__init__``."""
        self.logger.info("Fully Test-time Adaptation: Resetting the model!")
        self.model.load_state_dict(self.state_dict)

    def tta(self):
        """Adapt on every task in order, log per-task and average mIoU.

        Before each task the randomness config is re-applied. In non-continual
        mode the model is also reset. Optionally saves a checkpoint per task
        (``save_checkpoint`` config key) and runs a back-test afterwards.
        """
        all_metric = []
        num_tasks = len(self.tasks)
        for i, task in enumerate(self.tasks):
            self.set_randomness(**self._randomness_cfg)
            if not self.continual:
                self.reset_model()
            metric = self.perform_one_task(task, f"[{i}][{num_tasks}]")
            self.logger.info(f"Task {i}: mIoU: {metric['mIoU']}")
            all_metric.append(metric['mIoU'])
            if self.cfg.get("save_checkpoint", False):
                self._save_task_checkpoint(task)

        self._log_miou_summary(all_metric, "mIoU summary: ")

        if self.back_test:
            self.back_test_all_tasks()

    def _save_task_checkpoint(self, task):
        """Save the (task-)model weights after adapting on ``task``.

        The file name is the second '/'-separated component of the task's
        ``data_prefix.img_path``, so ``task`` must be a config supporting
        attribute access. Unsupported model types are reported instead of
        being silently skipped.
        """
        task_name = task.data_prefix.img_path.split('/')[1] + '.pth'
        ckpt_path = os.path.join(self.cfg.work_dir, task_name)
        if isinstance(self.model, WrappedEncoderDecoder):
            weights = self.model.state_dict()
        elif isinstance(self.model, WrappedModels):
            weights = self.model.task_model.state_dict()
        else:
            self.logger.warning(
                f"{task_name} is not saved: unsupported model type "
                f"{type(self.model).__name__}")
            return
        save_checkpoint(weights, ckpt_path)
        self.logger.info(f"{task_name} is saved in {ckpt_path}")

    def _log_miou_summary(self, values, title):
        """Log the per-task mIoU values and their mean (skipped if empty)."""
        if not values:
            self.logger.warning("No task was evaluated; mIoU summary skipped.")
            return
        self.logger.info(title + "\t".join(f"{mIoU:.2f}" for mIoU in values))
        self.logger.info(f"Average: {sum(values) / len(values)}")

    def _build_task_dataloader(self, task):
        """Build a dataloader from the ``tta_data_loader`` config using ``task`` as its dataset."""
        loader_cfg = deepcopy(self.data_loader)
        loader_cfg['dataset'] = task
        return self.build_dataloader(dataloader=loader_cfg)

    def perform_one_task(self, task, task_name=""):
        """Adapt on every batch of ``task`` and return the final metrics.

        Online metrics are recomputed after each batch from the first metric's
        accumulated results. They are shown on the progress bar and written to
        the visualizer as ``"{task_name}:{key}"`` scalars. The current logger is
        raised to ERROR during the loop and restored to its previous level
        afterwards, even if adaptation raises.

        :param task: a dataset config dict or dataset instance
        :param task_name: prefix for the scalar names
        :return: metrics dict from ``Evaluator.evaluate``
        """
        evaluator: Evaluator = self.build_evaluator(self.evaluator)
        data_loader = self._build_task_dataloader(task)
        if hasattr(data_loader.dataset, 'metainfo'):
            metainfo = data_loader.dataset.metainfo
            evaluator.dataset_meta = metainfo
            self.tta_visualizer.set_dataset_meta(metainfo["classes"], metainfo["palette"])

        tbar = tqdm(data_loader)

        # Silence info logs emitted while computing online metrics.
        logger: MMLogger = MMLogger.get_current_instance()
        previous_level = logger.level
        logger.setLevel('ERROR')
        try:
            for i, batch_data in enumerate(tbar):
                self.tta_one_batch(batch_data, evaluator)
                online_metric = evaluator.metrics[0]
                online_metrics = online_metric.compute_metrics(online_metric.results)

                tbar.set_postfix(online_metrics)

                all_scalars = {f"{task_name}:{k}": v for k, v in online_metrics.items()}
                self.visualizer.add_scalars(all_scalars, step=i)
        finally:
            logger.setLevel(previous_level)

        return evaluator.evaluate(len(data_loader.dataset))

    def tta_one_batch(self, batch_data, evaluator: Evaluator):
        """Adapt on one batch and feed predictions to ``evaluator`` (subclass hook)."""
        raise NotImplementedError

    def back_test_all_tasks(self):
        """Evaluate the final model on every task without further adaptation."""
        all_metric = []
        num_tasks = len(self.tasks)
        for i, task in enumerate(self.tasks):
            metric = self.test_one_task(task, f"[{i}][{num_tasks}]")
            self.logger.info(f"Back-Test Task {i}: mIoU: {metric['mIoU']}")
            all_metric.append(metric['mIoU'])

        self._log_miou_summary(all_metric, "Back-Test mIoU summary: ")

    def test_one_task(self, task, task_name=""):
        """Evaluate the current model on ``task`` via :meth:`test_one_batch`.

        :param task: a dataset config dict or dataset instance
        :param task_name: not used here (kept for signature symmetry)
        :return: metrics dict from ``Evaluator.evaluate``
        """
        evaluator: Evaluator = self.build_evaluator(self.evaluator)
        data_loader = self._build_task_dataloader(task)
        if hasattr(data_loader.dataset, 'metainfo'):
            evaluator.dataset_meta = data_loader.dataset.metainfo

        for batch_data in tqdm(data_loader):
            self.test_one_batch(batch_data, evaluator)

        return evaluator.evaluate(len(data_loader.dataset))

    @torch.no_grad()
    def test_one_batch(self, batch_data, evaluator):
        """Run inference on one batch and feed ``evaluator`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _sample_name(img_path: str) -> str:
        """Derive an output file stem from an image path.

        Keeps the original ACDC-oriented naming (path components 3 and 6 of
        the extension-less path, joined by '_'). For shorter paths it falls
        back to the file's base name without extension instead of crashing.
        """
        parts = img_path.split(".")[0].split("/")
        if len(parts) > 6:
            return parts[3] + "_" + parts[6]
        return osp.splitext(osp.basename(img_path))[0]

    @staticmethod
    def _score_to_rgba(score: torch.Tensor) -> np.ndarray:
        """Min-max normalise ``score`` and map it through 'viridis' to uint8 RGBA.

        A constant score map is rendered as all-zeros (lowest colour) instead
        of producing NaNs from a 0/0 division.
        """
        score = score.detach().float()
        low, high = torch.min(score), torch.max(score)
        span = (high - low).item()
        heatmap = (score - low) / span if span > 0 else torch.zeros_like(score)
        mapper = ScalarMappable(cmap='viridis', norm=Normalize(vmin=0.0, vmax=1.0))
        rgba = mapper.to_rgba(heatmap.cpu().numpy())
        return (rgba * 255).astype(np.uint8)

    def visualize(self, data_samples, selected_indices=None, active_scores=None):
        """Dump per-sample visualizations when ``visualizing`` is enabled.

        For each sample this draws it with ``tta_visualizer``. The background
        is a black image the size of the GT map, since the real image is not
        loaded. If given, it also saves the sample's ``selected_indices`` entry
        as ``selected/<name>.pth`` and its active score as a heatmap PNG at
        ``active_score/<name>.png``.
        """
        if not self.visualizing:
            return
        for idx, data_sample in enumerate(data_samples):
            gt_shape = data_sample.gt_sem_seg.data.shape
            img = np.zeros((gt_shape[1], gt_shape[2], 3))
            path = self._sample_name(data_sample.img_path)
            self.tta_visualizer.add_datasample(path, img, data_sample, show=False, withLabels=False)

            if selected_indices is not None:
                torch.save(selected_indices[idx], osp.join(self.selected_dir, f"{path}.pth"))

            if active_scores is not None and len(active_scores) > 0:
                rgba = self._score_to_rgba(active_scores[idx])
                Image.fromarray(rgba).save(osp.join(self.score_dir, f"{path}.png"))

    @classmethod
    def from_cfg(cls, cfg) -> 'Runner':
        """Alternate constructor used by mmengine: ``cls(cfg)``."""
        return cls(cfg)

    def xxxxxxxxxx(self, data_samples: List[SegDataSample]):
        """Re-apply flip and resize to the original ground-truth maps.

        Intended for pipelines where LoadAnnotations is directly followed by
        PackSegInputs, so only flip and resize need replaying. Each GT map is
        assumed to be (1, h, w). It is flipped according to the sample's
        ``flip``/``flip_direction`` meta, then nearest-resized to
        ``img_shape``.

        :return: tensor of shape (B, 1, H, W) with the original label dtype
        """
        all_labels = []
        for data_sample in data_samples:
            img_meta = data_sample.metainfo
            # current_gt.shape: (1, h, w); no in-place ops below, so no copy needed
            current_gt = data_sample.gt_sem_seg.get("data")

            if img_meta.get('flip', None):
                flip_direction = img_meta.get('flip_direction', None)
                assert flip_direction in ['horizontal', 'vertical']
                flip_dim = 2 if flip_direction == 'horizontal' else 1
                current_gt = current_gt.flip(dims=(flip_dim,))

            # F.interpolate needs (N, C, H, W) for a 2-D size, rejects
            # align_corners in 'nearest' mode, and does not upsample integer
            # label tensors on CPU: add a batch dim, go through float, cast back.
            label_dtype = current_gt.dtype
            current_gt = resize(
                current_gt.unsqueeze(0).float(),
                size=img_meta['img_shape'],
                mode='nearest',
                warning=False).squeeze(0).to(label_dtype)

            all_labels.append(current_gt)

        return torch.stack(all_labels, dim=0)

    @staticmethod
    def build_ema_model(model: torch.nn.Module):
        """Return a deep copy of ``model`` with gradients disabled (EMA teacher)."""
        ema = deepcopy(model)
        ema.requires_grad_(False)
        return ema

    @staticmethod
    def update_ema_variables(ema_model, model, alpha_teacher):
        """In place: ema = alpha * ema + (1 - alpha) * param, for parameters only.

        Buffers (e.g. BN running stats) are not touched. Returns ``ema_model``.
        """
        with torch.no_grad():
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                ema_param.mul_(alpha_teacher).add_(param, alpha=1 - alpha_teacher)
        return ema_model


def softmax_cross_entropy(logits: torch.Tensor, target_logits: torch.Tensor, weights=None) -> torch.Tensor:
    """Mean per-pixel cross-entropy between ``softmax(target_logits)`` and ``logits``.

    :param logits: predicted logits, shape (B, C, H, W)
    :param target_logits: target logits, shape (B, C, H, W); converted to a
        probability distribution with softmax over dim 1
    :param weights: optional per-pixel weights laid out as (B, H, W)
    :return: scalar tensor. The (weighted) per-pixel losses are averaged over
        the pixel count B*H*W, not normalised by the sum of the weights.
    """
    # (B, C, H, W) -> (B*H*W, C): one row per pixel, classes along dim 1.
    flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, logits.shape[1])
    flat_target = target_logits.permute(0, 2, 3, 1).reshape(-1, target_logits.shape[1])
    entropy_map = torch.sum(-flat_target.softmax(1) * flat_logits.log_softmax(1), dim=1)
    if weights is not None:
        # (B, H, W) -> (B*H*W,), same pixel order as the rows above.
        entropy_map = entropy_map * weights.reshape(-1)
    return torch.mean(entropy_map)


def cross_entropy(logits: torch.Tensor, target_label: torch.Tensor, weights=None) -> torch.Tensor:
    """Mean per-pixel cross-entropy between soft/one-hot targets and ``logits``.

    :param logits: predicted logits, shape (B, C, H, W)
    :param target_label: per-pixel class distribution (e.g. one-hot), shape
        (B, C, H, W); used as-is, without softmax
    :param weights: optional per-pixel weights laid out as (B, H, W)
    :return: scalar tensor. The (weighted) per-pixel losses are averaged over
        the pixel count B*H*W, not normalised by the sum of the weights.
    """
    # (B, C, H, W) -> (B*H*W, C): one row per pixel, classes along dim 1.
    flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, logits.shape[1])
    flat_target = target_label.permute(0, 2, 3, 1).reshape(-1, target_label.shape[1])
    entropy_map = torch.sum(-flat_target * flat_logits.log_softmax(1), dim=1)
    if weights is not None:
        # (B, H, W) -> (B*H*W,), same pixel order as the rows above.
        entropy_map = entropy_map * weights.reshape(-1)
    return torch.mean(entropy_map)


def heatmap_to_rgb(heatmap):
    """Place ``heatmap`` in the last (third) channel of a new trailing axis.

    The first two channels are zeros of the same shape and dtype, so an
    (H, W) map becomes an (H, W, 3) array whose only non-zero channel is
    the third one.
    """
    blank = np.zeros_like(heatmap)
    channels = (blank, blank, heatmap)
    return np.stack(channels, axis=-1)

