import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from collections import OrderedDict
from .rep_sims import CKA, DifferentiableRSA
from .wrapper import TransformerDecoderLayerWrapper, TransEncWrapper

def hook_fn(module, input, output, layer_outputs, layer_name):
    '''
    Forward-hook callback that records a layer's output.

    `module` and `input` are accepted only because the forward-hook call
    signature supplies them; they are not used. `output` is stored in the
    `layer_outputs` mapping under the key `layer_name`, replacing any value
    previously stored under that key.
    '''
    layer_outputs.update({layer_name: output})

def register_hooks(model, layer_outputs):
    '''
    Attaches output-recording forward hooks to the supported layers of `model`.

    The module tree is walked recursively. Every child whose type is in
    `hooked_types` gets a forward hook that writes its output into
    `layer_outputs` under the child's dotted name; its own children are not
    visited. Wrapper modules (TransformerDecoderLayerWrapper, TransEncWrapper)
    are descended into through their `.layer` attribute, every other module
    through itself.

    Returns the list of hook handles so the caller can remove them.
    '''
    # We access feature map information across several pytorch layers to
    # ensure maximum coverage of activations for supervision.
    hooked_types = (
        nn.Conv2d,
        nn.Linear,
        nn.AdaptiveAvgPool2d,
        nn.LSTM,
        nn.RNN,
        nn.TransformerDecoderLayer,
        nn.LayerNorm,
        nn.MultiheadAttention,
    )
    handles = []

    def make_recorder(key):
        # Bind `key` per layer so each hook writes under its own name.
        def recorder(m, i, o):
            return hook_fn(m, i, o, layer_outputs, key)
        return recorder

    def walk(parent, prefix=''):
        for child_name, child in parent.named_children():
            qualified = child_name if not prefix else f'{prefix}.{child_name}'
            if isinstance(child, hooked_types):
                handles.append(child.register_forward_hook(make_recorder(qualified)))
                continue
            if isinstance(child, (TransformerDecoderLayerWrapper, TransEncWrapper)):
                # Wrappers keep the real module in `.layer`; recurse into that.
                walk(child.layer, qualified)
            else:
                walk(child, qualified)

    walk(model)
    return handles

def get_layer_outputs(model, inputs, eval = False):
    '''
    Runs one forward pass of `model` on `inputs` and returns the recorded layer outputs.

    Hooks are registered with `register_hooks`, the model is called once, and
    the hooks are always removed afterwards, even if the forward pass raises,
    so a failed pass cannot leave stale hooks attached to the model.

    Args:
        model: module to run.
        inputs: passed to `model` as its single positional argument.
        eval: when True, `model.eval()` is called and the forward pass runs
            under `torch.no_grad()`. The model is left in eval mode afterwards.
            (The name shadows the builtin but is kept because callers pass it
            by keyword.)

    Returns:
        OrderedDict mapping dotted layer names to the outputs captured by the hooks.
    '''
    layer_outputs = OrderedDict()
    hooks = register_hooks(model, layer_outputs)
    try:
        if eval:
            model = model.eval()
            with torch.no_grad():
                model(inputs)
        else:
            model(inputs)
    finally:
        # Detach the hooks no matter how the forward pass ended.
        for hook in hooks:
            hook.remove()
    return layer_outputs

def layerwise_sim(train_model, target_model, rep_sim, inputs, target_inputs, device):
    '''
    Computes representational dissimilarity between identically named layers of two networks.

    Both networks are run once through `get_layer_outputs`. For every layer
    name recorded for both, the two activations are flattened to
    (inputs.size(0), -1), cast to float32 and scored as 1 - similarity.

    Args:
        train_model: network being trained; run on `inputs` with a plain forward pass.
        target_model: guide network; run on `target_inputs` with `eval = True`.
        rep_sim: similarity measure, 'CKA' or 'RSA'.
        inputs: batch for `train_model`. Its first dimension is also used as
            the batch size when flattening the target activations, so both
            batches are assumed to have the same size (TODO confirm with callers).
        target_inputs: batch for `target_model`.
        device: forwarded to the CKA and DifferentiableRSA constructors.

    Returns:
        dict mapping each shared layer name to its dissimilarity.

    Raises:
        NotImplementedError: if `rep_sim` is not supported and at least one
            layer name is shared by the two networks.
    '''
    cka = CKA(device)
    diff_rsa = DifferentiableRSA(device)

    def _flatten(activation):
        # Same unwrapping layermap_sim applies: modules that return tuples
        # contribute their first element, packed sequences are padded first.
        if isinstance(activation, tuple):
            activation = activation[0]
        if isinstance(activation, nn.utils.rnn.PackedSequence):
            activation, _ = nn.utils.rnn.pad_packed_sequence(activation, batch_first = True)
        # .contiguous() keeps .view() valid for non-contiguous activations.
        return activation.contiguous().view(inputs.size(0), -1).to(torch.float32)

    pretrained_outputs = get_layer_outputs(target_model, target_inputs, eval = True)
    training_outputs = get_layer_outputs(train_model, inputs)
    sim_scores = {}
    for layer_name in pretrained_outputs:
        if layer_name not in training_outputs:
            continue
        if rep_sim == 'CKA':
            similarity = cka.linear_CKA
        elif rep_sim == 'RSA':
            similarity = diff_rsa.rsa
        else:
            raise NotImplementedError(f'Unsupported rep_sim: {rep_sim!r}')
        pretrained_output = _flatten(pretrained_outputs[layer_name])
        training_output = _flatten(training_outputs[layer_name])
        sim_scores[layer_name] = 1 - similarity(training_output, pretrained_output)
    return sim_scores

def map_layers():
    '''
    Returns the fixed layer-name mapping used when the student is a ResNet-50.

    Keys are looked up among the guide network's recorded layers and values
    among the student's (see layermap_sim). Judging by the layer names, the
    guide is a ResNet-18 and each of its convolutions is paired with a
    convolution of the ResNet-50 - presumably torchvision naming, TODO confirm.
    '''
    layer_pairs = (
        # stem
        ('conv1', 'conv1'),
        # layer1
        ('layer1.0.conv1', 'layer1.0.conv1'),
        ('layer1.0.conv2', 'layer1.0.conv3'),
        ('layer1.1.conv1', 'layer1.2.conv1'),
        ('layer1.1.conv2', 'layer1.2.conv3'),
        # layer2
        ('layer2.0.conv1', 'layer2.0.conv1'),
        ('layer2.0.conv2', 'layer2.0.conv3'),
        ('layer2.0.downsample.0', 'layer2.0.downsample.0'),
        ('layer2.1.conv1', 'layer2.2.conv1'),
        ('layer2.1.conv2', 'layer2.2.conv3'),
        # layer3
        ('layer3.0.conv1', 'layer3.0.conv1'),
        ('layer3.0.conv2', 'layer3.0.conv3'),
        ('layer3.0.downsample.0', 'layer3.0.downsample.0'),
        ('layer3.1.conv1', 'layer3.3.conv1'),
        ('layer3.1.conv2', 'layer3.3.conv3'),
        # layer4
        ('layer4.0.conv1', 'layer4.0.conv1'),
        ('layer4.0.conv2', 'layer4.0.conv3'),
        ('layer4.0.downsample.0', 'layer4.0.downsample.0'),
        ('layer4.1.conv1', 'layer4.2.conv1'),
        ('layer4.1.conv2', 'layer4.2.conv3'),
    )
    return dict(layer_pairs)

def map_rn_mlp():
    '''
    Returns the fixed layer-name mapping between a ResNet-18 and an MLP.

    Keys are ResNet-18 style layer names; values are names of the form
    `initial_layer.0`, `intermediate_layers.<k>.0` and `output_layer.0`,
    presumably the sub-modules of the project's MLP (TODO confirm against the
    MLP definition).
    '''
    layer_pairs = (
        # stem
        ('conv1', 'initial_layer.0'),
        # layer1
        ('layer1.0.conv1', 'intermediate_layers.1.0'),
        ('layer1.0.conv2', 'intermediate_layers.3.0'),
        ('layer1.1.conv1', 'intermediate_layers.5.0'),
        ('layer1.1.conv2', 'intermediate_layers.7.0'),
        # layer2
        ('layer2.0.conv1', 'intermediate_layers.9.0'),
        ('layer2.0.conv2', 'intermediate_layers.11.0'),
        ('layer2.0.downsample.0', 'intermediate_layers.13.0'),
        ('layer2.1.conv1', 'intermediate_layers.15.0'),
        ('layer2.1.conv2', 'intermediate_layers.17.0'),
        # layer3
        ('layer3.0.conv1', 'intermediate_layers.20.0'),
        ('layer3.0.conv2', 'intermediate_layers.22.0'),
        ('layer3.0.downsample.0', 'intermediate_layers.24.0'),
        ('layer3.1.conv1', 'intermediate_layers.26.0'),
        ('layer3.1.conv2', 'intermediate_layers.28.0'),
        # layer4
        ('layer4.0.conv1', 'intermediate_layers.31.0'),
        ('layer4.0.conv2', 'intermediate_layers.33.0'),
        ('layer4.0.downsample.0', 'intermediate_layers.35.0'),
        ('layer4.1.conv1', 'intermediate_layers.37.0'),
        ('layer4.1.conv2', 'intermediate_layers.39.0'),
        # head
        ('avgpool', 'intermediate_layers.42.0'),
        ('fc', 'output_layer.0'),
    )
    return dict(layer_pairs)

def torchvision_fe(model, inputs, device):
    '''
    Extracts intermediate activations of `model` with torchvision's feature extractor.

    Every graph node whose name contains 'mlp' or 'self_attention' is
    returned, plus the hard-coded node 'getitem_5'. The extractor is moved to
    `device` and run on `inputs` under `torch.no_grad()`; `inputs` itself is
    not moved.

    NOTE(review): 'getitem_5' is specific to how one particular architecture
    is traced (presumably a torchvision ViT - confirm); other models may not
    have a node with that name.

    Returns whatever the feature extractor returns for `inputs`.
    '''
    fx = torchvision.models.feature_extraction
    # get_graph_node_names returns a pair; only its first element is used.
    node_names, _ = fx.get_graph_node_names(model)

    extract_layers = []
    for node in node_names:
        if 'mlp' in node or 'self_attention' in node:
            extract_layers.append(node)
    extract_layers.append('getitem_5')

    feature_extractor = fx.create_feature_extractor(model, return_nodes = extract_layers).to(device)
    with torch.no_grad():
        return feature_extractor(inputs)

def layer_supervision(target_model_layers, student_model_layers):
    '''
    Designs the layer mapping between guide and student layers by spreading the
    guide layers evenly across the student layers.

    The i-th guide layer is paired with the student layer at index
    round(i * stride), capped at the last student index, where stride is
    (number of student layers - 1) / (number of guide layers - 1). With a
    single guide layer the stride is 1, so it is paired with the first student
    layer. When there are more guide layers than student layers, several guide
    layers share one student layer.

    Returns a dict {guide layer: student layer}; it is empty when there are no
    guide layers.
    '''
    n_guide = len(target_model_layers)
    last_student = len(student_model_layers) - 1
    if n_guide > 1:
        stride = last_student / (n_guide - 1)
    else:
        stride = 1

    return {
        guide_layer: student_model_layers[min(round(pos * stride), last_student)]
        for pos, guide_layer in enumerate(target_model_layers)
    }

def layermap_sim(train_model, target_model, student_model, rep_sim, inputs, target_inputs, device, torchvision_extract = False):
    '''
    Computes layerwise representational dissimilarity between several layer pairs of the two networks.

    Activations of both networks are recorded, guide layers are paired with
    student layers, and each pair is scored as 1 - similarity after flattening
    both activations to (inputs.size(0), -1) and casting to float32.

    Args:
        train_model: student network; run on `inputs` with a plain forward pass.
        target_model: guide network.
        student_model: name of the student architecture. 'ResNet-50' selects
            the fixed table from `map_layers`; anything else uses
            `layer_supervision` on the recorded layer names.
        rep_sim: similarity measure, 'CKA' or 'RSA'.
        inputs: batch for `train_model`; its first dimension is the batch size
            used when flattening both networks' activations.
        target_inputs: batch for `target_model` on the hook-based path.
        device: forwarded to CKA, DifferentiableRSA and `torchvision_fe`.
        torchvision_extract: when True, guide activations come from
            `torchvision_fe` instead of forward hooks.

    Returns:
        dict mapping student layer names to their dissimilarity.

    Raises:
        AssertionError: if a mapped layer was not recorded for its network.
        NotImplementedError: if `rep_sim` is not supported and the mapping is
            non-empty.
    '''
    cka = CKA(device)
    diff_rsa = DifferentiableRSA(device)

    def _flatten(activation):
        # Modules that return tuples contribute their first element.
        if isinstance(activation, tuple):
            activation = activation[0]
        if isinstance(activation, nn.utils.rnn.PackedSequence):
            activation, _ = nn.utils.rnn.pad_packed_sequence(activation, batch_first = True)
        return activation.contiguous().view(inputs.size(0), -1).to(torch.float32)

    if torchvision_extract:
        # NOTE(review): this path feeds `inputs`, not `target_inputs`, to the
        # target network, unlike the hook-based path below. Kept as-is;
        # confirm with callers whether that is intended.
        pretrained_outputs = torchvision_fe(target_model, inputs, device)
    else:
        pretrained_outputs = get_layer_outputs(target_model, target_inputs, eval = True)
    training_outputs = get_layer_outputs(train_model, inputs)

    if student_model == 'ResNet-50':
        model_mapping = map_layers()
    else:
        # The same spreading is used whichever network has more recorded
        # layers. When the guide has more, several guide layers map to one
        # student layer (the author's "multiple levels of supervision"
        # experiment).
        model_mapping = layer_supervision(list(pretrained_outputs.keys()), list(training_outputs.keys()))

    sim_scores = {}
    for layer, tr_layer in model_mapping.items():
        assert layer in pretrained_outputs, f'Layer {layer} is not in target network {pretrained_outputs.keys()}'
        assert tr_layer in training_outputs, f'Layer {tr_layer} is not in {student_model} {training_outputs.keys()}'

        if rep_sim == 'CKA':
            similarity = cka.linear_CKA
        elif rep_sim == 'RSA':
            similarity = diff_rsa.rsa
        else:
            raise NotImplementedError(f'Unsupported rep_sim: {rep_sim!r}')

        pretrained_output = _flatten(pretrained_outputs[layer])
        training_output = _flatten(training_outputs[tr_layer])
        # NOTE(review): scores are keyed by student layer, so when several
        # guide layers share a student layer only the last pair's score is
        # kept. Behaviour unchanged here; confirm whether they should be
        # accumulated instead.
        sim_scores[tr_layer] = 1 - similarity(training_output, pretrained_output)
    return sim_scores