'''
    This script automates the steps necessary to execute the Information Analysis on Deep Nets
'''

import torch
import pickle
import numpy as np
import torch.nn as nn

from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageNet

from utils.recorder import Traced
from utils.miscellaneous import get_info
from utils.receptivefield import ForwardRF
from utils.information import Entropy, bin_data_equipop
from utils.information import bin_data_equisize
from utils.information import MutualInformation
from utils.information import Entropy
from utils.information import H_XY, I_XY, H

from tqdm import tqdm
from tqdm import trange

from typing import List, Tuple, Union, Callable

from functools import partial
from multiprocessing import Pool

class InfoNet:
    '''
        This class represents an Information Metric experiment performed on a Deep Network
    '''

    imagenet_root = 'path_to_ImageNet'

    def __init__(self,
        net : torch.nn.Module,
        img_size : int = 224,
        ) -> None:
        '''
            Build the image pre-processing pipeline, select the compute device
            and store the network under test.

            Args:
                net: The deep network whose activations will be analysed.
                img_size: Size passed to both the resize and the center-crop
                    steps of the pipeline. It is also used to build the
                    (3, img_size, img_size) input shape in other methods.
        '''

        # * Dataset loading
        self.img_size = img_size

        # Per-channel mean / std consumed by the Normalize step below
        img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        self.pipeline = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = img_mean, std = img_std)
        ])

        # Select the torch running device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # * Store the reference to the deep network to test
        # NOTE(review): the train/eval mode of the network is left exactly as
        # provided by the caller.
        self.net = net

        # State filled in by `record`
        self.recorder = None
        self.pop_size = None
        self.n_batch  = None

        # These attributes are read by `dump` and by the `measure*` methods but
        # are only assigned by `record` / `measure*`. Declaring them here makes
        # the "not computed yet" state explicit instead of a missing attribute.
        self.batch_size   = None
        self.features     = None
        self.bin_strategy = None

        # Results filled in by `measure_MI` and by `measure_H` / `measure`
        self.MI_r, self.MI_g, self.MI_b = None, None, None
        self.H_r,  self.H_g,  self.H_b  = None, None, None

    def record(self, 
        n_batch : int = 15,
        batch_size : int = 128,
        pop_size : int = 400,
        rec_imgs : List[torch.FloatTensor] = None,
        units : List[int] = None,
        excluded_layers = (nn.Dropout, nn.BatchNorm2d, nn.AdaptiveAvgPool2d),
        **kwargs
        ) -> List[torch.FloatTensor]:
        '''
            Record the activations of a population of units while the network
            processes a set of image batches.

            Args:
                n_batch: Number of batches drawn from the ImageNet validation
                    split. Ignored when `rec_imgs` is provided.
                batch_size: Batch size of the ImageNet loader. Ignored when
                    `rec_imgs` is provided.
                pop_size: Number of random units selected in each layer when
                    `units` is not provided.
                rec_imgs: Optional iterable of batches to feed to the network
                    instead of loading ImageNet. Each item is either an
                    (images, labels) pair or a bare batch of images (torch
                    tensor or numpy array), so the list returned by a previous
                    call can be passed back in.
                units: Optional list with one array of unit indices per layer.
                    When omitted, random indices are drawn from the layer
                    shapes reported by `get_info`.
                excluded_layers: Layer types forwarded as `exclude` to both
                    `get_info` and `Traced`.
                **kwargs: Ignored. Accepted so that `run` can forward a single
                    keyword set to several methods.

            Returns:
                The images used for the recording: a list of numpy arrays (one
                per batch) when ImageNet was loaded, `rec_imgs` itself otherwise.

            Raises:
                ValueError: If no batch was fed to the network.
        '''

        if units is None:
            self.pop_size = pop_size
            # Get network layer shapes for the selected input
            self.info = get_info(self.net, (1, 3, self.img_size, self.img_size), exclude = excluded_layers)

            # Use network shapes to select a random population of fixed size
            units = [np.random.randint(0, shape, size = (pop_size, len(shape))) for shape in self.info.shapes]

        else:
            self.pop_size = units[0].shape[0]

        # * Send network to working device and track its activations
        net = Traced(self.net.to(self.device), units = units, exclude = excluded_layers)

        # * Activate the traced network so to store its activations
        net.clean()
        net.trace()

        try:
            with torch.no_grad():
                if rec_imgs is None:
                    # Here we load the ImageNet dataset
                    self.batch_size = batch_size

                    # Prepare the dataset loader. `drop_last` guarantees that every
                    # batch has exactly `batch_size` images, which the reshape of
                    # the recorded features below relies upon.
                    imagenet = ImageNet(self.imagenet_root, split = 'val', transform = self.pipeline)
                    self.loader = DataLoader(imagenet, batch_size = self.batch_size, shuffle = True, drop_last = True)

                    imgs = []

                    # Iterate the loader ONCE: calling `next(iter(loader))` for each
                    # batch would rebuild (and reshuffle) the loader iterator every
                    # time. Putting the progress bar first in `zip` stops the loop
                    # after `n_batch` steps without fetching an extra batch.
                    with trange(n_batch, desc = 'Recording InfoNets', leave = False) as progress:
                        for _step, (imgs_batch, _labels) in zip(progress, self.loader):
                            net(imgs_batch.to(self.device))

                            imgs.append(imgs_batch.detach().cpu().numpy())

                    # The loader may hold fewer batches than requested
                    n_batch = len(imgs)

                else:
                    n_batch = 0

                    for item in tqdm(rec_imgs, desc = 'Recording InfoNets', leave = False):
                        # Accept both (images, labels) pairs and bare image batches
                        imgs_batch = item[0] if isinstance(item, (tuple, list)) else item
                        imgs_batch = torch.as_tensor(imgs_batch)

                        net(imgs_batch.to(self.device))

                        n_batch += 1
                        # As before, the size of the last batch is the one stored
                        self.batch_size = len(imgs_batch)

                    imgs = rec_imgs
        finally:
            # Always stop tracing, even when a forward pass failed
            net.untrace()

        if n_batch == 0:
            raise ValueError('Cannot record activity: no image batch was provided.')

        # Store the number of batches that were actually recorded (used by `dump`)
        self.n_batch = n_batch
        self.recorder = net.recorder

        # Stack the features for each layer for later use: one row per image.
        # NOTE(review): assumes recorder.features maps each hook to the list of
        # per-batch activations, all with the same batch size — confirm against
        # utils.recorder.
        n_samples = n_batch * self.batch_size
        self.features = {hook : np.array(feat).reshape(n_samples, -1) for hook, feat in self.recorder.features.items()}

        # Return the image that were used to record the activity
        return imgs

    def measure_RF(self, 
        RF : Callable = ForwardRF,
        keys : List[str] = None,
        units : List[np.ndarray] = None,
        excluded_layers = (nn.Dropout, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)
        ) -> dict:
        '''
            Measure the receptive field of the selected units.

            Args:
                RF: Receptive-field class. It is built from the layer list and
                    the (3, img_size, img_size) input shape, then called on the
                    network and the device.
                keys: Layers to analyse. Defaults to the recorded layers.
                units: One index array per layer. Defaults to the recorded units.
                excluded_layers: Layer types forwarded as `exclude` to the
                    receptive-field object.

            Returns:
                A {layer : receptive field} dictionary, where each value is the
                result of `get_unit_rf` for the units of that layer.

            Raises:
                ValueError: If nothing was recorded yet and either `keys` or
                    `units` is missing.
        '''
        nothing_recorded = self.recorder is None

        if nothing_recorded and keys is None:
            raise ValueError('Cannot measure RF without a key list or before activity is recorded.')

        if nothing_recorded and units is None:
            raise ValueError('Cannot measure RF without a unit list or before activity is recorded.')

        # Fall back on what was recorded for whatever the caller did not supply
        if keys is None:
            layer_names = list(self.recorder.keys)
        else:
            layer_names = list(keys)

        if units is None:
            layer_units = list(self.recorder.units)
        else:
            layer_units = list(units)

        # Build the Receptive Field object for the selected layers
        rf_model = RF(layer_names, (3, self.img_size, self.img_size))

        # Compute the model RF parameters (the returned value is not needed)
        rf_model(self.net, self.device, exclude = excluded_layers)

        # The first column of every index array is dropped before the lookup
        # (presumably the batch index — TODO confirm against utils.recorder)
        receptive_fields = {}
        for name, positions in zip(layer_names, layer_units):
            receptive_fields[name] = rf_model.get_unit_rf(name, unit_pos = positions[:, 1:])

        return receptive_fields

    def measure_MI(self,
        data : dict, 
        nbins : int = 20,  
        bin_strategy : str = 'equisize',
        **kwargs
        ) -> Tuple[dict, dict, dict]:
        '''
            Estimate the Mutual Information between an external quantity and
            the recorded unit activations, separately for each color channel.

            Args:
                data: {layer : array} dictionary with the external quantity.
                    The last axis of each array is indexed with 0, 1, 2 to
                    select the R, G, B channel respectively.
                nbins: Number of bins handed to the binning function.
                bin_strategy: Either 'equisize' or 'equipop'.
                **kwargs: Ignored. Accepted so that `run` can forward a single
                    keyword set to several methods.

            Returns:
                Three {layer : MI} dictionaries, for the R, G and B channels.

            Raises:
                ValueError: If `bin_strategy` is not a known strategy, or if no
                    activity was recorded yet.
        '''

        if self.recorder is None:
            raise ValueError('Cannot measure MI before activity is recorded.')

        valid_bs = ('equisize', 'equipop')

        # Compare against each strategy explicitly: a plain truthiness test on
        # the second branch would silently treat any unknown name as 'equipop'.
        if bin_strategy == 'equisize':
            bin_f = lambda x : bin_data_equisize(x, nbins = nbins)
        elif bin_strategy == 'equipop':
            bin_f = lambda x : bin_data_equipop (x, nbins = nbins)
        else:
            raise ValueError(f'Unknown binning strategy {bin_strategy}. Use one in {valid_bs}')
        
        # Store the bin strategy used
        self.bin_strategy = bin_strategy

        # * Bin the provided external quantity (one channel at a time, along axis 1)
        self.binR = {layer : np.apply_along_axis(bin_f, 1, X[..., 0]) for layer, X in data.items()}
        self.binG = {layer : np.apply_along_axis(bin_f, 1, X[..., 1]) for layer, X in data.items()}
        self.binB = {layer : np.apply_along_axis(bin_f, 1, X[..., 2]) for layer, X in data.items()}

        # * Bin the recorded unit activations along axis 0, then transpose
        self.binF = {layer : np.apply_along_axis(bin_f, 0, feat).T for layer, feat in self.features.items()}

        # * Compute the Mutual Information for each channel separately and for each layer
        _I = MutualInformation()

        keys = self.recorder.keys
        self.MI_r = {l : _I(self.binR[l], self.binF[l]) for l in tqdm(keys, desc = 'MI [R]', leave = False)}
        self.MI_g = {l : _I(self.binG[l], self.binF[l]) for l in tqdm(keys, desc = 'MI [G]', leave = False)}
        self.MI_b = {l : _I(self.binB[l], self.binF[l]) for l in tqdm(keys, desc = 'MI [B]', leave = False)}

        return self.MI_r, self.MI_g, self.MI_b

    def measure_H(self,
        data : dict, 
        nbins : int = 20,  
        bin_strategy : str = 'equisize',
        ent_args : List[str] = ('HR', 'HRS', 'HshRS', 'HiR'),
        bias : str = 'pt',
        **kwargs
        ) -> Tuple[dict, dict, dict]:
        '''
            Compute entropy quantities between an external quantity and the
            recorded unit activations, separately for each color channel.

            Args:
                data: {layer : array} dictionary with the external quantity.
                    The last axis of each array is indexed with 0, 1, 2 to
                    select the R, G, B channel respectively.
                nbins: Number of bins handed to the binning function.
                bin_strategy: Either 'equisize' or 'equipop'.
                ent_args: Names of the entropy terms, forwarded as `args` to
                    the `Entropy` estimator.
                bias: Forwarded as `bias` to the `Entropy` estimator.
                **kwargs: Ignored.

            Returns:
                Three {layer : entropy result} dictionaries, for the R, G and B
                channels.

            Raises:
                ValueError: If `bin_strategy` is not a known strategy, or if no
                    activity was recorded yet.
        '''

        if self.recorder is None:
            raise ValueError('Cannot measure H before activity is recorded.')

        valid_bs = ('equisize', 'equipop')

        # Compare against each strategy explicitly: a plain truthiness test on
        # the second branch would silently treat any unknown name as 'equipop'.
        if bin_strategy == 'equisize':
            bin_f = lambda x : bin_data_equisize(x, nbins = nbins)
        elif bin_strategy == 'equipop':
            bin_f = lambda x : bin_data_equipop (x, nbins = nbins)
        else:
            raise ValueError(f'Unknown binning strategy {bin_strategy}. Use one in {valid_bs}')
        
        # Store the bin strategy used
        self.bin_strategy = bin_strategy

        # * Bin the provided external quantity (one channel at a time, along axis 1)
        self.binR = {layer : np.apply_along_axis(bin_f, 1, X[..., 0]) for layer, X in data.items()}
        self.binG = {layer : np.apply_along_axis(bin_f, 1, X[..., 1]) for layer, X in data.items()}
        self.binB = {layer : np.apply_along_axis(bin_f, 1, X[..., 2]) for layer, X in data.items()}

        # * Bin the recorded unit activations along axis 0, then transpose
        self.binF = {layer : np.apply_along_axis(bin_f, 0, feat).T for layer, feat in self.features.items()}

        # * Compute the Entropy terms for each channel separately and for each layer
        _H = Entropy()

        keys = self.recorder.keys
        self.H_r = {l : _H(self.binR[l], self.binF[l], args = ent_args, bias = bias) for l in tqdm(keys, desc = 'H [R]', leave = False)}
        self.H_g = {l : _H(self.binG[l], self.binF[l], args = ent_args, bias = bias) for l in tqdm(keys, desc = 'H [G]', leave = False)}
        self.H_b = {l : _H(self.binB[l], self.binF[l], args = ent_args, bias = bias) for l in tqdm(keys, desc = 'H [B]', leave = False)}

        return self.H_r, self.H_g, self.H_b

    def measure(self,
        data : dict, 
        nbins : int = 20,  
        bin_strategy : str = 'equisize',
        bias : str = None,
        norm : str = None,
        **kwargs
        ) -> Tuple[dict, dict, dict]:
        '''
            Compute a set of information-theoretic quantities between an
            external quantity and the recorded unit activations, separately
            for each color channel, using a pool of worker processes.

            Args:
                data: {layer : array} dictionary with the external quantity.
                    The last axis of each array is indexed with 0, 1, 2 to
                    select the R, G, B channel respectively.
                nbins: Number of bins handed to the binning function.
                bin_strategy: Either 'equisize' or 'equipop'.
                bias: Forwarded as `bias` to `H`, `H_XY` and `I_XY`.
                norm: Forwarded as `norm` to `I_XY`.
                **kwargs: Ignored.

            Returns:
                Three {layer : info} dictionaries (R, G and B channels). Each
                `info` is itself a dictionary with the keys 'Hx', 'Hy',
                'Hx|y', 'Hy|x', 'Hx,y' and 'Ix,y', each holding one value per
                (external row, unit row) pair. The same dictionaries are
                stored in `self.H_r`, `self.H_g` and `self.H_b`.

            Raises:
                ValueError: If `bin_strategy` is not a known strategy, or if no
                    activity was recorded yet.
        '''

        if self.recorder is None:
            raise ValueError('Cannot measure before activity is recorded.')

        valid_bs = ('equisize', 'equipop')

        # Compare against each strategy explicitly: a plain truthiness test on
        # the second branch would silently treat any unknown name as 'equipop'.
        if bin_strategy == 'equisize':
            bin_f = lambda x : bin_data_equisize(x, nbins = nbins)
        elif bin_strategy == 'equipop':
            bin_f = lambda x : bin_data_equipop (x, nbins = nbins)
        else:
            raise ValueError(f'Unknown binning strategy {bin_strategy}. Use one in {valid_bs}')
        
        # Store the bin strategy used
        self.bin_strategy = bin_strategy

        # * Bin the provided external quantity (one channel at a time, along axis 1)
        self.binR = {layer : np.apply_along_axis(bin_f, 1, X[..., 0]) for layer, X in data.items()}
        self.binG = {layer : np.apply_along_axis(bin_f, 1, X[..., 1]) for layer, X in data.items()}
        self.binB = {layer : np.apply_along_axis(bin_f, 1, X[..., 2]) for layer, X in data.items()}

        # * Bin the recorded unit activations along axis 0, then transpose
        self.binF = {layer : np.apply_along_axis(bin_f, 0, feat).T for layer, feat in self.features.items()}

        # * Compute the relevant Information-Theoretic quantities
        # * for each channel separately and for each layer.
        # Plain local names are used on purpose: identifiers with two leading
        # underscores are name-mangled inside a class body.
        entropy     = partial(H,    bias = bias)
        h_x_given_y = partial(H_XY, bias = bias, kind = 'x|y')
        h_y_given_x = partial(H_XY, bias = bias, kind = 'y|x')
        h_joint     = partial(H_XY, bias = bias, kind = 'x,y')
        mutual_info = partial(I_XY, bias = bias, norm = norm, order = 'x,y')

        def _infoq(
                Xs : np.ndarray, 
                Ys : np.ndarray, 
                P : Pool, 
                chunk : int = 10
            ) -> dict:
            # Apply each estimator along the first axis of Xs / Ys, spreading
            # the work over the pool. `starmap` already returns a list, while
            # the lazy `imap` iterator has to be materialised.
            return {
                'Hx'   : list(P.imap(entropy, Xs, chunk)),
                'Hy'   : list(P.imap(entropy, Ys, chunk)),

                'Hx|y' : P.starmap(h_x_given_y, zip(Xs, Ys), chunk),
                'Hy|x' : P.starmap(h_y_given_x, zip(Xs, Ys), chunk),
                'Hx,y' : P.starmap(h_joint,     zip(Xs, Ys), chunk),
                'Ix,y' : P.starmap(mutual_info, zip(Xs, Ys), chunk),
            }

        keys = self.recorder.keys

        # A single pool is shared by all layers and channels
        with Pool() as P:
            self.H_r = {l : _infoq(self.binR[l], self.binF[l], P) for l in tqdm(keys, desc = 'H [R]', leave = False)}
            self.H_g = {l : _infoq(self.binG[l], self.binF[l], P) for l in tqdm(keys, desc = 'H [G]', leave = False)}
            self.H_b = {l : _infoq(self.binB[l], self.binF[l], P) for l in tqdm(keys, desc = 'H [B]', leave = False)}

        return self.H_r, self.H_g, self.H_b


    def run(self, Y : dict, **kwargs):
        '''
            Execute the whole experiment: record the network activity, then
            estimate the Mutual Information against the provided quantity.

            Y is expected to be a dictionary of {layers : data} format.
            Every keyword argument is forwarded to both `record` and
            `measure_MI` (e.g. optionally provided images and unit indices).
            When a `savepath` keyword is present, the results are also dumped
            to that path.
        '''

        # Step 1: trace the activity of the network
        self.record(**kwargs)

        # Step 2: estimate the Mutual Information
        mutual_info = self.measure_MI(Y, **kwargs)

        # Step 3 (optional): persist the results
        if 'savepath' not in kwargs:
            return mutual_info

        self.dump(kwargs['savepath'])

        return mutual_info


    def dump(self, savepath : str, desc = 'Results of InfoNets'):
        '''
            Pickle the results and the settings of the experiment to disk.

            Args:
                savepath: Path of the file to write (opened in binary mode).
                desc: Free-text description stored alongside the results.

            Raises:
                ValueError: If no activity was recorded yet, since the layer
                    list is read from the recorder.
        '''
        if self.recorder is None:
            raise ValueError('Cannot dump results before activity is recorded.')

        data = {
            'MI' : {
                'R' : self.MI_r,
                'G' : self.MI_g,
                'B' : self.MI_b
            },

            'H' : {
                'R' : self.H_r,
                'G' : self.H_g,
                'B' : self.H_b
            },

            # Only assigned by the measure methods: None when none of them ran
            'bin_strategy' : getattr(self, 'bin_strategy', None),

            'n_batch' : self.n_batch,
            'img_size' : self.img_size,
            'batch_size' : self.batch_size,
            'n_units' : self.pop_size,

            # Store a plain list: `measure_RF` also wraps `recorder.keys` in
            # list(), which suggests it may be a non-list iterable (a dict view
            # could not be pickled)
            'layers' : list(self.recorder.keys),
            'network' : str(self.net),

            'desc' : desc
        }

        # Store the collected results on disk
        with open(savepath, 'wb') as f:
            pickle.dump(data, f)