import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from argparse import ArgumentParser
import numpy as np
import tqdm
import os
import pickle
from sklearn.model_selection import KFold
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import accuracy_score
import clip
import spacy
import json
import torch
import time
from PIL import Image
from collections import defaultdict
from torch import pca_lowrank
from numpy.linalg import svd
from utils import orthogonal_procrustes, calc_cosine_sim


def create_parser(args=None):
    """Parse the command-line options for fitting the image-to-text mapping.

    Despite its name, this returns the *parsed* options, not the parser.

    Args:
        args: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``dataset``, ``vis_encoder``, ``fraction``,
        ``paragraphs``, ``localized_narratives``, ``lowrank`` and ``r``.

    Invalid values or combinations terminate via ``parser.error``
    (``SystemExit`` with code 2).
    """
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=["mscoco", "flickr30k"],
                        help="Dataset whose training split is used to fit the mapping")
    # Required: main() calls .replace() on this value, so None would crash later.
    parser.add_argument('--vis-encoder', required=True,
                        choices=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16',
                                 'ViT-L/14', 'ViT-L/14@336px'],
                        help="Vision encoder to be used")
    parser.add_argument('--fraction', type=float, default=1.,
                        help="Fraction of data to use for training of the mapping")
    # Only one caption source can be used; previously --paragraphs silently won.
    text_source = parser.add_mutually_exclusive_group()
    text_source.add_argument('--paragraphs', action='store_true',
                             help="whether to use entire paragraphs for alignment")
    text_source.add_argument('--localized-narratives', action='store_true',
                             help="whether to use localized narratives data for alignment")
    parser.add_argument('--lowrank', action='store_true',
                        help="fit a low-rank least-squares mapping (linear_reg only)")
    parser.add_argument('--r', type=int, default=None,
                        help="rank of the low-rank mapping (required with --lowrank)")

    options = parser.parse_args(args)
    if not 0. < options.fraction <= 1.:
        parser.error("--fraction must be in (0, 1]")
    if options.lowrank and (options.r is None or options.r <= 0):
        parser.error("--lowrank requires a positive --r")
    return options


def _load_pickle(path):
    """Load and return the object pickled at ``path``, closing the file afterwards.

    NOTE: pickle can execute arbitrary code on load; only use trusted,
    locally produced files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _lookup_caption(captions, key):
    """Return the caption embeddings stored for image ``key``, or None if absent.

    Candidates are tried in order: the file name, the file name without its
    extension, and the last '_'-separated token of that stem with leading
    zeros stripped.
    """
    filename = key.split("/")[-1]
    stem = filename.split('.')[0]
    for candidate in (filename, stem, stem.split('_')[-1].lstrip('0')):
        cap = captions.get(candidate)
        if cap is not None:
            return cap
    return None


def _encode_images(model, preprocess, images, keys, device, batch_size=128):
    """Encode ``images[k]`` for every k in ``keys`` in batches; returns a float32 array.

    Outputs are concatenated in ``keys`` order. This presumably gives one row
    per key, assuming ``encode_image`` returns one embedding per input image.
    """
    chunks = []
    with torch.no_grad():
        for start in tqdm.trange(0, len(keys), batch_size):
            batch_keys = keys[start:start + batch_size]
            # NOTE(review): assumes the pickled values are PIL images accepted by `preprocess` — confirm.
            batch = torch.stack([preprocess(images[k]) for k in batch_keys]).to(device)
            # Cast to float32 before leaving torch: on CUDA, CLIP produces fp16 outputs.
            chunks.append(model.encode_image(batch).float().cpu().numpy())
    return np.concatenate(chunks)


def _normalize_and_center(features):
    """Return float32 ``features`` with each row L2-normalized and then the column mean subtracted."""
    features = np.asarray(features, dtype=np.float32)
    features = features / np.linalg.norm(features, ord=2, axis=-1, keepdims=True)
    return features - features.mean(0)


def _fit_projection(train_method, src_embs, tar_embs, lowrank=False, rank=None):
    """Fit a mapping from ``src_embs`` to ``tar_embs``; returns (proj_mat, elapsed_seconds).

    Raises:
        NotImplementedError: if ``train_method`` is not a supported method name.
    """
    start = time.time()
    if train_method == 'procrustes':
        proj_mat = orthogonal_procrustes(src_embs, tar_embs)
    elif train_method == 'robust_procrustes':
        # Robust procrustes method from https://arxiv.org/abs/2205.11616:
        # iteratively re-weight samples by the inverse of their residual norm.
        eps = 1e-3
        n_iters = 5
        proj_mat = orthogonal_procrustes(src_embs, tar_embs)
        for _ in range(n_iters):
            weights = 1 / (np.linalg.norm(tar_embs - (src_embs @ proj_mat), ord=2, axis=-1) + eps)
            weights = ((weights / np.max(weights)) ** .5).reshape(-1, 1)
            proj_mat = orthogonal_procrustes(weights * src_embs, weights * tar_embs)
    elif train_method == 'linear_reg':
        if lowrank:
            # NOTE(review): `lowrank_lstsq` is neither defined nor imported in this file; it must
            # come from elsewhere (e.g. utils) or this branch raises NameError. Also, np.save of the
            # (u, v) tuple fails if u and v have different shapes — confirm the intended format.
            u, s, v = lowrank_lstsq(src_embs, tar_embs, rank)
            proj_mat = (u, v)
        else:
            reg = LinearRegression(fit_intercept=False)
            reg.fit(src_embs, tar_embs)
            proj_mat = reg.coef_.T
    elif train_method == 'ridge_reg':
        reg = Ridge(alpha=1., fit_intercept=False)
        reg.fit(src_embs, tar_embs)
        proj_mat = reg.coef_.T
    else:
        raise NotImplementedError(f'{train_method} - Training method not supported!!')
    return proj_mat, time.time() - start


def _output_path(options, vis_encoder, train_method, dataset):
    """Build the ./models path (without the .npy extension) that the mapping is saved to."""
    base = f'{vis_encoder}_{train_method}_{dataset}_{options.fraction}_retrieval'
    if options.paragraphs:
        suffix = '_paragraphs'
    elif options.localized_narratives:
        suffix = '_localized_narratives'
    elif options.lowrank:
        suffix = f'_lowrank_{options.r}'
    else:
        suffix = ''
    return os.path.join('./models', base + suffix)


def main():
    """Fit a mapping from CLIP image embeddings to caption embeddings and save it.

    Steps:
      * Load the training images and pickled caption embeddings for
        ``--dataset``.
      * Optionally subsample a ``--fraction`` of the images.
      * Encode the images with the ``--vis-encoder`` CLIP model.
      * L2-normalize and mean-center both embedding spaces.
      * Fit the projection and ``np.save`` it under ``./models``.
    """
    options = create_parser()

    vis_encoder = options.vis_encoder.replace("/", "")
    train_imgs = _load_pickle(f"data/{options.dataset}/imgs_train.pkl")
    if options.paragraphs:
        caps_file = f'{vis_encoder}_train_paragraphs.pkl'
    elif options.localized_narratives:
        caps_file = f'{vis_encoder}_train_caps_localized_narratives.pkl'
    else:
        caps_file = f'{vis_encoder}_train_caps.pkl'
    train_caps = _load_pickle(f'data/{options.dataset}/{caps_file}')
    # Entries are (key, <unused>, embeddings) triples; keep key -> embeddings.
    captions = {key: embs for key, _, embs in train_caps}

    dataset = "mscoco" if options.dataset == 'coco' else options.dataset
    all_keys = list(train_imgs.keys())
    if options.fraction == 1.:
        subset_keys = all_keys
    else:
        # Sample WITHOUT replacement so no image is duplicated in the subset; a local
        # RandomState(101) keeps the draw reproducible without touching global RNG state.
        size = int(len(all_keys) * options.fraction)
        subset_keys = list(np.random.RandomState(101).choice(all_keys, size=size, replace=False))
    if not subset_keys:
        raise ValueError("The selected training subset is empty; increase --fraction.")

    print(f"Model: {vis_encoder}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Constructing dataset...")

    # clip.load expects the original model name (with '/'), e.g. 'ViT-B/32'.
    model, preprocess = clip.load(options.vis_encoder, device=device)
    model = model.to(device).eval()

    text_features = []
    train_cls = []
    for idx, key in enumerate(subset_keys):
        cap = _lookup_caption(captions, key)
        if cap is None:
            continue
        text_features.append(cap)
        # Row `idx` of the image features belongs to this key; repeat it once per caption.
        train_cls.append(np.repeat(idx, len(cap), axis=0))
    if not text_features:
        raise ValueError("No caption embeddings matched the selected training images.")

    image_features = _encode_images(model, preprocess, train_imgs, subset_keys, device)

    # Preprocess embedding spaces: length normalization plus mean centering.
    text_features = _normalize_and_center(np.concatenate(text_features))
    image_features = _normalize_and_center(image_features)
    train_cls = np.concatenate(train_cls)

    src_embs = image_features[train_cls]
    tar_embs = text_features

    os.makedirs('./models', exist_ok=True)
    for train_method in ['linear_reg']:
        proj_mat, elapsed = _fit_projection(train_method, src_embs, tar_embs,
                                            lowrank=options.lowrank, rank=options.r)
        print(f"Elapsed time: {elapsed}")
        # Save inside the loop so every method's mapping is kept, not just the last one.
        np.save(_output_path(options, vis_encoder, train_method, dataset), proj_mat)


if __name__ == '__main__':
    # Run the training pipeline only when executed as a script, not when imported.
    main()
