import torch

from .rep_sims import CKA, Procrustes, train_ridge, fit_ridge, CCA, DifferentiableRSA
from .layer_sim import *
from .wrapper import LSTMWrapper, TransformerDecoderLayerWrapper, SeqClassWrapper, TransEncWrapper

def create_rand_inputs(inputs):
    """Build a random batch with the same shape and dtype as ``inputs``.

    Floating-point inputs produce standard-normal samples with an additional
    Gaussian perturbation (mean 0.0, std 0.1), clamped to ``[0, 1]`` and cast
    to ``inputs.dtype``.  ``torch.long`` inputs produce integers drawn
    uniformly from ``[0, inputs.max()]`` (inclusive).

    The values of ``inputs`` are only consulted for the integer upper bound;
    the result is allocated on torch's default device, so the caller is
    responsible for moving it to the target device.

    Args:
        inputs: Reference tensor providing shape and dtype.

    Returns:
        A freshly sampled tensor with ``inputs.shape`` and ``inputs.dtype``.

    Raises:
        ValueError: If ``inputs`` is neither floating point nor ``torch.long``.
    """
    # A plain tuple also covers 0-dim tensors, where ``*inputs.shape`` would
    # expand to no arguments at all.
    shape = tuple(inputs.shape)

    if inputs.is_floating_point():
        noise_mean = 0.0
        noise_std = 0.1
        # Sample in the default float dtype (keeps the RNG stream identical
        # for float32 inputs) and cast to the requested dtype at the end.
        base = torch.randn(shape)
        noise = torch.randn(shape) * noise_std + noise_mean
        noisy_batch = torch.clamp(base + noise, 0.0, 1.0).to(inputs.dtype)
    elif inputs.dtype == torch.long:
        if inputs.numel() == 0:
            # torch.max() is undefined on an empty tensor; there is nothing
            # to sample anyway.
            return torch.empty(shape, dtype=torch.long)
        max_val = torch.max(inputs).item()
        # randint's upper bound is exclusive, hence the +1.
        # NOTE(review): assumes the ids are non-negative; a negative maximum
        # makes the range empty and randint raises.
        noisy_batch = torch.randint(0, max_val + 1, shape)
    else:
        raise ValueError(f'Data type {inputs.dtype} is not compatible...')
    return noisy_batch

def _wrap_single_module(module, parity, lengths, mask):
    """Return ``module`` inside its matching project wrapper, or unchanged."""
    if isinstance(module, (nn.LSTM, nn.RNN)):
        if parity:
            return SeqClassWrapper(module, lengths)
        return LSTMWrapper(module)
    if isinstance(module, nn.TransformerDecoderLayer):
        return TransformerDecoderLayerWrapper(module)
    if isinstance(module, nn.TransformerEncoderLayer):
        # Encoder layers are only handled in the parity setting.
        assert parity, 'TransformerEncoderLayer is only supported with parity=True'
        return TransEncWrapper(module, mask)
    return module


def wrap_and_flatten_model(model, parity=False, lengths=None, mask=None):
    """Flatten ``model`` into an ``nn.Sequential`` with its trailing head removed.

    The direct children of ``model`` are visited in order.  Children that are
    ``nn.ModuleList`` / ``nn.Sequential`` containers are expanded one level
    (no deeper recursion).  Every resulting module is wrapped as follows:

    * ``nn.LSTM`` / ``nn.RNN`` -> ``SeqClassWrapper(module, lengths)`` when
      ``parity`` is true, otherwise ``LSTMWrapper(module)``.
    * ``nn.TransformerDecoderLayer`` -> ``TransformerDecoderLayerWrapper``.
    * ``nn.TransformerEncoderLayer`` -> ``TransEncWrapper(module, mask)``
      (asserts ``parity``).
    * anything else is kept as is.

    Finally the last module is dropped; the last two are dropped when the
    second-to-last one is an ``nn.Linear`` with 50257 output features.

    Args:
        model: Module whose children are flattened.
        parity: Selects the wrappers described above.
        lengths: Forwarded to ``SeqClassWrapper``.
        mask: Forwarded to ``TransEncWrapper``.

    Returns:
        An ``nn.Sequential`` of the (wrapped) modules without the trailing
        one or two modules.  It is empty when the flattened model has fewer
        than two modules.

    Raises:
        AssertionError: If a ``nn.TransformerEncoderLayer`` is encountered
            while ``parity`` is false.
    """
    flat_modules = []
    for child in model.children():
        if isinstance(child, (nn.ModuleList, nn.Sequential)):
            members = child
        else:
            members = (child,)
        flat_modules.extend(
            _wrap_single_module(member, parity, lengths, mask) for member in members
        )

    # NOTE(review): 50257 is presumably the GPT-2 vocabulary size, i.e. this
    # detects a language-model head followed by one more module - confirm.
    n_dropped = 1
    if (
        len(flat_modules) >= 2  # guards the [-2] lookup on tiny models
        and isinstance(flat_modules[-2], nn.Linear)
        and flat_modules[-2].out_features == 50257
    ):
        n_dropped = 2
    return nn.Sequential(*flat_modules[:-n_dropped])

def rep_similarity_loss(exp_name, train_model, target_model, rep_sim, inputs, device,
                        layerwise, student_model='ResNet-50', hf_base=False, one_layer=False,
                        parity=False, lengths=None, use_noise=False, torchvision_extract=False,
                        token_sim=False, batch_idx=None):
    """Compute a CKA-based representation-similarity loss between two models.

    Both models are converted with ``wrap_and_flatten_model`` and compared on
    either ``inputs`` or, when ``use_noise`` is set, on a random batch from
    ``create_rand_inputs`` moved to ``device``.

    Args:
        exp_name: Unused in this function.
        train_model: Model being trained.
        target_model: Reference model; its single-layer features are computed
            under ``torch.no_grad()``.
        rep_sim: Similarity measure name; only ``'CKA'`` is implemented.
        inputs: Input batch.
        device: Device passed to ``CKA`` and the similarity helpers.
        layerwise: When true (and ``one_layer`` is false) use
            ``layerwise_sim``; otherwise use ``layermap_sim``.
        student_model: Forwarded to ``layermap_sim``.
        hf_base: When true, ``target_model`` is additionally called with
            ``output_hidden_states=True`` (see the note in the body).
        one_layer: When true, compare only the final outputs of the two
            flattened models.
        parity: Forwarded to ``wrap_and_flatten_model``; also enables the
            ``inputs == 2`` mask.
        lengths: Forwarded to ``wrap_and_flatten_model``.
        use_noise: Feed random inputs instead of ``inputs`` to the models.
        torchvision_extract: When true, ``layermap_sim`` receives the raw
            ``target_model`` instead of its flattened version.
        token_sim: Forwarded to ``layermap_sim``.
        batch_idx: Unused in this function.

    Returns:
        ``1 - linear_CKA`` of the two output representations when
        ``one_layer`` is true, otherwise the sum of the values returned by
        ``layerwise_sim`` / ``layermap_sim``.

    Raises:
        NotImplementedError: If ``rep_sim`` is not ``'CKA'``.
    """
    # Fail fast: reject unsupported measures before wrapping or running
    # either model.
    if rep_sim != 'CKA':
        raise NotImplementedError(
            f'Representation similarity {rep_sim!r} is not implemented; only CKA is supported.'
        )

    if hf_base:
        # NOTE(review): this result is never read - it is overwritten in the
        # one_layer branch and ignored otherwise.  The call is kept so the
        # function behaves as before; confirm whether it can be removed.
        with torch.no_grad():
            target_features = target_model(inputs, output_hidden_states=True).last_hidden_state

    if use_noise:
        target_batch = create_rand_inputs(inputs).to(device)
    else:
        target_batch = inputs

    # NOTE(review): 2 is presumably the padding token id - confirm.  The mask
    # is derived from the real inputs even when use_noise is set.
    mask = (inputs == 2) if parity else None

    target_model_fe = wrap_and_flatten_model(target_model, parity, lengths, mask)
    train_model_fe = wrap_and_flatten_model(train_model, parity, lengths, mask)

    if one_layer:
        # squeeze() removes every size-1 dimension.
        # NOTE(review): that includes a batch dimension of size 1 - confirm
        # this is intended.
        batch_inputs = train_model_fe(target_batch).squeeze()
        with torch.no_grad():
            target_features = target_model_fe(target_batch).squeeze()
        cka = CKA(device)
        return 1 - cka.linear_CKA(batch_inputs.to(torch.float32),
                                  target_features.to(torch.float32))

    if layerwise:
        per_layer = layerwise_sim(train_model_fe, target_model_fe, rep_sim,
                                  inputs, target_batch, device)
    else:
        reference = target_model if torchvision_extract else target_model_fe
        per_layer = layermap_sim(train_model_fe, reference, student_model, rep_sim,
                                 inputs, target_batch, device,
                                 torchvision_extract=torchvision_extract,
                                 token_sim=token_sim)
    sims = torch.stack(list(per_layer.values()))
    return torch.sum(sims)