from collections import defaultdict
import glob
import os
import argparse
import re

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torchvision.utils
import math
import torch.nn.functional as F
from torchvision import transforms

import copy
import datetime

from data import get_dataset_flickr, textprocess, textprocess_train
from epoch import evaluate_synset_with_similarity, evaluate_synset
from networks import CLIPModel_full, CLIPModel_linear, MultilabelContrastiveLoss
from reparam_module import ReparamModule
from utils import ParamDiffAug, get_time
from similarity_mining import LowRankSimilarityGenerator, FullSimilarityGenerator
from vl_distill_utils import shuffle_files, nearest_neighbor, get_images_texts, load_or_process_file
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

from hook import Conv


def svd_entropy(X, eps=1e-12):
    """Return the entropy of the normalised singular-value spectrum of ``X``.

    Singular values that do not exceed ``eps`` are dropped, the remaining
    ones are divided by their sum (plus ``eps``) and the quantity
    ``-sum(p * log(p + eps))`` is evaluated with the natural logarithm.

    Args:
        X: Tensor accepted by ``torch.linalg.svdvals``.
        eps: Threshold for discarding singular values; also used as a
            stabiliser in the normalisation and inside the logarithm.

    Returns:
        The entropy as a Python ``float``. When no singular value exceeds
        ``eps`` the sum is empty and the result is zero.
    """
    # Only the singular values are needed, so skip computing U and V, and
    # avoid building an autograd graph for a value that is returned as a float.
    with torch.no_grad():
        S = torch.linalg.svdvals(X)
        S = S[S > eps]

        probs = S / (S.sum() + eps)
        entropy = -(probs * (probs + eps).log()).sum()
    return entropy.item()


def make_timestamp(prefix: str="", suffix: str="") -> str:
    """Return ``prefix`` + current local time as ``MMDD_HHMMSS`` + ``suffix``."""
    now = datetime.datetime.now()
    return "".join((prefix, now.strftime('%m%d_%H%M%S'), suffix))


def _mean_logit_similarity(img_embed, txt_embed, img_projectors, txt_projectors, logit_scale):
    """Average the scaled image-to-text logits over every projection pair.

    For each (image projector, text projector) pair the embeddings are
    projected, L2-normalised along dim 1 and multiplied into a logit matrix
    scaled by ``logit_scale.exp()``. No gradients are tracked.

    Returns:
        A tuple ``(similarity, img_proj, txt_proj)`` where ``similarity`` is
        the element-wise mean of the per-pair logit matrices and the two
        projections are the normalised ones of the LAST pair processed.
    """
    logits_per_model = []
    img_proj = txt_proj = None
    with torch.no_grad():
        for img_weight, txt_weight in zip(img_projectors, txt_projectors):
            img_proj = img_embed @ img_weight
            txt_proj = txt_embed @ txt_weight

            img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
            txt_proj = txt_proj / txt_proj.norm(dim=1, keepdim=True)

            logits_per_model.append(logit_scale.exp() * img_proj.float() @ txt_proj.float().T)

        similarity = torch.mean(torch.stack(logits_per_model, dim=0), dim=0)
    return similarity, img_proj, txt_proj


def main(args):
    """Optimise a synthetic image/text set and periodically evaluate and save it.

    The function configures dataset paths, loads the data and cached text
    embeddings, initialises the synthetic pairs, loads pretrained projection
    weights, precomputes ``Conv`` statistics on the stored real embeddings and
    then runs ``args.Iteration + 1`` Adam steps on a symmetric cross-entropy
    loss plus ``0.01`` times the summed statistic-matching norms.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``image_root``, ``ann_root``, ``device`` and ``projection_dim``
            are assigned here.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'flickr'`` nor ``'coco'``,
            or if ``Conv`` yields a statistic whose rank is not 1 or 2.
    """
    if args.dataset == 'flickr':
        args.image_root = '/root/autodl-tmp/data/Flickr30k'
        args.ann_root = '/root/autodl-tmp/data/Flickr30k_ann'
    elif args.dataset == 'coco':
        args.image_root = '/root/autodl-tmp/data/COCO'
        # NOTE(review): the COCO branch points at the Flickr30k annotation
        # directory; left unchanged, confirm whether this is intended.
        args.ann_root = '/root/autodl-tmp/data/Flickr30k_ann'
    else:
        # Fail loudly instead of returning without doing anything.
        raise ValueError("Unsupported dataset: {!r} (expected 'flickr' or 'coco')".format(args.dataset))

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)

    train_sentences = train_dataset.get_all_captions()

    data = load_or_process_file('text', textprocess, args, testloader)
    train_caption = load_or_process_file('train_text', textprocess_train, args, train_sentences)

    bert_test_embed = torch.from_numpy(data['bert_test_embed']).cpu()
    print("The shape of bert_test_embed: {}".format(bert_test_embed.shape))
    train_caption_embed = torch.from_numpy(train_caption['bert_test_embed']).cpu()
    print("The shape of train_caption_embed: {}".format(train_caption_embed.shape))

    # Fixed (never optimised) log-temperature used to scale the logits.
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.temperature))

    if args.eval_it > 0:
        eval_it_pool = set(np.arange(0, args.Iteration + 1, args.eval_it).tolist())
    else:
        eval_it_pool = set()
    save_it_pool = set(np.arange(0, args.Iteration + 1, 100).tolist())

    # --- initialise the synthetic data ---
    image_syn, text_syn = get_images_texts(args.num_queries, train_dataset, args)

    # --- training setup ---
    image_syn = image_syn.detach().to(args.device).requires_grad_(True)
    text_syn = text_syn.detach().to(args.device).requires_grad_(True)

    optimizer = torch.optim.Adam([
        {'params': [image_syn], 'lr': 0.1, 'betas': (0.6, 0.9)},
        {'params': [text_syn], 'lr': 0.1, 'betas': (0.6, 0.9)},
    ], eps=1e-8)
    optimizer.zero_grad()

    # --- model initialisation: frozen network used only as a feature extractor ---
    net = CLIPModel_linear(args, temperature=args.temperature)
    net = net.to(args.device)

    net.eval()
    for p in net.parameters():
        p.requires_grad = False

    img_projectors = []
    txt_projectors = []
    model_ids = [29]
    model_num = len(model_ids)

    for model_id in model_ids:
        weights = torch.load(f'/root/autodl-tmp/buffers/{args.dataset}/nfnet_bert/InfoNCE/pretrain_{model_id}.pt', weights_only=True, map_location=args.device)
        img_projectors.append(weights['image_projection.weight'].data)
        txt_projectors.append(weights['text_projection.weight'].data)

    args.projection_dim = img_projectors[0].shape[1]
    print(img_projectors[0].shape, txt_projectors[0].shape)

    # --- statistics of the stored real embeddings ---
    # Number of per-index text embedding files that are loaded and cycled through.
    num_captions = 5

    image_filename = f'statistics/{args.dataset}_{args.image_encoder}_train_image_embed.npz'
    image_embed = torch.from_numpy(np.load(image_filename)['image_embed']).to(args.device)

    text_embeds = []
    for t in range(num_captions):
        text_filename = f'statistics/{args.dataset}_{args.text_encoder}_train_{t}_text_embed.npz'
        bert_embed = torch.from_numpy(np.load(text_filename)['bert_test_embed']).to(args.device)
        text_embeds.append(bert_embed)

    # pre_calculations[m][t] holds the detached Conv statistics for
    # projector pair m and text embedding file t.
    pre_calculations = [[] for _ in range(model_num)]

    for model_idx in range(model_num):
        img_proj = image_embed @ img_projectors[model_idx]
        img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)

        for t in range(num_captions):
            txt_proj = text_embeds[t] @ txt_projectors[model_idx]
            txt_proj = txt_proj / txt_proj.norm(dim=1, keepdim=True)

            stats = Conv(image_embed, text_embeds[t], img_proj, txt_proj, alpha=args.alpha)
            pre_calculations[model_idx].append([s.detach() for s in stats])

    # --- distillation ---
    for it in tqdm(range(args.Iteration + 1)):

        # Evaluate the current synthetic data.
        if it in eval_it_pool:
            print('Evaluation\nimage_model_train = %s, text_model_train = %s, iteration = %d'%(args.image_encoder, args.text_encoder, it))

            multi_eval_aggr_result = defaultdict(list)  # aggregated results of multiple evaluations

            for it_eval in range(args.num_eval):
                net_eval = CLIPModel_linear(args, eval_stage=args.transfer)

                # Independent copies so evaluation cannot touch the tensors being optimised.
                image_syn_eval = image_syn.detach().clone()
                text_syn_eval = text_syn.detach().clone()

                with torch.no_grad():
                    img_embed = net.image_encoder(image_syn_eval)

                similarity_syn_eval, _, _ = _mean_logit_similarity(
                    img_embed, text_syn_eval, img_projectors, txt_projectors, logit_scale)

                _, _, best_val_result = evaluate_synset_with_similarity(
                    it_eval, net_eval, image_syn_eval, text_syn_eval, args.lr_teacher_img, args.lr_teacher_txt,
                    similarity_syn_eval, testloader, args, bert_test_embed)

                for k, v in best_val_result.items():
                    multi_eval_aggr_result[k].append(v)

            # Report what was collected; previously the results were gathered
            # and then dropped by a loop whose body did nothing.
            for key, values in multi_eval_aggr_result.items():
                if key in ("img_r_mean", "txt_r_mean"):
                    continue
                print("{}: {}".format(key, values))

        # Checkpoint the synthetic data (every 100 iterations, see save_it_pool).
        if it in save_it_pool:
            with torch.no_grad():
                save_dir = os.path.join(".", "logged_files", args.dataset)
                print("Saving to {}".format(save_dir))
                os.makedirs(save_dir, exist_ok=True)

                img_embed = net.image_encoder(image_syn.detach())
                similarity_syn_eval, img_proj, txt_proj = _mean_logit_similarity(
                    img_embed, text_syn.detach(), img_projectors, txt_projectors, logit_scale)

                image_save = image_syn.detach().cpu()
                text_save = text_syn.detach().cpu()

                # Everything is moved to CPU so the checkpoint can be loaded
                # on a machine without the original device.
                torch.save({
                    "image": image_save,
                    "text": text_save,
                    "similarity_mat": similarity_syn_eval.detach().cpu(),
                    "img_proj": img_proj.detach().cpu(),
                    "txt_proj": txt_proj.detach().cpu(),
                }, os.path.join(save_dir, "distilled_{}.pt".format(it)))

                if args.draw and args.force_save:
                    upsampled = image_save[:90]
                    if args.dataset != "ImageNet":
                        upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
                        upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
                    grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
                    sentence_list = nearest_neighbor(train_sentences, text_save, train_caption_embed)
                    sentence_list = sentence_list[:90]
                    torchvision.utils.save_image(grid, os.path.join(save_dir, "synthetic_images_{}.png".format(it)))

                    with open(os.path.join(save_dir, "synthetic_sentences_{}.txt".format(it)), "w", encoding="utf-8") as file:
                        file.write('\n'.join(sentence_list))
                    print("finish saving images")

                    for clip_val in [2.5]:
                        # Clip bounds come from the statistics of ALL saved images.
                        std = torch.std(image_save)
                        mean = torch.mean(image_save)
                        # Slice before upsampling: only the first 90 images are drawn.
                        upsampled = torch.clip(image_save, min=mean-clip_val*std, max=mean+clip_val*std).cpu()[:90]
                        if args.dataset != "ImageNet":
                            upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=2)
                            upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
                        grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
                        torchvision.utils.save_image(grid, os.path.join(save_dir, "clipped_synthetic_images_{}_std_{}.png".format(it, clip_val)))

        # --- one optimisation step on the synthetic data ---
        # The text statistics index changes every iteration; the projector
        # pair changes every `num_captions` iterations.
        model_idx = (it // num_captions) % model_num
        caption_idx = it % num_captions

        img_embed = net.image_encoder(image_syn)
        txt_embed = text_syn

        img_proj = img_embed @ img_projectors[model_idx]
        txt_proj = txt_embed @ txt_projectors[model_idx]

        img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
        txt_proj = txt_proj / txt_proj.norm(dim=1, keepdim=True)

        # Symmetric cross-entropy where the i-th image matches the i-th text.
        image_logits = logit_scale.exp() * img_proj.float() @ txt_proj.float().T
        ground_truth = torch.arange(len(image_logits), device=image_logits.device)
        loss_cls = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2

        post_calculations = Conv(img_embed, txt_embed, img_proj, txt_proj, alpha=args.alpha)
        targets = pre_calculations[model_idx][caption_idx]
        match_losses = []
        for i, stat in enumerate(post_calculations):
            rank = len(stat.shape)
            if rank == 1:
                match_losses.append(torch.norm(stat - targets[i], 2))
            elif rank == 2:
                match_losses.append(torch.norm(stat - targets[i], 'fro'))
            else:
                # Previously this silently ended the whole run.
                raise ValueError("Conv statistic {} has unsupported rank {}".format(i, rank))

        loss_match = sum(match_losses)

        loss = loss_cls + 0.01 * loss_match

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if it % 50 == 0:
            # Diagnostics only: no autograd graph is needed here.
            with torch.no_grad():
                conv_img = torch.cov(img_proj, correction=0)
                conv_txt = torch.cov(txt_proj, correction=0)

                print([round(m.item(), 4) for m in match_losses] + [svd_entropy(conv_img), svd_entropy(conv_txt)])


def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    maps the usual spellings explicitly.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"``.

    Returns:
        The corresponding ``bool``. An empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


if __name__ == '__main__':
    # Command-line entry point: declare every option, parse, then run main().
    parser = argparse.ArgumentParser(description='Parameter Processing')

    parser.add_argument('--dataset', type=str, default='flickr', help='dataset')
    parser.add_argument('--disabled_wandb', type=str2bool, default=True, help='disable wandb')

    parser.add_argument('--eval_mode', type=str, default='S',
                        help='eval_mode, check utils.py for more info')

    parser.add_argument('--num_eval', type=int, default=1, help='how many networks to evaluate on')

    parser.add_argument('--eval_it', type=int, default=0, help='how often to evaluate')

    parser.add_argument('--epoch_eval_train', type=int, default=100, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=400, help='how many distillation steps to perform')

    parser.add_argument('--lr_img', type=float, default=100., help='learning rate for updating synthetic images')
    parser.add_argument('--lr_txt', type=float, default=100., help='learning rate for updating synthetic texts')
    parser.add_argument('--lr_lr', type=float, default=1e-03, help='learning rate for updating... learning rate')
    parser.add_argument('--lr_teacher_img', type=float, default=0.1, help='learning rate for updating network parameters')
    parser.add_argument('--lr_teacher_txt', type=float, default=0.1, help='learning rate for updating network parameters')

    parser.add_argument('--loss_type', default='WBCE', type=str)

    parser.add_argument('--batch_train', type=int, default=128, help='batch size for training networks')

    parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
                        help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--txt_init', type=str, default='real', choices=["noise", "real"],
                        help='noise/real: initialize synthetic texts from random noise or randomly sampled real images.')

    parser.add_argument('--data_path', type=str, default='./data/Flickr30k/', help='dataset path')
    parser.add_argument('--no_aug', action="store_true", default=False, help='this turns off diff aug during distillation')

    parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc')
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')
    parser.add_argument('--num_queries', type=int, default=199, help='number of queries')
    parser.add_argument('--mini_batch_size', type=int, default=100, help='mini-batch size')
    parser.add_argument('--basis', type=str2bool, default=False, help='whether use basis or not')
    parser.add_argument('--n_basis', type=int, default=64, help='n_basis')
    parser.add_argument('--recursive', type=str2bool, default=False, help='whether use basis or not')
    parser.add_argument('--load_npy', type=str2bool, default=False, help='load_npy')
    parser.add_argument('--image_size', type=int, default=224, help='image_size')

    parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
    parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')

    parser.add_argument('--image_encoder', type=str, default='nfnet',  help='image encoder') # , choices=['clip', 'nfnet', 'vit', 'nf_resnet50', "nf_regnet"]
    parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert'], help='text encoder')

    parser.add_argument('--text_pretrained', type=str2bool, default=True, help='text_pretrained')
    parser.add_argument('--image_pretrained', type=str2bool, default=True, help='image_pretrained')

    parser.add_argument('--text_trainable', type=str2bool, default=False, help='text_trainable')
    parser.add_argument('--image_trainable', type=str2bool, default=False, help='image_trainable')

    parser.add_argument('--projection_dim', type=int, default=256, help='dimension of projection head')
    parser.add_argument('--alpha', type=float, default=0.05, help='alpha in inverse matrix')

    parser.add_argument('--distill', type=str2bool, default=True, help='whether distill')
    parser.add_argument('--optimize', type=str, default='reparam', choices=['reparam', 'ift'], help='matching_train')
    parser.add_argument('--image_only', type=str2bool, default=False, help='None')
    parser.add_argument('--text_only', type=str2bool, default=False, help='None')
    parser.add_argument('--draw', type=str2bool, default=False, help='None')
    parser.add_argument('--transfer', type=str2bool, default=False, help='transfer cross architecture')
    parser.add_argument('--std', type=str2bool, default=True, help='standard deviation')
    parser.add_argument('--test_with_norm', type=str2bool, default=False, help='')

    parser.add_argument('--clamp_lr', type=float, default=None, help='')
    parser.add_argument('--temperature', type=float, default=0.07, help="temperature of CLIP model")

    # Arguments below are for LoRS
    # parser.add_argument('--resume_from', default=None, type=str)

    # parser.add_argument('--sim_type', type=str, default="full", choices=["full", "lowrank"], help='similarity matrix type')
    # parser.add_argument('--sim_rank', type=int, default=10, help='similarity matrix rank')
    # parser.add_argument('--lr_sim', type=float, default=1e-03, help='learning rate for updating similarity mat learning rate')

    # parser.add_argument('--momentum_lr', type=float, default=0.5)
    # parser.add_argument('--momentum_syn', type=float, default=0.5)
    # parser.add_argument('--momentum_sim', type=float, default=0.5)
    # parser.add_argument('--merge_loss_branches', action="store_true", default=False)

    args = parser.parse_args()

    main(args)
