from models.planner import PtMLPClsApproxPlanner
import os
import logging
import pickle
import numpy as np
import scipy.stats
import torch
import torch.nn.functional as F
import tensorflow as tf
from collections import defaultdict
from pprint import pformat

from models import get_model
from utils.hparams import HParams
from datasets.mri_utils import *

def get_wrapper(hps):
    '''Build the model wrapper that matches the model named in ``hps.model_cfg_file``.

    Dispatch on the config's ``model`` field:
      * ``'acflow_classifier'`` -> ``FclsModelWrapper_flat`` when ``hps.agent``
        contains ``'flat'`` or ``'wolp'``, otherwise ``FclsModelWrapper``;
      * ``'classifier'``        -> ``ClsModelWrapper``;
      * ``'acflow'``            -> ``FairModelWrapper``.

    Raises:
        NotImplementedError: for any other model name.
    '''
    model_name = HParams(hps.model_cfg_file).model

    if model_name == 'acflow_classifier':
        wants_flat_actions = any(tag in hps.agent for tag in ('flat', 'wolp'))
        if wants_flat_actions:
            return FclsModelWrapper_flat(hps)
        return FclsModelWrapper(hps)

    if model_name == 'classifier':
        return ClsModelWrapper(hps)

    if model_name == 'acflow':
        return FairModelWrapper(hps)

    raise NotImplementedError()


class ClsModelWrapper(object):
    '''Wrap a ``classifier`` model for an agent that acquires pixels in action groups.

    Pixels are partitioned into ``planner_cfg.num_clusters`` groups. The
    partition comes from the per-pixel labels stored in
    ``<planner_cfg.exp_dir>/results.pkl``. An action is addressed as
    ``(group, slot within group)``.
    '''

    def __init__(self, hps):
        self.hps = hps
        H, W, C = hps.image_shape

        model_cfg = HParams(hps.model_cfg_file)
        self.model = get_model(model_cfg)
        self.model.load()

        planner_cfg = HParams(hps.planner_cfg_file)
        self.num_groups = planner_cfg.num_clusters
        # Assumes the clusters split the H*W pixels evenly; see the size check
        # in _load_action_groups.
        self.group_size = H * W // self.num_groups
        self._load_action_groups(planner_cfg.exp_dir)

    def _load_action_groups(self, exp_dir):
        '''Build ``group2idx`` / ``idx2group`` from the cluster labels saved in ``exp_dir``.'''
        # NOTE: results.pkl is unpickled, and unpickling can execute arbitrary
        # code, so only point planner_cfg.exp_dir at trusted experiment outputs.
        with open(f'{exp_dir}/results.pkl', 'rb') as f:
            labels = pickle.load(f)['labels']
        group2idx = defaultdict(list)  # group -> pixel indices, in pixel order
        idx2group = dict()             # pixel index -> (group, slot within group)
        group_count = [0]*self.num_groups
        for i, lab in enumerate(labels):
            lab = int(lab)
            group2idx[lab].append(i)
            idx2group[i] = (lab, group_count[lab])
            group_count[lab] = group_count[lab] + 1
        self.group2idx = group2idx
        self.idx2group = idx2group

        # The action grid is (num_groups, group_size). A group with more pixels
        # than group_size has slots that cannot be represented, so
        # get_availability raises IndexError on them. Smaller groups leave
        # padded slots, which get_availability marks unavailable.
        if any(count != self.group_size for count in group_count):
            logging.warning(f'action group sizes {group_count} differ from group_size={self.group_size}')

        logging.info('>'*30)
        logging.info('action groups -> pixel index:')
        logging.info(f'\n{pformat(self.group2idx)}\n')
        logging.info('pixel index -> (action groups, action index):')
        logging.info(f'\n{pformat(self.idx2group)}\n')
        logging.info('<'*30)

    def predict(self, state, mask, return_prob=False):
        '''Classify partially observed inputs.

        Evaluates ``self.model.prob`` with ``x=state``, ``b=mask`` and
        ``is_training=False``.

        Returns:
            The probabilities when ``return_prob`` is true. Otherwise, their
            argmax over the last axis.
        '''
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.is_training: False})
        if return_prob: return prob
        return np.argmax(prob, axis=-1)

    @property
    def auxiliary_shape(self):
        '''Per-example shape of ``get_auxiliary`` output: ``[H, W, num_classes]``.'''
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes
        return [H,W,K]

    def get_auxiliary(self, state, mask):
        '''Return the class probabilities broadcast to every pixel, shape ``[B, H, W, K]``.'''
        B = mask.shape[0]
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes

        # probability: the model output is reshaped to [B, 1, 1, K], then tiled
        # over the spatial axes.
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.is_training: False})
        prob = np.repeat(np.repeat(np.reshape(prob, [B,1,1,K]), H, axis=1), W, axis=2)

        # auxiliary
        aux = prob

        return aux

    def get_availability(self, state, mask):
        '''Return which action groups and in-group actions can still be chosen.

        Convention: True (1) means available, False (0) means occupied.

        Args:
            state: unused; kept for interface parity with the other wrappers.
            mask: array of shape ``[B, H, W, C]``. A pixel counts as observed
                when ``mask[..., 0]`` is non-zero there.

        Returns:
            ``(avail_group, avail_action)``: boolean arrays of shape
            ``[B, num_groups]`` and ``[B, num_groups, group_size]``.
        '''
        B, H, W, C = mask.shape
        # Slots past the end of a group with fewer than group_size pixels map
        # to no pixel (get_action would raise IndexError), so they start
        # unavailable. With equal-size groups every slot starts available.
        template = np.zeros([self.num_groups, self.group_size], dtype=bool)
        for g in range(self.num_groups):
            template[g, :len(self.group2idx.get(g, ()))] = True
        avail_action = np.repeat(template[None], B, axis=0)
        for i, m in enumerate(mask):
            for idx in np.flatnonzero(m[..., 0].reshape(H*W)):
                group, slot = self.idx2group[idx]
                avail_action[i, group, slot] = False
        # A group stays available while at least one of its slots is.
        avail_group = avail_action.any(axis=2)

        return avail_group, avail_action

    def get_action(self, group, action):
        '''Map parallel ``(group, slot)`` sequences to flat pixel indices (int32 array).'''
        pixel_idx = [self.group2idx[g][a] for g, a in zip(group, action)]
        return np.array(pixel_idx, dtype=np.int32)

    def get_reward(self, state, mask, action, next_state, next_mask, done, target):
        '''Reward for one acquisition step.

        The base reward is ``cmi = ent(state, mask) - ent(next_state, next_mask)``,
        the drop in the model's ``ent`` output. When every entry of ``done``
        is true, the model's ``xent`` for ``target`` on the next observation
        is also subtracted.
        '''
        # information gain
        pre_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.is_training: False})
        post_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.is_training: False})
        cmi = pre_ent - post_ent

        # NOTE(review): if episodes in one batch can end at different steps,
        # the terminal xent term is skipped until all of them are done.
        if not np.all(done): return cmi

        # prediction reward
        xent = self.model.execute(self.model.xent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.y: target,
                               self.model.is_training: False})

        return cmi - xent

class FclsModelWrapper_flat(object):
    '''Wrap an ``acflow_classifier`` model for an agent with a flat action space.

    Each of the ``H*W`` pixels is its own action; there is no grouping.
    '''

    def __init__(self, hps):
        self.hps = hps
        H, W, C = hps.image_shape
        self.num_actions = H*W

        model_cfg = HParams(hps.model_cfg_file)
        self.model = get_model(model_cfg)
        self.model.load()

    def predict(self, state, mask, return_prob=False):
        '''Classify partially observed inputs.

        Evaluates ``self.model.prob`` with ``x=state`` and both ``b`` and ``m``
        set to ``mask``.

        Returns:
            The probabilities when ``return_prob`` is true. Otherwise, their
            argmax over the last axis.
        '''
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        if return_prob: return prob
        return np.argmax(prob, axis=-1)

    @property
    def auxiliary_shape(self):
        '''Per-example shape of ``get_auxiliary`` output: ``[H, W, K + 4*C]``.'''
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes
        return [H,W,K+C*4]

    def get_auxiliary(self, state, mask):
        '''Build per-pixel auxiliary features, shape ``[B, H, W, K + 4*C]``.

        The channels are, in order: the class probabilities tiled over every
        pixel (K), then the mean and std over ``hps.num_samples`` draws of
        ``self.model.sam`` (C each), then the mean and std of
        ``self.model.pred_sam`` (C each). Sample values are divided by 255
        before the statistics are taken.
        '''
        B = mask.shape[0]
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes
        N = self.hps.num_samples

        # probability
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        prob = np.repeat(np.repeat(np.reshape(prob, [B,1,1,K]), H, axis=1), W, axis=2)

        # sample p(x_u | x_o) and p(x_u | x_o, pred): each example is repeated
        # N times so a single run yields N samples per example.
        state_extended = np.repeat(np.expand_dims(state, axis=1), N, axis=1)
        state_extended = np.reshape(state_extended, [B*N,H,W,C])
        mask_extended = np.repeat(np.expand_dims(mask, axis=1), N, axis=1)
        mask_extended = np.reshape(mask_extended, [B*N,H,W,C])
        sam, pred_sam = self.model.execute([self.model.sam, self.model.pred_sam],
                    feed_dict={self.model.x: state_extended,
                               self.model.b: mask_extended,
                               self.model.m: np.ones_like(mask_extended)})
        # Assumes samples are on a 0-255 scale -- TODO confirm.
        sam = np.reshape(sam, [B,N,H,W,C]) / 255.
        pred_sam = np.reshape(pred_sam, [B,N,H,W,C]) / 255.

        sam_mean = np.mean(sam, axis=1)
        sam_std = np.std(sam, axis=1)
        pred_sam_mean = np.mean(pred_sam, axis=1)
        pred_sam_std = np.std(pred_sam, axis=1)

        # auxiliary
        aux = np.concatenate([prob, sam_mean, sam_std, pred_sam_mean, pred_sam_std], axis=-1)

        return aux

    def get_availability(self, state, mask):
        '''Return a ``[B, H*W]`` boolean array: True = available, False = occupied.

        A pixel is occupied when ``mask[..., 0]`` is non-zero there. ``state``
        is unused.
        '''
        B, H, W, C = mask.shape
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        observed = mask[:,:,:,0].reshape([B,H*W]).astype(bool)
        return np.logical_not(observed)

    def get_reward(self, state, mask, action, next_state, next_mask, done, target):
        '''Reward for one acquisition step.

        The base reward is ``cmi = ent(state, mask) - ent(next_state, next_mask)``,
        the drop in the model's ``ent`` output. When every entry of ``done``
        is true, the model's ``xent`` for ``target`` on the next observation
        is also subtracted.
        '''
        # information gain
        pre_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        post_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.m: next_mask})
        cmi = pre_ent - post_ent

        # NOTE(review): if episodes in one batch can end at different steps,
        # the terminal xent term is skipped until all of them are done.
        if not np.all(done): return cmi

        # prediction reward
        xent = self.model.execute(self.model.xent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.m: next_mask,
                               self.model.y: target})

        return cmi - xent
    
class FclsModelWrapper(object):
    '''Wrap an ``acflow_classifier`` model for an agent that acquires pixels in action groups.

    Pixels are partitioned into ``planner_cfg.num_clusters`` groups. The
    partition comes from the per-pixel labels stored in
    ``<planner_cfg.exp_dir>/results.pkl``. An action is addressed as
    ``(group, slot within group)``.
    '''

    def __init__(self, hps):
        self.hps = hps
        H, W, C = hps.image_shape

        model_cfg = HParams(hps.model_cfg_file)
        self.model = get_model(model_cfg)
        self.model.load()

        planner_cfg = HParams(hps.planner_cfg_file)
        self.num_groups = planner_cfg.num_clusters
        # Assumes the clusters split the H*W pixels evenly; see the size check
        # in _load_action_groups.
        self.group_size = H * W // self.num_groups
        self._load_action_groups(planner_cfg.exp_dir)

    def _load_action_groups(self, exp_dir):
        '''Build ``group2idx`` / ``idx2group`` from the cluster labels saved in ``exp_dir``.'''
        # NOTE: results.pkl is unpickled, and unpickling can execute arbitrary
        # code, so only point planner_cfg.exp_dir at trusted experiment outputs.
        with open(f'{exp_dir}/results.pkl', 'rb') as f:
            labels = pickle.load(f)['labels']
        group2idx = defaultdict(list)  # group -> pixel indices, in pixel order
        idx2group = dict()             # pixel index -> (group, slot within group)
        group_count = [0]*self.num_groups
        for i, lab in enumerate(labels):
            lab = int(lab)
            group2idx[lab].append(i)
            idx2group[i] = (lab, group_count[lab])
            group_count[lab] = group_count[lab] + 1
        self.group2idx = group2idx
        self.idx2group = idx2group

        # The action grid is (num_groups, group_size). A group with more pixels
        # than group_size has slots that cannot be represented, so
        # get_availability raises IndexError on them. Smaller groups leave
        # padded slots, which get_availability marks unavailable.
        if any(count != self.group_size for count in group_count):
            logging.warning(f'action group sizes {group_count} differ from group_size={self.group_size}')

        logging.info('>'*30)
        logging.info('action groups -> pixel index:')
        logging.info(f'\n{pformat(self.group2idx)}\n')
        logging.info('pixel index -> (action groups, action index):')
        logging.info(f'\n{pformat(self.idx2group)}\n')
        logging.info('<'*30)

    def predict(self, state, mask, return_prob=False):
        '''Classify partially observed inputs.

        Evaluates ``self.model.prob`` with ``x=state`` and both ``b`` and ``m``
        set to ``mask``.

        Returns:
            The probabilities when ``return_prob`` is true. Otherwise, their
            argmax over the last axis.
        '''
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        if return_prob: return prob
        return np.argmax(prob, axis=-1)

    @property
    def auxiliary_shape(self):
        '''Per-example shape of ``get_auxiliary`` output: ``[H, W, K + 4*C]``.'''
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes
        return [H,W,K+C*4]

    def get_auxiliary(self, state, mask):
        '''Build per-pixel auxiliary features, shape ``[B, H, W, K + 4*C]``.

        The channels are, in order: the class probabilities tiled over every
        pixel (K), then the mean and std over ``hps.num_samples`` draws of
        ``self.model.sam`` (C each), then the mean and std of
        ``self.model.pred_sam`` (C each). Sample values are divided by 255
        before the statistics are taken.
        '''
        B = mask.shape[0]
        H, W, C = self.hps.image_shape
        K = self.hps.num_classes
        N = self.hps.num_samples

        # probability
        prob = self.model.execute(self.model.prob,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        prob = np.repeat(np.repeat(np.reshape(prob, [B,1,1,K]), H, axis=1), W, axis=2)

        # sample p(x_u | x_o) and p(x_u | x_o, pred): each example is repeated
        # N times so a single run yields N samples per example.
        state_extended = np.repeat(np.expand_dims(state, axis=1), N, axis=1)
        state_extended = np.reshape(state_extended, [B*N,H,W,C])
        mask_extended = np.repeat(np.expand_dims(mask, axis=1), N, axis=1)
        mask_extended = np.reshape(mask_extended, [B*N,H,W,C])
        sam, pred_sam = self.model.execute([self.model.sam, self.model.pred_sam],
                    feed_dict={self.model.x: state_extended,
                               self.model.b: mask_extended,
                               self.model.m: np.ones_like(mask_extended)})
        # Assumes samples are on a 0-255 scale -- TODO confirm.
        sam = np.reshape(sam, [B,N,H,W,C]) / 255.
        pred_sam = np.reshape(pred_sam, [B,N,H,W,C]) / 255.

        sam_mean = np.mean(sam, axis=1)
        sam_std = np.std(sam, axis=1)
        pred_sam_mean = np.mean(pred_sam, axis=1)
        pred_sam_std = np.std(pred_sam, axis=1)

        # auxiliary
        aux = np.concatenate([prob, sam_mean, sam_std, pred_sam_mean, pred_sam_std], axis=-1)

        return aux

    def get_availability(self, state, mask):
        '''Return which action groups and in-group actions can still be chosen.

        Convention: True (1) means available, False (0) means occupied.

        Args:
            state: unused; kept for interface parity with the other wrappers.
            mask: array of shape ``[B, H, W, C]``. A pixel counts as observed
                when ``mask[..., 0]`` is non-zero there.

        Returns:
            ``(avail_group, avail_action)``: boolean arrays of shape
            ``[B, num_groups]`` and ``[B, num_groups, group_size]``.
        '''
        B, H, W, C = mask.shape
        # Slots past the end of a group with fewer than group_size pixels map
        # to no pixel (get_action would raise IndexError), so they start
        # unavailable. With equal-size groups every slot starts available.
        template = np.zeros([self.num_groups, self.group_size], dtype=bool)
        for g in range(self.num_groups):
            template[g, :len(self.group2idx.get(g, ()))] = True
        avail_action = np.repeat(template[None], B, axis=0)
        for i, m in enumerate(mask):
            for idx in np.flatnonzero(m[..., 0].reshape(H*W)):
                group, slot = self.idx2group[idx]
                avail_action[i, group, slot] = False
        # A group stays available while at least one of its slots is.
        avail_group = avail_action.any(axis=2)

        return avail_group, avail_action

    def get_action(self, group, action):
        '''Map parallel ``(group, slot)`` sequences to flat pixel indices (int32 array).'''
        pixel_idx = [self.group2idx[g][a] for g, a in zip(group, action)]
        return np.array(pixel_idx, dtype=np.int32)

    def get_reward(self, state, mask, action, next_state, next_mask, done, target):
        '''Reward for one acquisition step.

        The base reward is ``cmi = ent(state, mask) - ent(next_state, next_mask)``,
        the drop in the model's ``ent`` output. When every entry of ``done``
        is true, the model's ``xent`` for ``target`` on the next observation
        is also subtracted.
        '''
        # information gain
        pre_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: mask})
        post_ent = self.model.execute(self.model.ent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.m: next_mask})
        cmi = pre_ent - post_ent

        # NOTE(review): if episodes in one batch can end at different steps,
        # the terminal xent term is skipped until all of them are done.
        if not np.all(done): return cmi

        # prediction reward
        xent = self.model.execute(self.model.xent,
                    feed_dict={self.model.x: next_state,
                               self.model.b: next_mask,
                               self.model.m: next_mask,
                               self.model.y: target})

        return cmi - xent
        
class FairModelWrapper(object):
    '''Wrap an ``acflow`` model for a grouped-action agent that reconstructs the image.

    Pixels are partitioned into ``planner_cfg.num_clusters`` groups. The
    partition comes from the per-pixel labels stored in
    ``<planner_cfg.exp_dir>/results.pkl``. An action is addressed as
    ``(group, slot within group)``.
    '''

    def __init__(self, hps):
        self.hps = hps
        H, W, C = hps.image_shape

        model_cfg = HParams(hps.model_cfg_file)
        self.model = get_model(model_cfg)
        self.model.load()

        planner_cfg = HParams(hps.planner_cfg_file)
        self.num_groups = planner_cfg.num_clusters
        # Assumes the clusters split the H*W pixels evenly; see the size check
        # in _load_action_groups.
        self.group_size = H * W // self.num_groups
        self._load_action_groups(planner_cfg.exp_dir)

    def _load_action_groups(self, exp_dir):
        '''Build ``group2idx`` / ``idx2group`` from the cluster labels saved in ``exp_dir``.'''
        # NOTE: results.pkl is unpickled, and unpickling can execute arbitrary
        # code, so only point planner_cfg.exp_dir at trusted experiment outputs.
        with open(f'{exp_dir}/results.pkl', 'rb') as f:
            labels = pickle.load(f)['labels']
        group2idx = defaultdict(list)  # group -> pixel indices, in pixel order
        idx2group = dict()             # pixel index -> (group, slot within group)
        group_count = [0]*self.num_groups
        for i, lab in enumerate(labels):
            lab = int(lab)
            group2idx[lab].append(i)
            idx2group[i] = (lab, group_count[lab])
            group_count[lab] = group_count[lab] + 1
        self.group2idx = group2idx
        self.idx2group = idx2group

        # The action grid is (num_groups, group_size). A group with more pixels
        # than group_size has slots that cannot be represented, so
        # get_availability raises IndexError on them. Smaller groups leave
        # padded slots, which get_availability marks unavailable.
        if any(count != self.group_size for count in group_count):
            logging.warning(f'action group sizes {group_count} differ from group_size={self.group_size}')

        logging.info('>'*30)
        logging.info('action groups -> pixel index:')
        logging.info(f'\n{pformat(self.group2idx)}\n')
        logging.info('pixel index -> (action groups, action index):')
        logging.info(f'\n{pformat(self.idx2group)}\n')
        logging.info('<'*30)

    def predict(self, state, mask, return_samples=False):
        '''Reconstruct inputs from their observed pixels.

        With ``return_samples`` true, returns ``hps.num_samples`` draws of
        ``self.model.sam`` per example, shape ``[B, N, H, W, C]``. Unlike
        ``get_auxiliary``, these samples are not divided by 255. Otherwise,
        returns ``self.model.mean``. Both paths feed ``m`` as all ones.
        '''
        if return_samples:
            B = mask.shape[0]
            H, W, C = self.hps.image_shape
            N = self.hps.num_samples
            # Repeat each example N times so one run yields N samples.
            state_extended = np.repeat(np.expand_dims(state, axis=1), N, axis=1)
            state_extended = np.reshape(state_extended, [B*N,H,W,C])
            mask_extended = np.repeat(np.expand_dims(mask, axis=1), N, axis=1)
            mask_extended = np.reshape(mask_extended, [B*N,H,W,C])
            sam = self.model.execute(self.model.sam,
                    feed_dict={self.model.x: state_extended,
                               self.model.b: mask_extended,
                               self.model.m: np.ones_like(mask_extended)})
            sam = np.reshape(sam, [B,N,H,W,C])

            return sam
        else:
            mean = self.model.execute(self.model.mean,
                    feed_dict={self.model.x: state,
                               self.model.b: mask,
                               self.model.m: np.ones_like(mask)})

            return mean

    @property
    def auxiliary_shape(self):
        '''Per-example shape of ``get_auxiliary`` output: ``[H, W, 2*C]``.'''
        H, W, C = self.hps.image_shape
        return [H,W,C*2]

    def get_auxiliary(self, state, mask):
        '''Return the per-pixel mean and std of ``hps.num_samples`` model samples.

        Samples are divided by 255 before the statistics are taken. The result
        has shape ``[B, H, W, 2*C]``: the means, then the stds.
        '''
        B = mask.shape[0]
        H, W, C = self.hps.image_shape
        N = self.hps.num_samples
        state_extended = np.repeat(np.expand_dims(state, axis=1), N, axis=1)
        state_extended = np.reshape(state_extended, [B*N,H,W,C])
        mask_extended = np.repeat(np.expand_dims(mask, axis=1), N, axis=1)
        mask_extended = np.reshape(mask_extended, [B*N,H,W,C])
        sam = self.model.execute(self.model.sam,
                feed_dict={self.model.x: state_extended,
                           self.model.b: mask_extended,
                           self.model.m: np.ones_like(mask_extended)})
        # Assumes samples are on a 0-255 scale -- TODO confirm.
        sam = np.reshape(sam, [B,N,H,W,C]) / 255.

        sam_mean = np.mean(sam, axis=1)
        sam_std = np.std(sam, axis=1)

        aux = np.concatenate([sam_mean, sam_std], axis=-1)

        return aux

    def get_availability(self, state, mask):
        '''Return which action groups and in-group actions can still be chosen.

        Convention: True (1) means available, False (0) means occupied.

        Args:
            state: unused; kept for interface parity with the other wrappers.
            mask: array of shape ``[B, H, W, C]``. A pixel counts as observed
                when ``mask[..., 0]`` is non-zero there.

        Returns:
            ``(avail_group, avail_action)``: boolean arrays of shape
            ``[B, num_groups]`` and ``[B, num_groups, group_size]``.
        '''
        B, H, W, C = mask.shape
        # Slots past the end of a group with fewer than group_size pixels map
        # to no pixel (get_action would raise IndexError), so they start
        # unavailable. With equal-size groups every slot starts available.
        template = np.zeros([self.num_groups, self.group_size], dtype=bool)
        for g in range(self.num_groups):
            template[g, :len(self.group2idx.get(g, ()))] = True
        avail_action = np.repeat(template[None], B, axis=0)
        for i, m in enumerate(mask):
            for idx in np.flatnonzero(m[..., 0].reshape(H*W)):
                group, slot = self.idx2group[idx]
                avail_action[i, group, slot] = False
        # A group stays available while at least one of its slots is.
        avail_group = avail_action.any(axis=2)

        return avail_group, avail_action

    def get_action(self, group, action):
        '''Map parallel ``(group, slot)`` sequences to flat pixel indices (int32 array).'''
        pixel_idx = [self.group2idx[g][a] for g, a in zip(group, action)]
        return np.array(pixel_idx, dtype=np.int32)

    def get_reward(self, state, mask, action, next_state, next_mask, done, target):
        '''Reward for one acquisition step.

        The base reward is the drop in ``self.model.bits_per_dim`` for
        ``target`` between ``mask`` and ``next_mask``. When every entry of
        ``done`` is true, ``self.model.mse_per_dim`` for ``target`` under
        ``next_mask`` is also subtracted. ``state``, ``action`` and
        ``next_state`` are unused.
        '''
        # information gain
        pre_bpd = self.model.execute(self.model.bits_per_dim,
                    feed_dict={self.model.x: target,
                               self.model.b: mask,
                               self.model.m: np.ones_like(mask)})
        post_bpd = self.model.execute(self.model.bits_per_dim,
                    feed_dict={self.model.x: target,
                               self.model.b: next_mask,
                               self.model.m: np.ones_like(next_mask)})
        gain = pre_bpd - post_bpd

        # NOTE(review): if episodes in one batch can end at different steps,
        # the terminal mse term is skipped until all of them are done.
        if not np.all(done): return gain

        # prediction reward
        mse = self.model.execute(self.model.mse_per_dim,
                    feed_dict={self.model.x: target,
                               self.model.b: next_mask,
                               self.model.m: np.ones_like(next_mask)})

        return gain - mse