import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
#from visdom_logger import VisdomLogger
from collections import defaultdict
from .dataset.dataset import get_val_loader
from .util import AverageMeter, batch_intersectionAndUnionGPU, get_model_dir, main_process, \
                  batch_vid_consistencyGPU, compute_map

from .util import find_free_port, setup, cleanup, to_one_hot, intersectionAndUnionGPU
from .model.model import get_model
import torch.distributed as dist
from tqdm import tqdm
from .util import load_cfg_from_cfg_file, merge_cfg_from_list
import argparse
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import time
from .visu import make_episode_visualization, make_episode_visualization_cv2, \
                  make_keyframes_vis
from typing import Tuple

def parse_args():
    """Parse the command line and build the test configuration.

    ``--config`` names the config file, which is loaded with
    ``load_cfg_from_cfg_file``. Any trailing ``--opts KEY VALUE ...`` pairs are
    merged on top with ``merge_cfg_from_list``. The config path is recorded as
    ``cfg.config_name``.

    Returns:
        The merged configuration object. It has attribute-style access; its
        concrete type comes from ``.util``.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--config', type=str, required=True, help='config file')
    parser.add_argument('--opts', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # `required=True` already guarantees args.config is set.
    cfg = load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = merge_cfg_from_list(cfg, args.opts)

    cfg.config_name = args.config
    return cfg

def main_worker(rank: int,
                world_size: int,
                args: argparse.Namespace) -> None:
    """Run the test pipeline in one process, bound to GPU ``rank``.

    The steps are:
    1. Call ``setup`` for this rank.
    2. Optionally seed every RNG with ``args.manual_seed + rank``.
    3. Build the model wrapped in SyncBatchNorm + DDP.
    4. Optionally load ``<model_dir>/<ckpt_used>.pth``.
    5. Run episodic validation on the test split.

    In distributed mode, any metrics returned by validation are averaged
    across ranks.

    Args:
        rank: process rank, also used as the CUDA device index.
        world_size: total number of processes.
        args: configuration object produced by ``parse_args``.
    """
    print(f"==> Running DDP checkpoint example on rank {rank}.")
    setup(args, rank, world_size)

    if args.manual_seed is not None:
        seed = args.manual_seed + rank  # distinct but reproducible seed per rank
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)

    # ========== Model  ==========
    model = get_model(args).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[rank])

    root = get_model_dir(args)

    if args.ckpt_used is not None:
        filepath = os.path.join(root, f'{args.ckpt_used}.pth')
        assert os.path.isfile(filepath), filepath
        print("=> loading weight '{}'".format(filepath))
        # Load to CPU; load_state_dict copies the weights onto the model's own
        # device, so tensors are not materialised on the GPU they were saved from.
        checkpoint = torch.load(filepath, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded weight '{}'".format(filepath))
    else:
        print("=> Not loading anything")

    # ========== Data  ==========
    val_loader, _ = get_val_loader(args, split_type='test')

    # ========== Test  ==========
    val_Iou, val_loss = None, None
    if args.episodic_val or args.temporal_episodic_val:
        val_Iou, val_loss = episodic_validate(args=args,
                                              val_loader=val_loader,
                                              model=model,
                                              use_callback=(args.visdom_port != -1),
                                              suffix='test')

    # episodic_validate returns (None, None) unless it runs in 'quickval' mode,
    # and nothing is returned at all when validation is disabled: only reduce
    # real tensors.
    if args.distributed and val_Iou is not None and val_loss is not None:
        dist.all_reduce(val_Iou)
        dist.all_reduce(val_loss)
        val_Iou /= world_size
        val_loss /= world_size



def _collect_visu_inputs(val_loader: torch.utils.data.DataLoader,
                         spprt_imgs: torch.Tensor,
                         s_label: torch.Tensor,
                         qry_img: torch.Tensor,
                         q_label: torch.Tensor,
                         misc,
                         paths,
                         quickval: bool,
                         skip: int,
                         max_frames: int):
    """Copy one episode to numpy and resolve its query-frame paths for visualization.

    The query-frame paths are the sorted listing of the directory
    ``val_loader.dataset.img_dir/misc[0]``. In quick-validation mode they are
    subsampled with the same ``[::skip][:max_frames]`` pattern as the frames.

    Returns:
        A ``(support, query)`` pair of dicts, each with keys
        ``'imgs'``, ``'masks'`` and ``'paths'``.
    """
    qry_dir = os.path.join(val_loader.dataset.img_dir, misc[0])
    qry_paths = [os.path.join(qry_dir, fname) for fname in sorted(os.listdir(qry_dir))]
    if quickval:
        qry_paths = qry_paths[::skip][:max_frames]

    all_sprt = {'imgs': spprt_imgs.cpu().numpy(),
                'masks': s_label.cpu().numpy(),
                'paths': [p[0] for p in paths]}
    all_qry = {'imgs': qry_img.cpu().numpy(),
               'masks': q_label.cpu().numpy(),
               'paths': qry_paths}
    return all_sprt, all_qry


def _save_episode_visualizations(args: argparse.Namespace,
                                 method: str,
                                 run: int,
                                 iter_num: int,
                                 n_frames: int,
                                 all_sprt: dict,
                                 all_qry: dict,
                                 probas: torch.Tensor) -> None:
    """Write one PNG per query frame under ``args.vis_dir/episodes/<method>/split_<k>``."""
    out_dir = os.path.join(args.vis_dir, 'episodes', method, 'split_%d' % args.train_split)
    os.makedirs(out_dir, exist_ok=True)
    for i in range(n_frames):
        save_path = os.path.join(out_dir, f'run_{run}_iter_{iter_num}_{i:05d}.png')
        make_episode_visualization_cv2(img_s=all_sprt['imgs'][0].copy(),
                                       img_q=all_qry['imgs'][0, i].copy(),
                                       gt_s=all_sprt['masks'][0].copy(),
                                       gt_q=all_qry['masks'][0, i].copy(),
                                       path_s=all_sprt['paths'],
                                       path_q=all_qry['paths'][i],
                                       preds=probas[i].cpu().numpy().copy(),
                                       save_path=save_path,
                                       flow_q=None)


def episodic_validate(args: argparse.Namespace,
                      val_loader: torch.utils.data.DataLoader,
                      model: DDP,
                      use_callback: bool,
                      suffix: str = 'test') -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``model`` episodically over ``args.n_runs`` passes of ``val_loader``.

    Each batch is one video episode: query frames, their labels, support
    images and masks, the class id, ``misc`` (its first entry names the query
    frame directory) and support paths. Only a query batch size of 1 is
    supported.

    For every run, foreground intersection/union is accumulated per class, and
    per-class IoU, mIoU and the mean prediction time per sequence are printed.
    Averages over runs are printed at the end.

    If ``args.selected_weights`` is non-empty, the function runs in
    'quickval' mode: query videos are subsampled to every 5th frame, capped at
    20 frames.

    Args:
        args: configuration object.
        val_loader: loader yielding episodes as described above.
        model: DDP-wrapped model exposing ``module.predict_mask_nshot``.
        use_callback: currently unused.
        suffix: currently unused.

    Returns:
        In 'quickval' mode, ``(mean mIoU over runs, mean loss)`` as tensors on
        the current rank. The loss is always 0 here because no loss is
        computed during testing. Otherwise ``(None, None)``.
    """
    print('==> Start testing')

    model.eval()

    # ========== Metrics initialization  ==========
    all_weights = {'notti': [0.0, 0.0]}
    if hasattr(args, 'selected_weights') and len(args.selected_weights) != 0:
        all_weights = {'quickval': args.selected_weights}
    quickval = 'quickval' in all_weights

    # Frame subsampling, only used by quick val
    skip = 5
    max_frames = 20

    runtimes = {k: torch.zeros(args.n_runs) for k in all_weights}
    val_IoUs = {k: np.zeros(args.n_runs) for k in all_weights}
    # No loss is computed at test time; kept so the quickval return has a loss slot.
    val_losses = {k: np.zeros(args.n_runs) for k in all_weights}

    if hasattr(args, 'single_proto_flag') and args.single_proto_flag:
        assert args.tloss_type == "fb_consistency", "This temporal loss doesnt work with single proto"

    rank = dist.get_rank()

    # ========== Perform the runs  ==========
    for run in tqdm(range(args.n_runs)):

        # =============== Initialize the metric dictionaries ===============
        iter_num = 0
        cls_intersection = {k: defaultdict(int) for k in all_weights}  # Default value is 0
        cls_union = {k: defaultdict(int) for k in all_weights}
        runtime = {k: 0.0 for k in all_weights}

        # =============== episode = group of tasks ===============
        for qry_img, q_label, spprt_imgs, s_label, subcls, misc, paths in tqdm(val_loader):
            if quickval:
                qry_img = qry_img[:, ::skip][:, :max_frames]
                q_label = q_label[:, ::skip][:, :max_frames]

            assert qry_img.shape[0] == 1, "Allow only batch 1 query set with video frames, qry_img was is BxTxCxHxW"

            spprt_imgs = spprt_imgs.to(rank, non_blocking=True)
            s_label = s_label.to(rank, non_blocking=True)
            qry_img = qry_img.to(rank, non_blocking=True)
            q_label = q_label.to(rank, non_blocking=True)

            n_frames = qry_img.size(1)
            iter_num += n_frames

            # One class id per frame (batch size is 1).
            classes = [class_.item() for class_ in subcls] * n_frames
            gt_q = q_label.permute(1, 0, 2, 3)  # frames first

            if args.visu:
                all_sprt, all_qry = _collect_visu_inputs(val_loader, spprt_imgs, s_label,
                                                         qry_img, q_label, misc, paths,
                                                         quickval, skip, max_frames)

            for method in all_weights:
                t_start = time.time()
                with torch.no_grad():
                    probas = model.module.predict_mask_nshot(qry_img, spprt_imgs, s_label, misc)
                # CUDA kernels run asynchronously; wait so the timing covers the prediction.
                if torch.is_tensor(probas) and probas.is_cuda:
                    torch.cuda.synchronize(probas.device)
                runtime[method] += time.time() - t_start

                # Evaluate the results
                intersection, union, _ = batch_intersectionAndUnionGPU(probas, gt_q, 2)  # [n_tasks, shot, num_class]
                intersection, union = intersection.cpu(), union.cpu()

                # ================== Log metrics ==================
                for i, class_ in enumerate(classes):
                    cls_intersection[method][class_] += intersection[i, 0, 1]  # Do not count background
                    cls_union[method][class_] += union[i, 0, 1]

                # ================== Visualization ==================
                if args.visu:
                    _save_episode_visualizations(args, method, run, iter_num, n_frames,
                                                 all_sprt, all_qry, probas)

        # ================== Evaluation Metrics on all episodes ==================
        n_episodes = max(len(val_loader), 1)  # avoid division by zero on an empty loader
        for method in all_weights:
            IoU = {class_: cls_intersection[method][class_] / (cls_union[method][class_] + 1e-10)
                   for class_ in cls_union[method]}
            print('========= Method {}==========='.format(method))
            runtimes[method][run] = runtime[method] / float(n_episodes)
            mIoU = np.mean(list(IoU.values()))

            print('mIoU---Val result: mIoU {:.4f}.'.format(mIoU))
            for class_, class_iou in IoU.items():
                print("Class {} : {:.4f}".format(class_, class_iou))
            val_IoUs[method][run] = mIoU

    # ================== Save metrics ==================
    for method in all_weights:
        str_weights = str(all_weights[method])
        print(f'========Final Evaluation of {method} with weights {str_weights}============')
        print('Average mIoU over {} runs --- {:.4f}.'.format(args.n_runs, val_IoUs[method].mean()))
        print('Average runtime / seq --- {:.4f}.'.format(runtimes[method].mean()))

    # Only the quickval mode returns metrics (e.g. for use by a training loop).
    if quickval:
        return torch.tensor(float(val_IoUs['quickval'].mean())).to(rank), \
                    torch.tensor(np.mean(val_losses['quickval'])).to(rank)
    return None, None


if __name__ == "__main__":
    args = parse_args()
    # Must be set before CUDA is first initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpus)

    if args.debug:
        args.test_num = 500
        args.n_runs = 2

    world_size = len(args.gpus)
    args.distributed = world_size > 1
    args.port = find_free_port()

    if args.distributed:
        # main_worker passes world_size to setup(), which presumably creates a
        # process group of that size (episodic_validate calls dist.get_rank()).
        # Launching only rank 0 would leave the group waiting for the other
        # ranks, so start one process per GPU.
        mp.spawn(main_worker,
                 args=(world_size, args),
                 nprocs=world_size,
                 join=True)
    else:
        main_worker(0, world_size, args)
