from collections import defaultdict
import glob
import os
import argparse
import re

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torchvision.utils
import math

import copy
import datetime

from data import get_dataset_flickr, create_dataset, textprocess_train, textprocess
from epoch import evaluate_synset_with_similarity
from networks import CLIPModel_linear, MultilabelContrastiveLoss
from reparam_module import ReparamModule
from utils import ParamDiffAug, get_time
from similarity_mining import LowRankSimilarityGenerator, FullSimilarityGenerator
from vl_distill_utils import shuffle_files, nearest_neighbor, get_images_texts, load_or_process_file
from data.randaugment import RandomAugment
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image


def pre_caption(caption, max_words=50):
    """Normalise a raw caption and cap its length.

    The caption is lower-cased, the punctuation characters
    ``. ! " ( ) * # : ; ~`` are each replaced by a space, runs of two or
    more whitespace characters are collapsed into a single space, and
    trailing newlines plus leading/trailing spaces are removed.

    Args:
        caption: The raw caption string.
        max_words: Maximum number of space-separated words to keep.

    Returns:
        The cleaned caption, truncated to at most ``max_words`` words.
    """
    # Lower-case first, then blank out the listed punctuation characters.
    cleaned = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    # Removing punctuation can leave whitespace runs behind; squeeze them.
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')

    words = cleaned.split(' ')
    if len(words) <= max_words:
        return cleaned
    return ' '.join(words[:max_words])


@torch.no_grad()
def textprocess_train_split(args, texts, idx):
    """Encode one split of training captions and save the embeddings to disk.

    A fresh ``CLIPModel_linear`` is built from ``args``, moved to CUDA and
    put in eval mode. ``texts`` is encoded with ``net.text_encoder`` in
    chunks of 1000 and the concatenated result is written to
    ``{dataset}_{text_encoder}_train_{idx}_text_embed.npz`` (relative to the
    current working directory) under the key ``bert_test_embed``.

    Args:
        args: Namespace providing at least ``dataset`` and ``text_encoder``
            (plus whatever ``CLIPModel_linear`` reads).
        texts: Sliceable sequence of captions to encode.
        idx: Split index, used only in the output file name.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``'flickr'`` or
            ``'coco'``.
    """
    # Validate before the expensive encoding pass so that an unsupported
    # dataset does not waste a full run over the captions.
    if args.dataset not in ('flickr', 'coco'):
        raise NotImplementedError(f"Unsupported dataset name: {args.dataset}")

    net = CLIPModel_linear(args).to('cuda')
    net.eval()

    chunk_size = 1000
    chunks = []
    for start in tqdm(range(0, len(texts), chunk_size)):
        # Move each chunk to the CPU right away so GPU memory stays bounded.
        chunks.append(net.text_encoder(texts[start:start + chunk_size]).cpu())
        torch.cuda.empty_cache()  # free up memory
    bert_test_embed = torch.cat(chunks, dim=0)

    print('bert_test_embed.shape: ', bert_test_embed.shape)
    np.savez(
        f'{args.dataset}_{args.text_encoder}_train_{idx}_text_embed.npz',
        bert_test_embed=bert_test_embed.numpy(),
    )


@torch.no_grad()
def imageprocess_train(args, net, images):
    """Encode training images with ``net.image_encoder`` and save the result.

    Each image is loaded from ``args.image_root``, converted to RGB, resized
    to ``args.image_size`` x ``args.image_size`` (bicubic), turned into a
    tensor and normalised. Images are encoded in batches of 512 and the
    concatenated embeddings are written to
    ``{dataset}_{image_encoder}_train_image_embed.npz`` (relative to the
    current working directory) under the key ``image_embed``.

    Args:
        args: Namespace providing ``dataset``, ``image_encoder``,
            ``image_root`` and ``image_size``.
        net: Model exposing ``image_encoder``; it is switched to eval mode.
            It is not moved here, so it is assumed to already live on the
            device selected below -- TODO confirm at call sites.
        images: Sliceable sequence of image file names relative to
            ``args.image_root``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``'flickr'`` or
            ``'coco'``.
    """
    # Validate before the expensive encoding pass so that an unsupported
    # dataset does not waste a full run over the images.
    if args.dataset not in ("flickr", "coco"):
        raise NotImplementedError(f"Unsupported dataset name: {args.dataset}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    net.eval()

    # NOTE(review): these mean/std values look like the commonly used CLIP
    # image normalisation statistics -- confirm they match the encoder.
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
        ])

    batch_size = 512
    image_root = args.image_root

    chunks = []
    pbar = tqdm(range(0, len(images), batch_size), desc="Encoding images")
    for start in pbar:
        batch_files = images[start: start + batch_size]
        pbar.set_postfix_str(f"first={batch_files[0]}")

        batch_imgs = []
        for name in batch_files:
            path = os.path.join(image_root, name)
            # Close the file handle as soon as the pixels are converted;
            # otherwise handles pile up until garbage collection.
            with Image.open(path) as raw:
                batch_imgs.append(transform_test(raw.convert("RGB")))

        batch_tensor = torch.stack(batch_imgs).to(device, non_blocking=True)

        # Keep only a CPU copy of each batch so GPU memory stays bounded.
        chunks.append(net.image_encoder(batch_tensor).detach().cpu())

        del batch_tensor, batch_imgs
        torch.cuda.empty_cache()

    img_embeds = torch.cat(chunks, dim=0)
    print("img_embeds.shape:", img_embeds.shape)

    out_name = f"{args.dataset}_{args.image_encoder}_train_image_embed.npz"
    np.savez(out_name, image_embed=img_embeds.numpy())
    print(f"Saved to {out_name}")



def make_timestamp(prefix: str="", suffix: str="") -> str:
    """Return ``prefix`` + current local time as ``MMDD_HHMMSS`` + ``suffix``."""
    stamp = datetime.datetime.now().strftime('%m%d_%H%M%S')
    return "".join((prefix, stamp, suffix))


def split_images_and_captions(dataset, num_captions=5):
    """Split a caption dataset into one image list and per-slot caption lists.

    ``dataset.annotation`` is scanned in order and grouped into consecutive
    runs of entries sharing the same ``'image'`` value. Each run contributes
    one image name and exactly ``num_captions`` captions, one per slot:

    * a run with more than ``num_captions`` entries keeps only the first
      ``num_captions``;
    * a run with fewer entries is padded by cycling through its own captions
      in order (the earlier ``len(block) % len(block)`` index was always 0,
      so padding used to repeat only the first caption).

    Every caption is cleaned with ``pre_caption`` (limited to
    ``dataset.max_words`` words) and prefixed with ``dataset.prompt``.

    Args:
        dataset: Object exposing ``annotation`` (sequence of dicts with
            ``'image'`` and ``'caption'`` keys), ``prompt`` and ``max_words``.
        num_captions: Number of caption slots per image. Defaults to 5.

    Returns:
        ``(images, captions_split)`` where ``images[i]`` is the i-th image
        name and ``captions_split[k][i]`` is the k-th caption of that image.
    """
    # Local import keeps this block self-contained.
    from itertools import groupby

    images = []
    captions_split = [[] for _ in range(num_captions)]

    # groupby only merges *adjacent* equal keys, matching a linear scan over
    # consecutive annotations of the same image.
    for img_name, run in groupby(dataset.annotation, key=lambda ann: ann['image']):
        block = list(run)
        n_real = len(block)

        images.append(img_name)
        for k in range(num_captions):
            # k % n_real is the identity while k < n_real, and wraps around
            # the real captions once the run is exhausted.
            ann = block[k % n_real]
            cap = dataset.prompt + pre_caption(ann['caption'], dataset.max_words)
            captions_split[k].append(cap)

    return images, captions_split


def Conv(img_embed, txt_embed, img_proj, txt_proj, alpha=0.1):
    """Compute two ridge-regularised cross-covariance maps.

    All four inputs are treated as 2-D tensors with samples along dim 0 and
    must share the same number of rows. Each is mean-centred over dim 0.
    With ``S_XY = X_c^T Y_c / N`` and ``S_XX = X_c^T X_c / N + alpha * I``
    (``N`` is the row count of ``img_embed``), the function returns

    * ``w_I = S_II^{-1} S_IV S_VV^{-1}``  (image embedding vs. text projection)
    * ``w_T = S_TT^{-1} S_TU S_UU^{-1}``  (text embedding vs. image projection)

    Args:
        img_embed: Image embeddings (I).
        txt_embed: Text embeddings (T).
        img_proj: Image projections (U).
        txt_proj: Text projections (V).
        alpha: Weight of the identity added to every auto-covariance.

    Returns:
        Tuple ``(w_I, w_T)``.
    """
    device = img_embed.device
    n_samples = img_embed.shape[0]

    def centered(x):
        # Remove the per-feature mean taken over the sample axis.
        return x - x.mean(0, keepdim=True)

    def ridge_cov(h):
        # Auto-covariance plus alpha * identity.
        return (h.T @ h) / n_samples + alpha * torch.eye(h.shape[1], device=device)

    def cross_map(h_src, h_dst):
        cov_src = ridge_cov(h_src)
        cross = (h_src.T @ h_dst) / n_samples
        cov_dst = ridge_cov(h_dst)
        # Left solve gives cov_src^{-1} @ cross; the right solve
        # (left=False) then multiplies by cov_dst^{-1} on the right.
        left_solved = torch.linalg.solve(cov_src, cross)
        return torch.linalg.solve(cov_dst, left_solved, left=False)

    w_I = cross_map(centered(img_embed), centered(txt_proj))
    w_T = cross_map(centered(txt_embed), centered(img_proj))

    return w_I, w_T


def main(args):
    """Pre-compute and save image embeddings for the selected dataset.

    Sets ``args.image_root`` / ``args.ann_root`` from ``args.dataset``,
    loads (or computes) the test and train caption embeddings and prints
    their shapes, splits the training set into images and five caption
    lists, then encodes every training image with the checkpoint at
    ``buffers/{dataset}/nfnet_bert/InfoNCE/pretrain_9.pt`` via
    ``imageprocess_train``.

    Returns silently, without doing anything, when ``args.dataset`` is
    neither ``'flickr'`` nor ``'coco'``.
    """
    # dataset name -> (image_root, ann_root)
    # NOTE(review): 'coco' points at the Flickr30k annotation directory --
    # confirm this is intentional and not a copy-paste slip.
    dataset_roots = {
        'flickr': ('/root/autodl-tmp/data/Flickr30k', '/root/autodl-tmp/data/Flickr30k_ann'),
        'coco': ('/root/autodl-tmp/data/COCO', '/root/autodl-tmp/data/Flickr30k_ann'),
    }
    roots = dataset_roots.get(args.dataset)
    if roots is None:
        return
    args.image_root, args.ann_root = roots

    _, testloader, train_dataset, _ = get_dataset_flickr(args)
    train_sentences = train_dataset.get_all_captions()

    # Both results are indexed like mappings holding a numpy array under
    # 'bert_test_embed'; presumably cached on disk -- verify in
    # load_or_process_file.
    test_text = load_or_process_file('text', textprocess, args, testloader)
    train_text = load_or_process_file('train_text', textprocess_train, args, train_sentences)

    bert_test_embed = torch.from_numpy(test_text['bert_test_embed']).cpu()
    print(f"The shape of bert_test_embed: {bert_test_embed.shape}")
    train_caption_embed = torch.from_numpy(train_text['bert_test_embed']).cpu()
    print(f"The shape of train_caption_embed: {train_caption_embed.shape}")

    ################### Split image and text #################
    train_dataset, _, _ = create_dataset(args)
    images, captions_split = split_images_and_captions(train_dataset)

    # Per-split caption embeddings are currently disabled; they can be
    # produced with textprocess_train_split(args, captions_split[k], k)
    # for each of the five splits.

    net = CLIPModel_linear(args).cuda()
    checkpoint = torch.load(
        f'buffers/{args.dataset}/nfnet_bert/InfoNCE/pretrain_9.pt',
        weights_only=True,
        map_location='cuda',
    )
    net.load_state_dict(checkpoint)
    imageprocess_train(args, net, images)


def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` and ``"0"`` -- becomes
    True. This converter interprets the text instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"`` (case-insensitive). The empty
            string maps to False, matching ``bool("")``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for this script.

    Option names, defaults and choices are unchanged; boolean options now
    use ``str2bool`` so that ``--flag False`` really yields ``False``.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')

    parser.add_argument('--dataset', type=str, default='flickr', help='dataset')
    parser.add_argument('--disabled_wandb', type=str2bool, default=True, help='disable wandb')

    parser.add_argument('--eval_mode', type=str, default='S',
                        help='eval_mode, check utils.py for more info')

    parser.add_argument('--num_eval', type=int, default=5, help='how many networks to evaluate on')

    parser.add_argument('--eval_it', type=int, default=50, help='how often to evaluate')

    parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=3000, help='how many distillation steps to perform')

    parser.add_argument('--loss_type', type=str)

    parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')

    parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
                        help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"],
                        help='noise/real: initialize synthetic texts from random noise or randomly sampled real texts.')

    parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
    parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation')

    parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc')
    # The default run name is the wall-clock time at parser construction.
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
    parser.add_argument('--num_queries', type=int, default=100, help='number of queries')
    parser.add_argument('--mini_batch_size', type=int, default=100, help='mini batch size')
    parser.add_argument('--basis', type=str2bool, default=False, help='whether use basis or not')
    parser.add_argument('--n_basis', type=int, default=64, help='n_basis')
    parser.add_argument('--recursive', type=str2bool, default=False, help='recursive')
    parser.add_argument('--load_npy', type=str2bool, default=False, help='load_npy')
    parser.add_argument('--image_size', type=int, default=224, help='image_size')

    parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
    parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')

    parser.add_argument('--image_encoder', type=str, default='nfnet',  help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"]
    parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder')

    parser.add_argument('--text_pretrained', type=str2bool, default=True, help='text_pretrained')
    parser.add_argument('--image_pretrained', type=str2bool, default=True, help='image_pretrained')

    parser.add_argument('--text_trainable', type=str2bool, default=False, help='text_trainable')
    parser.add_argument('--image_trainable', type=str2bool, default=False, help='image_trainable')

    parser.add_argument('--projection_dim', type=int, default=512, help='dimension of projection head')

    parser.add_argument('--only_has_image_projection', type=str2bool, default=False, help='None')
    parser.add_argument('--distill', type=str2bool, default=True, help='whether distill')
    parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train')
    parser.add_argument('--image_only', type=str2bool, default=False, help='None')
    parser.add_argument('--text_only', type=str2bool, default=False, help='None')
    parser.add_argument('--draw', type=str2bool, default=False, help='None')
    parser.add_argument('--transfer', type=str2bool, default=False, help='transfer cross architecture')
    parser.add_argument('--std', type=str2bool, default=True, help='standard deviation')
    parser.add_argument('--test_with_norm', type=str2bool, default=False, help='')

    parser.add_argument('--clamp_lr', type=float, default=None, help='')

    # Arguments below are for LoRS

    parser.add_argument('--resume_from', default=None, type=str)

    parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type')
    parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha in LoRA')
    parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating similarity mat learning rate')
    parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model")

    parser.add_argument('--momentum_lr', type=float, default=0.5)
    parser.add_argument('--momentum_syn', type=float, default=0.5)
    parser.add_argument('--momentum_sim', type=float, default=0.5)
    parser.add_argument('--merge_loss_branches', action="store_true", default=False)

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    main(args)
