import numpy as np
import torch
import time
from PIL import Image
import os
import time
import pickle
import copy
from torchvision.utils import save_image
from torchvision.transforms import Resize, ToTensor, CenterCrop
from support.gen_img import gen_img_from_text
from infer_text import GenTxtFromClipFea
from optmize_img_fea import optimize_img_embedding_by_influence, get_sim_between_img_txt, get_sim_between_img_txt_iterative
from support.utils import get_labs_from_dataloader, get_features, get_class_centers, get_cand_text, init_seeds, get_highest_score_fea
from model_train import get_train_lr


def optimize_img_embedding(clip_model, cla_model, lab_loader, loss_fun, train_fea, train_lab, val_loader, class_to_idx_dict, idx_to_class, class_with_min_instances, text_generator):
    """Optimize a candidate image embedding for one target class.

    The optimization starts from the target class's center, as computed by
    ``get_class_centers`` over the training features and labels. It is then
    refined by ``optimize_img_embedding_by_influence``, using validation
    features extracted with ``clip_model`` and every other class name as a
    negative.

    Args:
        clip_model: Model used by ``get_features`` to embed ``val_loader``.
        cla_model: Classifier passed through to the influence optimizer.
        lab_loader: Labeled loader, used only to derive the training LR via
            ``get_train_lr``.
        loss_fun: Loss passed through to the influence optimizer.
        train_fea: Training features (tensor; moved to CPU/NumPy for the
            center computation).
        train_lab: Training labels aligned with ``train_fea``.
        val_loader: Validation loader to embed.
        class_to_idx_dict: Mapping class name -> class index.
        idx_to_class: Inverse mapping class index -> class name.
        class_with_min_instances: Index of the target class. It may be a
            Python int, a NumPy integer, or a 0-d ``torch.Tensor``.
        text_generator: Passed through to the influence optimizer.

    Returns:
        The result of ``optimize_img_embedding_by_influence``. Callers in
        this file treat it as a dict of candidate image features.
    """
    # Normalize the class index to a plain int. A torch.Tensor hashes by
    # identity, so it cannot be used directly as a key into ``idx_to_class``.
    if isinstance(class_with_min_instances, torch.Tensor):
        desired_class = class_with_min_instances.item()
    else:
        desired_class = int(class_with_min_instances)
    desired_class_name = idx_to_class[desired_class]
    # Every class other than the target is used as a negative.
    all_other_classes = [name for name in class_to_idx_dict if name != desired_class_name]

    # Embed the validation set once; these features drive the influence objective.
    val_x, val_y = get_features(clip_model, val_loader, return_class_center=False)

    # Start the optimization from the target class's feature center.
    class_data_ids, centers = get_class_centers(train_fea.cpu().numpy(), train_lab.cpu().numpy())
    img_origin = torch.as_tensor(centers[desired_class])
    print(desired_class_name)
    train_lr = get_train_lr(lab_loader)
    cand_img_fea_dict = optimize_img_embedding_by_influence(cla_model=cla_model, clip_model=clip_model, text_generator=text_generator,
                                                            loss=loss_fun, train_fea=train_fea, train_lab=train_lab,
                                                            val_x=torch.as_tensor(val_x), val_y=torch.as_tensor(val_y),
                                                            desired_class=class_with_min_instances, desired_class_name=desired_class_name,
                                                            negative_class_name=all_other_classes, train_lr=train_lr,
                                                            img_origin=img_origin, step_size=1e-4,
                                                            optimize_iterations=200)
    return cand_img_fea_dict


def our_query(dataset, lab_idx, unlab_idx, lab_loader, unlab_loader, val_loader, clip_model, cla_model, 
              loss_fun, train_fea, train_lab, query_size, class_to_idx_dict, gen_class_txt_num=1, save_img_flag=False, 
              use_cache=True, save_dir="", query_iter=0, ext_labels=None, method_dir_name=''):
    """Generate text descriptions for a subset of classes and synthesize images.

    For each selected class this function does the following:

    1. It reuses cached descriptions from
       ``save_dir/gen_txt_{i}_{class_name}.txt`` when ``use_cache`` is true
       and that file exists.
    2. Otherwise it optimizes an image embedding with
       ``optimize_img_embedding``, picks a candidate text with
       ``get_cand_text`` and writes it to that cache file.

    Afterwards ``gen_img_from_text`` is called per class, writing into
    ``./gen_img_save/{dataset}/{method_dir_name}/{query_iter}/``.

    When ``gen_class_txt_num`` is at least the number of classes, every
    class is used once. Otherwise ``gen_class_txt_num`` distinct classes
    are sampled at random.

    ``lab_idx``, ``unlab_idx``, ``unlab_loader``, ``query_size``,
    ``save_img_flag`` and ``ext_labels`` are not used here. They are
    presumably kept for a signature shared with other query strategies.

    Returns:
        An empty list: no unlabeled indices are selected by this strategy.
    """
    # Seed from wall-clock time, so each call samples a different class subset.
    init_seeds(int(time.time()))
    idx_to_class = {v: k for k, v in class_to_idx_dict.items()}
    if gen_class_txt_num >= len(idx_to_class):
        sorted_class_count = list(idx_to_class.keys())
    else:
        sorted_class_count = np.random.choice(list(idx_to_class.keys()), gen_class_txt_num, replace=False)
    all_txt_desc_list = []
    text_generator = GenTxtFromClipFea(clip_model=clip_model)
    # Iterate the classes actually selected. ``gen_class_txt_num`` may exceed
    # the number of classes, so indexing by range(gen_class_txt_num) overruns.
    for i, desired_class in enumerate(sorted_class_count):
        desired_class_name = idx_to_class[desired_class]
        cache_path = os.path.join(save_dir, f"gen_txt_{i}_{desired_class_name}.txt")
        if use_cache:
            try:
                with open(cache_path, "r") as f:
                    txt_descriptions_listi = f.read().splitlines()
            except FileNotFoundError:
                print("Cache not found, generating new txt descriptions.")
            else:
                all_txt_desc_list.append(txt_descriptions_listi)
                continue

        cand_img_fea_dict = optimize_img_embedding(clip_model, cla_model, lab_loader, loss_fun, train_fea, train_lab, val_loader, class_to_idx_dict, idx_to_class, desired_class, text_generator)

        txt_descriptions_listi = get_cand_text(cand_img_fea_dict, cand_num=1)
        with open(cache_path, "w") as f:
            f.write("\n".join(txt_descriptions_listi))

        all_txt_desc_list.append(txt_descriptions_listi)

    img_save_dir = f"./gen_img_save/{dataset}/{method_dir_name}/{query_iter}/"
    max_num = 30 if '1k' in dataset else 10
    for desired_class, txt_descriptions in zip(sorted_class_count, all_txt_desc_list):
        gen_img_from_text(txt_descriptions, idx_to_class[desired_class], img_save_dir, max_num=max_num)

    return []


def our_query_fea(lab_loader, val_loader, clip_model, cla_model, 
              loss_fun, train_fea, train_lab, class_to_idx_dict, gen_class_txt_num=1,
              use_cache=True, save_dir=""):
    """Generate one optimized feature per selected class and save them to disk.

    For each selected class this function does the following:

    1. It loads the candidate dict from ``save_dir/cand_{i}_{class_name}.pkl``
       when ``use_cache`` is true and that file exists.
    2. Otherwise it runs ``optimize_img_embedding`` and pickles the result
       to that path.
    3. It keeps the feature returned by ``get_highest_score_fea``.

    When ``gen_class_txt_num`` is at least the number of classes, every
    class is used once. Otherwise ``gen_class_txt_num`` distinct classes
    are sampled at random.

    The features are stacked with ``torch.vstack`` and pickled, on CPU, to
    ``save_dir/gen_fea_all.pkl``. Their class indices are pickled to
    ``save_dir/gen_label_all.pkl``.

    NOTE(review): ``torch.vstack`` raises on an empty list, so at least one
    class must be selected.

    Returns:
        An empty list: no unlabeled indices are selected by this strategy.
    """
    idx_to_class = {v: k for k, v in class_to_idx_dict.items()}
    if gen_class_txt_num >= len(idx_to_class):
        sorted_class_count = list(idx_to_class.keys())
    else:
        sorted_class_count = np.random.choice(list(idx_to_class.keys()), gen_class_txt_num, replace=False)
    text_generator = GenTxtFromClipFea(clip_model=clip_model)
    gen_imgs_all = []
    gen_labels_all = []
    # Iterate the classes actually selected. ``gen_class_txt_num`` may exceed
    # the number of classes, so indexing by range(gen_class_txt_num) overruns.
    for i, desired_class in enumerate(sorted_class_count):
        desired_class_name = idx_to_class[desired_class]
        cache_path = os.path.join(save_dir, f"cand_{i}_{desired_class_name}.pkl")
        loaded_from_cache = False
        if use_cache:
            try:
                # pickle.load can execute arbitrary code: only point save_dir
                # at caches written by this pipeline.
                with open(cache_path, "rb") as f:
                    cand_img_fea_dict = pickle.load(f)
            except FileNotFoundError:
                print("Cache not found, generating new txt descriptions.")
            else:
                loaded_from_cache = True

        if not loaded_from_cache:
            cand_img_fea_dict = optimize_img_embedding(clip_model, cla_model, lab_loader, loss_fun, train_fea, train_lab, val_loader, class_to_idx_dict, idx_to_class, desired_class, text_generator)
            with open(cache_path, "wb") as f:
                pickle.dump(cand_img_fea_dict, f)

        ret_fea, _ = get_highest_score_fea(cand_img_fea_dict)
        gen_imgs_all.append(ret_fea)
        gen_labels_all.append(desired_class)

    gen_imgs_all = torch.vstack(gen_imgs_all)
    gen_labels_all = torch.as_tensor(gen_labels_all)
    with open(os.path.join(save_dir, "gen_fea_all.pkl"), "wb") as f:
        pickle.dump(gen_imgs_all.cpu(), f)
    with open(os.path.join(save_dir, "gen_label_all.pkl"), "wb") as f:
        pickle.dump(gen_labels_all.cpu(), f)
    return []
