import torch
import torch.nn as nn
from tqdm import tqdm
from easydict import EasyDict as edict

from src.data_loader.utils import get_data, get_train_val_split
from src.data_loader.data_set import Data_Set
import torchvision.models as models
from torch.utils.data import DataLoader

from alignment_uniformity import calc_alignment, calc_uniformity


from src.constants import (
    COMET_KWARGS,
    HYBRID2_CONFIG,
    BASE_DIR,
    TRAINING_CONFIG_PATH,
)
from src.utils import get_console_logger, read_json

class ResNetModelWithoutFC(nn.Module):
    """ResNet backbone that returns pooled features instead of class logits.

    The torchvision ResNet selected by ``resnet_name`` is instantiated and kept
    as ``self.model``. Its stem and four residual stages are chained with an
    adaptive average pool in ``self.features``; ``forward`` only runs
    ``self.features``, so the ResNet's own ``fc`` layer is never applied.
    """

    # Architectures accepted by ``get_resnet``; each one is looked up as an
    # attribute of ``torchvision.models``.
    _SUPPORTED_RESNETS = (
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
    )

    def __init__(self, resnet_name):
        """Build the backbone for the architecture called ``resnet_name``.

        Raises:
            NotImplementedError: If ``resnet_name`` is not a supported variant.
        """
        super().__init__()
        model_function = self.get_resnet(resnet_name)
        self.model = model_function(norm_layer=nn.BatchNorm2d)
        # The layers below are the same module objects as in ``self.model``,
        # so weights copied into ``self.model`` are also used by ``features``.
        self.features = nn.Sequential(
            self.model.conv1,
            self.model.bn1,
            self.model.relu,
            self.model.maxpool,
            self.model.layer1,
            self.model.layer2,
            self.model.layer3,
            self.model.layer4,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )

    def get_resnet(self, resnet_name):
        """Return the torchvision constructor for ``resnet_name``.

        Raises:
            NotImplementedError: If ``resnet_name`` is not one of the
                supported ResNet variants.
        """
        if resnet_name not in self._SUPPORTED_RESNETS:
            raise NotImplementedError(
                f"Unsupported ResNet variant {resnet_name!r}; expected one of "
                f"{', '.join(self._SUPPORTED_RESNETS)}."
            )
        return getattr(models, resnet_name)

    def import_pretrained(self, path_to_peclr_weights):
        """Copy the ``features`` tensors of a PeCLR checkpoint into the ResNet.

        Checkpoint entries whose key contains ``"features"`` are paired, in
        order, with the entries of ``self.model.state_dict()``. A pair is
        copied only if the last dotted component of both keys agrees. Loading
        is best effort: on the first mismatch or failed copy a message is
        printed and the remaining tensors are left untouched.

        Args:
            path_to_peclr_weights: Path to a checkpoint file containing a
                ``"state_dict"`` entry.
        """
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        peclr_weights = torch.load(
            path_to_peclr_weights, map_location=torch.device("cpu")
        )
        peclr_state_dict = peclr_weights["state_dict"]
        peclr_entries = [
            (key, value)
            for key, value in peclr_state_dict.items()
            if "features" in key
        ]

        # The tensors in the state dict share storage with the live
        # parameters, so copying into them updates the model in place.
        own_state = self.model.state_dict()
        resnet_names = list(own_state)

        # A checkpoint with more feature tensors than the ResNet used to end in
        # an IndexError; report it like the other incompatibilities instead.
        surplus = len(peclr_entries) - len(resnet_names)
        if surplus > 0:
            print(
                f"PeCLR checkpoint has {surplus} more feature tensors than "
                "the Resnet; ignoring the surplus."
            )

        paired = zip(resnet_names, peclr_entries)
        total = min(len(resnet_names), len(peclr_entries))
        for name, (peclr_key, param) in tqdm(paired, total=total):
            if name.split(".")[-1] != peclr_key.split(".")[-1]:
                print("PeCLR layers don't match with Resnet layer ")
                break
            try:
                own_state[name].copy_(param)
            except Exception as e:
                # Deliberately broad: any failed copy ends the best-effort load.
                print("The models are not compatible!")
                print(f"Exception :{e}")
                break

    def forward(self, x):
        """Return the pooled backbone features of ``x``, flattened per sample.

        ``x`` is passed through ``self.features`` and every dimension after
        the first is flattened into one.
        """
        # Assumes x is an image batch of shape (batch, channels, H, W), which
        # would give a (batch, feature_dim) result -- TODO confirm with callers.
        z = self.features(x)
        return z.flatten(start_dim=1)

def make_batch(samples):
    """Collate samples into a batch holding only ``transformed_image1``.

    Each element of ``samples`` is indexed with ``'transformed_image1'``; the
    resulting tensors are stacked along a new leading dimension and returned
    in a dict under the same key.
    """
    key = 'transformed_image1'
    images = [sample[key] for sample in samples]
    return {key: torch.stack(images)}


def calc(model_path):
    """Compute the mean alignment and uniformity of a pretrained encoder.

    A ``resnet152`` backbone is built, the checkpoint at ``model_path`` is
    loaded into it, and the FreiHAND ``test`` split is iterated in shuffled
    batches of 512. For every batch the representation is scored with
    ``calc_alignment`` (against the joints) and ``calc_uniformity``.

    Args:
        model_path: Checkpoint path passed to
            ``ResNetModelWithoutFC.import_pretrained``.

    Returns:
        A tuple ``(alignment, uniformity)``: the negated mean of the per-batch
        alignment values and the mean of the per-batch uniformity values.

    Raises:
        ValueError: If the data loader yields no full batch, so no mean can
            be computed.
    """
    BATCH_SIZE = 512
    NUM_WORKERS = 4
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_param = edict(read_json(TRAINING_CONFIG_PATH))
    train_param['train_ratio'] = 0.999999999

    # data preparation
    data = get_data(
        Data_Set,
        edict(train_param),
        sources=["freihand"],
        experiment_type="supervised",
        split='test',
    )
    # NUM_WORKERS was previously defined but never handed to the loader, so
    # batches were loaded in the main process only.
    data_loader = DataLoader(
        data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=NUM_WORKERS,
    )
    # With drop_last=True a dataset smaller than one batch yields nothing;
    # fail early with a clear message instead of dividing by zero below.
    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError(
            f"Dataset has fewer than {BATCH_SIZE} samples; "
            "no full batch is available to evaluate."
        )

    # get model
    model = ResNetModelWithoutFC('resnet152')
    model.import_pretrained(model_path)
    model = model.to(device)
    model.eval()

    alignment_list = []
    uniformity_list = []
    with torch.no_grad():
        for batch in tqdm(data_loader, total=num_batches):
            image = batch['image'].to(device)
            label = batch['joints'].to(device)
            representation = model(image)
            alignment_list.append(calc_alignment(representation, label))
            uniformity_list.append(calc_uniformity(representation))

    # NOTE(review): the alignment mean is negated while uniformity is not;
    # kept as in the original -- confirm the sign convention of calc_alignment.
    alignment = -sum(alignment_list) / len(alignment_list)
    uniformity = sum(uniformity_list) / len(uniformity_list)
    return alignment, uniformity

def main():
    """Evaluate every listed checkpoint and print its alignment/uniformity."""
    # Alternative checkpoint names kept for reference (currently disabled):
    #   simclr.ckpt,
    #   regcon_linear_in.ckpt, regcon_linear_out.ckpt,
    #   regcon_tanh_in.ckpt, regcon_tanh_out.ckpt,
    #   regcon_tanh0.5_in.ckpt, regcon_tanh0.5_out.ckpt,
    #   regcon_tanh2_in.ckpt, regcon_tanh2_out.ckpt
    model_list = ['regcon_tanh3_out.ckpt']

    # All checkpoints are scored before anything is printed, so the report
    # comes out as one contiguous block.
    results = [calc(checkpoint) for checkpoint in model_list]

    for model_name, (alignment, uniformity) in zip(model_list, results):
        print(f'model_name: {model_name:30s}, alignment: {alignment:10f}, uniformity: {uniformity:10f}')
    
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()